diff --git a/_examples/xmpp_chat_client/xmpp_chat_client.go b/_examples/xmpp_chat_client/xmpp_chat_client.go new file mode 100644 index 0000000..2b2d2e7 --- /dev/null +++ b/_examples/xmpp_chat_client/xmpp_chat_client.go @@ -0,0 +1,95 @@ +package main + +/* +xmpp_chat_client is a demo client that connect on an XMPP server to chat with other members +Note that this example sends to a very specific user. User logic is not implemented here. +*/ + +import ( + . "bufio" + "fmt" + "os" + + "gosrc.io/xmpp" + "gosrc.io/xmpp/stanza" +) + +const ( + currentUserAddress = "localhost:5222" + currentUserJid = "testuser@localhost" + currentUserPass = "testpass" + correspondantJid = "testuser2@localhost" +) + +func main() { + config := xmpp.Config{ + TransportConfiguration: xmpp.TransportConfiguration{ + Address: currentUserAddress, + }, + Jid: currentUserJid, + Credential: xmpp.Password(currentUserPass), + Insecure: true} + + var client *xmpp.Client + var err error + router := xmpp.NewRouter() + router.HandleFunc("message", handleMessage) + if client, err = xmpp.NewClient(config, router, errorHandler); err != nil { + fmt.Println("Error new client") + } + + // Connecting client and handling messages + // To use a stream manager, just write something like this instead : + //cm := xmpp.NewStreamManager(client, startMessaging) + //log.Fatal(cm.Run()) //=> this will lock the calling goroutine + + if err = client.Connect(); err != nil { + fmt.Printf("XMPP connection failed: %s", err) + return + } + startMessaging(client) + +} + +func startMessaging(client xmpp.Sender) { + reader := NewReader(os.Stdin) + textChan := make(chan string) + var text string + for { + fmt.Print("Enter text: ") + go readInput(reader, textChan) + select { + case <-killChan: + return + case text = <-textChan: + reply := stanza.Message{Attrs: stanza.Attrs{To: correspondantJid}, Body: text} + err := client.Send(reply) + if err != nil { + fmt.Printf("There was a problem sending the message : %v", reply) + return + } + } + } +} + +func readInput(reader *Reader, textChan chan string) { + text, _ := reader.ReadString('\n') + textChan <- text +} + +var killChan = make(chan struct{}) + +// If an error occurs, this is used +func errorHandler(err error) { + fmt.Printf("%v", err) + killChan <- struct{}{} +} + +func handleMessage(s xmpp.Sender, p stanza.Packet) { + msg, ok := p.(stanza.Message) + if !ok { + _, _ = fmt.Fprintf(os.Stdout, "Ignoring packet: %T\n", p) + return + } + _, _ = fmt.Fprintf(os.Stdout, "Body = %s - from = %s\n", msg.Body, msg.From) +} diff --git a/_examples/xmpp_component/xmpp_component.go b/_examples/xmpp_component/xmpp_component.go index 0452888..7f676cb 100644 --- a/_examples/xmpp_component/xmpp_component.go +++ b/_examples/xmpp_component/xmpp_component.go @@ -35,7 +35,7 @@ func main() { IQNamespaces("jabber:iq:version"). HandlerFunc(handleVersion) - component, err := xmpp.NewComponent(opts, router) + component, err := xmpp.NewComponent(opts, router, handleError) if err != nil { log.Fatalf("%+v", err) } @@ -47,6 +47,10 @@ func main() { log.Fatal(cm.Run()) } +func handleError(err error) { + fmt.Println(err.Error()) +} + func handleMessage(_ xmpp.Sender, p stanza.Packet) { msg, ok := p.(stanza.Message) if !ok { diff --git a/_examples/xmpp_jukebox/xmpp_jukebox.go b/_examples/xmpp_jukebox/xmpp_jukebox.go index 91f453c..ce7ebc9 100644 --- a/_examples/xmpp_jukebox/xmpp_jukebox.go +++ b/_examples/xmpp_jukebox/xmpp_jukebox.go @@ -53,7 +53,7 @@ func main() { handleIQ(s, p, player) }) - client, err := xmpp.NewClient(config, router) + client, err := xmpp.NewClient(config, router, errorHandler) if err != nil { log.Fatalf("%+v", err) } @@ -61,6 +61,9 @@ func main() { cm := xmpp.NewStreamManager(client, nil) log.Fatal(cm.Run()) } +func errorHandler(err error) { + fmt.Println(err.Error()) +} func handleMessage(s xmpp.Sender, p stanza.Packet, player *mpg123.Player) { msg, ok := p.(stanza.Message) diff --git a/_examples/xmpp_oauth2/xmpp_oauth2.go b/_examples/xmpp_oauth2/xmpp_oauth2.go index f322447..89b2639 100644 --- a/_examples/xmpp_oauth2/xmpp_oauth2.go +++ b/_examples/xmpp_oauth2/xmpp_oauth2.go @@ -28,7 +28,7 @@ func main() { router := xmpp.NewRouter() router.HandleFunc("message", handleMessage) - client, err := xmpp.NewClient(config, router) + client, err := xmpp.NewClient(config, router, errorHandler) if err != nil { log.Fatalf("%+v", err) } @@ -39,6 +39,10 @@ func main() { log.Fatal(cm.Run()) } +func errorHandler(err error) { + fmt.Println(err.Error()) +} + func handleMessage(s xmpp.Sender, p stanza.Packet) { msg, ok := p.(stanza.Message) if !ok { diff --git a/_examples/xmpp_websocket/xmpp_websocket.go b/_examples/xmpp_websocket/xmpp_websocket.go index 428a1d1..c8c0620 100644 --- a/_examples/xmpp_websocket/xmpp_websocket.go +++ b/_examples/xmpp_websocket/xmpp_websocket.go @@ -26,7 +26,7 @@ func main() { router := xmpp.NewRouter() router.HandleFunc("message", handleMessage) - client, err := xmpp.NewClient(config, router) + client, err := xmpp.NewClient(config, router, errorHandler) if err != nil { log.Fatalf("%+v", err) } @@ -37,6 +37,10 @@ func main() { log.Fatal(cm.Run()) } +func errorHandler(err error) { + fmt.Println(err.Error()) +} + func handleMessage(s xmpp.Sender, p stanza.Packet) { msg, ok := p.(stanza.Message) if !ok { diff --git a/client.go b/client.go index 14537db..cc152f3 100644 --- a/client.go +++ b/client.go @@ -98,6 +98,8 @@ type Client struct { router *Router // Track and broadcast connection state EventManager + // Handle errors from client execution + ErrorHandler func(error) } /* @@ -107,7 +109,7 @@ Setting up the client / Checking the parameters // NewClient generates a new XMPP client, based on Config passed as parameters. // If host is not specified, the DNS SRV should be used to find the host from the domainpart of the JID. // Default the port to 5222. -func NewClient(config Config, r *Router) (c *Client, err error) { +func NewClient(config Config, r *Router, errorHandler func(error)) (c *Client, err error) { // Parse JID if config.parsedJid, err = NewJid(config.Jid); err != nil { err = errors.New("missing jid") @@ -140,6 +142,7 @@ func NewClient(config Config, r *Router) (c *Client, err error) { c = new(Client) c.config = config c.router = r + c.ErrorHandler = errorHandler if c.config.ConnectTimeout == 0 { c.config.ConnectTimeout = 15 // 15 second as default @@ -185,13 +188,10 @@ func (c *Client) Resume(state SMState) error { // Start the keepalive go routine keepaliveQuit := make(chan struct{}) - go keepalive(c.transport, keepaliveQuit) + go keepalive(c, keepaliveQuit) // Start the receiver go routine state = c.Session.SMState - // Leaving this channel here for later. Not used atm. We should return this instead of an error because right - // now the returned error is lost in limbo. - errChan := make(chan error) - go c.recv(state, keepaliveQuit, errChan) + go c.recv(state, keepaliveQuit) // We're connected and can now receive and send messages. //fmt.Fprintf(client.conn, "%s%s", "chat", "Online") @@ -270,11 +270,11 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error { // Go routines // Loop: Receive data from server -func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan<- error) { +func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) { for { val, err := stanza.NextPacket(c.transport.GetDecoder()) if err != nil { - errChan <- err + c.ErrorHandler(err) close(keepaliveQuit) c.disconnected(state) return @@ -286,7 +286,7 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan c.router.route(c, val) close(keepaliveQuit) c.streamError(packet.Error.Local, packet.Text) - errChan <- errors.New("stream error: " + packet.Error.Local) + c.ErrorHandler(errors.New("stream error: " + packet.Error.Local)) return // Process Stream management nonzas case stanza.SMRequest: @@ -296,7 +296,7 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan }, H: state.Inbound} err = c.Send(answer) if err != nil { - errChan <- err + c.ErrorHandler(err) return } default: @@ -312,8 +312,9 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan // Loop: send whitespace keepalive to server // This is use to keep the connection open, but also to detect connection loss // and trigger proper client connection shutdown. -func keepalive(transport Transport, quit <-chan struct{}) { +func keepalive(c *Client, quit <-chan struct{}) { // TODO: Make keepalive interval configurable + transport := c.transport ticker := time.NewTicker(30 * time.Second) for { select { diff --git a/client_test.go b/client_test.go index 2636f29..15e104f 100644 --- a/client_test.go +++ b/client_test.go @@ -1,6 +1,7 @@ package xmpp import ( + "context" "encoding/xml" "errors" "fmt" @@ -14,15 +15,14 @@ import ( const ( // Default port is not standard XMPP port to avoid interfering // with local running XMPP server - testXMPPAddress = "localhost:15222" - - defaultTimeout = 2 * time.Second + testXMPPAddress = "localhost:15222" + testClientDomain = "localhost" ) func TestClient_Connect(t *testing.T) { // Setup Mock server mock := ServerMock{} - mock.Start(t, testXMPPAddress, handlerConnectSuccess) + mock.Start(t, testXMPPAddress, handlerClientConnectSuccess) // Test / Check result config := Config{ @@ -36,7 +36,7 @@ func TestClient_Connect(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router); err != nil { + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { t.Errorf("connect create XMPP client: %s", err) } @@ -64,7 +64,7 @@ func TestClient_NoInsecure(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router); err != nil { + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { t.Errorf("cannot create XMPP client: %s", err) } @@ -94,7 +94,7 @@ func TestClient_FeaturesTracking(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router); err != nil { + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { t.Errorf("cannot create XMPP client: %s", err) } @@ -109,7 +109,7 @@ func TestClient_FeaturesTracking(t *testing.T) { func TestClient_RFC3921Session(t *testing.T) { // Setup Mock server mock := ServerMock{} - mock.Start(t, testXMPPAddress, handlerConnectWithSession) + mock.Start(t, testXMPPAddress, handlerClientConnectWithSession) // Test / Check result config := Config{ @@ -124,7 +124,7 @@ func TestClient_RFC3921Session(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router); err != nil { + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { t.Errorf("connect create XMPP client: %s", err) } @@ -135,48 +135,254 @@ func TestClient_RFC3921Session(t *testing.T) { mock.Stop() } +// Testing sending an IQ to the mock server and reading its response. +func TestClient_SendIQ(t *testing.T) { + done := make(chan struct{}) + // Handler for Mock server + h := func(t *testing.T, c net.Conn) { + handlerClientConnectSuccess(t, c) + discardPresence(t, c) + respondToIQ(t, c) + done <- struct{}{} + } + client, mock := mockClientConnection(t, h, testClientIqPort) + + ctx, _ := context.WithTimeout(context.Background(), 30*time.Second) + iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"}) + disco := iqReq.DiscoInfo() + iqReq.Payload = disco + + // Handle a possible error + errChan := make(chan error) + errorHandler := func(err error) { + errChan <- err + } + client.ErrorHandler = errorHandler + res, err := client.SendIQ(ctx, iqReq) + if err != nil { + t.Errorf(err.Error()) + } + + select { + case <-res: // If the server responds with an IQ, we pass the test + case err := <-errChan: // If the server sends an error, or there is a connection error + t.Errorf(err.Error()) + case <-time.After(defaultChannelTimeout): // If we timeout + t.Errorf("Failed to receive response, to sent IQ, from mock server") + } + select { + case <-done: + mock.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } +} + +func TestClient_SendIQFail(t *testing.T) { + done := make(chan struct{}) + // Handler for Mock server + h := func(t *testing.T, c net.Conn) { + handlerClientConnectSuccess(t, c) + discardPresence(t, c) + respondToIQ(t, c) + done <- struct{}{} + } + client, mock := mockClientConnection(t, h, testClientIqFailPort) + + //================== + // Create an IQ to send + ctx, _ := context.WithTimeout(context.Background(), 30*time.Second) + iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"}) + disco := iqReq.DiscoInfo() + iqReq.Payload = disco + // Removing the id to make the stanza invalid. The IQ constructor makes a random one if none is specified + // so we need to overwrite it. + iqReq.Id = "" + + // Handle a possible error + errChan := make(chan error) + errorHandler := func(err error) { + errChan <- err + } + client.ErrorHandler = errorHandler + res, _ := client.SendIQ(ctx, iqReq) + + // Test + select { + case <-res: // If the server responds with an IQ + t.Errorf("Server should not respond with an IQ since the request is expected to be invalid !") + case <-errChan: // If the server sends an error, the test passes + case <-time.After(defaultChannelTimeout): // If we timeout + t.Errorf("Failed to receive response, to sent IQ, from mock server") + } + select { + case <-done: + mock.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } +} + +func TestClient_SendRaw(t *testing.T) { + done := make(chan struct{}) + // Handler for Mock server + h := func(t *testing.T, c net.Conn) { + handlerClientConnectSuccess(t, c) + discardPresence(t, c) + respondToIQ(t, c) + done <- struct{}{} + } + type testCase struct { + req string + shouldErr bool + port int + } + testRequests := make(map[string]testCase) + // Sending a correct IQ of type get. Not supposed to err + testRequests["Correct IQ"] = testCase{ + req: ``, + shouldErr: false, + port: testClientRawPort + 100, + } + // Sending an IQ with a missing ID. Should err + testRequests["IQ with missing ID"] = testCase{ + req: ``, + shouldErr: true, + port: testClientRawPort, + } + + // A handler for the client. + // In the failing test, the server returns a stream error, which triggers this handler, client side. + errChan := make(chan error) + errHandler := func(err error) { + errChan <- err + } + + // Tests for all the IQs + for name, tcase := range testRequests { + t.Run(name, func(st *testing.T) { + //Connecting to a mock server, initialized with given port and handler function + c, m := mockClientConnection(t, h, tcase.port) + c.ErrorHandler = errHandler + // Sending raw xml from test case + err := c.SendRaw(tcase.req) + if err != nil { + t.Errorf("Error sending Raw string") + } + // Just wait a little so the message has time to arrive + select { + // We don't use the default "long" timeout here because waiting it out means passing the test. + case <-time.After(100 * time.Millisecond): + case err = <-errChan: + if err == nil && tcase.shouldErr { + t.Errorf("Failed to get closing stream err") + } else if err != nil && !tcase.shouldErr { + t.Errorf("This test is not supposed to err !") + } + } + c.transport.Close() + select { + case <-done: + m.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } + }) + } +} + +func TestClient_Disconnect(t *testing.T) { + c, m := mockClientConnection(t, handlerClientConnectSuccess, testClientBasePort) + err := c.transport.Ping() + if err != nil { + t.Errorf("Could not ping but not disconnected yet") + } + c.Disconnect() + err = c.transport.Ping() + if err == nil { + t.Errorf("Did not disconnect properly") + } + m.Stop() +} + +func TestClient_DisconnectStreamManager(t *testing.T) { + // Init mock server + // Setup Mock server + mock := ServerMock{} + mock.Start(t, testXMPPAddress, handlerAbortTLS) + + // Test / Check result + config := Config{ + TransportConfiguration: TransportConfiguration{ + Address: testXMPPAddress, + }, + Jid: "test@localhost", + Credential: Password("test"), + } + + var client *Client + var err error + router := NewRouter() + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { + t.Errorf("cannot create XMPP client: %s", err) + } + + sman := NewStreamManager(client, nil) + errChan := make(chan error) + runSMan := func(errChan chan error) { + errChan <- sman.Run() + } + + go runSMan(errChan) + select { + case <-errChan: + case <-time.After(defaultChannelTimeout): + // When insecure is not allowed: + t.Errorf("should fail as insecure connection is not allowed and server does not support TLS") + } + mock.Stop() +} + //============================================================================= // Basic XMPP Server Mock Handlers. -const serverStreamOpen = "" - // Test connection with a basic straightforward workflow -func handlerConnectSuccess(t *testing.T, c net.Conn) { +func handlerClientConnectSuccess(t *testing.T, c net.Conn) { decoder := xml.NewDecoder(c) - checkOpenStream(t, c, decoder) + checkClientOpenStream(t, c, decoder) sendStreamFeatures(t, c, decoder) // Send initial features readAuth(t, decoder) fmt.Fprintln(c, "") - checkOpenStream(t, c, decoder) // Reset stream - sendBindFeature(t, c, decoder) // Send post auth features + checkClientOpenStream(t, c, decoder) // Reset stream + sendBindFeature(t, c, decoder) // Send post auth features bind(t, c, decoder) } // We expect client will abort on TLS func handlerAbortTLS(t *testing.T, c net.Conn) { decoder := xml.NewDecoder(c) - checkOpenStream(t, c, decoder) + checkClientOpenStream(t, c, decoder) sendStreamFeatures(t, c, decoder) // Send initial features } // Test connection with mandatory session (RFC-3921) -func handlerConnectWithSession(t *testing.T, c net.Conn) { +func handlerClientConnectWithSession(t *testing.T, c net.Conn) { decoder := xml.NewDecoder(c) - checkOpenStream(t, c, decoder) + checkClientOpenStream(t, c, decoder) sendStreamFeatures(t, c, decoder) // Send initial features readAuth(t, decoder) fmt.Fprintln(c, "") - checkOpenStream(t, c, decoder) // Reset stream - sendRFC3921Feature(t, c, decoder) // Send post auth features + checkClientOpenStream(t, c, decoder) // Reset stream + sendRFC3921Feature(t, c, decoder) // Send post auth features bind(t, c, decoder) session(t, c, decoder) } -func checkOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { +func checkClientOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { c.SetDeadline(time.Now().Add(defaultTimeout)) defer c.SetDeadline(time.Time{}) @@ -202,105 +408,35 @@ func checkOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { } } -func sendStreamFeatures(t *testing.T, c net.Conn, _ *xml.Decoder) { - // This is a basic server, supporting only 1 stream feature: SASL Plain Auth - features := ` - - PLAIN - -` - if _, err := fmt.Fprintln(c, features); err != nil { - t.Errorf("cannot send stream feature: %s", err) +func mockClientConnection(t *testing.T, serverHandler func(*testing.T, net.Conn), port int) (*Client, ServerMock) { + mock := ServerMock{} + testServerAddress := fmt.Sprintf("%s:%d", testClientDomain, port) + + mock.Start(t, testServerAddress, serverHandler) + + config := Config{ + TransportConfiguration: TransportConfiguration{ + Address: testServerAddress, + }, + Jid: "test@localhost", + Credential: Password("test"), + Insecure: true} + + var client *Client + var err error + router := NewRouter() + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { + t.Errorf("connect create XMPP client: %s", err) } + + if err = client.Connect(); err != nil { + t.Errorf("XMPP connection failed: %s", err) + } + + return client, mock } -// TODO return err in case of error reading the auth params -func readAuth(t *testing.T, decoder *xml.Decoder) string { - se, err := stanza.NextStart(decoder) - if err != nil { - t.Errorf("cannot read auth: %s", err) - return "" - } - - var nv interface{} - nv = &stanza.SASLAuth{} - // Decode element into pointer storage - if err = decoder.DecodeElement(nv, &se); err != nil { - t.Errorf("cannot decode auth: %s", err) - return "" - } - - switch v := nv.(type) { - case *stanza.SASLAuth: - return v.Value - } - return "" -} - -func sendBindFeature(t *testing.T, c net.Conn, _ *xml.Decoder) { - // This is a basic server, supporting only 1 stream feature after auth: resource binding - features := ` - -` - if _, err := fmt.Fprintln(c, features); err != nil { - t.Errorf("cannot send stream feature: %s", err) - } -} - -func sendRFC3921Feature(t *testing.T, c net.Conn, _ *xml.Decoder) { - // This is a basic server, supporting only 2 features after auth: resource & session binding - features := ` - - -` - if _, err := fmt.Fprintln(c, features); err != nil { - t.Errorf("cannot send stream feature: %s", err) - } -} - -func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) { - se, err := stanza.NextStart(decoder) - if err != nil { - t.Errorf("cannot read bind: %s", err) - return - } - - iq := &stanza.IQ{} - // Decode element into pointer storage - if err = decoder.DecodeElement(&iq, &se); err != nil { - t.Errorf("cannot decode bind iq: %s", err) - return - } - - // TODO Check all elements - switch iq.Payload.(type) { - case *stanza.Bind: - result := ` - - %s - -` - fmt.Fprintf(c, result, iq.Id, "test@localhost/test") // TODO use real JID - } -} - -func session(t *testing.T, c net.Conn, decoder *xml.Decoder) { - se, err := stanza.NextStart(decoder) - if err != nil { - t.Errorf("cannot read session: %s", err) - return - } - - iq := &stanza.IQ{} - // Decode element into pointer storage - if err = decoder.DecodeElement(&iq, &se); err != nil { - t.Errorf("cannot decode session iq: %s", err) - return - } - - switch iq.Payload.(type) { - case *stanza.StreamSession: - result := `` - fmt.Fprintf(c, result, iq.Id) - } +// This really should not be used as is. +// It's just meant to be a placeholder when error handling is not needed at this level +func clientDefaultErrorHandler(err error) { } diff --git a/component.go b/component.go index 471f1db..2f61aef 100644 --- a/component.go +++ b/component.go @@ -48,11 +48,12 @@ type Component struct { transport Transport // read / write - socketProxy io.ReadWriter // TODO + socketProxy io.ReadWriter // TODO + ErrorHandler func(error) } -func NewComponent(opts ComponentOptions, r *Router) (*Component, error) { - c := Component{ComponentOptions: opts, router: r} +func NewComponent(opts ComponentOptions, r *Router, errorHandler func(error)) (*Component, error) { + c := Component{ComponentOptions: opts, router: r, ErrorHandler: errorHandler} return &c, nil } @@ -104,11 +105,8 @@ func (c *Component) Resume(sm SMState) error { case stanza.Handshake: // Start the receiver go routine c.updateState(StateSessionEstablished) - // Leaving this channel here for later. Not used atm. We should return this instead of an error because right - // now the returned error is lost in limbo. - errChan := make(chan error) - go c.recv(errChan) // Sends to errChan - return err // Should be empty at this point + go c.recv() + return err // Should be empty at this point default: c.updateState(StatePermanentError) return NewConnError(errors.New("expecting handshake result, got "+v.Name()), true) @@ -128,13 +126,13 @@ func (c *Component) SetHandler(handler EventHandler) { } // Receiver Go routine receiver -func (c *Component) recv(errChan chan<- error) { +func (c *Component) recv() { for { val, err := stanza.NextPacket(c.transport.GetDecoder()) if err != nil { c.updateState(StateDisconnected) - errChan <- err + c.ErrorHandler(err) return } // Handle stream errors @@ -142,7 +140,7 @@ func (c *Component) recv(errChan chan<- error) { case stanza.StreamError: c.router.route(c, val) c.streamError(p.Error.Local, p.Text) - errChan <- errors.New("stream error: " + p.Error.Local) + c.ErrorHandler(errors.New("stream error: " + p.Error.Local)) return } c.router.route(c, val) diff --git a/component_test.go b/component_test.go index 4e115f0..48963a5 100644 --- a/component_test.go +++ b/component_test.go @@ -5,6 +5,7 @@ import ( "encoding/xml" "errors" "fmt" + "github.com/google/uuid" "gosrc.io/xmpp/stanza" "net" "strings" @@ -15,19 +16,7 @@ import ( // Tests are ran in parallel, so each test creating a server must use a different port so we do not get any // conflict. Using iota for this should do the trick. const ( - testComponentDomain = "localhost" - defaultServerName = "testServer" - defaultStreamID = "91bd0bba-012f-4d92-bb17-5fc41e6fe545" - defaultComponentName = "Test Component" - - // Default port is not standard XMPP port to avoid interfering - // with local running XMPP server - testHandshakePort = iota + 15222 - testDecoderPort - testSendIqPort - testSendRawPort - testDisconnectPort - testSManDisconnectPort + defaultChannelTimeout = 5 * time.Second ) func TestHandshake(t *testing.T) { @@ -48,16 +37,14 @@ func TestHandshake(t *testing.T) { // Tests connection process with a handshake exchange // Tests multiple session IDs. All connections should generate a unique stream ID -func TestGenerateHandshake(t *testing.T) { +func TestGenerateHandshakeId(t *testing.T) { // Using this array with a channel to make a queue of values to test // These are stream IDs that will be used to test the connection process, mixing them with the "secret" to generate // some handshake value - var uuidsArray = [5]string{ - "cc9b3249-9582-4780-825f-4311b42f9b0e", - "bba8be3c-d98e-4e26-b9bb-9ed34578a503", - "dae72822-80e8-496b-b763-ab685f53a188", - "a45d6c06-de49-4bb0-935b-1a2201b71028", - "7dc6924f-0eca-4237-9898-18654b8d891e", + var uuidsArray = [5]string{} + for i := 1; i < len(uuidsArray); i++ { + id, _ := uuid.NewRandom() + uuidsArray[i] = id.String() } // Channel to pass stream IDs as a queue @@ -95,7 +82,7 @@ func TestGenerateHandshake(t *testing.T) { Type: "service", } router := NewRouter() - c, err := NewComponent(opts, router) + c, err := NewComponent(opts, router, componentDefaultErrorHandler) if err != nil { t.Errorf("%+v", err) } @@ -126,7 +113,7 @@ func TestStreamManager(t *testing.T) { // The decoder is expected to be built after a valid connection // Based on the xmpp_component example. func TestDecoder(t *testing.T) { - c, _ := mockConnection(t, testDecoderPort, handlerForComponentHandshakeDefaultID) + c, _ := mockComponentConnection(t, testDecoderPort, handlerForComponentHandshakeDefaultID) if c.transport.GetDecoder() == nil { t.Errorf("Failed to initialize decoder. Decoder is nil.") } @@ -134,39 +121,103 @@ func TestDecoder(t *testing.T) { // Tests sending an IQ to the server, and getting the response func TestSendIq(t *testing.T) { + done := make(chan struct{}) + h := func(t *testing.T, c net.Conn) { + handlerForComponentIQSend(t, c) + done <- struct{}{} + } + //Connecting to a mock server, initialized with given port and handler function - c, m := mockConnection(t, testSendIqPort, handlerForComponentIQSend) + c, m := mockComponentConnection(t, testSendIqPort, h) ctx, _ := context.WithTimeout(context.Background(), 30*time.Second) iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"}) disco := iqReq.DiscoInfo() iqReq.Payload = disco + // Handle a possible error + errChan := make(chan error) + errorHandler := func(err error) { + errChan <- err + } + c.ErrorHandler = errorHandler + var res chan stanza.IQ res, _ = c.SendIQ(ctx, iqReq) select { case <-res: - case <-time.After(100 * time.Millisecond): + case err := <-errChan: + t.Errorf(err.Error()) + case <-time.After(defaultChannelTimeout): t.Errorf("Failed to receive response, to sent IQ, from mock server") } - m.Stop() + select { + case <-done: + m.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } +} + +// Checking that error handling is done properly client side when an invalid IQ is sent and the server responds in kind. +func TestSendIqFail(t *testing.T) { + done := make(chan struct{}) + h := func(t *testing.T, c net.Conn) { + handlerForComponentIQSend(t, c) + done <- struct{}{} + } + //Connecting to a mock server, initialized with given port and handler function + c, m := mockComponentConnection(t, testSendIqFailPort, h) + + ctx, _ := context.WithTimeout(context.Background(), 30*time.Second) + iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"}) + + // Removing the id to make the stanza invalid. The IQ constructor makes a random one if none is specified + // so we need to overwrite it. + iqReq.Id = "" + disco := iqReq.DiscoInfo() + iqReq.Payload = disco + + errChan := make(chan error) + errorHandler := func(err error) { + errChan <- err + } + c.ErrorHandler = errorHandler + + var res chan stanza.IQ + res, _ = c.SendIQ(ctx, iqReq) + + select { + case r := <-res: // Do we get an IQ response from the server ? + t.Errorf("We should not be getting an IQ response here : this should fail !") + fmt.Println(r) + case <-errChan: // Do we get a stream error from the server ? + // If we get an error from the server, the test passes. + case <-time.After(defaultChannelTimeout): // Timeout ? + t.Errorf("Failed to receive response, to sent IQ, from mock server") + } + + select { + case <-done: + m.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } } // Tests sending raw xml to the mock server. -// TODO : check the server response client side ? // Right now, the server response is not checked and an err is passed in a channel if the test is supposed to err. // In this test, we use IQs func TestSendRaw(t *testing.T) { - // Error channel for the handler - errChan := make(chan error) + done := make(chan struct{}) // Handler for the mock server h := func(t *testing.T, c net.Conn) { // Completes the connection by exchanging handshakes handlerForComponentHandshakeDefaultID(t, c) - receiveRawIq(t, c, errChan) - return + receiveIq(c, xml.NewDecoder(c)) + done <- struct{}{} } type testCase struct { @@ -185,12 +236,19 @@ func TestSendRaw(t *testing.T) { shouldErr: true, } + // A handler for the component. + // In the failing test, the server returns a stream error, which triggers this handler, component side. + errChan := make(chan error) + errHandler := func(err error) { + errChan <- err + } + // Tests for all the IQs for name, tcase := range testRequests { t.Run(name, func(st *testing.T) { //Connecting to a mock server, initialized with given port and handler function - c, m := mockConnection(t, testSendRawPort, h) - + c, m := mockComponentConnection(t, testSendRawPort, h) + c.ErrorHandler = errHandler // Sending raw xml from test case err := c.SendRaw(tcase.req) if err != nil { @@ -198,21 +256,29 @@ func TestSendRaw(t *testing.T) { } // Just wait a little so the message has time to arrive select { - case <-time.After(100 * time.Millisecond): + // We don't use the default "long" timeout here because waiting it out means passing the test. + case <-time.After(200 * time.Millisecond): case err = <-errChan: if err == nil && tcase.shouldErr { t.Errorf("Failed to get closing stream err") + } else if err != nil && !tcase.shouldErr { + t.Errorf("This test is not supposed to err ! => %s", err.Error()) } } c.transport.Close() - m.Stop() + select { + case <-done: + m.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } }) } } // Tests the Disconnect method for Components func TestDisconnect(t *testing.T) { - c, m := mockConnection(t, testDisconnectPort, handlerForComponentHandshakeDefaultID) + c, m := mockComponentConnection(t, testDisconnectPort, handlerForComponentHandshakeDefaultID) err := c.transport.Ping() if err != nil { t.Errorf("Could not ping but not disconnected yet") @@ -257,14 +323,97 @@ func TestStreamManagerDisconnect(t *testing.T) { //============================================================================= // Basic XMPP Server Mock Handlers. -// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. -// Used in the mock server as a Handler -func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkOpenStreamHandshakeDefaultID(t, c, decoder) - readHandshakeComponent(t, decoder) - fmt.Fprintln(c, "") // That's all the server needs to return (see xep-0114) - return + +//=============================== +// Init mock server and connection +// Creating a mock server and connecting a Component to it. Initialized with given port and handler function +// The Component and mock are both returned +func mockComponentConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) { + // Init mock server + testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port) + mock := ServerMock{} + mock.Start(t, testComponentAddress, handler) + + //================================== + // Create Component to connect to it + c := makeBasicComponent(defaultComponentName, testComponentAddress, t) + + //======================================== + // Connect the new Component to the server + err := c.Connect() + if err != nil { + t.Errorf("%+v", err) + } + + return c, &mock +} + +func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component { + opts := ComponentOptions{ + TransportConfiguration: TransportConfiguration{ + Address: mockServerAddr, + Domain: "localhost", + }, + Domain: testComponentDomain, + Secret: "mypass", + Name: name, + Category: "gateway", + Type: "service", + } + router := NewRouter() + c, err := NewComponent(opts, router, componentDefaultErrorHandler) + if err != nil { + t.Errorf("%+v", err) + } + c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration) + if err != nil { + t.Errorf("%+v", err) + } + return c +} + +// This really should not be used as is. +// It's just meant to be a placeholder when error handling is not needed at this level +func componentDefaultErrorHandler(err error) { + +} + +// Sends IQ response to Component request. +// No parsing of the request here. We just check that it's valid, and send the default response. +func handlerForComponentIQSend(t *testing.T, c net.Conn) { + // Completes the connection by exchanging handshakes + handlerForComponentHandshakeDefaultID(t, c) + respondToIQ(t, c) +} + +// Used for ID and handshake related tests +func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) { + c.SetDeadline(time.Now().Add(defaultTimeout)) + defer c.SetDeadline(time.Time{}) + + for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion. + token, err := decoder.Token() + if err != nil { + t.Errorf("cannot read next token: %s", err) + } + + switch elem := token.(type) { + // Wait for first startElement + case xml.StartElement: + if elem.Name.Space != stanza.NSStream || elem.Name.Local != "stream" { + err = errors.New("xmpp: expected but got <" + elem.Name.Local + "> in " + elem.Name.Space) + return + } + if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil { + t.Errorf("cannot write server stream open: %s", err) + } + return + } + } +} + +func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) { + checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID) } // Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. @@ -303,152 +452,12 @@ func readHandshakeComponent(t *testing.T, decoder *xml.Decoder) { } } -func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) { - checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID) -} - -// Used for ID and handshake related tests -func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) - - for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion. - token, err := decoder.Token() - if err != nil { - t.Errorf("cannot read next token: %s", err) - } - - switch elem := token.(type) { - // Wait for first startElement - case xml.StartElement: - if elem.Name.Space != stanza.NSStream || elem.Name.Local != "stream" { - err = errors.New("xmpp: expected but got <" + elem.Name.Local + "> in " + elem.Name.Space) - return - } - if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil { - t.Errorf("cannot write server stream open: %s", err) - } - return - } - } -} - -//============================================================================= -// Sends IQ response to Component request. -// No parsing of the request here. We just check that it's valid, and send the default response. -func handlerForComponentIQSend(t *testing.T, c net.Conn) { - // Completes the connection by exchanging handshakes - handlerForComponentHandshakeDefaultID(t, c) - - // Decoder to parse the request +// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. +// Used in the mock server as a Handler +func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) { decoder := xml.NewDecoder(c) - - iqReq, err := receiveIq(t, c, decoder) - if err != nil { - t.Errorf("Error receiving the IQ stanza : %v", err) - } else if !iqReq.IsValid() { - t.Errorf("server received an IQ stanza : %v", iqReq) - } - - // Crafting response - iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"}) - disco := iqResp.DiscoInfo() - disco.AddFeatures("vcard-temp", - `http://jabber.org/protocol/address`) - - disco.AddIdentity("Multicast", "service", "multicast") - iqResp.Payload = disco - - // Sending response to the Component - mResp, err := xml.Marshal(iqResp) - _, err = fmt.Fprintln(c, string(mResp)) - if err != nil { - t.Errorf("Could not send response stanza : %s", err) - } + checkOpenStreamHandshakeDefaultID(t, c, decoder) + readHandshakeComponent(t, decoder) + fmt.Fprintln(c, "") // That's all the server needs to return (see xep-0114) return } - -// Reads next request coming from the Component. Expecting it to be an IQ request -func receiveIq(t *testing.T, c net.Conn, decoder *xml.Decoder) (stanza.IQ, error) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) - var iqStz stanza.IQ - err := decoder.Decode(&iqStz) - if err != nil { - t.Errorf("cannot read the received IQ stanza: %s", err) - } - if !iqStz.IsValid() { - t.Errorf("received IQ stanza is invalid : %s", err) - } - return iqStz, nil -} - -func receiveRawIq(t *testing.T, c net.Conn, errChan chan error) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) - decoder := xml.NewDecoder(c) - var iq stanza.IQ - err := decoder.Decode(&iq) - if err != nil || !iq.IsValid() { - s := stanza.StreamError{ - XMLName: xml.Name{Local: "stream:error"}, - Error: xml.Name{Local: "xml-not-well-formed"}, - Text: `XML was not well-formed`, - } - raw, _ := xml.Marshal(s) - fmt.Fprintln(c, string(raw)) - fmt.Fprintln(c, ``) // TODO : check this client side - errChan <- fmt.Errorf("invalid xml") - return - } - errChan <- nil - return -} - -//=============================== -// Init mock server and connection -// Creating a mock server and connecting a Component to it. Initialized with given port and handler function -// The Component and mock are both returned -func mockConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) { - // Init mock server - testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port) - mock := ServerMock{} - mock.Start(t, testComponentAddress, handler) - - //================================== - // Create Component to connect to it - c := makeBasicComponent(defaultComponentName, testComponentAddress, t) - - //======================================== - // Connect the new Component to the server - err := c.Connect() - if err != nil { - t.Errorf("%+v", err) - } - - return c, &mock -} - -func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component { - opts := ComponentOptions{ - TransportConfiguration: TransportConfiguration{ - Address: mockServerAddr, - Domain: "localhost", - }, - Domain: testComponentDomain, - Secret: "mypass", - Name: name, - Category: "gateway", - Type: "service", - } - router := NewRouter() - c, err := NewComponent(opts, router) - if err != nil { - t.Errorf("%+v", err) - } - c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration) - if err != nil { - t.Errorf("%+v", err) - } - return c -} diff --git a/tcp_server_mock.go b/tcp_server_mock.go index bdc4397..4afed80 100644 --- a/tcp_server_mock.go +++ b/tcp_server_mock.go @@ -1,12 +1,42 @@ package xmpp import ( + "encoding/xml" + "fmt" + "gosrc.io/xmpp/stanza" "net" "testing" + "time" ) //============================================================================= // TCP Server Mock +const ( + defaultTimeout = 2 * time.Second + testComponentDomain = "localhost" + defaultServerName = "testServer" + defaultStreamID = "91bd0bba-012f-4d92-bb17-5fc41e6fe545" + defaultComponentName = "Test Component" + serverStreamOpen = "" + + // Default port is not standard XMPP port to avoid interfering + // with local running XMPP server + + // Component tests + testHandshakePort = iota + 15222 + testDecoderPort + testSendIqPort + testSendIqFailPort + testSendRawPort + testDisconnectPort + testSManDisconnectPort + + // Client tests + testClientBasePort + testClientRawPort + testClientIqPort + testClientIqFailPort +) // ClientHandler is passed by the test client to provide custom behaviour to // the TCP server mock. This allows customizing the server behaviour to allow @@ -81,3 +111,180 @@ func (mock *ServerMock) loop() { go mock.handler(mock.t, conn) } } + +//====================================================================================================================== +// A few functions commonly used for tests. Trying to avoid duplicates in client and component test files. +//====================================================================================================================== + +func respondToIQ(t *testing.T, c net.Conn) { + // Decoder to parse the request + decoder := xml.NewDecoder(c) + + iqReq, err := receiveIq(c, decoder) + if err != nil { + t.Fatalf("failed to receive IQ : %s", err.Error()) + } + + if !iqReq.IsValid() { + mockIQError(c) + return + } + + // Crafting response + iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"}) + disco := iqResp.DiscoInfo() + disco.AddFeatures("vcard-temp", + `http://jabber.org/protocol/address`) + + disco.AddIdentity("Multicast", "service", "multicast") + iqResp.Payload = disco + + // Sending response to the Component + mResp, err := xml.Marshal(iqResp) + _, err = fmt.Fprintln(c, string(mResp)) + if err != nil { + t.Errorf("Could not send response stanza : %s", err) + } + return +} + +// When a presence stanza is automatically sent (right now it's the case in the client), we may want to discard it +// and test further stanzas. +func discardPresence(t *testing.T, c net.Conn) { + decoder := xml.NewDecoder(c) + c.SetDeadline(time.Now().Add(defaultTimeout)) + defer c.SetDeadline(time.Time{}) + var presenceStz stanza.Presence + err := decoder.Decode(&presenceStz) + if err != nil { + t.Errorf("Expected presence but this happened : %s", err.Error()) + } +} + +// Reads next request coming from the Component. Expecting it to be an IQ request +func receiveIq(c net.Conn, decoder *xml.Decoder) (*stanza.IQ, error) { + c.SetDeadline(time.Now().Add(defaultTimeout)) + defer c.SetDeadline(time.Time{}) + var iqStz stanza.IQ + err := decoder.Decode(&iqStz) + if err != nil { + return nil, err + } + return &iqStz, nil +} + +// Should be used in server handlers when an IQ sent by a client or component is invalid. +// This responds as expected from a "real" server, aside from the error message. +func mockIQError(c net.Conn) { + s := stanza.StreamError{ + XMLName: xml.Name{Local: "stream:error"}, + Error: xml.Name{Local: "xml-not-well-formed"}, + Text: `XML was not well-formed`, + } + raw, _ := xml.Marshal(s) + fmt.Fprintln(c, string(raw)) + fmt.Fprintln(c, ``) +} + +func sendStreamFeatures(t *testing.T, c net.Conn, _ *xml.Decoder) { + // This is a basic server, supporting only 1 stream feature: SASL Plain Auth + features := ` + + PLAIN + +` + if _, err := fmt.Fprintln(c, features); err != nil { + t.Errorf("cannot send stream feature: %s", err) + } +} + +// TODO return err in case of error reading the auth params +func readAuth(t *testing.T, decoder *xml.Decoder) string { + se, err := stanza.NextStart(decoder) + if err != nil { + t.Errorf("cannot read auth: %s", err) + return "" + } + + var nv interface{} + nv = &stanza.SASLAuth{} + // Decode element into pointer storage + if err = decoder.DecodeElement(nv, &se); err != nil { + t.Errorf("cannot decode auth: %s", err) + return "" + } + + switch v := nv.(type) { + case *stanza.SASLAuth: + return v.Value + } + return "" +} + +func sendBindFeature(t *testing.T, c net.Conn, _ *xml.Decoder) { + // This is a basic server, supporting only 1 stream feature after auth: resource binding + features := ` + +` + if _, err := fmt.Fprintln(c, features); err != nil { + t.Errorf("cannot send stream feature: %s", err) + } +} + +func sendRFC3921Feature(t *testing.T, c net.Conn, _ *xml.Decoder) { + // This is a basic server, supporting only 2 features after auth: resource & session binding + features := ` + + +` + if _, err := fmt.Fprintln(c, features); err != nil { + t.Errorf("cannot send stream feature: %s", err) + } +} + +func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) { + se, err := stanza.NextStart(decoder) + if err != nil { + t.Errorf("cannot read bind: %s", err) + return + } + + iq := &stanza.IQ{} + // Decode element into pointer storage + if err = decoder.DecodeElement(&iq, &se); err != nil { + t.Errorf("cannot decode bind iq: %s", err) + return + } + + // TODO Check all elements + switch iq.Payload.(type) { + case *stanza.Bind: + result := ` + + %s + +` + fmt.Fprintf(c, result, iq.Id, "test@localhost/test") // TODO use real JID + } +} + +func session(t *testing.T, c net.Conn, decoder *xml.Decoder) { + se, err := stanza.NextStart(decoder) + if err != nil { + t.Errorf("cannot read session: %s", err) + return + } + + iq := &stanza.IQ{} + // Decode element into pointer storage + if err = decoder.DecodeElement(&iq, &se); err != nil { + t.Errorf("cannot decode session iq: %s", err) + return + } + + switch iq.Payload.(type) { + case *stanza.StreamSession: + result := `` + fmt.Fprintf(c, result, iq.Id) + } +}