Introduce Transport interface

This commit is contained in:
Wichert Akkerman 2019-10-06 19:37:56 +02:00 committed by Mickaël Rémond
parent 2781563ea7
commit a3c62e515e
4 changed files with 115 additions and 57 deletions

View file

@ -90,7 +90,7 @@ type Client struct {
// Session gather data that can be accessed by users of this library
Session *Session
// TCP level connection / can be replaced by a TLS session after starttls
conn net.Conn
transport Transport
// Router is used to dispatch packets
router *Router
// Track and broadcast connection state
@ -139,6 +139,7 @@ func NewClient(config Config, r *Router) (c *Client, err error) {
c = new(Client)
c.config = config
c.router = r
c.transport = &XMPPTransport{}
if c.config.ConnectTimeout == 0 {
c.config.ConnectTimeout = 15 // 15 second as default
@ -159,21 +160,21 @@ func (c *Client) Connect() error {
func (c *Client) Resume(state SMState) error {
var err error
c.conn, err = net.DialTimeout("tcp", c.config.Address, time.Duration(c.config.ConnectTimeout)*time.Second)
err = c.transport.Connect(c.config.Address, c.config)
if err != nil {
return err
}
c.updateState(StateConnected)
// Client is ok, we now open XMPP session
if c.conn, c.Session, err = NewSession(c.conn, c.config, state); err != nil {
if c.Session, err = NewSession(c.transport, c.config, state); err != nil {
return err
}
c.updateState(StateSessionEstablished)
// Start the keepalive go routine
keepaliveQuit := make(chan struct{})
go keepalive(c.conn, keepaliveQuit)
go keepalive(c.transport, keepaliveQuit)
// Start the receiver go routine
state = c.Session.SMState
go c.recv(state, keepaliveQuit)
@ -190,7 +191,7 @@ func (c *Client) Resume(state SMState) error {
func (c *Client) Disconnect() {
_ = c.SendRaw("</stream:stream>")
// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect
conn := c.conn
conn := c.transport
if conn != nil {
_ = conn.Close()
}
@ -202,7 +203,7 @@ func (c *Client) SetHandler(handler EventHandler) {
// Send marshals XMPP stanza and sends it to the server.
func (c *Client) Send(packet stanza.Packet) error {
conn := c.conn
conn := c.transport
if conn == nil {
return errors.New("client is not connected")
}
@ -220,7 +221,7 @@ func (c *Client) Send(packet stanza.Packet) error {
// disconnect the client. It is up to the user of this method to
// carefully craft the XML content to produce valid XMPP.
func (c *Client) SendRaw(packet string) error {
conn := c.conn
conn := c.transport
if conn == nil {
return errors.New("client is not connected")
}
@ -272,16 +273,16 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error)
// Loop: send whitespace keepalive to server
// This is use to keep the connection open, but also to detect connection loss
// and trigger proper client connection shutdown.
func keepalive(conn net.Conn, quit <-chan struct{}) {
func keepalive(transport Transport, quit <-chan struct{}) {
// TODO: Make keepalive interval configurable
ticker := time.NewTicker(30 * time.Second)
for {
select {
case <-ticker.C:
if n, err := fmt.Fprintf(conn, "\n"); err != nil || n != 1 {
// When keep alive fails, we force close the connection. In all cases, the recv will also fail.
if n, err := fmt.Fprintf(transport, "\n"); err != nil || n != 1 {
// When keep alive fails, we force close the transportection. In all cases, the recv will also fail.
ticker.Stop()
_ = conn.Close()
_ = transport.Close()
return
}
case <-quit:

View file

@ -14,8 +14,8 @@ type Config struct {
StreamLogger *os.File // Used for debugging
Lang string // TODO: should default to 'en'
ConnectTimeout int // Client timeout in seconds. Default to 15
// tls.Config must not be modified after having been passed to NewClient. The
// Client connect method may override the tls.Config.ServerName if it was not set.
// tls.Config must not be modified after having been passed to NewClient. Any
// changes made after connecting are ignored.
TLSConfig *tls.Config
// Insecure can be set to true to allow to open a session without TLS. If TLS
// is supported on the server, we will still try to use it.

View file

@ -1,12 +1,10 @@
package xmpp
import (
"crypto/tls"
"encoding/xml"
"errors"
"fmt"
"io"
"net"
"gosrc.io/xmpp/stanza"
)
@ -30,35 +28,33 @@ type Session struct {
err error
}
func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, error) {
func NewSession(transport Transport, o Config, state SMState) (*Session, error) {
s := new(Session)
s.SMState = state
s.init(conn, o)
s.init(transport, o)
// starttls
var tlsConn net.Conn
tlsConn = s.startTlsIfSupported(conn, o.parsedJid.Domain, o)
s.startTlsIfSupported(transport, o.parsedJid.Domain, o)
if s.err != nil {
return nil, nil, NewConnError(s.err, true)
return nil, NewConnError(s.err, true)
}
if !s.TlsEnabled && !o.Insecure {
err := fmt.Errorf("failed to negotiate TLS session : %s", s.err)
return nil, nil, NewConnError(err, true)
return nil, NewConnError(err, true)
}
if s.TlsEnabled {
s.reset(conn, tlsConn, o)
s.reset(transport, o)
}
// auth
s.auth(o)
s.reset(tlsConn, tlsConn, o)
s.reset(transport, o)
// attempt resumption
if s.resume(o) {
return tlsConn, s, s.err
return s, s.err
}
// otherwise, bind resource and 'start' XMPP session
@ -68,7 +64,7 @@ func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, err
// Enable stream management if supported
s.EnableStreamManagement(o)
return tlsConn, s, s.err
return s, s.err
}
func (s *Session) PacketId() string {
@ -76,24 +72,22 @@ func (s *Session) PacketId() string {
return fmt.Sprintf("%x", s.lastPacketId)
}
func (s *Session) init(conn net.Conn, o Config) {
s.setStreamLogger(nil, conn, o)
func (s *Session) init(transport Transport, o Config) {
s.setStreamLogger(transport, o)
s.Features = s.open(o.parsedJid.Domain)
}
func (s *Session) reset(conn net.Conn, newConn net.Conn, o Config) {
func (s *Session) reset(transport Transport, o Config) {
if s.err != nil {
return
}
s.setStreamLogger(conn, newConn, o)
s.setStreamLogger(transport, o)
s.Features = s.open(o.parsedJid.Domain)
}
func (s *Session) setStreamLogger(conn net.Conn, newConn net.Conn, o Config) {
if newConn != conn {
s.streamLogger = newStreamLogger(newConn, o.StreamLogger)
}
func (s *Session) setStreamLogger(transport Transport, o Config) {
s.streamLogger = newStreamLogger(transport, o.StreamLogger)
s.decoder = xml.NewDecoder(s.streamLogger)
s.decoder.CharsetReader = o.CharsetReader
}
@ -117,9 +111,16 @@ func (s *Session) open(domain string) (f stanza.StreamFeatures) {
return
}
func (s *Session) startTlsIfSupported(conn net.Conn, domain string, o Config) net.Conn {
func (s *Session) startTlsIfSupported(transport Transport, domain string, o Config) {
if s.err != nil {
return conn
return
}
if !transport.DoesStartTLS() {
if !o.Insecure {
s.err = errors.New("Transport does not support starttls")
}
return
}
if _, ok := s.Features.DoesStartTLS(); ok {
@ -128,39 +129,21 @@ func (s *Session) startTlsIfSupported(conn net.Conn, domain string, o Config) ne
var k stanza.TLSProceed
if s.err = s.decoder.DecodeElement(&k, nil); s.err != nil {
s.err = errors.New("expecting starttls proceed: " + s.err.Error())
return conn
return
}
if o.TLSConfig == nil {
o.TLSConfig = &tls.Config{}
}
if o.TLSConfig.ServerName == "" {
o.TLSConfig.ServerName = domain
}
tlsConn := tls.Client(conn, o.TLSConfig)
// We convert existing connection to TLS
if s.err = tlsConn.Handshake(); s.err != nil {
return tlsConn
}
if !o.TLSConfig.InsecureSkipVerify {
s.err = tlsConn.VerifyHostname(domain)
}
s.err = transport.StartTLS(domain, o)
if s.err == nil {
s.TlsEnabled = true
}
return tlsConn
return
}
// If we do not allow cleartext connections, make it explicit that server do not support starttls
if !o.Insecure {
s.err = errors.New("XMPP server does not advertise support for starttls")
}
// starttls is not supported => we do not upgrade the connection:
return conn
}
func (s *Session) auth(o Config) {

74
transport.go Normal file
View file

@ -0,0 +1,74 @@
package xmpp
import (
"crypto/tls"
"net"
"time"
)
type Transport interface {
Connect(address string, c Config) error
DoesStartTLS() bool
StartTLS(domain string, c Config) error
Read(p []byte) (n int, err error)
Write(p []byte) (n int, err error)
Close() error
}
// XMPPTransport implements the XMPP native TCP transport
type XMPPTransport struct {
TLSConfig *tls.Config
// TCP level connection / can be replaced by a TLS session after starttls
conn net.Conn
}
func (t *XMPPTransport) Connect(address string, c Config) error {
var err error
t.conn, err = net.DialTimeout("tcp", address, time.Duration(c.ConnectTimeout)*time.Second)
return err
}
func (t XMPPTransport) DoesStartTLS() bool {
return true
}
func (t *XMPPTransport) StartTLS(domain string, c Config) error {
if t.TLSConfig == nil {
if c.TLSConfig != nil {
t.TLSConfig = c.TLSConfig
} else {
t.TLSConfig = &tls.Config{}
}
}
if t.TLSConfig.ServerName == "" {
t.TLSConfig.ServerName = domain
}
tlsConn := tls.Client(t.conn, t.TLSConfig)
// We convert existing connection to TLS
if err := tlsConn.Handshake(); err != nil {
return err
}
if !t.TLSConfig.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(domain); err != nil {
return err
}
}
return nil
}
func (t XMPPTransport) Read(p []byte) (n int, err error) {
return t.conn.Read(p)
}
func (t XMPPTransport) Write(p []byte) (n int, err error) {
return t.conn.Write(p)
}
func (t XMPPTransport) Close() error {
return t.conn.Close()
}