Create a new stream after StartTLS

This commit is contained in:
Wichert Akkerman 2019-10-25 15:22:01 +02:00 committed by Mickaël Rémond
parent 390f9b065e
commit 33446ad0ba
7 changed files with 55 additions and 26 deletions

View file

@ -51,7 +51,7 @@ func (c *ServerCheck) Check() error {
decoder := xml.NewDecoder(tcpconn) decoder := xml.NewDecoder(tcpconn)
// Send stream open tag // Send stream open tag
if _, err = fmt.Fprintf(tcpconn, xmppStreamOpen, c.domain, stanza.NSClient, stanza.NSStream); err != nil { if _, err = fmt.Fprintf(tcpconn, clientStreamOpen, c.domain); err != nil {
return err return err
} }

View file

@ -144,7 +144,7 @@ func NewClient(config Config, r *Router) (c *Client, err error) {
if config.TransportConfiguration.Domain == "" { if config.TransportConfiguration.Domain == "" {
config.TransportConfiguration.Domain = config.parsedJid.Domain config.TransportConfiguration.Domain = config.parsedJid.Domain
} }
c.transport = NewTransport(config.TransportConfiguration) c.transport = NewClientTransport(config.TransportConfiguration)
if config.StreamLogger != nil { if config.StreamLogger != nil {
c.transport.LogTraffic(config.StreamLogger) c.transport.LogTraffic(config.StreamLogger)

View file

@ -11,8 +11,6 @@ import (
"gosrc.io/xmpp/stanza" "gosrc.io/xmpp/stanza"
) )
const componentStreamOpen = "<?xml version='1.0'?><stream:stream to='%s' xmlns='%s' xmlns:stream='%s'>"
type ComponentOptions struct { type ComponentOptions struct {
TransportConfiguration TransportConfiguration
@ -71,7 +69,11 @@ func (c *Component) Resume(sm SMState) error {
if c.ComponentOptions.TransportConfiguration.Domain == "" { if c.ComponentOptions.TransportConfiguration.Domain == "" {
c.ComponentOptions.TransportConfiguration.Domain = c.ComponentOptions.Domain c.ComponentOptions.TransportConfiguration.Domain = c.ComponentOptions.Domain
} }
c.transport = NewTransport(c.ComponentOptions.TransportConfiguration) c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration)
if err != nil {
c.updateState(StateStreamError)
return err
}
if streamId, err = c.transport.Connect(); err != nil { if streamId, err = c.transport.Connect(); err != nil {
c.updateState(StateStreamError) c.updateState(StateStreamError)

View file

@ -108,7 +108,7 @@ func (s *Session) startTlsIfSupported(o Config) {
return return
} }
s.err = s.transport.StartTLS() s.StreamId, s.err = s.transport.StartTLS()
if s.err == nil { if s.err == nil {
s.TlsEnabled = true s.TlsEnabled = true

View file

@ -4,11 +4,13 @@ import (
"crypto/tls" "crypto/tls"
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt"
"io" "io"
"strings" "strings"
) )
var TLSNotSupported = errors.New("Transport does not support StartTLS") var ErrTransportProtocolNotSupported = errors.New("Transport protocol not supported")
var ErrTLSNotSupported = errors.New("Transport does not support StartTLS")
type TransportConfiguration struct { type TransportConfiguration struct {
// Address is the XMPP Host and port to connect to. Host is of // Address is the XMPP Host and port to connect to. Host is of
@ -25,7 +27,7 @@ type TransportConfiguration struct {
type Transport interface { type Transport interface {
Connect() (string, error) Connect() (string, error)
DoesStartTLS() bool DoesStartTLS() bool
StartTLS() error StartTLS() (string, error)
LogTraffic(logFile io.Writer) LogTraffic(logFile io.Writer)
@ -38,16 +40,34 @@ type Transport interface {
Close() error Close() error
} }
// NewTransport creates a new Transport instance. // NewClientTransport creates a new Transport instance for clients.
// The type of transport is determined by the address in the configuration: // The type of transport is determined by the address in the configuration:
// - if the address is a URL with the `ws` or `wss` scheme WebsocketTransport is used // - if the address is a URL with the `ws` or `wss` scheme WebsocketTransport is used
// - in all other cases a XMPPTransport is used // - in all other cases a XMPPTransport is used
// For XMPPTransport it is mandatory for the address to have a port specified. // For XMPPTransport it is mandatory for the address to have a port specified.
func NewTransport(config TransportConfiguration) Transport { func NewClientTransport(config TransportConfiguration) Transport {
if strings.HasPrefix(config.Address, "ws:") || strings.HasPrefix(config.Address, "wss:") { if strings.HasPrefix(config.Address, "ws:") || strings.HasPrefix(config.Address, "wss:") {
return &WebsocketTransport{Config: config} return &WebsocketTransport{Config: config}
} }
config.Address = ensurePort(config.Address, 5222) config.Address = ensurePort(config.Address, 5222)
return &XMPPTransport{Config: config} return &XMPPTransport{
Config: config,
openStatement: clientStreamOpen,
}
}
// NewComponentTransport creates a new Transport instance for components.
// Only XMPP transports are allowed. If you try to use any other protocol an error
// will be returned.
func NewComponentTransport(config TransportConfiguration) (Transport, error) {
if strings.HasPrefix(config.Address, "ws:") || strings.HasPrefix(config.Address, "wss:") {
return nil, fmt.Errorf("Components only support XMPP transport: %w", ErrTransportProtocolNotSupported)
}
config.Address = ensurePort(config.Address, 5222)
return &XMPPTransport{
Config: config,
openStatement: componentStreamOpen,
}, nil
} }

View file

@ -108,8 +108,8 @@ func (t WebsocketTransport) startReader() {
}() }()
} }
func (t WebsocketTransport) StartTLS() error { func (t WebsocketTransport) StartTLS() (string, error) {
return TLSNotSupported return "", ErrTLSNotSupported
} }
func (t WebsocketTransport) DoesStartTLS() bool { func (t WebsocketTransport) DoesStartTLS() bool {

View file

@ -14,16 +14,19 @@ import (
// XMPPTransport implements the XMPP native TCP transport // XMPPTransport implements the XMPP native TCP transport
type XMPPTransport struct { type XMPPTransport struct {
Config TransportConfiguration openStatement string
TLSConfig *tls.Config Config TransportConfiguration
decoder *xml.Decoder TLSConfig *tls.Config
conn net.Conn decoder *xml.Decoder
readWriter io.ReadWriter conn net.Conn
logFile io.Writer readWriter io.ReadWriter
isSecure bool logFile io.Writer
isSecure bool
} }
const xmppStreamOpen = "<?xml version='1.0'?><stream:stream to='%s' xmlns='%s' xmlns:stream='%s' version='1.0'>" var componentStreamOpen = fmt.Sprintf("<?xml version='1.0'?><stream:stream to='%%s' xmlns='%s' xmlns:stream='%s'>", stanza.NSComponent, stanza.NSStream)
var clientStreamOpen = fmt.Sprintf("<?xml version='1.0'?><stream:stream to='%%s' xmlns='%s' xmlns:stream='%s' version='1.0'>", stanza.NSClient, stanza.NSStream)
func (t *XMPPTransport) Connect() (string, error) { func (t *XMPPTransport) Connect() (string, error) {
var err error var err error
@ -34,8 +37,11 @@ func (t *XMPPTransport) Connect() (string, error) {
} }
t.readWriter = newStreamLogger(t.conn, t.logFile) t.readWriter = newStreamLogger(t.conn, t.logFile)
return t.startStream()
}
if _, err = fmt.Fprintf(t.readWriter, xmppStreamOpen, t.Config.Domain, stanza.NSClient, stanza.NSStream); err != nil { func (t *XMPPTransport) startStream() (string, error) {
if _, err := fmt.Fprintf(t.readWriter, t.openStatement, t.Config.Domain); err != nil {
t.conn.Close() t.conn.Close()
return "", NewConnError(err, true) return "", NewConnError(err, true)
} }
@ -62,7 +68,7 @@ func (t XMPPTransport) IsSecure() bool {
return t.isSecure return t.isSecure
} }
func (t *XMPPTransport) StartTLS() error { func (t *XMPPTransport) StartTLS() (string, error) {
if t.Config.TLSConfig == nil { if t.Config.TLSConfig == nil {
t.TLSConfig = &tls.Config{} t.TLSConfig = &tls.Config{}
} else { } else {
@ -75,7 +81,7 @@ func (t *XMPPTransport) StartTLS() error {
tlsConn := tls.Client(t.conn, t.TLSConfig) tlsConn := tls.Client(t.conn, t.TLSConfig)
// We convert existing connection to TLS // We convert existing connection to TLS
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
return err return "", err
} }
t.conn = tlsConn t.conn = tlsConn
@ -85,12 +91,13 @@ func (t *XMPPTransport) StartTLS() error {
if !t.TLSConfig.InsecureSkipVerify { if !t.TLSConfig.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil { if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil {
return err return "", err
} }
} }
t.isSecure = true t.isSecure = true
return nil
return t.startStream()
} }
func (t XMPPTransport) Ping() error { func (t XMPPTransport) Ping() error {