From 3de99e0e0e6582486ad052edcc41bbce12205c5c Mon Sep 17 00:00:00 2001 From: Mickael Remond Date: Wed, 31 Jul 2019 18:47:30 +0200 Subject: [PATCH] Add initial support for stream management For now it support enabling SM, replying to ack requests from server, and trying resuming the session with existing Stream Management state. --- client.go | 60 +++++++++++++++--- component.go | 4 ++ session.go | 75 +++++++++++++++++++++- stanza/parser.go | 4 +- stanza/stream_management.go | 121 ++++++++++++++++++++++++++++++++++++ stream_manager.go | 12 +++- 6 files changed, 261 insertions(+), 15 deletions(-) create mode 100644 stanza/stream_management.go diff --git a/client.go b/client.go index 7bb7a80..7e40cb4 100644 --- a/client.go +++ b/client.go @@ -31,6 +31,18 @@ type Event struct { State ConnState Description string StreamError string + SMState SMState +} + +// SMState holds Stream Management information regarding the session that can be +// used to resume session after disconnect +type SMState struct { + // Stream Management ID + Id string + // Inbound stanza count + Inbound uint + // TODO Store location for IP affinity + // TODO Store max and timestamp, to check if we should retry resumption or not } // EventHandler is use to pass events about state of the connection to @@ -52,6 +64,13 @@ func (em EventManager) updateState(state ConnState) { } } +func (em EventManager) disconnected(state SMState) { + em.CurrentState = StateDisconnected + if em.Handler != nil { + em.Handler(Event{State: em.CurrentState, SMState: state}) + } +} + func (em EventManager) streamError(error, desc string) { em.CurrentState = StateStreamError if em.Handler != nil { @@ -128,7 +147,15 @@ func NewClient(config Config, r *Router) (c *Client, err error) { } // Connect triggers actual TCP connection, based on previously defined parameters. +// Connect simply triggers resumption, with an empty session state. func (c *Client) Connect() error { + var state SMState + return c.Resume(state) +} + +// Resume attempts resuming a Stream Managed session, based on the provided stream management +// state. +func (c *Client) Resume(state SMState) error { var err error c.conn, err = net.DialTimeout("tcp", c.config.Address, time.Duration(c.config.ConnectTimeout)*time.Second) @@ -138,23 +165,24 @@ func (c *Client) Connect() error { c.updateState(StateConnected) // Client is ok, we now open XMPP session - if c.conn, c.Session, err = NewSession(c.conn, c.config); err != nil { + if c.conn, c.Session, err = NewSession(c.conn, c.config, state); err != nil { return err } c.updateState(StateSessionEstablished) + // Start the keepalive go routine + keepaliveQuit := make(chan struct{}) + go keepalive(c.conn, keepaliveQuit) + // Start the receiver go routine + state = c.Session.SMState + go c.recv(state, keepaliveQuit) + // We're connected and can now receive and send messages. //fmt.Fprintf(client.conn, "%s%s", "chat", "Online") // TODO: Do we always want to send initial presence automatically ? // Do we need an option to avoid that or do we rely on client to send the presence itself ? fmt.Fprintf(c.Session.streamLogger, "") - // Start the keepalive go routine - keepaliveQuit := make(chan struct{}) - go keepalive(c.conn, keepaliveQuit) - // Start the receiver go routine - go c.recv(keepaliveQuit) - return err } @@ -206,12 +234,12 @@ func (c *Client) sendWithLogger(packet string) error { // Go routines // Loop: Receive data from server -func (c *Client) recv(keepaliveQuit chan<- struct{}) (err error) { +func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) { for { val, err := stanza.NextPacket(c.Session.decoder) if err != nil { close(keepaliveQuit) - c.updateState(StateDisconnected) + c.disconnected(state) return err } @@ -222,6 +250,17 @@ func (c *Client) recv(keepaliveQuit chan<- struct{}) (err error) { close(keepaliveQuit) c.streamError(packet.Error.Local, packet.Text) return errors.New("stream error: " + packet.Error.Local) + // Process Stream management nonzas + case stanza.SMRequest: + fmt.Println("MREMOND: inbound: ", state.Inbound) + answer := stanza.SMAnswer{XMLName: xml.Name{ + Space: stanza.NSStreamManagement, + Local: "a", + }, H: state.Inbound} + c.Send(answer) + default: + fmt.Println(packet) + state.Inbound++ } c.router.route(c, val) @@ -243,6 +282,9 @@ func keepalive(conn net.Conn, quit <-chan struct{}) { _ = conn.Close() return } + case <-time.After(3 * time.Second): + _ = conn.Close() + return case <-quit: ticker.Stop() return diff --git a/component.go b/component.go index 0176371..af424a2 100644 --- a/component.go +++ b/component.go @@ -108,6 +108,10 @@ func (c *Component) Connect() error { } } +func (c *Component) Resume() error { + return errors.New("components do not support stream management") +} + func (c *Component) Disconnect() { _ = c.SendRaw("") // TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect diff --git a/session.go b/session.go index 1092569..8253ab6 100644 --- a/session.go +++ b/session.go @@ -17,6 +17,7 @@ type Session struct { // Session info BindJid string // Jabber ID as provided by XMPP server StreamId string + SMState SMState Features stanza.StreamFeatures TlsEnabled bool lastPacketId int @@ -29,8 +30,9 @@ type Session struct { err error } -func NewSession(conn net.Conn, o Config) (net.Conn, *Session, error) { +func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, error) { s := new(Session) + s.SMState = state s.init(conn, o) // starttls @@ -54,10 +56,18 @@ func NewSession(conn net.Conn, o Config) (net.Conn, *Session, error) { s.auth(o) s.reset(tlsConn, tlsConn, o) - // bind resource and 'start' XMPP session + // attempt resumption + if s.resume(o) { + return tlsConn, s, s.err + } + + // otherwise, bind resource and 'start' XMPP session s.bind(o) s.rfc3921Session(o) + // Enable stream management if supported + s.EnableStreamManagement(o) + return tlsConn, s, s.err } @@ -161,6 +171,39 @@ func (s *Session) auth(o Config) { s.err = authSASL(s.streamLogger, s.decoder, s.Features, o.parsedJid.Node, o.Password) } +// Attempt to resume session using stream management +func (s *Session) resume(o Config) bool { + if !s.Features.DoesStreamManagement() { + return false + } + if s.SMState.Id == "" { + return false + } + + fmt.Fprintf(s.streamLogger, "", + stanza.NSStreamManagement, s.SMState.Inbound, s.SMState.Id) + + var packet stanza.Packet + packet, s.err = stanza.NextPacket(s.decoder) + if s.err == nil { + switch p := packet.(type) { + case stanza.SMResumed: + if p.PrevId != s.SMState.Id { + s.err = errors.New("session resumption: mismatched id") + s.SMState = SMState{} + return false + } + return true + case stanza.SMFailed: + fmt.Println("MREMOND SM Failed") + default: + s.err = errors.New("unexpected reply to SM resume") + } + } + s.SMState = SMState{} + return false +} + func (s *Session) bind(o Config) { if s.err != nil { return @@ -208,3 +251,31 @@ func (s *Session) rfc3921Session(o Config) { } } } + +// Enable stream management, with session resumption, if supported. +func (s *Session) EnableStreamManagement(o Config) { + if s.err != nil { + return + } + if !s.Features.DoesStreamManagement() { + return + } + + fmt.Fprintf(s.streamLogger, "", stanza.NSStreamManagement) + + var packet stanza.Packet + packet, s.err = stanza.NextPacket(s.decoder) + if s.err == nil { + switch p := packet.(type) { + case stanza.SMEnabled: + s.SMState = SMState{Id: p.Id} + case stanza.SMFailed: + // TODO: Store error in SMState + default: + fmt.Println(p) + s.err = errors.New("unexpected reply to SM enable") + } + } + + return +} diff --git a/stanza/parser.go b/stanza/parser.go index c83e17e..cdd8b70 100644 --- a/stanza/parser.go +++ b/stanza/parser.go @@ -63,6 +63,8 @@ func NextPacket(p *xml.Decoder) (Packet, error) { return decodeClient(p, se) case NSComponent: return decodeComponent(p, se) + case NSStreamManagement: + return sm.decode(p, se) default: return nil, errors.New("unknown namespace " + se.Name.Space + " <" + se.Name.Local + "/>") @@ -133,7 +135,7 @@ func decodeClient(p *xml.Decoder, se xml.StartElement) (Packet, error) { } } -// decodeClient decodes all known packets in the component namespace. +// decodeComponent decodes all known packets in the component namespace. func decodeComponent(p *xml.Decoder, se xml.StartElement) (Packet, error) { switch se.Name.Local { case "handshake": // handshake is used to authenticate components diff --git a/stanza/stream_management.go b/stanza/stream_management.go new file mode 100644 index 0000000..ddbe9cd --- /dev/null +++ b/stanza/stream_management.go @@ -0,0 +1,121 @@ +package stanza + +import ( + "encoding/xml" + "errors" +) + +const ( + NSStreamManagement = "urn:xmpp:sm:3" +) + +// Enabled as defined in Stream Management spec +// Reference: https://xmpp.org/extensions/xep-0198.html#enable +type SMEnabled struct { + XMLName xml.Name `xml:"urn:xmpp:sm:3 enabled"` + Id string `xml:"id,attr,omitempty"` + Location string `xml:"location,attr,omitempty"` + Resume string `xml:"resume,attr,omitempty"` + Max uint `xml:"max,attr,omitempty"` +} + +func (SMEnabled) Name() string { + return "Stream Management: enabled" +} + +// Request as defined in Stream Management spec +// Reference: https://xmpp.org/extensions/xep-0198.html#acking +type SMRequest struct { + XMLName xml.Name `xml:"urn:xmpp:sm:3 r"` +} + +func (SMRequest) Name() string { + return "Stream Management: request" +} + +// Answer as defined in Stream Management spec +// Reference: https://xmpp.org/extensions/xep-0198.html#acking +type SMAnswer struct { + XMLName xml.Name `xml:"urn:xmpp:sm:3 a"` + H uint `xml:"h,attr,omitempty"` +} + +func (SMAnswer) Name() string { + return "Stream Management: answer" +} + +// Resumed as defined in Stream Management spec +// Reference: https://xmpp.org/extensions/xep-0198.html#acking +type SMResumed struct { + XMLName xml.Name `xml:"urn:xmpp:sm:3 resumed"` + PrevId string `xml:"previd,attr,omitempty"` + H uint `xml:"h,attr,omitempty"` +} + +func (SMResumed) Name() string { + return "Stream Management: resumed" +} + +// Failed as defined in Stream Management spec +// Reference: https://xmpp.org/extensions/xep-0198.html#acking +type SMFailed struct { + XMLName xml.Name `xml:"urn:xmpp:sm:3 failed"` + // TODO: Handle decoding error cause (need custom parsing). +} + +func (SMFailed) Name() string { + return "Stream Management: failed" +} + +type smDecoder struct{} + +var sm smDecoder + +// decode decodes all known nonza in the stream management namespace. +func (s smDecoder) decode(p *xml.Decoder, se xml.StartElement) (Packet, error) { + switch se.Name.Local { + case "enabled": + return s.decodeEnabled(p, se) + case "resumed": + return s.decodeResumed(p, se) + case "r": + return s.decodeRequest(p, se) + case "h": + return s.decodeAnswer(p, se) + case "failed": + return s.decodeFailed(p, se) + default: + return nil, errors.New("unexpected XMPP packet " + + se.Name.Space + " <" + se.Name.Local + "/>") + } +} + +func (smDecoder) decodeEnabled(p *xml.Decoder, se xml.StartElement) (SMEnabled, error) { + var packet SMEnabled + err := p.DecodeElement(&packet, &se) + return packet, err +} + +func (smDecoder) decodeResumed(p *xml.Decoder, se xml.StartElement) (SMResumed, error) { + var packet SMResumed + err := p.DecodeElement(&packet, &se) + return packet, err +} + +func (smDecoder) decodeRequest(p *xml.Decoder, se xml.StartElement) (SMRequest, error) { + var packet SMRequest + err := p.DecodeElement(&packet, &se) + return packet, err +} + +func (smDecoder) decodeAnswer(p *xml.Decoder, se xml.StartElement) (SMAnswer, error) { + var packet SMAnswer + err := p.DecodeElement(&packet, &se) + return packet, err +} + +func (smDecoder) decodeFailed(p *xml.Decoder, se xml.StartElement) (SMFailed, error) { + var packet SMFailed + err := p.DecodeElement(&packet, &se) + return packet, err +} diff --git a/stream_manager.go b/stream_manager.go index 1aaf164..b81a783 100644 --- a/stream_manager.go +++ b/stream_manager.go @@ -24,6 +24,7 @@ import ( // set callback and trigger reconnection. type StreamClient interface { Connect() error + Resume(state SMState) error Send(packet stanza.Packet) error SendRaw(packet string) error Disconnect() @@ -78,7 +79,7 @@ func (sm *StreamManager) Run() error { sm.Metrics.setLoginTime() case StateDisconnected: // Reconnect on disconnection - sm.connect() + sm.resume(e.SMState) case StateStreamError: sm.client.Disconnect() // Only try reconnecting if we have not been kicked by another session to avoid connection loop. @@ -106,8 +107,13 @@ func (sm *StreamManager) Stop() { sm.wg.Done() } -// connect manages the reconnection loop and apply the define backoff to avoid overloading the server. func (sm *StreamManager) connect() error { + var state SMState + return sm.resume(state) +} + +// resume manages the reconnection loop and apply the define backoff to avoid overloading the server. +func (sm *StreamManager) resume(state SMState) error { var backoff backoff // TODO: Group backoff calculation features with connection manager? for { @@ -115,7 +121,7 @@ func (sm *StreamManager) connect() error { // TODO: Make it possible to define logger to log disconnect and reconnection attempts sm.Metrics = initMetrics() - if err = sm.client.Connect(); err != nil { + if err = sm.client.Resume(state); err != nil { var actualErr ConnError if xerrors.As(err, &actualErr) { if actualErr.Permanent {