Add initial support for stream management

For now it support enabling SM, replying to ack requests from server,
and trying resuming the session with existing Stream Management state.
This commit is contained in:
Mickael Remond 2019-07-31 18:47:30 +02:00 committed by Mickaël Rémond
parent e531370dc9
commit 3de99e0e0e
6 changed files with 261 additions and 15 deletions

View file

@ -31,6 +31,18 @@ type Event struct {
State ConnState State ConnState
Description string Description string
StreamError string StreamError string
SMState SMState
}
// SMState holds Stream Management information regarding the session that can be
// used to resume session after disconnect
type SMState struct {
// Stream Management ID
Id string
// Inbound stanza count
Inbound uint
// TODO Store location for IP affinity
// TODO Store max and timestamp, to check if we should retry resumption or not
} }
// EventHandler is use to pass events about state of the connection to // EventHandler is use to pass events about state of the connection to
@ -52,6 +64,13 @@ func (em EventManager) updateState(state ConnState) {
} }
} }
func (em EventManager) disconnected(state SMState) {
em.CurrentState = StateDisconnected
if em.Handler != nil {
em.Handler(Event{State: em.CurrentState, SMState: state})
}
}
func (em EventManager) streamError(error, desc string) { func (em EventManager) streamError(error, desc string) {
em.CurrentState = StateStreamError em.CurrentState = StateStreamError
if em.Handler != nil { if em.Handler != nil {
@ -128,7 +147,15 @@ func NewClient(config Config, r *Router) (c *Client, err error) {
} }
// Connect triggers actual TCP connection, based on previously defined parameters. // Connect triggers actual TCP connection, based on previously defined parameters.
// Connect simply triggers resumption, with an empty session state.
func (c *Client) Connect() error { func (c *Client) Connect() error {
var state SMState
return c.Resume(state)
}
// Resume attempts resuming a Stream Managed session, based on the provided stream management
// state.
func (c *Client) Resume(state SMState) error {
var err error var err error
c.conn, err = net.DialTimeout("tcp", c.config.Address, time.Duration(c.config.ConnectTimeout)*time.Second) c.conn, err = net.DialTimeout("tcp", c.config.Address, time.Duration(c.config.ConnectTimeout)*time.Second)
@ -138,23 +165,24 @@ func (c *Client) Connect() error {
c.updateState(StateConnected) c.updateState(StateConnected)
// Client is ok, we now open XMPP session // Client is ok, we now open XMPP session
if c.conn, c.Session, err = NewSession(c.conn, c.config); err != nil { if c.conn, c.Session, err = NewSession(c.conn, c.config, state); err != nil {
return err return err
} }
c.updateState(StateSessionEstablished) c.updateState(StateSessionEstablished)
// Start the keepalive go routine
keepaliveQuit := make(chan struct{})
go keepalive(c.conn, keepaliveQuit)
// Start the receiver go routine
state = c.Session.SMState
go c.recv(state, keepaliveQuit)
// We're connected and can now receive and send messages. // We're connected and can now receive and send messages.
//fmt.Fprintf(client.conn, "<presence xml:lang='en'><show>%s</show><status>%s</status></presence>", "chat", "Online") //fmt.Fprintf(client.conn, "<presence xml:lang='en'><show>%s</show><status>%s</status></presence>", "chat", "Online")
// TODO: Do we always want to send initial presence automatically ? // TODO: Do we always want to send initial presence automatically ?
// Do we need an option to avoid that or do we rely on client to send the presence itself ? // Do we need an option to avoid that or do we rely on client to send the presence itself ?
fmt.Fprintf(c.Session.streamLogger, "<presence/>") fmt.Fprintf(c.Session.streamLogger, "<presence/>")
// Start the keepalive go routine
keepaliveQuit := make(chan struct{})
go keepalive(c.conn, keepaliveQuit)
// Start the receiver go routine
go c.recv(keepaliveQuit)
return err return err
} }
@ -206,12 +234,12 @@ func (c *Client) sendWithLogger(packet string) error {
// Go routines // Go routines
// Loop: Receive data from server // Loop: Receive data from server
func (c *Client) recv(keepaliveQuit chan<- struct{}) (err error) { func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) {
for { for {
val, err := stanza.NextPacket(c.Session.decoder) val, err := stanza.NextPacket(c.Session.decoder)
if err != nil { if err != nil {
close(keepaliveQuit) close(keepaliveQuit)
c.updateState(StateDisconnected) c.disconnected(state)
return err return err
} }
@ -222,6 +250,17 @@ func (c *Client) recv(keepaliveQuit chan<- struct{}) (err error) {
close(keepaliveQuit) close(keepaliveQuit)
c.streamError(packet.Error.Local, packet.Text) c.streamError(packet.Error.Local, packet.Text)
return errors.New("stream error: " + packet.Error.Local) return errors.New("stream error: " + packet.Error.Local)
// Process Stream management nonzas
case stanza.SMRequest:
fmt.Println("MREMOND: inbound: ", state.Inbound)
answer := stanza.SMAnswer{XMLName: xml.Name{
Space: stanza.NSStreamManagement,
Local: "a",
}, H: state.Inbound}
c.Send(answer)
default:
fmt.Println(packet)
state.Inbound++
} }
c.router.route(c, val) c.router.route(c, val)
@ -243,6 +282,9 @@ func keepalive(conn net.Conn, quit <-chan struct{}) {
_ = conn.Close() _ = conn.Close()
return return
} }
case <-time.After(3 * time.Second):
_ = conn.Close()
return
case <-quit: case <-quit:
ticker.Stop() ticker.Stop()
return return

View file

@ -108,6 +108,10 @@ func (c *Component) Connect() error {
} }
} }
func (c *Component) Resume() error {
return errors.New("components do not support stream management")
}
func (c *Component) Disconnect() { func (c *Component) Disconnect() {
_ = c.SendRaw("</stream:stream>") _ = c.SendRaw("</stream:stream>")
// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect // TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect

View file

@ -17,6 +17,7 @@ type Session struct {
// Session info // Session info
BindJid string // Jabber ID as provided by XMPP server BindJid string // Jabber ID as provided by XMPP server
StreamId string StreamId string
SMState SMState
Features stanza.StreamFeatures Features stanza.StreamFeatures
TlsEnabled bool TlsEnabled bool
lastPacketId int lastPacketId int
@ -29,8 +30,9 @@ type Session struct {
err error err error
} }
func NewSession(conn net.Conn, o Config) (net.Conn, *Session, error) { func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, error) {
s := new(Session) s := new(Session)
s.SMState = state
s.init(conn, o) s.init(conn, o)
// starttls // starttls
@ -54,10 +56,18 @@ func NewSession(conn net.Conn, o Config) (net.Conn, *Session, error) {
s.auth(o) s.auth(o)
s.reset(tlsConn, tlsConn, o) s.reset(tlsConn, tlsConn, o)
// bind resource and 'start' XMPP session // attempt resumption
if s.resume(o) {
return tlsConn, s, s.err
}
// otherwise, bind resource and 'start' XMPP session
s.bind(o) s.bind(o)
s.rfc3921Session(o) s.rfc3921Session(o)
// Enable stream management if supported
s.EnableStreamManagement(o)
return tlsConn, s, s.err return tlsConn, s, s.err
} }
@ -161,6 +171,39 @@ func (s *Session) auth(o Config) {
s.err = authSASL(s.streamLogger, s.decoder, s.Features, o.parsedJid.Node, o.Password) s.err = authSASL(s.streamLogger, s.decoder, s.Features, o.parsedJid.Node, o.Password)
} }
// Attempt to resume session using stream management
func (s *Session) resume(o Config) bool {
if !s.Features.DoesStreamManagement() {
return false
}
if s.SMState.Id == "" {
return false
}
fmt.Fprintf(s.streamLogger, "<resume xmlns='%s' h='%d' previd='%s'/>",
stanza.NSStreamManagement, s.SMState.Inbound, s.SMState.Id)
var packet stanza.Packet
packet, s.err = stanza.NextPacket(s.decoder)
if s.err == nil {
switch p := packet.(type) {
case stanza.SMResumed:
if p.PrevId != s.SMState.Id {
s.err = errors.New("session resumption: mismatched id")
s.SMState = SMState{}
return false
}
return true
case stanza.SMFailed:
fmt.Println("MREMOND SM Failed")
default:
s.err = errors.New("unexpected reply to SM resume")
}
}
s.SMState = SMState{}
return false
}
func (s *Session) bind(o Config) { func (s *Session) bind(o Config) {
if s.err != nil { if s.err != nil {
return return
@ -208,3 +251,31 @@ func (s *Session) rfc3921Session(o Config) {
} }
} }
} }
// Enable stream management, with session resumption, if supported.
func (s *Session) EnableStreamManagement(o Config) {
if s.err != nil {
return
}
if !s.Features.DoesStreamManagement() {
return
}
fmt.Fprintf(s.streamLogger, "<enable xmlns='%s' resume='true'/>", stanza.NSStreamManagement)
var packet stanza.Packet
packet, s.err = stanza.NextPacket(s.decoder)
if s.err == nil {
switch p := packet.(type) {
case stanza.SMEnabled:
s.SMState = SMState{Id: p.Id}
case stanza.SMFailed:
// TODO: Store error in SMState
default:
fmt.Println(p)
s.err = errors.New("unexpected reply to SM enable")
}
}
return
}

View file

@ -63,6 +63,8 @@ func NextPacket(p *xml.Decoder) (Packet, error) {
return decodeClient(p, se) return decodeClient(p, se)
case NSComponent: case NSComponent:
return decodeComponent(p, se) return decodeComponent(p, se)
case NSStreamManagement:
return sm.decode(p, se)
default: default:
return nil, errors.New("unknown namespace " + return nil, errors.New("unknown namespace " +
se.Name.Space + " <" + se.Name.Local + "/>") se.Name.Space + " <" + se.Name.Local + "/>")
@ -133,7 +135,7 @@ func decodeClient(p *xml.Decoder, se xml.StartElement) (Packet, error) {
} }
} }
// decodeClient decodes all known packets in the component namespace. // decodeComponent decodes all known packets in the component namespace.
func decodeComponent(p *xml.Decoder, se xml.StartElement) (Packet, error) { func decodeComponent(p *xml.Decoder, se xml.StartElement) (Packet, error) {
switch se.Name.Local { switch se.Name.Local {
case "handshake": // handshake is used to authenticate components case "handshake": // handshake is used to authenticate components

121
stanza/stream_management.go Normal file
View file

@ -0,0 +1,121 @@
package stanza
import (
"encoding/xml"
"errors"
)
const (
NSStreamManagement = "urn:xmpp:sm:3"
)
// Enabled as defined in Stream Management spec
// Reference: https://xmpp.org/extensions/xep-0198.html#enable
type SMEnabled struct {
XMLName xml.Name `xml:"urn:xmpp:sm:3 enabled"`
Id string `xml:"id,attr,omitempty"`
Location string `xml:"location,attr,omitempty"`
Resume string `xml:"resume,attr,omitempty"`
Max uint `xml:"max,attr,omitempty"`
}
func (SMEnabled) Name() string {
return "Stream Management: enabled"
}
// Request as defined in Stream Management spec
// Reference: https://xmpp.org/extensions/xep-0198.html#acking
type SMRequest struct {
XMLName xml.Name `xml:"urn:xmpp:sm:3 r"`
}
func (SMRequest) Name() string {
return "Stream Management: request"
}
// Answer as defined in Stream Management spec
// Reference: https://xmpp.org/extensions/xep-0198.html#acking
type SMAnswer struct {
XMLName xml.Name `xml:"urn:xmpp:sm:3 a"`
H uint `xml:"h,attr,omitempty"`
}
func (SMAnswer) Name() string {
return "Stream Management: answer"
}
// Resumed as defined in Stream Management spec
// Reference: https://xmpp.org/extensions/xep-0198.html#acking
type SMResumed struct {
XMLName xml.Name `xml:"urn:xmpp:sm:3 resumed"`
PrevId string `xml:"previd,attr,omitempty"`
H uint `xml:"h,attr,omitempty"`
}
func (SMResumed) Name() string {
return "Stream Management: resumed"
}
// Failed as defined in Stream Management spec
// Reference: https://xmpp.org/extensions/xep-0198.html#acking
type SMFailed struct {
XMLName xml.Name `xml:"urn:xmpp:sm:3 failed"`
// TODO: Handle decoding error cause (need custom parsing).
}
func (SMFailed) Name() string {
return "Stream Management: failed"
}
type smDecoder struct{}
var sm smDecoder
// decode decodes all known nonza in the stream management namespace.
func (s smDecoder) decode(p *xml.Decoder, se xml.StartElement) (Packet, error) {
switch se.Name.Local {
case "enabled":
return s.decodeEnabled(p, se)
case "resumed":
return s.decodeResumed(p, se)
case "r":
return s.decodeRequest(p, se)
case "h":
return s.decodeAnswer(p, se)
case "failed":
return s.decodeFailed(p, se)
default:
return nil, errors.New("unexpected XMPP packet " +
se.Name.Space + " <" + se.Name.Local + "/>")
}
}
func (smDecoder) decodeEnabled(p *xml.Decoder, se xml.StartElement) (SMEnabled, error) {
var packet SMEnabled
err := p.DecodeElement(&packet, &se)
return packet, err
}
func (smDecoder) decodeResumed(p *xml.Decoder, se xml.StartElement) (SMResumed, error) {
var packet SMResumed
err := p.DecodeElement(&packet, &se)
return packet, err
}
func (smDecoder) decodeRequest(p *xml.Decoder, se xml.StartElement) (SMRequest, error) {
var packet SMRequest
err := p.DecodeElement(&packet, &se)
return packet, err
}
func (smDecoder) decodeAnswer(p *xml.Decoder, se xml.StartElement) (SMAnswer, error) {
var packet SMAnswer
err := p.DecodeElement(&packet, &se)
return packet, err
}
func (smDecoder) decodeFailed(p *xml.Decoder, se xml.StartElement) (SMFailed, error) {
var packet SMFailed
err := p.DecodeElement(&packet, &se)
return packet, err
}

View file

@ -24,6 +24,7 @@ import (
// set callback and trigger reconnection. // set callback and trigger reconnection.
type StreamClient interface { type StreamClient interface {
Connect() error Connect() error
Resume(state SMState) error
Send(packet stanza.Packet) error Send(packet stanza.Packet) error
SendRaw(packet string) error SendRaw(packet string) error
Disconnect() Disconnect()
@ -78,7 +79,7 @@ func (sm *StreamManager) Run() error {
sm.Metrics.setLoginTime() sm.Metrics.setLoginTime()
case StateDisconnected: case StateDisconnected:
// Reconnect on disconnection // Reconnect on disconnection
sm.connect() sm.resume(e.SMState)
case StateStreamError: case StateStreamError:
sm.client.Disconnect() sm.client.Disconnect()
// Only try reconnecting if we have not been kicked by another session to avoid connection loop. // Only try reconnecting if we have not been kicked by another session to avoid connection loop.
@ -106,8 +107,13 @@ func (sm *StreamManager) Stop() {
sm.wg.Done() sm.wg.Done()
} }
// connect manages the reconnection loop and apply the define backoff to avoid overloading the server.
func (sm *StreamManager) connect() error { func (sm *StreamManager) connect() error {
var state SMState
return sm.resume(state)
}
// resume manages the reconnection loop and apply the define backoff to avoid overloading the server.
func (sm *StreamManager) resume(state SMState) error {
var backoff backoff // TODO: Group backoff calculation features with connection manager? var backoff backoff // TODO: Group backoff calculation features with connection manager?
for { for {
@ -115,7 +121,7 @@ func (sm *StreamManager) connect() error {
// TODO: Make it possible to define logger to log disconnect and reconnection attempts // TODO: Make it possible to define logger to log disconnect and reconnection attempts
sm.Metrics = initMetrics() sm.Metrics = initMetrics()
if err = sm.client.Connect(); err != nil { if err = sm.client.Resume(state); err != nil {
var actualErr ConnError var actualErr ConnError
if xerrors.As(err, &actualErr) { if xerrors.As(err, &actualErr) {
if actualErr.Permanent { if actualErr.Permanent {