Create a new stream after StartTLS
This commit is contained in:
parent
390f9b065e
commit
33446ad0ba
|
@ -51,7 +51,7 @@ func (c *ServerCheck) Check() error {
|
||||||
decoder := xml.NewDecoder(tcpconn)
|
decoder := xml.NewDecoder(tcpconn)
|
||||||
|
|
||||||
// Send stream open tag
|
// Send stream open tag
|
||||||
if _, err = fmt.Fprintf(tcpconn, xmppStreamOpen, c.domain, stanza.NSClient, stanza.NSStream); err != nil {
|
if _, err = fmt.Fprintf(tcpconn, clientStreamOpen, c.domain); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -144,7 +144,7 @@ func NewClient(config Config, r *Router) (c *Client, err error) {
|
||||||
if config.TransportConfiguration.Domain == "" {
|
if config.TransportConfiguration.Domain == "" {
|
||||||
config.TransportConfiguration.Domain = config.parsedJid.Domain
|
config.TransportConfiguration.Domain = config.parsedJid.Domain
|
||||||
}
|
}
|
||||||
c.transport = NewTransport(config.TransportConfiguration)
|
c.transport = NewClientTransport(config.TransportConfiguration)
|
||||||
|
|
||||||
if config.StreamLogger != nil {
|
if config.StreamLogger != nil {
|
||||||
c.transport.LogTraffic(config.StreamLogger)
|
c.transport.LogTraffic(config.StreamLogger)
|
||||||
|
|
|
@ -11,8 +11,6 @@ import (
|
||||||
"gosrc.io/xmpp/stanza"
|
"gosrc.io/xmpp/stanza"
|
||||||
)
|
)
|
||||||
|
|
||||||
const componentStreamOpen = "<?xml version='1.0'?><stream:stream to='%s' xmlns='%s' xmlns:stream='%s'>"
|
|
||||||
|
|
||||||
type ComponentOptions struct {
|
type ComponentOptions struct {
|
||||||
TransportConfiguration
|
TransportConfiguration
|
||||||
|
|
||||||
|
@ -71,7 +69,11 @@ func (c *Component) Resume(sm SMState) error {
|
||||||
if c.ComponentOptions.TransportConfiguration.Domain == "" {
|
if c.ComponentOptions.TransportConfiguration.Domain == "" {
|
||||||
c.ComponentOptions.TransportConfiguration.Domain = c.ComponentOptions.Domain
|
c.ComponentOptions.TransportConfiguration.Domain = c.ComponentOptions.Domain
|
||||||
}
|
}
|
||||||
c.transport = NewTransport(c.ComponentOptions.TransportConfiguration)
|
c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration)
|
||||||
|
if err != nil {
|
||||||
|
c.updateState(StateStreamError)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if streamId, err = c.transport.Connect(); err != nil {
|
if streamId, err = c.transport.Connect(); err != nil {
|
||||||
c.updateState(StateStreamError)
|
c.updateState(StateStreamError)
|
||||||
|
|
|
@ -108,7 +108,7 @@ func (s *Session) startTlsIfSupported(o Config) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.err = s.transport.StartTLS()
|
s.StreamId, s.err = s.transport.StartTLS()
|
||||||
|
|
||||||
if s.err == nil {
|
if s.err == nil {
|
||||||
s.TlsEnabled = true
|
s.TlsEnabled = true
|
||||||
|
|
30
transport.go
30
transport.go
|
@ -4,11 +4,13 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var TLSNotSupported = errors.New("Transport does not support StartTLS")
|
var ErrTransportProtocolNotSupported = errors.New("Transport protocol not supported")
|
||||||
|
var ErrTLSNotSupported = errors.New("Transport does not support StartTLS")
|
||||||
|
|
||||||
type TransportConfiguration struct {
|
type TransportConfiguration struct {
|
||||||
// Address is the XMPP Host and port to connect to. Host is of
|
// Address is the XMPP Host and port to connect to. Host is of
|
||||||
|
@ -25,7 +27,7 @@ type TransportConfiguration struct {
|
||||||
type Transport interface {
|
type Transport interface {
|
||||||
Connect() (string, error)
|
Connect() (string, error)
|
||||||
DoesStartTLS() bool
|
DoesStartTLS() bool
|
||||||
StartTLS() error
|
StartTLS() (string, error)
|
||||||
|
|
||||||
LogTraffic(logFile io.Writer)
|
LogTraffic(logFile io.Writer)
|
||||||
|
|
||||||
|
@ -38,16 +40,34 @@ type Transport interface {
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTransport creates a new Transport instance.
|
// NewClientTransport creates a new Transport instance for clients.
|
||||||
// The type of transport is determined by the address in the configuration:
|
// The type of transport is determined by the address in the configuration:
|
||||||
// - if the address is a URL with the `ws` or `wss` scheme WebsocketTransport is used
|
// - if the address is a URL with the `ws` or `wss` scheme WebsocketTransport is used
|
||||||
// - in all other cases a XMPPTransport is used
|
// - in all other cases a XMPPTransport is used
|
||||||
// For XMPPTransport it is mandatory for the address to have a port specified.
|
// For XMPPTransport it is mandatory for the address to have a port specified.
|
||||||
func NewTransport(config TransportConfiguration) Transport {
|
func NewClientTransport(config TransportConfiguration) Transport {
|
||||||
if strings.HasPrefix(config.Address, "ws:") || strings.HasPrefix(config.Address, "wss:") {
|
if strings.HasPrefix(config.Address, "ws:") || strings.HasPrefix(config.Address, "wss:") {
|
||||||
return &WebsocketTransport{Config: config}
|
return &WebsocketTransport{Config: config}
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Address = ensurePort(config.Address, 5222)
|
config.Address = ensurePort(config.Address, 5222)
|
||||||
return &XMPPTransport{Config: config}
|
return &XMPPTransport{
|
||||||
|
Config: config,
|
||||||
|
openStatement: clientStreamOpen,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewComponentTransport creates a new Transport instance for components.
|
||||||
|
// Only XMPP transports are allowed. If you try to use any other protocol an error
|
||||||
|
// will be returned.
|
||||||
|
func NewComponentTransport(config TransportConfiguration) (Transport, error) {
|
||||||
|
if strings.HasPrefix(config.Address, "ws:") || strings.HasPrefix(config.Address, "wss:") {
|
||||||
|
return nil, fmt.Errorf("Components only support XMPP transport: %w", ErrTransportProtocolNotSupported)
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Address = ensurePort(config.Address, 5222)
|
||||||
|
return &XMPPTransport{
|
||||||
|
Config: config,
|
||||||
|
openStatement: componentStreamOpen,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -108,8 +108,8 @@ func (t WebsocketTransport) startReader() {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t WebsocketTransport) StartTLS() error {
|
func (t WebsocketTransport) StartTLS() (string, error) {
|
||||||
return TLSNotSupported
|
return "", ErrTLSNotSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t WebsocketTransport) DoesStartTLS() bool {
|
func (t WebsocketTransport) DoesStartTLS() bool {
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
|
|
||||||
// XMPPTransport implements the XMPP native TCP transport
|
// XMPPTransport implements the XMPP native TCP transport
|
||||||
type XMPPTransport struct {
|
type XMPPTransport struct {
|
||||||
|
openStatement string
|
||||||
Config TransportConfiguration
|
Config TransportConfiguration
|
||||||
TLSConfig *tls.Config
|
TLSConfig *tls.Config
|
||||||
decoder *xml.Decoder
|
decoder *xml.Decoder
|
||||||
|
@ -23,7 +24,9 @@ type XMPPTransport struct {
|
||||||
isSecure bool
|
isSecure bool
|
||||||
}
|
}
|
||||||
|
|
||||||
const xmppStreamOpen = "<?xml version='1.0'?><stream:stream to='%s' xmlns='%s' xmlns:stream='%s' version='1.0'>"
|
var componentStreamOpen = fmt.Sprintf("<?xml version='1.0'?><stream:stream to='%%s' xmlns='%s' xmlns:stream='%s'>", stanza.NSComponent, stanza.NSStream)
|
||||||
|
|
||||||
|
var clientStreamOpen = fmt.Sprintf("<?xml version='1.0'?><stream:stream to='%%s' xmlns='%s' xmlns:stream='%s' version='1.0'>", stanza.NSClient, stanza.NSStream)
|
||||||
|
|
||||||
func (t *XMPPTransport) Connect() (string, error) {
|
func (t *XMPPTransport) Connect() (string, error) {
|
||||||
var err error
|
var err error
|
||||||
|
@ -34,8 +37,11 @@ func (t *XMPPTransport) Connect() (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.readWriter = newStreamLogger(t.conn, t.logFile)
|
t.readWriter = newStreamLogger(t.conn, t.logFile)
|
||||||
|
return t.startStream()
|
||||||
|
}
|
||||||
|
|
||||||
if _, err = fmt.Fprintf(t.readWriter, xmppStreamOpen, t.Config.Domain, stanza.NSClient, stanza.NSStream); err != nil {
|
func (t *XMPPTransport) startStream() (string, error) {
|
||||||
|
if _, err := fmt.Fprintf(t.readWriter, t.openStatement, t.Config.Domain); err != nil {
|
||||||
t.conn.Close()
|
t.conn.Close()
|
||||||
return "", NewConnError(err, true)
|
return "", NewConnError(err, true)
|
||||||
}
|
}
|
||||||
|
@ -62,7 +68,7 @@ func (t XMPPTransport) IsSecure() bool {
|
||||||
return t.isSecure
|
return t.isSecure
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *XMPPTransport) StartTLS() error {
|
func (t *XMPPTransport) StartTLS() (string, error) {
|
||||||
if t.Config.TLSConfig == nil {
|
if t.Config.TLSConfig == nil {
|
||||||
t.TLSConfig = &tls.Config{}
|
t.TLSConfig = &tls.Config{}
|
||||||
} else {
|
} else {
|
||||||
|
@ -75,7 +81,7 @@ func (t *XMPPTransport) StartTLS() error {
|
||||||
tlsConn := tls.Client(t.conn, t.TLSConfig)
|
tlsConn := tls.Client(t.conn, t.TLSConfig)
|
||||||
// We convert existing connection to TLS
|
// We convert existing connection to TLS
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
t.conn = tlsConn
|
t.conn = tlsConn
|
||||||
|
@ -85,12 +91,13 @@ func (t *XMPPTransport) StartTLS() error {
|
||||||
|
|
||||||
if !t.TLSConfig.InsecureSkipVerify {
|
if !t.TLSConfig.InsecureSkipVerify {
|
||||||
if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil {
|
if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil {
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t.isSecure = true
|
t.isSecure = true
|
||||||
return nil
|
|
||||||
|
return t.startStream()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t XMPPTransport) Ping() error {
|
func (t XMPPTransport) Ping() error {
|
||||||
|
|
Loading…
Reference in a new issue