diff --git a/auth.go b/auth.go
index 8569297..e69f82e 100644
--- a/auth.go
+++ b/auth.go
@@ -50,11 +50,6 @@ func authPlain(socket io.ReadWriter, decoder *xml.Decoder, user string, password
return err
}
-type saslMechanisms struct {
- XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
- Mechanism []string `xml:"mechanism"`
-}
-
// ============================================================================
// SASLSuccess
diff --git a/check_cert.go b/check_cert.go
index 6190d96..074676e 100644
--- a/check_cert.go
+++ b/check_cert.go
@@ -76,8 +76,7 @@ func (c *ServerCheck) Check() error {
return errors.New("expected packet received while expecting features, got " + p.Name())
}
- startTLSFeature := f.StartTLS.XMLName.Space + " " + f.StartTLS.XMLName.Local
- if startTLSFeature == nsTLS+" starttls" {
+ if _, ok := f.DoesStartTLS(); ok {
fmt.Fprintf(tcpconn, "")
var k tlsProceed
diff --git a/client_test.go b/client_test.go
index b7fb1ac..7d68717 100644
--- a/client_test.go
+++ b/client_test.go
@@ -60,6 +60,30 @@ func TestClient_NoInsecure(t *testing.T) {
mock.Stop()
}
+// Check that the client is properly tracking features, as session negotiation progresses.
+func TestClient_FeaturesTracking(t *testing.T) {
+ // Setup Mock server
+ mock := ServerMock{}
+ mock.Start(t, testXMPPAddress, handlerAbortTLS)
+
+ // Test / Check result
+ config := Config{Address: testXMPPAddress, Jid: "test@localhost", Password: "test"}
+
+ var client *Client
+ var err error
+ if client, err = NewClient(config); err != nil {
+ t.Errorf("cannot create XMPP client: %s", err)
+ }
+
+ if err = client.Connect(); err == nil {
+ // When insecure is not allowed:
+ t.Errorf("should fail as insecure connection is not allowed and server does not support TLS")
+ }
+
+ mock.Stop()
+
+}
+
//=============================================================================
// Basic XMPP Server Mock Handlers.
diff --git a/session.go b/session.go
index 0c0b278..3bce3cd 100644
--- a/session.go
+++ b/session.go
@@ -109,7 +109,7 @@ func (s *Session) startTlsIfSupported(conn net.Conn, domain string) net.Conn {
return conn
}
- if s.Features.StartTLS.XMLName.Space+" "+s.Features.StartTLS.XMLName.Local == nsTLS+" starttls" {
+ if _, ok := s.Features.DoesStartTLS(); ok {
fmt.Fprintf(s.socketProxy, "")
var k tlsProceed
diff --git a/starttls.go b/starttls.go
index 28149b7..8c36222 100644
--- a/starttls.go
+++ b/starttls.go
@@ -7,12 +7,7 @@ import (
var DefaultTlsConfig tls.Config
-// XMPP Packet Parsing
-type tlsStartTLS struct {
- XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
- Required bool
-}
-
+// Used during stream initiation / session establishment
type tlsProceed struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"`
}
diff --git a/stream.go b/stream.go
index a66ac85..b887c55 100644
--- a/stream.go
+++ b/stream.go
@@ -6,11 +6,14 @@ import (
// ============================================================================
// StreamFeatures Packet
+// Reference: https://xmpp.org/registrar/stream-features.html
type StreamFeatures struct {
- XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
+ XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
+ // Server capabilities hash
+ Caps Caps
+ // Stream features
StartTLS tlsStartTLS
- Caps Caps
Mechanisms saslMechanisms
Bind BindBind
Session sessionSession
@@ -31,6 +34,76 @@ func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamF
return packet, err
}
+// Capabilities
+// Reference: https://xmpp.org/extensions/xep-0115.html#stream
+// "A server MAY include its entity capabilities in a stream feature element so that connecting clients
+// and peer servers do not need to send service discovery requests each time they connect."
+// This is not a stream feature but a way to let client cache server disco info.
+type Caps struct {
+ XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"`
+ Hash string `xml:"hash,attr"`
+ Node string `xml:"node,attr"`
+ Ver string `xml:"ver,attr"`
+ Ext string `xml:"ext,attr,omitempty"`
+}
+
+// ============================================================================
+// Supported Stream Features
+
+// StartTLS feature
+// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
+type tlsStartTLS struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
+ Required bool
+}
+
+// UnmarshalXML implements custom parsing startTLS required flag
+func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
+ stls.XMLName = start.Name
+
+ // Check subelements to extract required field as boolean
+ for {
+ t, err := d.Token()
+ if err != nil {
+ return err
+ }
+
+ switch tt := t.(type) {
+
+ case xml.StartElement:
+ elt := new(Node)
+
+ err = d.DecodeElement(elt, &tt)
+ if err != nil {
+ return err
+ }
+
+ if elt.XMLName.Local == "required" {
+ stls.Required = true
+ }
+
+ case xml.EndElement:
+ if tt == start.End() {
+ return nil
+ }
+ }
+ }
+}
+
+func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) {
+ if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
+ return sf.StartTLS, true
+ }
+ return feature, false
+}
+
+// Mechanisms
+// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-6.4.1
+type saslMechanisms struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
+ Mechanism []string `xml:"mechanism"`
+}
+
// ============================================================================
// StreamError Packet
@@ -53,14 +126,3 @@ func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamErr
err := p.DecodeElement(&packet, &se)
return packet, err
}
-
-// ============================================================================
-// Caps subElement
-
-type Caps struct {
- XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"`
- Hash string `xml:"hash,attr"`
- Node string `xml:"node,attr"`
- Ver string `xml:"ver,attr"`
- Ext string `xml:"ext,attr,omitempty"`
-}
diff --git a/stream_test.go b/stream_test.go
new file mode 100644
index 0000000..f10d1de
--- /dev/null
+++ b/stream_test.go
@@ -0,0 +1,47 @@
+package xmpp_test
+
+import (
+ "encoding/xml"
+ "testing"
+
+ "gosrc.io/xmpp"
+)
+
+func TestNoStartTLS(t *testing.T) {
+ streamFeatures := `
+`
+
+ var parsedSF xmpp.StreamFeatures
+ if err := xml.Unmarshal([]byte(streamFeatures), &parsedSF); err != nil {
+ t.Errorf("Unmarshal(%s) returned error: %v", streamFeatures, err)
+ }
+
+ startTLS, ok := parsedSF.DoesStartTLS()
+ if ok {
+ t.Error("StartTLS feature should not be enabled")
+ }
+ if startTLS.Required {
+ t.Error("StartTLS cannot be required as default")
+ }
+}
+
+func TestStartTLS(t *testing.T) {
+ streamFeatures := `
+
+
+
+`
+
+ var parsedSF xmpp.StreamFeatures
+ if err := xml.Unmarshal([]byte(streamFeatures), &parsedSF); err != nil {
+ t.Errorf("Unmarshal(%s) returned error: %v", streamFeatures, err)
+ }
+
+ startTLS, ok := parsedSF.DoesStartTLS()
+ if !ok {
+ t.Error("StartTLS feature should be enabled")
+ }
+ if !startTLS.Required {
+ t.Error("StartTLS feature should be required")
+ }
+}