Add IsSecure() to Transport

This commit is contained in:
Wichert Akkerman 2019-10-11 07:15:47 +02:00 committed by Mickaël Rémond
parent 7fa4b06705
commit 8db608ccc1
4 changed files with 21 additions and 6 deletions

View file

@ -39,7 +39,7 @@ func NewSession(transport Transport, o Config, state SMState) (*Session, error)
return nil, NewConnError(s.err, true) return nil, NewConnError(s.err, true)
} }
if !s.TlsEnabled && !o.Insecure { if !transport.IsSecure() && !o.Insecure {
err := fmt.Errorf("failed to negotiate TLS session : %s", s.err) err := fmt.Errorf("failed to negotiate TLS session : %s", s.err)
return nil, NewConnError(err, true) return nil, NewConnError(err, true)
} }

View file

@ -2,8 +2,11 @@ package xmpp
import ( import (
"crypto/tls" "crypto/tls"
"errors"
) )
var TLSNotSupported = errors.New("Transport does not support StartTLS")
type TransportConfiguration struct { type TransportConfiguration struct {
// Address is the XMPP Host and port to connect to. Host is of // Address is the XMPP Host and port to connect to. Host is of
// the form 'serverhost:port' i.e "localhost:8888" // the form 'serverhost:port' i.e "localhost:8888"
@ -19,6 +22,8 @@ type Transport interface {
DoesStartTLS() bool DoesStartTLS() bool
StartTLS(domain string) error StartTLS(domain string) error
IsSecure() bool
Read(p []byte) (n int, err error) Read(p []byte) (n int, err error)
Write(p []byte) (n int, err error) Write(p []byte) (n int, err error)
Close() error Close() error

View file

@ -2,7 +2,6 @@ package xmpp
import ( import (
"context" "context"
"errors"
"net" "net"
"strings" "strings"
"time" "time"
@ -23,9 +22,6 @@ func (t *WebsocketTransport) Connect() error {
ctx, cancel := context.WithTimeout(t.ctx, time.Duration(t.Config.ConnectTimeout)*time.Second) ctx, cancel := context.WithTimeout(t.ctx, time.Duration(t.Config.ConnectTimeout)*time.Second)
defer cancel() defer cancel()
if !c.Insecure && strings.HasPrefix(address, "wss:") {
return errors.New("Websocket address is not secure")
}
wsConn, _, err := websocket.Dial(ctx, t.Config.Address, nil) wsConn, _, err := websocket.Dial(ctx, t.Config.Address, nil)
if err != nil { if err != nil {
t.wsConn = wsConn t.wsConn = wsConn
@ -34,10 +30,18 @@ func (t *WebsocketTransport) Connect() error {
return err return err
} }
func (t WebsocketTransport) StartTLS(domain string) error {
return TLSNotSupported
}
func (t WebsocketTransport) DoesStartTLS() bool { func (t WebsocketTransport) DoesStartTLS() bool {
return false return false
} }
func (t WebsocketTransport) IsSecure() bool {
return strings.HasPrefix(t.Config.Address, "wss:")
}
func (t WebsocketTransport) Read(p []byte) (n int, err error) { func (t WebsocketTransport) Read(p []byte) (n int, err error) {
return t.netConn.Read(p) return t.netConn.Read(p)
} }

View file

@ -11,7 +11,8 @@ type XMPPTransport struct {
Config TransportConfiguration Config TransportConfiguration
TLSConfig *tls.Config TLSConfig *tls.Config
// TCP level connection / can be replaced by a TLS session after starttls // TCP level connection / can be replaced by a TLS session after starttls
conn net.Conn conn net.Conn
isSecure bool
} }
func (t *XMPPTransport) Connect() error { func (t *XMPPTransport) Connect() error {
@ -25,6 +26,10 @@ func (t XMPPTransport) DoesStartTLS() bool {
return true return true
} }
func (t XMPPTransport) IsSecure() bool {
return t.isSecure
}
func (t *XMPPTransport) StartTLS(domain string) error { func (t *XMPPTransport) StartTLS(domain string) error {
if t.Config.TLSConfig == nil { if t.Config.TLSConfig == nil {
t.Config.TLSConfig = &tls.Config{} t.Config.TLSConfig = &tls.Config{}
@ -45,6 +50,7 @@ func (t *XMPPTransport) StartTLS(domain string) error {
} }
} }
t.isSecure = true
return nil return nil
} }