From 94aceac802e51f58710300bd94d33566bb3b236f Mon Sep 17 00:00:00 2001 From: remicorniere Date: Thu, 26 Dec 2019 13:47:02 +0000 Subject: [PATCH] Changed "Disconnect" to wait for the closing stream tag. (#141) Updated example with a README.md and fixed some logs. --- _examples/xmpp_chat_client/README.md | 51 ++++++++++++++ _examples/xmpp_chat_client/config.yml | 4 +- _examples/xmpp_chat_client/interface.go | 17 +++-- .../xmpp_chat_client/xmpp_chat_client.go | 10 ++- client.go | 40 ++++++++--- client_test.go | 40 +++++++++-- component.go | 12 +++- stanza/parser.go | 66 +++++++++++++++---- stanza/stream_features.go | 18 +++++ stream_manager.go | 2 +- transport.go | 3 + websocket_transport.go | 8 ++- xmpp_transport.go | 33 +++++++--- 13 files changed, 252 insertions(+), 52 deletions(-) create mode 100644 _examples/xmpp_chat_client/README.md diff --git a/_examples/xmpp_chat_client/README.md b/_examples/xmpp_chat_client/README.md new file mode 100644 index 0000000..d4db360 --- /dev/null +++ b/_examples/xmpp_chat_client/README.md @@ -0,0 +1,51 @@ +# Chat TUI example +This is a simple chat example, with a TUI. +It shows the library usage and a few of its capabilities. +## How to run +### Build +You can build the client using : +``` + go build -o example_client +``` +and then run with (on unix for example): +``` + ./example_client +``` +or you can simply build + run in one command while at the example directory root, like this: +``` + go run xmpp_chat_client.go interface.go +``` + +### Configuration +The example needs a configuration file to run. A sample file is provided. +By default, the example will look for a file named "config" in the current directory. +To provide a different configuration file, pass the following argument to the example : +``` + go run xmpp_chat_client.go interface.go -c /path/to/config +``` +where /path/to/config is the path to the directory containing the configuration file. The configuration file must be named +"config" and be using the yaml format. + +Required fields are : +```yaml +Server : + - full_address: "localhost:5222" +Client : # This is you + - jid: "testuser2@localhost" + - pass: "pass123" #Password in a config file yay + +# Contacts list, ";" separated +Contacts : "testuser1@localhost;testuser3@localhost" +# Should we log stanzas ? +LogStanzas: + - logger_on: "true" + - logfile_path: "./logs" # Path to directory, not file. +``` + +## How to use +Shortcuts : + - ctrl+space : switch between input window and menu window. + - While in input window : + - enter : sends a message if in message mode (see menu options) + - ctrl+e : sends a raw stanza when in raw mode (see menu options) + - ctrl+c : quit \ No newline at end of file diff --git a/_examples/xmpp_chat_client/config.yml b/_examples/xmpp_chat_client/config.yml index ed6e902..6c6498b 100644 --- a/_examples/xmpp_chat_client/config.yml +++ b/_examples/xmpp_chat_client/config.yml @@ -1,9 +1,7 @@ -# Default config for the client +# Sample config for the client Server : - full_address: "localhost:5222" - - port: 5222 Client : - - name: "testuser2" - jid: "testuser2@localhost" - pass: "pass123" #Password in a config file yay diff --git a/_examples/xmpp_chat_client/interface.go b/_examples/xmpp_chat_client/interface.go index 0919709..0c05edd 100644 --- a/_examples/xmpp_chat_client/interface.go +++ b/_examples/xmpp_chat_client/interface.go @@ -17,6 +17,13 @@ const ( menuWindow = "mw" // Where the menu is shown disconnectMsg = "msg" + // Windows titles + chatLogWindowTitle = "Chat log" + menuWindowTitle = "Menu" + chatInputWindowTitle = "Write a message :" + rawInputWindowTitle = "Write or paste a raw stanza. Press \"Ctrl+E\" to send :" + contactsListWindowTitle = "Contacts" + // Menu options disconnect = "Disconnect" askServerForRoster = "Ask server for roster" @@ -60,7 +67,7 @@ func layout(g *gocui.Gui) error { if !gocui.IsUnknownView(err) { return err } - v.Title = "Chat log" + v.Title = chatLogWindowTitle v.Wrap = true v.Autoscroll = true } @@ -69,7 +76,7 @@ func layout(g *gocui.Gui) error { if !gocui.IsUnknownView(err) { return err } - v.Title = "Contacts" + v.Title = contactsListWindowTitle v.Wrap = true // If we set this to true, the contacts list will "fit" in the window but if the number // of contacts exceeds the maximum height, some contacts will be hidden... @@ -82,7 +89,7 @@ func layout(g *gocui.Gui) error { if !gocui.IsUnknownView(err) { return err } - v.Title = "Menu" + v.Title = menuWindowTitle v.Wrap = true v.Autoscroll = true fmt.Fprint(v, strings.Join(menuOptions, "\n")) @@ -95,7 +102,7 @@ func layout(g *gocui.Gui) error { if !gocui.IsUnknownView(err) { return err } - v.Title = "Write or paste a raw stanza. Press \"Ctrl+E\" to send :" + v.Title = rawInputWindowTitle v.Editable = true v.Wrap = true } @@ -104,7 +111,7 @@ func layout(g *gocui.Gui) error { if !gocui.IsUnknownView(err) { return err } - v.Title = "Write a message :" + v.Title = chatInputWindowTitle v.Editable = true v.Wrap = true diff --git a/_examples/xmpp_chat_client/xmpp_chat_client.go b/_examples/xmpp_chat_client/xmpp_chat_client.go index b13e398..0d4b94b 100644 --- a/_examples/xmpp_chat_client/xmpp_chat_client.go +++ b/_examples/xmpp_chat_client/xmpp_chat_client.go @@ -63,7 +63,7 @@ func main() { // ============================================================ // Parse the flag with the config directory path as argument flag.String("c", defaultConfigFilePath, "Provide a path to the directory that contains the configuration"+ - " file you want to use. Config file should be named \"config\" and be of YAML format..") + " file you want to use. Config file should be named \"config\" and be in YAML format..") pflag.CommandLine.AddGoFlagSet(flag.CommandLine) pflag.Parse() @@ -139,7 +139,8 @@ func startClient(g *gocui.Gui, config *config) { handlerWithGui := func(_ xmpp.Sender, p stanza.Packet) { msg, ok := p.(stanza.Message) if logger != nil { - logger.Println(msg) + m, _ := xml.Marshal(msg) + logger.Println(string(m)) } v, err := g.View(chatLogWindow) @@ -209,7 +210,7 @@ func startMessaging(client xmpp.Sender, config *config, g *gocui.Gui) { } return case text = <-textChan: - reply := stanza.Message{Attrs: stanza.Attrs{To: correspondent, From: config.Client[clientJid], Type: stanza.MessageTypeChat}, Body: text} + reply := stanza.Message{Attrs: stanza.Attrs{To: correspondent, Type: stanza.MessageTypeChat}, Body: text} if logger != nil { raw, _ := xml.Marshal(reply) logger.Println(string(raw)) @@ -284,6 +285,8 @@ func errorHandler(err error) { // If user tries to send a message to someone not registered with the server, the server will return an error. func updateRosterFromConfig(g *gocui.Gui, config *config) { viewState.contacts = append(strings.Split(config.Contacts, configContactSep), backFromContacts) + // Put a "go back" button at the end of the list + viewState.contacts = append(viewState.contacts, backFromContacts) } // Updates the menu panel of the view with the current user's roster, by asking the server. @@ -318,6 +321,7 @@ func askForRoster(client xmpp.Sender, g *gocui.Gui, config *config) { for _, item := range rosterItems.Items { viewState.contacts = append(viewState.contacts, item.Jid) } + // Put a "go back" button at the end of the list viewState.contacts = append(viewState.contacts, backFromContacts) fmt.Fprintln(chlw, infoFormat+"Contacts list updated !") return diff --git a/client.go b/client.go index 30c9d7e..be15540 100644 --- a/client.go +++ b/client.go @@ -154,7 +154,8 @@ func NewClient(config Config, r *Router, errorHandler func(error)) (c *Client, e if config.TransportConfiguration.Domain == "" { config.TransportConfiguration.Domain = config.parsedJid.Domain } - c.transport = NewClientTransport(config.TransportConfiguration) + c.config.TransportConfiguration.ConnectTimeout = c.config.ConnectTimeout + c.transport = NewClientTransport(c.config.TransportConfiguration) if config.StreamLogger != nil { c.transport.LogTraffic(config.StreamLogger) @@ -183,7 +184,24 @@ func (c *Client) Resume(state SMState) error { // Client is ok, we now open XMPP session if c.Session, err = NewSession(c.transport, c.config, state); err != nil { - c.transport.Close() + // Try to get the stream close tag from the server. + go func() { + for { + val, err := stanza.NextPacket(c.transport.GetDecoder()) + if err != nil { + c.ErrorHandler(err) + c.disconnected(state) + return + } + switch val.(type) { + case stanza.StreamClosePacket: + // TCP messages should arrive in order, so we can expect to get nothing more after this occurs + c.transport.ReceivedStreamClose() + return + } + } + }() + c.Disconnect() return err } c.Session.StreamId = streamId @@ -205,15 +223,12 @@ func (c *Client) Resume(state SMState) error { return err } -func (c *Client) Disconnect() { - // TODO : Wait for server response for clean disconnect - presence := stanza.NewPresence(stanza.Attrs{From: c.config.Jid}) - presence.Type = stanza.PresenceTypeUnavailable - c.Send(presence) - c.SendRaw(stanza.StreamClose) +func (c *Client) Disconnect() error { if c.transport != nil { - _ = c.transport.Close() + return c.transport.Close() } + // No transport so no connection. + return nil } func (c *Client) SetHandler(handler EventHandler) { @@ -294,7 +309,8 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) { close(keepaliveQuit) c.streamError(packet.Error.Local, packet.Text) c.ErrorHandler(errors.New("stream error: " + packet.Error.Local)) - return + // We don't return here, because we want to wait for the stream close tag from the server, or timeout. + c.Disconnect() // Process Stream management nonzas case stanza.SMRequest: answer := stanza.SMAnswer{XMLName: xml.Name{ @@ -306,6 +322,10 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) { c.ErrorHandler(err) return } + case stanza.StreamClosePacket: + // TCP messages should arrive in order, so we can expect to get nothing more after this occurs + c.transport.ReceivedStreamClose() + return default: state.Inbound++ } diff --git a/client_test.go b/client_test.go index 8d109d0..f455fdf 100644 --- a/client_test.go +++ b/client_test.go @@ -67,7 +67,10 @@ func TestClient_Connect(t *testing.T) { func TestClient_NoInsecure(t *testing.T) { // Setup Mock server mock := ServerMock{} - mock.Start(t, testXMPPAddress, handlerAbortTLS) + mock.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) { + handlerAbortTLS(t, sc) + closeConn(t, sc) + }) // Test / Check result config := Config{ @@ -97,7 +100,10 @@ func TestClient_NoInsecure(t *testing.T) { func TestClient_FeaturesTracking(t *testing.T) { // Setup Mock server mock := ServerMock{} - mock.Start(t, testXMPPAddress, handlerAbortTLS) + mock.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) { + handlerAbortTLS(t, sc) + closeConn(t, sc) + }) // Test / Check result config := Config{ @@ -247,6 +253,7 @@ func TestClient_SendRaw(t *testing.T) { handlerClientConnectSuccess(t, sc) discardPresence(t, sc) respondToIQ(t, sc) + closeConn(t, sc) done <- struct{}{} } type testCase struct { @@ -290,6 +297,7 @@ func TestClient_SendRaw(t *testing.T) { select { // We don't use the default "long" timeout here because waiting it out means passing the test. case <-time.After(100 * time.Millisecond): + c.Disconnect() case err = <-errChan: if err == nil && tcase.shouldErr { t.Errorf("Failed to get closing stream err") @@ -297,7 +305,6 @@ func TestClient_SendRaw(t *testing.T) { t.Errorf("This test is not supposed to err !") } } - c.transport.Close() select { case <-done: m.Stop() @@ -309,7 +316,10 @@ func TestClient_SendRaw(t *testing.T) { } func TestClient_Disconnect(t *testing.T) { - c, m := mockClientConnection(t, handlerClientConnectSuccess, testClientBasePort) + c, m := mockClientConnection(t, func(t *testing.T, sc *ServerConn) { + handlerClientConnectSuccess(t, sc) + closeConn(t, sc) + }, testClientBasePort) err := c.transport.Ping() if err != nil { t.Errorf("Could not ping but not disconnected yet") @@ -326,7 +336,10 @@ func TestClient_DisconnectStreamManager(t *testing.T) { // Init mock server // Setup Mock server mock := ServerMock{} - mock.Start(t, testXMPPAddress, handlerAbortTLS) + mock.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) { + handlerAbortTLS(t, sc) + closeConn(t, sc) + }) // Test / Check result config := Config{ @@ -375,6 +388,23 @@ func handlerClientConnectSuccess(t *testing.T, sc *ServerConn) { bind(t, sc) } +// closeConn closes the connection on request from the client +func closeConn(t *testing.T, sc *ServerConn) { + for { + cls, err := stanza.NextPacket(sc.decoder) + if err != nil { + t.Errorf("cannot read from socket: %s", err) + return + } + switch cls.(type) { + case stanza.StreamClosePacket: + fmt.Fprintf(sc.connection, stanza.StreamClose) + return + } + } + +} + // We expect client will abort on TLS func handlerAbortTLS(t *testing.T, sc *ServerConn) { checkClientOpenStream(t, sc) diff --git a/component.go b/component.go index 828ba07..bd85aa2 100644 --- a/component.go +++ b/component.go @@ -113,11 +113,13 @@ func (c *Component) Resume(sm SMState) error { } } -func (c *Component) Disconnect() { +func (c *Component) Disconnect() error { // TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect if c.transport != nil { - _ = c.transport.Close() + return c.transport.Close() } + // No transport so no connection. + return nil } func (c *Component) SetHandler(handler EventHandler) { @@ -126,7 +128,6 @@ func (c *Component) SetHandler(handler EventHandler) { // Receiver Go routine receiver func (c *Component) recv() { - for { val, err := stanza.NextPacket(c.transport.GetDecoder()) if err != nil { @@ -140,6 +141,11 @@ func (c *Component) recv() { c.router.route(c, val) c.streamError(p.Error.Local, p.Text) c.ErrorHandler(errors.New("stream error: " + p.Error.Local)) + // We don't return here, because we want to wait for the stream close tag from the server, or timeout. + c.Disconnect() + case stanza.StreamClosePacket: + // TCP messages should arrive in order, so we can expect to get nothing more after this occurs + c.transport.ReceivedStreamClose() return } c.router.route(c, val) diff --git a/stanza/parser.go b/stanza/parser.go index 75f78e7..b5f11cf 100644 --- a/stanza/parser.go +++ b/stanza/parser.go @@ -50,11 +50,20 @@ func InitStream(p *xml.Decoder) (sessionID string, err error) { // TODO make auth and bind use NextPacket instead of directly NextStart func NextPacket(p *xml.Decoder) (Packet, error) { // Read start element to find out how we want to parse the XMPP packet - se, err := NextStart(p) + t, err := NextXmppToken(p) if err != nil { return nil, err } + if ee, ok := t.(xml.EndElement); ok { + return decodeStream(p, ee) + } + + // If not an end element, then must be a start + se, ok := t.(xml.StartElement) + if !ok { + return nil, errors.New("unknown token ") + } // Decode one of the top level XMPP namespace switch se.Name.Space { case NSStream: @@ -73,7 +82,29 @@ func NextPacket(p *xml.Decoder) (Packet, error) { } } -// Scan XML token stream to find next StartElement. +// NextXmppToken scans XML token stream to find next StartElement or stream EndElement. +// We need the EndElement scan, because we must register stream close tags +func NextXmppToken(p *xml.Decoder) (xml.Token, error) { + for { + t, err := p.Token() + if err == io.EOF { + return xml.StartElement{}, errors.New("connection closed") + } + if err != nil { + return xml.StartElement{}, fmt.Errorf("NextStart %s", err) + } + switch t := t.(type) { + case xml.StartElement: + return t, nil + case xml.EndElement: + if t.Name.Space == NSStream && t.Name.Local == "stream" { + return t, nil + } + } + } +} + +// NextStart scans XML token stream to find next StartElement. func NextStart(p *xml.Decoder) (xml.StartElement, error) { for { t, err := p.Token() @@ -97,16 +128,29 @@ TODO: From all the decoder, we can return a pointer to the actual concrete type, */ // decodeStream will fully decode a stream packet -func decodeStream(p *xml.Decoder, se xml.StartElement) (Packet, error) { - switch se.Name.Local { - case "error": - return streamError.decode(p, se) - case "features": - return streamFeatures.decode(p, se) - default: - return nil, errors.New("unexpected XMPP packet " + - se.Name.Space + " <" + se.Name.Local + "/>") +func decodeStream(p *xml.Decoder, t xml.Token) (Packet, error) { + if se, ok := t.(xml.StartElement); ok { + switch se.Name.Local { + case "error": + return streamError.decode(p, se) + case "features": + return streamFeatures.decode(p, se) + default: + return nil, errors.New("unexpected XMPP packet " + + se.Name.Space + " <" + se.Name.Local + "/>") + } } + + if ee, ok := t.(xml.EndElement); ok { + if ee.Name.Local == "stream" { + return streamClose.decode(ee), nil + } + return nil, errors.New("unexpected XMPP packet " + + ee.Name.Space + " <" + ee.Name.Local + "/>") + } + + // Should not happen + return nil, errors.New("unexpected XML token ") } // decodeSASL decodes a packet related to SASL authentication. diff --git a/stanza/stream_features.go b/stanza/stream_features.go index 14358f0..d5bed5c 100644 --- a/stanza/stream_features.go +++ b/stanza/stream_features.go @@ -165,3 +165,21 @@ func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamErr err := p.DecodeElement(&packet, &se) return packet, err } + +// ============================================================================ +// StreamClose "Packet" + +// This is just a closing tag and hold no information +type StreamClosePacket struct{} + +func (StreamClosePacket) Name() string { + return "stream:stream" +} + +type streamCloseDecoder struct{} + +var streamClose streamCloseDecoder + +func (streamCloseDecoder) decode(_ xml.EndElement) StreamClosePacket { + return StreamClosePacket{} +} diff --git a/stream_manager.go b/stream_manager.go index aebd8a4..18e1434 100644 --- a/stream_manager.go +++ b/stream_manager.go @@ -29,7 +29,7 @@ type StreamClient interface { Send(packet stanza.Packet) error SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) SendRaw(packet string) error - Disconnect() + Disconnect() error SetHandler(handler EventHandler) } diff --git a/transport.go b/transport.go index c6134fb..abf7f4a 100644 --- a/transport.go +++ b/transport.go @@ -40,6 +40,9 @@ type Transport interface { Read(p []byte) (n int, err error) Write(p []byte) (n int, err error) Close() error + // ReceivedStreamClose signals to the transport that a has been received and that the tcp connection + // should be closed. + ReceivedStreamClose() } // NewClientTransport creates a new Transport instance for clients. diff --git a/websocket_transport.go b/websocket_transport.go index 69c0183..7631fc8 100644 --- a/websocket_transport.go +++ b/websocket_transport.go @@ -18,7 +18,7 @@ const maxPacketSize = 32768 const pingTimeout = time.Duration(5) * time.Second -var ServerDoesNotSupportXmppOverWebsocket = errors.New("The websocket server does not support the xmpp subprotocol") +var ServerDoesNotSupportXmppOverWebsocket = errors.New("the websocket server does not support the xmpp subprotocol") // The decoder is expected to be initialized after connecting to a server. type WebsocketTransport struct { @@ -47,6 +47,7 @@ func (t *WebsocketTransport) Connect() (string, error) { wsConn, response, err := websocket.Dial(ctx, t.Config.Address, &websocket.DialOptions{ Subprotocols: []string{"xmpp"}, }) + if err != nil { return "", NewConnError(err, true) } @@ -177,3 +178,8 @@ func (t *WebsocketTransport) cleanup(code websocket.StatusCode) error { } return err } + +// ReceivedStreamClose is not used for websockets for now +func (t *WebsocketTransport) ReceivedStreamClose() { + return +} diff --git a/xmpp_transport.go b/xmpp_transport.go index 34b0d3e..6e1209f 100644 --- a/xmpp_transport.go +++ b/xmpp_transport.go @@ -24,6 +24,7 @@ type XMPPTransport struct { readWriter io.ReadWriter logFile io.Writer isSecure bool + closeChan chan stanza.StreamClosePacket } var componentStreamOpen = fmt.Sprintf("", stanza.NSComponent, stanza.NSStream) @@ -38,13 +39,14 @@ func (t *XMPPTransport) Connect() (string, error) { return "", NewConnError(err, true) } + t.closeChan = make(chan stanza.StreamClosePacket) t.readWriter = newStreamLogger(t.conn, t.logFile) t.decoder = xml.NewDecoder(bufio.NewReaderSize(t.readWriter, maxPacketSize)) t.decoder.CharsetReader = t.Config.CharsetReader return t.StartStream() } -func (t XMPPTransport) StartStream() (string, error) { +func (t *XMPPTransport) StartStream() (string, error) { if _, err := fmt.Fprintf(t, t.openStatement, t.Config.Domain); err != nil { t.Close() return "", NewConnError(err, true) @@ -58,19 +60,19 @@ func (t XMPPTransport) StartStream() (string, error) { return sessionID, nil } -func (t XMPPTransport) DoesStartTLS() bool { +func (t *XMPPTransport) DoesStartTLS() bool { return true } -func (t XMPPTransport) GetDomain() string { +func (t *XMPPTransport) GetDomain() string { return t.Config.Domain } -func (t XMPPTransport) GetDecoder() *xml.Decoder { +func (t *XMPPTransport) GetDecoder() *xml.Decoder { return t.decoder } -func (t XMPPTransport) IsSecure() bool { +func (t *XMPPTransport) IsSecure() bool { return t.isSecure } @@ -105,7 +107,7 @@ func (t *XMPPTransport) StartTLS() error { return nil } -func (t XMPPTransport) Ping() error { +func (t *XMPPTransport) Ping() error { n, err := t.conn.Write([]byte("\n")) if err != nil { return err @@ -116,24 +118,31 @@ func (t XMPPTransport) Ping() error { return nil } -func (t XMPPTransport) Read(p []byte) (n int, err error) { +func (t *XMPPTransport) Read(p []byte) (n int, err error) { if t.readWriter == nil { return 0, errors.New("cannot read: not connected, no readwriter") } return t.readWriter.Read(p) } -func (t XMPPTransport) Write(p []byte) (n int, err error) { +func (t *XMPPTransport) Write(p []byte) (n int, err error) { if t.readWriter == nil { return 0, errors.New("cannot write: not connected, no readwriter") } return t.readWriter.Write(p) } -func (t XMPPTransport) Close() error { +func (t *XMPPTransport) Close() error { if t.readWriter != nil { - _, _ = t.readWriter.Write([]byte("")) + _, _ = t.readWriter.Write([]byte(stanza.StreamClose)) } + + // Try to wait for the stream close tag from the server. After a timeout, disconnect anyway. + select { + case <-t.closeChan: + case <-time.After(time.Duration(t.Config.ConnectTimeout) * time.Second): + } + if t.conn != nil { return t.conn.Close() } @@ -143,3 +152,7 @@ func (t XMPPTransport) Close() error { func (t *XMPPTransport) LogTraffic(logFile io.Writer) { t.logFile = logFile } + +func (t *XMPPTransport) ReceivedStreamClose() { + t.closeChan <- stanza.StreamClosePacket{} +}