Add IsSecure() to Transport
This commit is contained in:
parent
7fa4b06705
commit
8db608ccc1
|
@ -39,7 +39,7 @@ func NewSession(transport Transport, o Config, state SMState) (*Session, error)
|
||||||
return nil, NewConnError(s.err, true)
|
return nil, NewConnError(s.err, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.TlsEnabled && !o.Insecure {
|
if !transport.IsSecure() && !o.Insecure {
|
||||||
err := fmt.Errorf("failed to negotiate TLS session : %s", s.err)
|
err := fmt.Errorf("failed to negotiate TLS session : %s", s.err)
|
||||||
return nil, NewConnError(err, true)
|
return nil, NewConnError(err, true)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,8 +2,11 @@ package xmpp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var TLSNotSupported = errors.New("Transport does not support StartTLS")
|
||||||
|
|
||||||
type TransportConfiguration struct {
|
type TransportConfiguration struct {
|
||||||
// Address is the XMPP Host and port to connect to. Host is of
|
// Address is the XMPP Host and port to connect to. Host is of
|
||||||
// the form 'serverhost:port' i.e "localhost:8888"
|
// the form 'serverhost:port' i.e "localhost:8888"
|
||||||
|
@ -19,6 +22,8 @@ type Transport interface {
|
||||||
DoesStartTLS() bool
|
DoesStartTLS() bool
|
||||||
StartTLS(domain string) error
|
StartTLS(domain string) error
|
||||||
|
|
||||||
|
IsSecure() bool
|
||||||
|
|
||||||
Read(p []byte) (n int, err error)
|
Read(p []byte) (n int, err error)
|
||||||
Write(p []byte) (n int, err error)
|
Write(p []byte) (n int, err error)
|
||||||
Close() error
|
Close() error
|
||||||
|
|
|
@ -2,7 +2,6 @@ package xmpp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -23,9 +22,6 @@ func (t *WebsocketTransport) Connect() error {
|
||||||
ctx, cancel := context.WithTimeout(t.ctx, time.Duration(t.Config.ConnectTimeout)*time.Second)
|
ctx, cancel := context.WithTimeout(t.ctx, time.Duration(t.Config.ConnectTimeout)*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if !c.Insecure && strings.HasPrefix(address, "wss:") {
|
|
||||||
return errors.New("Websocket address is not secure")
|
|
||||||
}
|
|
||||||
wsConn, _, err := websocket.Dial(ctx, t.Config.Address, nil)
|
wsConn, _, err := websocket.Dial(ctx, t.Config.Address, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.wsConn = wsConn
|
t.wsConn = wsConn
|
||||||
|
@ -34,10 +30,18 @@ func (t *WebsocketTransport) Connect() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t WebsocketTransport) StartTLS(domain string) error {
|
||||||
|
return TLSNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
func (t WebsocketTransport) DoesStartTLS() bool {
|
func (t WebsocketTransport) DoesStartTLS() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t WebsocketTransport) IsSecure() bool {
|
||||||
|
return strings.HasPrefix(t.Config.Address, "wss:")
|
||||||
|
}
|
||||||
|
|
||||||
func (t WebsocketTransport) Read(p []byte) (n int, err error) {
|
func (t WebsocketTransport) Read(p []byte) (n int, err error) {
|
||||||
return t.netConn.Read(p)
|
return t.netConn.Read(p)
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,8 @@ type XMPPTransport struct {
|
||||||
Config TransportConfiguration
|
Config TransportConfiguration
|
||||||
TLSConfig *tls.Config
|
TLSConfig *tls.Config
|
||||||
// TCP level connection / can be replaced by a TLS session after starttls
|
// TCP level connection / can be replaced by a TLS session after starttls
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
|
isSecure bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *XMPPTransport) Connect() error {
|
func (t *XMPPTransport) Connect() error {
|
||||||
|
@ -25,6 +26,10 @@ func (t XMPPTransport) DoesStartTLS() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t XMPPTransport) IsSecure() bool {
|
||||||
|
return t.isSecure
|
||||||
|
}
|
||||||
|
|
||||||
func (t *XMPPTransport) StartTLS(domain string) error {
|
func (t *XMPPTransport) StartTLS(domain string) error {
|
||||||
if t.Config.TLSConfig == nil {
|
if t.Config.TLSConfig == nil {
|
||||||
t.Config.TLSConfig = &tls.Config{}
|
t.Config.TLSConfig = &tls.Config{}
|
||||||
|
@ -45,6 +50,7 @@ func (t *XMPPTransport) StartTLS(domain string) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.isSecure = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue