Clean up and fix StartTLS feature discovery
Required field was never set to true
This commit is contained in:
parent
44568fcf2b
commit
709a95129e
5
auth.go
5
auth.go
|
@ -50,11 +50,6 @@ func authPlain(socket io.ReadWriter, decoder *xml.Decoder, user string, password
|
|||
return err
|
||||
}
|
||||
|
||||
type saslMechanisms struct {
|
||||
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
|
||||
Mechanism []string `xml:"mechanism"`
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SASLSuccess
|
||||
|
||||
|
|
|
@ -76,8 +76,7 @@ func (c *ServerCheck) Check() error {
|
|||
return errors.New("expected packet received while expecting features, got " + p.Name())
|
||||
}
|
||||
|
||||
startTLSFeature := f.StartTLS.XMLName.Space + " " + f.StartTLS.XMLName.Local
|
||||
if startTLSFeature == nsTLS+" starttls" {
|
||||
if _, ok := f.DoesStartTLS(); ok {
|
||||
fmt.Fprintf(tcpconn, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
|
||||
|
||||
var k tlsProceed
|
||||
|
|
|
@ -60,6 +60,30 @@ func TestClient_NoInsecure(t *testing.T) {
|
|||
mock.Stop()
|
||||
}
|
||||
|
||||
// Check that the client is properly tracking features, as session negotiation progresses.
|
||||
func TestClient_FeaturesTracking(t *testing.T) {
|
||||
// Setup Mock server
|
||||
mock := ServerMock{}
|
||||
mock.Start(t, testXMPPAddress, handlerAbortTLS)
|
||||
|
||||
// Test / Check result
|
||||
config := Config{Address: testXMPPAddress, Jid: "test@localhost", Password: "test"}
|
||||
|
||||
var client *Client
|
||||
var err error
|
||||
if client, err = NewClient(config); err != nil {
|
||||
t.Errorf("cannot create XMPP client: %s", err)
|
||||
}
|
||||
|
||||
if err = client.Connect(); err == nil {
|
||||
// When insecure is not allowed:
|
||||
t.Errorf("should fail as insecure connection is not allowed and server does not support TLS")
|
||||
}
|
||||
|
||||
mock.Stop()
|
||||
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Basic XMPP Server Mock Handlers.
|
||||
|
||||
|
|
|
@ -109,7 +109,7 @@ func (s *Session) startTlsIfSupported(conn net.Conn, domain string) net.Conn {
|
|||
return conn
|
||||
}
|
||||
|
||||
if s.Features.StartTLS.XMLName.Space+" "+s.Features.StartTLS.XMLName.Local == nsTLS+" starttls" {
|
||||
if _, ok := s.Features.DoesStartTLS(); ok {
|
||||
fmt.Fprintf(s.socketProxy, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
|
||||
|
||||
var k tlsProceed
|
||||
|
|
|
@ -7,12 +7,7 @@ import (
|
|||
|
||||
var DefaultTlsConfig tls.Config
|
||||
|
||||
// XMPP Packet Parsing
|
||||
type tlsStartTLS struct {
|
||||
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
|
||||
Required bool
|
||||
}
|
||||
|
||||
// Used during stream initiation / session establishment
|
||||
type tlsProceed struct {
|
||||
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"`
|
||||
}
|
||||
|
|
86
stream.go
86
stream.go
|
@ -6,11 +6,14 @@ import (
|
|||
|
||||
// ============================================================================
|
||||
// StreamFeatures Packet
|
||||
// Reference: https://xmpp.org/registrar/stream-features.html
|
||||
|
||||
type StreamFeatures struct {
|
||||
XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
|
||||
StartTLS tlsStartTLS
|
||||
// Server capabilities hash
|
||||
Caps Caps
|
||||
// Stream features
|
||||
StartTLS tlsStartTLS
|
||||
Mechanisms saslMechanisms
|
||||
Bind BindBind
|
||||
Session sessionSession
|
||||
|
@ -31,6 +34,76 @@ func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamF
|
|||
return packet, err
|
||||
}
|
||||
|
||||
// Capabilities
|
||||
// Reference: https://xmpp.org/extensions/xep-0115.html#stream
|
||||
// "A server MAY include its entity capabilities in a stream feature element so that connecting clients
|
||||
// and peer servers do not need to send service discovery requests each time they connect."
|
||||
// This is not a stream feature but a way to let client cache server disco info.
|
||||
type Caps struct {
|
||||
XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"`
|
||||
Hash string `xml:"hash,attr"`
|
||||
Node string `xml:"node,attr"`
|
||||
Ver string `xml:"ver,attr"`
|
||||
Ext string `xml:"ext,attr,omitempty"`
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Supported Stream Features
|
||||
|
||||
// StartTLS feature
|
||||
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
|
||||
type tlsStartTLS struct {
|
||||
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
|
||||
Required bool
|
||||
}
|
||||
|
||||
// UnmarshalXML implements custom parsing startTLS required flag
|
||||
func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
|
||||
stls.XMLName = start.Name
|
||||
|
||||
// Check subelements to extract required field as boolean
|
||||
for {
|
||||
t, err := d.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch tt := t.(type) {
|
||||
|
||||
case xml.StartElement:
|
||||
elt := new(Node)
|
||||
|
||||
err = d.DecodeElement(elt, &tt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if elt.XMLName.Local == "required" {
|
||||
stls.Required = true
|
||||
}
|
||||
|
||||
case xml.EndElement:
|
||||
if tt == start.End() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) {
|
||||
if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
|
||||
return sf.StartTLS, true
|
||||
}
|
||||
return feature, false
|
||||
}
|
||||
|
||||
// Mechanisms
|
||||
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-6.4.1
|
||||
type saslMechanisms struct {
|
||||
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
|
||||
Mechanism []string `xml:"mechanism"`
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// StreamError Packet
|
||||
|
||||
|
@ -53,14 +126,3 @@ func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamErr
|
|||
err := p.DecodeElement(&packet, &se)
|
||||
return packet, err
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Caps subElement
|
||||
|
||||
type Caps struct {
|
||||
XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"`
|
||||
Hash string `xml:"hash,attr"`
|
||||
Node string `xml:"node,attr"`
|
||||
Ver string `xml:"ver,attr"`
|
||||
Ext string `xml:"ext,attr,omitempty"`
|
||||
}
|
||||
|
|
47
stream_test.go
Normal file
47
stream_test.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package xmpp_test
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"testing"
|
||||
|
||||
"gosrc.io/xmpp"
|
||||
)
|
||||
|
||||
func TestNoStartTLS(t *testing.T) {
|
||||
streamFeatures := `<stream:features xmlns:stream='http://etherx.jabber.org/streams'>
|
||||
</stream:features>`
|
||||
|
||||
var parsedSF xmpp.StreamFeatures
|
||||
if err := xml.Unmarshal([]byte(streamFeatures), &parsedSF); err != nil {
|
||||
t.Errorf("Unmarshal(%s) returned error: %v", streamFeatures, err)
|
||||
}
|
||||
|
||||
startTLS, ok := parsedSF.DoesStartTLS()
|
||||
if ok {
|
||||
t.Error("StartTLS feature should not be enabled")
|
||||
}
|
||||
if startTLS.Required {
|
||||
t.Error("StartTLS cannot be required as default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartTLS(t *testing.T) {
|
||||
streamFeatures := `<stream:features xmlns:stream='http://etherx.jabber.org/streams'>
|
||||
<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'>
|
||||
<required/>
|
||||
</starttls>
|
||||
</stream:features>`
|
||||
|
||||
var parsedSF xmpp.StreamFeatures
|
||||
if err := xml.Unmarshal([]byte(streamFeatures), &parsedSF); err != nil {
|
||||
t.Errorf("Unmarshal(%s) returned error: %v", streamFeatures, err)
|
||||
}
|
||||
|
||||
startTLS, ok := parsedSF.DoesStartTLS()
|
||||
if !ok {
|
||||
t.Error("StartTLS feature should be enabled")
|
||||
}
|
||||
if !startTLS.Required {
|
||||
t.Error("StartTLS feature should be required")
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue