Introduce Transport interface
This commit is contained in:
parent
2781563ea7
commit
a3c62e515e
23
client.go
23
client.go
|
@ -90,7 +90,7 @@ type Client struct {
|
||||||
// Session gather data that can be accessed by users of this library
|
// Session gather data that can be accessed by users of this library
|
||||||
Session *Session
|
Session *Session
|
||||||
// TCP level connection / can be replaced by a TLS session after starttls
|
// TCP level connection / can be replaced by a TLS session after starttls
|
||||||
conn net.Conn
|
transport Transport
|
||||||
// Router is used to dispatch packets
|
// Router is used to dispatch packets
|
||||||
router *Router
|
router *Router
|
||||||
// Track and broadcast connection state
|
// Track and broadcast connection state
|
||||||
|
@ -139,6 +139,7 @@ func NewClient(config Config, r *Router) (c *Client, err error) {
|
||||||
c = new(Client)
|
c = new(Client)
|
||||||
c.config = config
|
c.config = config
|
||||||
c.router = r
|
c.router = r
|
||||||
|
c.transport = &XMPPTransport{}
|
||||||
|
|
||||||
if c.config.ConnectTimeout == 0 {
|
if c.config.ConnectTimeout == 0 {
|
||||||
c.config.ConnectTimeout = 15 // 15 second as default
|
c.config.ConnectTimeout = 15 // 15 second as default
|
||||||
|
@ -159,21 +160,21 @@ func (c *Client) Connect() error {
|
||||||
func (c *Client) Resume(state SMState) error {
|
func (c *Client) Resume(state SMState) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
c.conn, err = net.DialTimeout("tcp", c.config.Address, time.Duration(c.config.ConnectTimeout)*time.Second)
|
err = c.transport.Connect(c.config.Address, c.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.updateState(StateConnected)
|
c.updateState(StateConnected)
|
||||||
|
|
||||||
// Client is ok, we now open XMPP session
|
// Client is ok, we now open XMPP session
|
||||||
if c.conn, c.Session, err = NewSession(c.conn, c.config, state); err != nil {
|
if c.Session, err = NewSession(c.transport, c.config, state); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.updateState(StateSessionEstablished)
|
c.updateState(StateSessionEstablished)
|
||||||
|
|
||||||
// Start the keepalive go routine
|
// Start the keepalive go routine
|
||||||
keepaliveQuit := make(chan struct{})
|
keepaliveQuit := make(chan struct{})
|
||||||
go keepalive(c.conn, keepaliveQuit)
|
go keepalive(c.transport, keepaliveQuit)
|
||||||
// Start the receiver go routine
|
// Start the receiver go routine
|
||||||
state = c.Session.SMState
|
state = c.Session.SMState
|
||||||
go c.recv(state, keepaliveQuit)
|
go c.recv(state, keepaliveQuit)
|
||||||
|
@ -190,7 +191,7 @@ func (c *Client) Resume(state SMState) error {
|
||||||
func (c *Client) Disconnect() {
|
func (c *Client) Disconnect() {
|
||||||
_ = c.SendRaw("</stream:stream>")
|
_ = c.SendRaw("</stream:stream>")
|
||||||
// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect
|
// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect
|
||||||
conn := c.conn
|
conn := c.transport
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
}
|
}
|
||||||
|
@ -202,7 +203,7 @@ func (c *Client) SetHandler(handler EventHandler) {
|
||||||
|
|
||||||
// Send marshals XMPP stanza and sends it to the server.
|
// Send marshals XMPP stanza and sends it to the server.
|
||||||
func (c *Client) Send(packet stanza.Packet) error {
|
func (c *Client) Send(packet stanza.Packet) error {
|
||||||
conn := c.conn
|
conn := c.transport
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return errors.New("client is not connected")
|
return errors.New("client is not connected")
|
||||||
}
|
}
|
||||||
|
@ -220,7 +221,7 @@ func (c *Client) Send(packet stanza.Packet) error {
|
||||||
// disconnect the client. It is up to the user of this method to
|
// disconnect the client. It is up to the user of this method to
|
||||||
// carefully craft the XML content to produce valid XMPP.
|
// carefully craft the XML content to produce valid XMPP.
|
||||||
func (c *Client) SendRaw(packet string) error {
|
func (c *Client) SendRaw(packet string) error {
|
||||||
conn := c.conn
|
conn := c.transport
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return errors.New("client is not connected")
|
return errors.New("client is not connected")
|
||||||
}
|
}
|
||||||
|
@ -272,16 +273,16 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error)
|
||||||
// Loop: send whitespace keepalive to server
|
// Loop: send whitespace keepalive to server
|
||||||
// This is use to keep the connection open, but also to detect connection loss
|
// This is use to keep the connection open, but also to detect connection loss
|
||||||
// and trigger proper client connection shutdown.
|
// and trigger proper client connection shutdown.
|
||||||
func keepalive(conn net.Conn, quit <-chan struct{}) {
|
func keepalive(transport Transport, quit <-chan struct{}) {
|
||||||
// TODO: Make keepalive interval configurable
|
// TODO: Make keepalive interval configurable
|
||||||
ticker := time.NewTicker(30 * time.Second)
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if n, err := fmt.Fprintf(conn, "\n"); err != nil || n != 1 {
|
if n, err := fmt.Fprintf(transport, "\n"); err != nil || n != 1 {
|
||||||
// When keep alive fails, we force close the connection. In all cases, the recv will also fail.
|
// When keep alive fails, we force close the transportection. In all cases, the recv will also fail.
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
_ = conn.Close()
|
_ = transport.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case <-quit:
|
case <-quit:
|
||||||
|
|
|
@ -14,8 +14,8 @@ type Config struct {
|
||||||
StreamLogger *os.File // Used for debugging
|
StreamLogger *os.File // Used for debugging
|
||||||
Lang string // TODO: should default to 'en'
|
Lang string // TODO: should default to 'en'
|
||||||
ConnectTimeout int // Client timeout in seconds. Default to 15
|
ConnectTimeout int // Client timeout in seconds. Default to 15
|
||||||
// tls.Config must not be modified after having been passed to NewClient. The
|
// tls.Config must not be modified after having been passed to NewClient. Any
|
||||||
// Client connect method may override the tls.Config.ServerName if it was not set.
|
// changes made after connecting are ignored.
|
||||||
TLSConfig *tls.Config
|
TLSConfig *tls.Config
|
||||||
// Insecure can be set to true to allow to open a session without TLS. If TLS
|
// Insecure can be set to true to allow to open a session without TLS. If TLS
|
||||||
// is supported on the server, we will still try to use it.
|
// is supported on the server, we will still try to use it.
|
||||||
|
|
71
session.go
71
session.go
|
@ -1,12 +1,10 @@
|
||||||
package xmpp
|
package xmpp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
|
|
||||||
"gosrc.io/xmpp/stanza"
|
"gosrc.io/xmpp/stanza"
|
||||||
)
|
)
|
||||||
|
@ -30,35 +28,33 @@ type Session struct {
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, error) {
|
func NewSession(transport Transport, o Config, state SMState) (*Session, error) {
|
||||||
s := new(Session)
|
s := new(Session)
|
||||||
s.SMState = state
|
s.SMState = state
|
||||||
s.init(conn, o)
|
s.init(transport, o)
|
||||||
|
|
||||||
// starttls
|
s.startTlsIfSupported(transport, o.parsedJid.Domain, o)
|
||||||
var tlsConn net.Conn
|
|
||||||
tlsConn = s.startTlsIfSupported(conn, o.parsedJid.Domain, o)
|
|
||||||
|
|
||||||
if s.err != nil {
|
if s.err != nil {
|
||||||
return nil, nil, NewConnError(s.err, true)
|
return nil, NewConnError(s.err, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.TlsEnabled && !o.Insecure {
|
if !s.TlsEnabled && !o.Insecure {
|
||||||
err := fmt.Errorf("failed to negotiate TLS session : %s", s.err)
|
err := fmt.Errorf("failed to negotiate TLS session : %s", s.err)
|
||||||
return nil, nil, NewConnError(err, true)
|
return nil, NewConnError(err, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.TlsEnabled {
|
if s.TlsEnabled {
|
||||||
s.reset(conn, tlsConn, o)
|
s.reset(transport, o)
|
||||||
}
|
}
|
||||||
|
|
||||||
// auth
|
// auth
|
||||||
s.auth(o)
|
s.auth(o)
|
||||||
s.reset(tlsConn, tlsConn, o)
|
s.reset(transport, o)
|
||||||
|
|
||||||
// attempt resumption
|
// attempt resumption
|
||||||
if s.resume(o) {
|
if s.resume(o) {
|
||||||
return tlsConn, s, s.err
|
return s, s.err
|
||||||
}
|
}
|
||||||
|
|
||||||
// otherwise, bind resource and 'start' XMPP session
|
// otherwise, bind resource and 'start' XMPP session
|
||||||
|
@ -68,7 +64,7 @@ func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, err
|
||||||
// Enable stream management if supported
|
// Enable stream management if supported
|
||||||
s.EnableStreamManagement(o)
|
s.EnableStreamManagement(o)
|
||||||
|
|
||||||
return tlsConn, s, s.err
|
return s, s.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) PacketId() string {
|
func (s *Session) PacketId() string {
|
||||||
|
@ -76,24 +72,22 @@ func (s *Session) PacketId() string {
|
||||||
return fmt.Sprintf("%x", s.lastPacketId)
|
return fmt.Sprintf("%x", s.lastPacketId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) init(conn net.Conn, o Config) {
|
func (s *Session) init(transport Transport, o Config) {
|
||||||
s.setStreamLogger(nil, conn, o)
|
s.setStreamLogger(transport, o)
|
||||||
s.Features = s.open(o.parsedJid.Domain)
|
s.Features = s.open(o.parsedJid.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) reset(conn net.Conn, newConn net.Conn, o Config) {
|
func (s *Session) reset(transport Transport, o Config) {
|
||||||
if s.err != nil {
|
if s.err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.setStreamLogger(conn, newConn, o)
|
s.setStreamLogger(transport, o)
|
||||||
s.Features = s.open(o.parsedJid.Domain)
|
s.Features = s.open(o.parsedJid.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) setStreamLogger(conn net.Conn, newConn net.Conn, o Config) {
|
func (s *Session) setStreamLogger(transport Transport, o Config) {
|
||||||
if newConn != conn {
|
s.streamLogger = newStreamLogger(transport, o.StreamLogger)
|
||||||
s.streamLogger = newStreamLogger(newConn, o.StreamLogger)
|
|
||||||
}
|
|
||||||
s.decoder = xml.NewDecoder(s.streamLogger)
|
s.decoder = xml.NewDecoder(s.streamLogger)
|
||||||
s.decoder.CharsetReader = o.CharsetReader
|
s.decoder.CharsetReader = o.CharsetReader
|
||||||
}
|
}
|
||||||
|
@ -117,9 +111,16 @@ func (s *Session) open(domain string) (f stanza.StreamFeatures) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) startTlsIfSupported(conn net.Conn, domain string, o Config) net.Conn {
|
func (s *Session) startTlsIfSupported(transport Transport, domain string, o Config) {
|
||||||
if s.err != nil {
|
if s.err != nil {
|
||||||
return conn
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !transport.DoesStartTLS() {
|
||||||
|
if !o.Insecure {
|
||||||
|
s.err = errors.New("Transport does not support starttls")
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := s.Features.DoesStartTLS(); ok {
|
if _, ok := s.Features.DoesStartTLS(); ok {
|
||||||
|
@ -128,39 +129,21 @@ func (s *Session) startTlsIfSupported(conn net.Conn, domain string, o Config) ne
|
||||||
var k stanza.TLSProceed
|
var k stanza.TLSProceed
|
||||||
if s.err = s.decoder.DecodeElement(&k, nil); s.err != nil {
|
if s.err = s.decoder.DecodeElement(&k, nil); s.err != nil {
|
||||||
s.err = errors.New("expecting starttls proceed: " + s.err.Error())
|
s.err = errors.New("expecting starttls proceed: " + s.err.Error())
|
||||||
return conn
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.TLSConfig == nil {
|
s.err = transport.StartTLS(domain, o)
|
||||||
o.TLSConfig = &tls.Config{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if o.TLSConfig.ServerName == "" {
|
|
||||||
o.TLSConfig.ServerName = domain
|
|
||||||
}
|
|
||||||
tlsConn := tls.Client(conn, o.TLSConfig)
|
|
||||||
// We convert existing connection to TLS
|
|
||||||
if s.err = tlsConn.Handshake(); s.err != nil {
|
|
||||||
return tlsConn
|
|
||||||
}
|
|
||||||
|
|
||||||
if !o.TLSConfig.InsecureSkipVerify {
|
|
||||||
s.err = tlsConn.VerifyHostname(domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.err == nil {
|
if s.err == nil {
|
||||||
s.TlsEnabled = true
|
s.TlsEnabled = true
|
||||||
}
|
}
|
||||||
return tlsConn
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we do not allow cleartext connections, make it explicit that server do not support starttls
|
// If we do not allow cleartext connections, make it explicit that server do not support starttls
|
||||||
if !o.Insecure {
|
if !o.Insecure {
|
||||||
s.err = errors.New("XMPP server does not advertise support for starttls")
|
s.err = errors.New("XMPP server does not advertise support for starttls")
|
||||||
}
|
}
|
||||||
|
|
||||||
// starttls is not supported => we do not upgrade the connection:
|
|
||||||
return conn
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) auth(o Config) {
|
func (s *Session) auth(o Config) {
|
||||||
|
|
74
transport.go
Normal file
74
transport.go
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
package xmpp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Transport interface {
|
||||||
|
Connect(address string, c Config) error
|
||||||
|
DoesStartTLS() bool
|
||||||
|
StartTLS(domain string, c Config) error
|
||||||
|
|
||||||
|
Read(p []byte) (n int, err error)
|
||||||
|
Write(p []byte) (n int, err error)
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// XMPPTransport implements the XMPP native TCP transport
|
||||||
|
type XMPPTransport struct {
|
||||||
|
TLSConfig *tls.Config
|
||||||
|
// TCP level connection / can be replaced by a TLS session after starttls
|
||||||
|
conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *XMPPTransport) Connect(address string, c Config) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
t.conn, err = net.DialTimeout("tcp", address, time.Duration(c.ConnectTimeout)*time.Second)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t XMPPTransport) DoesStartTLS() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *XMPPTransport) StartTLS(domain string, c Config) error {
|
||||||
|
if t.TLSConfig == nil {
|
||||||
|
if c.TLSConfig != nil {
|
||||||
|
t.TLSConfig = c.TLSConfig
|
||||||
|
} else {
|
||||||
|
t.TLSConfig = &tls.Config{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.TLSConfig.ServerName == "" {
|
||||||
|
t.TLSConfig.ServerName = domain
|
||||||
|
}
|
||||||
|
tlsConn := tls.Client(t.conn, t.TLSConfig)
|
||||||
|
// We convert existing connection to TLS
|
||||||
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !t.TLSConfig.InsecureSkipVerify {
|
||||||
|
if err := tlsConn.VerifyHostname(domain); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t XMPPTransport) Read(p []byte) (n int, err error) {
|
||||||
|
return t.conn.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t XMPPTransport) Write(p []byte) (n int, err error) {
|
||||||
|
return t.conn.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t XMPPTransport) Close() error {
|
||||||
|
return t.conn.Close()
|
||||||
|
}
|
Loading…
Reference in a new issue