Introduce Transport interface
This commit is contained in:
parent
2781563ea7
commit
a3c62e515e
23
client.go
23
client.go
|
@ -90,7 +90,7 @@ type Client struct {
|
|||
// Session gather data that can be accessed by users of this library
|
||||
Session *Session
|
||||
// TCP level connection / can be replaced by a TLS session after starttls
|
||||
conn net.Conn
|
||||
transport Transport
|
||||
// Router is used to dispatch packets
|
||||
router *Router
|
||||
// Track and broadcast connection state
|
||||
|
@ -139,6 +139,7 @@ func NewClient(config Config, r *Router) (c *Client, err error) {
|
|||
c = new(Client)
|
||||
c.config = config
|
||||
c.router = r
|
||||
c.transport = &XMPPTransport{}
|
||||
|
||||
if c.config.ConnectTimeout == 0 {
|
||||
c.config.ConnectTimeout = 15 // 15 second as default
|
||||
|
@ -159,21 +160,21 @@ func (c *Client) Connect() error {
|
|||
func (c *Client) Resume(state SMState) error {
|
||||
var err error
|
||||
|
||||
c.conn, err = net.DialTimeout("tcp", c.config.Address, time.Duration(c.config.ConnectTimeout)*time.Second)
|
||||
err = c.transport.Connect(c.config.Address, c.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.updateState(StateConnected)
|
||||
|
||||
// Client is ok, we now open XMPP session
|
||||
if c.conn, c.Session, err = NewSession(c.conn, c.config, state); err != nil {
|
||||
if c.Session, err = NewSession(c.transport, c.config, state); err != nil {
|
||||
return err
|
||||
}
|
||||
c.updateState(StateSessionEstablished)
|
||||
|
||||
// Start the keepalive go routine
|
||||
keepaliveQuit := make(chan struct{})
|
||||
go keepalive(c.conn, keepaliveQuit)
|
||||
go keepalive(c.transport, keepaliveQuit)
|
||||
// Start the receiver go routine
|
||||
state = c.Session.SMState
|
||||
go c.recv(state, keepaliveQuit)
|
||||
|
@ -190,7 +191,7 @@ func (c *Client) Resume(state SMState) error {
|
|||
func (c *Client) Disconnect() {
|
||||
_ = c.SendRaw("</stream:stream>")
|
||||
// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect
|
||||
conn := c.conn
|
||||
conn := c.transport
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
@ -202,7 +203,7 @@ func (c *Client) SetHandler(handler EventHandler) {
|
|||
|
||||
// Send marshals XMPP stanza and sends it to the server.
|
||||
func (c *Client) Send(packet stanza.Packet) error {
|
||||
conn := c.conn
|
||||
conn := c.transport
|
||||
if conn == nil {
|
||||
return errors.New("client is not connected")
|
||||
}
|
||||
|
@ -220,7 +221,7 @@ func (c *Client) Send(packet stanza.Packet) error {
|
|||
// disconnect the client. It is up to the user of this method to
|
||||
// carefully craft the XML content to produce valid XMPP.
|
||||
func (c *Client) SendRaw(packet string) error {
|
||||
conn := c.conn
|
||||
conn := c.transport
|
||||
if conn == nil {
|
||||
return errors.New("client is not connected")
|
||||
}
|
||||
|
@ -272,16 +273,16 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error)
|
|||
// Loop: send whitespace keepalive to server
|
||||
// This is use to keep the connection open, but also to detect connection loss
|
||||
// and trigger proper client connection shutdown.
|
||||
func keepalive(conn net.Conn, quit <-chan struct{}) {
|
||||
func keepalive(transport Transport, quit <-chan struct{}) {
|
||||
// TODO: Make keepalive interval configurable
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if n, err := fmt.Fprintf(conn, "\n"); err != nil || n != 1 {
|
||||
// When keep alive fails, we force close the connection. In all cases, the recv will also fail.
|
||||
if n, err := fmt.Fprintf(transport, "\n"); err != nil || n != 1 {
|
||||
// When keep alive fails, we force close the transportection. In all cases, the recv will also fail.
|
||||
ticker.Stop()
|
||||
_ = conn.Close()
|
||||
_ = transport.Close()
|
||||
return
|
||||
}
|
||||
case <-quit:
|
||||
|
|
|
@ -14,8 +14,8 @@ type Config struct {
|
|||
StreamLogger *os.File // Used for debugging
|
||||
Lang string // TODO: should default to 'en'
|
||||
ConnectTimeout int // Client timeout in seconds. Default to 15
|
||||
// tls.Config must not be modified after having been passed to NewClient. The
|
||||
// Client connect method may override the tls.Config.ServerName if it was not set.
|
||||
// tls.Config must not be modified after having been passed to NewClient. Any
|
||||
// changes made after connecting are ignored.
|
||||
TLSConfig *tls.Config
|
||||
// Insecure can be set to true to allow to open a session without TLS. If TLS
|
||||
// is supported on the server, we will still try to use it.
|
||||
|
|
71
session.go
71
session.go
|
@ -1,12 +1,10 @@
|
|||
package xmpp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"gosrc.io/xmpp/stanza"
|
||||
)
|
||||
|
@ -30,35 +28,33 @@ type Session struct {
|
|||
err error
|
||||
}
|
||||
|
||||
func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, error) {
|
||||
func NewSession(transport Transport, o Config, state SMState) (*Session, error) {
|
||||
s := new(Session)
|
||||
s.SMState = state
|
||||
s.init(conn, o)
|
||||
s.init(transport, o)
|
||||
|
||||
// starttls
|
||||
var tlsConn net.Conn
|
||||
tlsConn = s.startTlsIfSupported(conn, o.parsedJid.Domain, o)
|
||||
s.startTlsIfSupported(transport, o.parsedJid.Domain, o)
|
||||
|
||||
if s.err != nil {
|
||||
return nil, nil, NewConnError(s.err, true)
|
||||
return nil, NewConnError(s.err, true)
|
||||
}
|
||||
|
||||
if !s.TlsEnabled && !o.Insecure {
|
||||
err := fmt.Errorf("failed to negotiate TLS session : %s", s.err)
|
||||
return nil, nil, NewConnError(err, true)
|
||||
return nil, NewConnError(err, true)
|
||||
}
|
||||
|
||||
if s.TlsEnabled {
|
||||
s.reset(conn, tlsConn, o)
|
||||
s.reset(transport, o)
|
||||
}
|
||||
|
||||
// auth
|
||||
s.auth(o)
|
||||
s.reset(tlsConn, tlsConn, o)
|
||||
s.reset(transport, o)
|
||||
|
||||
// attempt resumption
|
||||
if s.resume(o) {
|
||||
return tlsConn, s, s.err
|
||||
return s, s.err
|
||||
}
|
||||
|
||||
// otherwise, bind resource and 'start' XMPP session
|
||||
|
@ -68,7 +64,7 @@ func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, err
|
|||
// Enable stream management if supported
|
||||
s.EnableStreamManagement(o)
|
||||
|
||||
return tlsConn, s, s.err
|
||||
return s, s.err
|
||||
}
|
||||
|
||||
func (s *Session) PacketId() string {
|
||||
|
@ -76,24 +72,22 @@ func (s *Session) PacketId() string {
|
|||
return fmt.Sprintf("%x", s.lastPacketId)
|
||||
}
|
||||
|
||||
func (s *Session) init(conn net.Conn, o Config) {
|
||||
s.setStreamLogger(nil, conn, o)
|
||||
func (s *Session) init(transport Transport, o Config) {
|
||||
s.setStreamLogger(transport, o)
|
||||
s.Features = s.open(o.parsedJid.Domain)
|
||||
}
|
||||
|
||||
func (s *Session) reset(conn net.Conn, newConn net.Conn, o Config) {
|
||||
func (s *Session) reset(transport Transport, o Config) {
|
||||
if s.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.setStreamLogger(conn, newConn, o)
|
||||
s.setStreamLogger(transport, o)
|
||||
s.Features = s.open(o.parsedJid.Domain)
|
||||
}
|
||||
|
||||
func (s *Session) setStreamLogger(conn net.Conn, newConn net.Conn, o Config) {
|
||||
if newConn != conn {
|
||||
s.streamLogger = newStreamLogger(newConn, o.StreamLogger)
|
||||
}
|
||||
func (s *Session) setStreamLogger(transport Transport, o Config) {
|
||||
s.streamLogger = newStreamLogger(transport, o.StreamLogger)
|
||||
s.decoder = xml.NewDecoder(s.streamLogger)
|
||||
s.decoder.CharsetReader = o.CharsetReader
|
||||
}
|
||||
|
@ -117,9 +111,16 @@ func (s *Session) open(domain string) (f stanza.StreamFeatures) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Session) startTlsIfSupported(conn net.Conn, domain string, o Config) net.Conn {
|
||||
func (s *Session) startTlsIfSupported(transport Transport, domain string, o Config) {
|
||||
if s.err != nil {
|
||||
return conn
|
||||
return
|
||||
}
|
||||
|
||||
if !transport.DoesStartTLS() {
|
||||
if !o.Insecure {
|
||||
s.err = errors.New("Transport does not support starttls")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := s.Features.DoesStartTLS(); ok {
|
||||
|
@ -128,39 +129,21 @@ func (s *Session) startTlsIfSupported(conn net.Conn, domain string, o Config) ne
|
|||
var k stanza.TLSProceed
|
||||
if s.err = s.decoder.DecodeElement(&k, nil); s.err != nil {
|
||||
s.err = errors.New("expecting starttls proceed: " + s.err.Error())
|
||||
return conn
|
||||
return
|
||||
}
|
||||
|
||||
if o.TLSConfig == nil {
|
||||
o.TLSConfig = &tls.Config{}
|
||||
}
|
||||
|
||||
if o.TLSConfig.ServerName == "" {
|
||||
o.TLSConfig.ServerName = domain
|
||||
}
|
||||
tlsConn := tls.Client(conn, o.TLSConfig)
|
||||
// We convert existing connection to TLS
|
||||
if s.err = tlsConn.Handshake(); s.err != nil {
|
||||
return tlsConn
|
||||
}
|
||||
|
||||
if !o.TLSConfig.InsecureSkipVerify {
|
||||
s.err = tlsConn.VerifyHostname(domain)
|
||||
}
|
||||
s.err = transport.StartTLS(domain, o)
|
||||
|
||||
if s.err == nil {
|
||||
s.TlsEnabled = true
|
||||
}
|
||||
return tlsConn
|
||||
return
|
||||
}
|
||||
|
||||
// If we do not allow cleartext connections, make it explicit that server do not support starttls
|
||||
if !o.Insecure {
|
||||
s.err = errors.New("XMPP server does not advertise support for starttls")
|
||||
}
|
||||
|
||||
// starttls is not supported => we do not upgrade the connection:
|
||||
return conn
|
||||
}
|
||||
|
||||
func (s *Session) auth(o Config) {
|
||||
|
|
74
transport.go
Normal file
74
transport.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package xmpp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Transport interface {
|
||||
Connect(address string, c Config) error
|
||||
DoesStartTLS() bool
|
||||
StartTLS(domain string, c Config) error
|
||||
|
||||
Read(p []byte) (n int, err error)
|
||||
Write(p []byte) (n int, err error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// XMPPTransport implements the XMPP native TCP transport
|
||||
type XMPPTransport struct {
|
||||
TLSConfig *tls.Config
|
||||
// TCP level connection / can be replaced by a TLS session after starttls
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (t *XMPPTransport) Connect(address string, c Config) error {
|
||||
var err error
|
||||
|
||||
t.conn, err = net.DialTimeout("tcp", address, time.Duration(c.ConnectTimeout)*time.Second)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t XMPPTransport) DoesStartTLS() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *XMPPTransport) StartTLS(domain string, c Config) error {
|
||||
if t.TLSConfig == nil {
|
||||
if c.TLSConfig != nil {
|
||||
t.TLSConfig = c.TLSConfig
|
||||
} else {
|
||||
t.TLSConfig = &tls.Config{}
|
||||
}
|
||||
}
|
||||
|
||||
if t.TLSConfig.ServerName == "" {
|
||||
t.TLSConfig.ServerName = domain
|
||||
}
|
||||
tlsConn := tls.Client(t.conn, t.TLSConfig)
|
||||
// We convert existing connection to TLS
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !t.TLSConfig.InsecureSkipVerify {
|
||||
if err := tlsConn.VerifyHostname(domain); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t XMPPTransport) Read(p []byte) (n int, err error) {
|
||||
return t.conn.Read(p)
|
||||
}
|
||||
|
||||
func (t XMPPTransport) Write(p []byte) (n int, err error) {
|
||||
return t.conn.Write(p)
|
||||
}
|
||||
|
||||
func (t XMPPTransport) Close() error {
|
||||
return t.conn.Close()
|
||||
}
|
Loading…
Reference in a new issue