diff --git a/client.go b/client.go
index 7bb7a80..7e40cb4 100644
--- a/client.go
+++ b/client.go
@@ -31,6 +31,18 @@ type Event struct {
State ConnState
Description string
StreamError string
+ SMState SMState
+}
+
+// SMState holds Stream Management information regarding the session that can be
+// used to resume session after disconnect
+type SMState struct {
+ // Stream Management ID
+ Id string
+ // Inbound stanza count
+ Inbound uint
+ // TODO Store location for IP affinity
+ // TODO Store max and timestamp, to check if we should retry resumption or not
}
// EventHandler is use to pass events about state of the connection to
@@ -52,6 +64,13 @@ func (em EventManager) updateState(state ConnState) {
}
}
+func (em EventManager) disconnected(state SMState) {
+ em.CurrentState = StateDisconnected
+ if em.Handler != nil {
+ em.Handler(Event{State: em.CurrentState, SMState: state})
+ }
+}
+
func (em EventManager) streamError(error, desc string) {
em.CurrentState = StateStreamError
if em.Handler != nil {
@@ -128,7 +147,15 @@ func NewClient(config Config, r *Router) (c *Client, err error) {
}
// Connect triggers actual TCP connection, based on previously defined parameters.
+// Connect simply triggers resumption, with an empty session state.
func (c *Client) Connect() error {
+ var state SMState
+ return c.Resume(state)
+}
+
+// Resume attempts resuming a Stream Managed session, based on the provided stream management
+// state.
+func (c *Client) Resume(state SMState) error {
var err error
c.conn, err = net.DialTimeout("tcp", c.config.Address, time.Duration(c.config.ConnectTimeout)*time.Second)
@@ -138,23 +165,24 @@ func (c *Client) Connect() error {
c.updateState(StateConnected)
// Client is ok, we now open XMPP session
- if c.conn, c.Session, err = NewSession(c.conn, c.config); err != nil {
+ if c.conn, c.Session, err = NewSession(c.conn, c.config, state); err != nil {
return err
}
c.updateState(StateSessionEstablished)
+ // Start the keepalive go routine
+ keepaliveQuit := make(chan struct{})
+ go keepalive(c.conn, keepaliveQuit)
+ // Start the receiver go routine
+ state = c.Session.SMState
+ go c.recv(state, keepaliveQuit)
+
// We're connected and can now receive and send messages.
//fmt.Fprintf(client.conn, "%s%s", "chat", "Online")
// TODO: Do we always want to send initial presence automatically ?
// Do we need an option to avoid that or do we rely on client to send the presence itself ?
fmt.Fprintf(c.Session.streamLogger, "")
- // Start the keepalive go routine
- keepaliveQuit := make(chan struct{})
- go keepalive(c.conn, keepaliveQuit)
- // Start the receiver go routine
- go c.recv(keepaliveQuit)
-
return err
}
@@ -206,12 +234,12 @@ func (c *Client) sendWithLogger(packet string) error {
// Go routines
// Loop: Receive data from server
-func (c *Client) recv(keepaliveQuit chan<- struct{}) (err error) {
+func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) {
for {
val, err := stanza.NextPacket(c.Session.decoder)
if err != nil {
close(keepaliveQuit)
- c.updateState(StateDisconnected)
+ c.disconnected(state)
return err
}
@@ -222,6 +250,17 @@ func (c *Client) recv(keepaliveQuit chan<- struct{}) (err error) {
close(keepaliveQuit)
c.streamError(packet.Error.Local, packet.Text)
return errors.New("stream error: " + packet.Error.Local)
+ // Process Stream management nonzas
+ case stanza.SMRequest:
+ fmt.Println("MREMOND: inbound: ", state.Inbound)
+ answer := stanza.SMAnswer{XMLName: xml.Name{
+ Space: stanza.NSStreamManagement,
+ Local: "a",
+ }, H: state.Inbound}
+ c.Send(answer)
+ default:
+ fmt.Println(packet)
+ state.Inbound++
}
c.router.route(c, val)
@@ -243,6 +282,9 @@ func keepalive(conn net.Conn, quit <-chan struct{}) {
_ = conn.Close()
return
}
+ case <-time.After(3 * time.Second):
+ _ = conn.Close()
+ return
case <-quit:
ticker.Stop()
return
diff --git a/component.go b/component.go
index 0176371..af424a2 100644
--- a/component.go
+++ b/component.go
@@ -108,6 +108,10 @@ func (c *Component) Connect() error {
}
}
+func (c *Component) Resume() error {
+ return errors.New("components do not support stream management")
+}
+
func (c *Component) Disconnect() {
_ = c.SendRaw("")
// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect
diff --git a/session.go b/session.go
index 1092569..8253ab6 100644
--- a/session.go
+++ b/session.go
@@ -17,6 +17,7 @@ type Session struct {
// Session info
BindJid string // Jabber ID as provided by XMPP server
StreamId string
+ SMState SMState
Features stanza.StreamFeatures
TlsEnabled bool
lastPacketId int
@@ -29,8 +30,9 @@ type Session struct {
err error
}
-func NewSession(conn net.Conn, o Config) (net.Conn, *Session, error) {
+func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, error) {
s := new(Session)
+ s.SMState = state
s.init(conn, o)
// starttls
@@ -54,10 +56,18 @@ func NewSession(conn net.Conn, o Config) (net.Conn, *Session, error) {
s.auth(o)
s.reset(tlsConn, tlsConn, o)
- // bind resource and 'start' XMPP session
+ // attempt resumption
+ if s.resume(o) {
+ return tlsConn, s, s.err
+ }
+
+ // otherwise, bind resource and 'start' XMPP session
s.bind(o)
s.rfc3921Session(o)
+ // Enable stream management if supported
+ s.EnableStreamManagement(o)
+
return tlsConn, s, s.err
}
@@ -161,6 +171,39 @@ func (s *Session) auth(o Config) {
s.err = authSASL(s.streamLogger, s.decoder, s.Features, o.parsedJid.Node, o.Password)
}
+// Attempt to resume session using stream management
+func (s *Session) resume(o Config) bool {
+ if !s.Features.DoesStreamManagement() {
+ return false
+ }
+ if s.SMState.Id == "" {
+ return false
+ }
+
+ fmt.Fprintf(s.streamLogger, "",
+ stanza.NSStreamManagement, s.SMState.Inbound, s.SMState.Id)
+
+ var packet stanza.Packet
+ packet, s.err = stanza.NextPacket(s.decoder)
+ if s.err == nil {
+ switch p := packet.(type) {
+ case stanza.SMResumed:
+ if p.PrevId != s.SMState.Id {
+ s.err = errors.New("session resumption: mismatched id")
+ s.SMState = SMState{}
+ return false
+ }
+ return true
+ case stanza.SMFailed:
+ fmt.Println("MREMOND SM Failed")
+ default:
+ s.err = errors.New("unexpected reply to SM resume")
+ }
+ }
+ s.SMState = SMState{}
+ return false
+}
+
func (s *Session) bind(o Config) {
if s.err != nil {
return
@@ -208,3 +251,31 @@ func (s *Session) rfc3921Session(o Config) {
}
}
}
+
+// Enable stream management, with session resumption, if supported.
+func (s *Session) EnableStreamManagement(o Config) {
+ if s.err != nil {
+ return
+ }
+ if !s.Features.DoesStreamManagement() {
+ return
+ }
+
+ fmt.Fprintf(s.streamLogger, "", stanza.NSStreamManagement)
+
+ var packet stanza.Packet
+ packet, s.err = stanza.NextPacket(s.decoder)
+ if s.err == nil {
+ switch p := packet.(type) {
+ case stanza.SMEnabled:
+ s.SMState = SMState{Id: p.Id}
+ case stanza.SMFailed:
+ // TODO: Store error in SMState
+ default:
+ fmt.Println(p)
+ s.err = errors.New("unexpected reply to SM enable")
+ }
+ }
+
+ return
+}
diff --git a/stanza/parser.go b/stanza/parser.go
index c83e17e..cdd8b70 100644
--- a/stanza/parser.go
+++ b/stanza/parser.go
@@ -63,6 +63,8 @@ func NextPacket(p *xml.Decoder) (Packet, error) {
return decodeClient(p, se)
case NSComponent:
return decodeComponent(p, se)
+ case NSStreamManagement:
+ return sm.decode(p, se)
default:
return nil, errors.New("unknown namespace " +
se.Name.Space + " <" + se.Name.Local + "/>")
@@ -133,7 +135,7 @@ func decodeClient(p *xml.Decoder, se xml.StartElement) (Packet, error) {
}
}
-// decodeClient decodes all known packets in the component namespace.
+// decodeComponent decodes all known packets in the component namespace.
func decodeComponent(p *xml.Decoder, se xml.StartElement) (Packet, error) {
switch se.Name.Local {
case "handshake": // handshake is used to authenticate components
diff --git a/stanza/stream_management.go b/stanza/stream_management.go
new file mode 100644
index 0000000..ddbe9cd
--- /dev/null
+++ b/stanza/stream_management.go
@@ -0,0 +1,121 @@
+package stanza
+
+import (
+ "encoding/xml"
+ "errors"
+)
+
+const (
+ NSStreamManagement = "urn:xmpp:sm:3"
+)
+
+// Enabled as defined in Stream Management spec
+// Reference: https://xmpp.org/extensions/xep-0198.html#enable
+type SMEnabled struct {
+ XMLName xml.Name `xml:"urn:xmpp:sm:3 enabled"`
+ Id string `xml:"id,attr,omitempty"`
+ Location string `xml:"location,attr,omitempty"`
+ Resume string `xml:"resume,attr,omitempty"`
+ Max uint `xml:"max,attr,omitempty"`
+}
+
+func (SMEnabled) Name() string {
+ return "Stream Management: enabled"
+}
+
+// Request as defined in Stream Management spec
+// Reference: https://xmpp.org/extensions/xep-0198.html#acking
+type SMRequest struct {
+ XMLName xml.Name `xml:"urn:xmpp:sm:3 r"`
+}
+
+func (SMRequest) Name() string {
+ return "Stream Management: request"
+}
+
+// Answer as defined in Stream Management spec
+// Reference: https://xmpp.org/extensions/xep-0198.html#acking
+type SMAnswer struct {
+ XMLName xml.Name `xml:"urn:xmpp:sm:3 a"`
+ H uint `xml:"h,attr,omitempty"`
+}
+
+func (SMAnswer) Name() string {
+ return "Stream Management: answer"
+}
+
+// Resumed as defined in Stream Management spec
+// Reference: https://xmpp.org/extensions/xep-0198.html#acking
+type SMResumed struct {
+ XMLName xml.Name `xml:"urn:xmpp:sm:3 resumed"`
+ PrevId string `xml:"previd,attr,omitempty"`
+ H uint `xml:"h,attr,omitempty"`
+}
+
+func (SMResumed) Name() string {
+ return "Stream Management: resumed"
+}
+
+// Failed as defined in Stream Management spec
+// Reference: https://xmpp.org/extensions/xep-0198.html#acking
+type SMFailed struct {
+ XMLName xml.Name `xml:"urn:xmpp:sm:3 failed"`
+ // TODO: Handle decoding error cause (need custom parsing).
+}
+
+func (SMFailed) Name() string {
+ return "Stream Management: failed"
+}
+
+type smDecoder struct{}
+
+var sm smDecoder
+
+// decode decodes all known nonza in the stream management namespace.
+func (s smDecoder) decode(p *xml.Decoder, se xml.StartElement) (Packet, error) {
+ switch se.Name.Local {
+ case "enabled":
+ return s.decodeEnabled(p, se)
+ case "resumed":
+ return s.decodeResumed(p, se)
+ case "r":
+ return s.decodeRequest(p, se)
+ case "h":
+ return s.decodeAnswer(p, se)
+ case "failed":
+ return s.decodeFailed(p, se)
+ default:
+ return nil, errors.New("unexpected XMPP packet " +
+ se.Name.Space + " <" + se.Name.Local + "/>")
+ }
+}
+
+func (smDecoder) decodeEnabled(p *xml.Decoder, se xml.StartElement) (SMEnabled, error) {
+ var packet SMEnabled
+ err := p.DecodeElement(&packet, &se)
+ return packet, err
+}
+
+func (smDecoder) decodeResumed(p *xml.Decoder, se xml.StartElement) (SMResumed, error) {
+ var packet SMResumed
+ err := p.DecodeElement(&packet, &se)
+ return packet, err
+}
+
+func (smDecoder) decodeRequest(p *xml.Decoder, se xml.StartElement) (SMRequest, error) {
+ var packet SMRequest
+ err := p.DecodeElement(&packet, &se)
+ return packet, err
+}
+
+func (smDecoder) decodeAnswer(p *xml.Decoder, se xml.StartElement) (SMAnswer, error) {
+ var packet SMAnswer
+ err := p.DecodeElement(&packet, &se)
+ return packet, err
+}
+
+func (smDecoder) decodeFailed(p *xml.Decoder, se xml.StartElement) (SMFailed, error) {
+ var packet SMFailed
+ err := p.DecodeElement(&packet, &se)
+ return packet, err
+}
diff --git a/stream_manager.go b/stream_manager.go
index 1aaf164..b81a783 100644
--- a/stream_manager.go
+++ b/stream_manager.go
@@ -24,6 +24,7 @@ import (
// set callback and trigger reconnection.
type StreamClient interface {
Connect() error
+ Resume(state SMState) error
Send(packet stanza.Packet) error
SendRaw(packet string) error
Disconnect()
@@ -78,7 +79,7 @@ func (sm *StreamManager) Run() error {
sm.Metrics.setLoginTime()
case StateDisconnected:
// Reconnect on disconnection
- sm.connect()
+ sm.resume(e.SMState)
case StateStreamError:
sm.client.Disconnect()
// Only try reconnecting if we have not been kicked by another session to avoid connection loop.
@@ -106,8 +107,13 @@ func (sm *StreamManager) Stop() {
sm.wg.Done()
}
-// connect manages the reconnection loop and apply the define backoff to avoid overloading the server.
func (sm *StreamManager) connect() error {
+ var state SMState
+ return sm.resume(state)
+}
+
+// resume manages the reconnection loop and apply the define backoff to avoid overloading the server.
+func (sm *StreamManager) resume(state SMState) error {
var backoff backoff // TODO: Group backoff calculation features with connection manager?
for {
@@ -115,7 +121,7 @@ func (sm *StreamManager) connect() error {
// TODO: Make it possible to define logger to log disconnect and reconnection attempts
sm.Metrics = initMetrics()
- if err = sm.client.Connect(); err != nil {
+ if err = sm.client.Resume(state); err != nil {
var actualErr ConnError
if xerrors.As(err, &actualErr) {
if actualErr.Permanent {