diff --git a/_examples/xmpp_component/xmpp_component.go b/_examples/xmpp_component/xmpp_component.go index e36b287..0452888 100644 --- a/_examples/xmpp_component/xmpp_component.go +++ b/_examples/xmpp_component/xmpp_component.go @@ -58,7 +58,7 @@ func handleMessage(_ xmpp.Sender, p stanza.Packet) { func discoInfo(c xmpp.Sender, p stanza.Packet, opts xmpp.ComponentOptions) { // Type conversion & sanity checks iq, ok := p.(stanza.IQ) - if !ok || iq.Type != "get" { + if !ok || iq.Type != stanza.IQTypeGet { return } @@ -73,7 +73,7 @@ func discoInfo(c xmpp.Sender, p stanza.Packet, opts xmpp.ComponentOptions) { func discoItems(c xmpp.Sender, p stanza.Packet) { // Type conversion & sanity checks iq, ok := p.(stanza.IQ) - if !ok || iq.Type != "get" { + if !ok || iq.Type != stanza.IQTypeGet { return } diff --git a/_examples/xmpp_jukebox/xmpp_jukebox.go b/_examples/xmpp_jukebox/xmpp_jukebox.go index 10e5dfc..91f453c 100644 --- a/_examples/xmpp_jukebox/xmpp_jukebox.go +++ b/_examples/xmpp_jukebox/xmpp_jukebox.go @@ -106,7 +106,7 @@ func handleIQ(s xmpp.Sender, p stanza.Packet, player *mpg123.Player) { func sendUserTune(s xmpp.Sender, artist string, title string) { tune := stanza.Tune{Artist: artist, Title: title} - iq := stanza.NewIQ(stanza.Attrs{Type: "set", Id: "usertune-1", Lang: "en"}) + iq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeSet, Id: "usertune-1", Lang: "en"}) payload := stanza.PubSub{Publish: &stanza.Publish{Node: "http://jabber.org/protocol/tune", Item: stanza.Item{Tune: &tune}}} iq.Payload = &payload _ = s.Send(iq) diff --git a/auth.go b/auth.go index 726e15a..b8d20b9 100644 --- a/auth.go +++ b/auth.go @@ -60,7 +60,10 @@ func authPlain(socket io.ReadWriter, decoder *xml.Decoder, mech string, user str raw := "\x00" + user + "\x00" + secret enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw))) base64.StdEncoding.Encode(enc, []byte(raw)) - fmt.Fprintf(socket, "%s", stanza.NSSASL, mech, enc) + _, err := fmt.Fprintf(socket, "%s", stanza.NSSASL, mech, enc) + if err != nil { + return err + } // Next message should be either success or failure. val, err := stanza.NextPacket(decoder) diff --git a/cert_checker.go b/cert_checker.go index fcee7b1..30a265a 100644 --- a/cert_checker.go +++ b/cert_checker.go @@ -79,7 +79,10 @@ func (c *ServerCheck) Check() error { } if _, ok := f.DoesStartTLS(); ok { - fmt.Fprintf(tcpconn, "") + _, err = fmt.Fprintf(tcpconn, "") + if err != nil { + return err + } var k stanza.TLSProceed if err = decoder.DecodeElement(&k, nil); err != nil { diff --git a/client.go b/client.go index b7111f2..14537db 100644 --- a/client.go +++ b/client.go @@ -50,7 +50,7 @@ type SMState struct { // EventHandler is use to pass events about state of the connection to // client implementation. -type EventHandler func(Event) +type EventHandler func(Event) error type EventManager struct { // Store current state @@ -188,13 +188,16 @@ func (c *Client) Resume(state SMState) error { go keepalive(c.transport, keepaliveQuit) // Start the receiver go routine state = c.Session.SMState - go c.recv(state, keepaliveQuit) + // Leaving this channel here for later. Not used atm. We should return this instead of an error because right + // now the returned error is lost in limbo. + errChan := make(chan error) + go c.recv(state, keepaliveQuit, errChan) // We're connected and can now receive and send messages. //fmt.Fprintf(client.conn, "%s%s", "chat", "Online") // TODO: Do we always want to send initial presence automatically ? // Do we need an option to avoid that or do we rely on client to send the presence itself ? - fmt.Fprintf(c.transport, "") + _, err = fmt.Fprintf(c.transport, "") return err } @@ -235,7 +238,7 @@ func (c *Client) Send(packet stanza.Packet) error { // result := <- client.SendIQ(ctx, iq) // func (c *Client) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) { - if iq.Attrs.Type != "set" && iq.Attrs.Type != "get" { + if iq.Attrs.Type != stanza.IQTypeSet && iq.Attrs.Type != stanza.IQTypeGet { return nil, ErrCanOnlySendGetOrSetIq } if err := c.Send(iq); err != nil { @@ -267,13 +270,14 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error { // Go routines // Loop: Receive data from server -func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) { +func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan<- error) { for { val, err := stanza.NextPacket(c.transport.GetDecoder()) if err != nil { + errChan <- err close(keepaliveQuit) c.disconnected(state) - return err + return } // Handle stream errors @@ -282,18 +286,22 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) c.router.route(c, val) close(keepaliveQuit) c.streamError(packet.Error.Local, packet.Text) - return errors.New("stream error: " + packet.Error.Local) + errChan <- errors.New("stream error: " + packet.Error.Local) + return // Process Stream management nonzas case stanza.SMRequest: answer := stanza.SMAnswer{XMLName: xml.Name{ Space: stanza.NSStreamManagement, Local: "a", }, H: state.Inbound} - c.Send(answer) + err = c.Send(answer) + if err != nil { + errChan <- err + return + } default: state.Inbound++ } - // Do normal route processing in a go-routine so we can immediately // start receiving other stanzas. This also allows route handlers to // send and receive more stanzas. diff --git a/component.go b/component.go index cb468db..471f1db 100644 --- a/component.go +++ b/component.go @@ -72,11 +72,13 @@ func (c *Component) Resume(sm SMState) error { c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration) if err != nil { c.updateState(StatePermanentError) + return NewConnError(err, true) } if streamId, err = c.transport.Connect(); err != nil { c.updateState(StatePermanentError) + return NewConnError(err, true) } c.updateState(StateConnected) @@ -84,6 +86,7 @@ func (c *Component) Resume(sm SMState) error { // Authentication if _, err := fmt.Fprintf(c.transport, "%s", c.handshake(streamId)); err != nil { c.updateState(StateStreamError) + return NewConnError(errors.New("cannot send handshake "+err.Error()), false) } @@ -101,12 +104,16 @@ func (c *Component) Resume(sm SMState) error { case stanza.Handshake: // Start the receiver go routine c.updateState(StateSessionEstablished) - go c.recv() - return nil + // Leaving this channel here for later. Not used atm. We should return this instead of an error because right + // now the returned error is lost in limbo. + errChan := make(chan error) + go c.recv(errChan) // Sends to errChan + return err // Should be empty at this point default: c.updateState(StatePermanentError) return NewConnError(errors.New("expecting handshake result, got "+v.Name()), true) } + return err } func (c *Component) Disconnect() { @@ -121,20 +128,22 @@ func (c *Component) SetHandler(handler EventHandler) { } // Receiver Go routine receiver -func (c *Component) recv() (err error) { +func (c *Component) recv(errChan chan<- error) { + for { val, err := stanza.NextPacket(c.transport.GetDecoder()) if err != nil { c.updateState(StateDisconnected) - return err + errChan <- err + return } - // Handle stream errors switch p := val.(type) { case stanza.StreamError: c.router.route(c, val) c.streamError(p.Error.Local, p.Text) - return errors.New("stream error: " + p.Error.Local) + errChan <- errors.New("stream error: " + p.Error.Local) + return } c.router.route(c, val) } @@ -168,7 +177,7 @@ func (c *Component) Send(packet stanza.Packet) error { // result := <- client.SendIQ(ctx, iq) // func (c *Component) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) { - if iq.Attrs.Type != "set" && iq.Attrs.Type != "get" { + if iq.Attrs.Type != stanza.IQTypeSet && iq.Attrs.Type != stanza.IQTypeGet { return nil, ErrCanOnlySendGetOrSetIq } if err := c.Send(iq); err != nil { diff --git a/component_test.go b/component_test.go index 8938769..4e115f0 100644 --- a/component_test.go +++ b/component_test.go @@ -1,12 +1,34 @@ package xmpp import ( + "context" + "encoding/xml" + "errors" "fmt" + "gosrc.io/xmpp/stanza" + "net" + "strings" "testing" + "time" ) -const testComponentDomain = "localhost" -const testComponentPort = "15222" +// Tests are ran in parallel, so each test creating a server must use a different port so we do not get any +// conflict. Using iota for this should do the trick. +const ( + testComponentDomain = "localhost" + defaultServerName = "testServer" + defaultStreamID = "91bd0bba-012f-4d92-bb17-5fc41e6fe545" + defaultComponentName = "Test Component" + + // Default port is not standard XMPP port to avoid interfering + // with local running XMPP server + testHandshakePort = iota + 15222 + testDecoderPort + testSendIqPort + testSendRawPort + testDisconnectPort + testSManDisconnectPort +) func TestHandshake(t *testing.T) { opts := ComponentOptions{ @@ -24,25 +46,43 @@ func TestHandshake(t *testing.T) { } } +// Tests connection process with a handshake exchange +// Tests multiple session IDs. All connections should generate a unique stream ID func TestGenerateHandshake(t *testing.T) { - // TODO -} + // Using this array with a channel to make a queue of values to test + // These are stream IDs that will be used to test the connection process, mixing them with the "secret" to generate + // some handshake value + var uuidsArray = [5]string{ + "cc9b3249-9582-4780-825f-4311b42f9b0e", + "bba8be3c-d98e-4e26-b9bb-9ed34578a503", + "dae72822-80e8-496b-b763-ab685f53a188", + "a45d6c06-de49-4bb0-935b-1a2201b71028", + "7dc6924f-0eca-4237-9898-18654b8d891e", + } -// Test that NewStreamManager can accept a Component. -// -// This validates that Component conforms to StreamClient interface. -func TestStreamManager(t *testing.T) { - NewStreamManager(&Component{}, nil) -} + // Channel to pass stream IDs as a queue + var uchan = make(chan string, len(uuidsArray)) + // Populate test channel + for _, elt := range uuidsArray { + uchan <- elt + } -// Tests that the decoder is properly initialized when connecting a component to a server. -// The decoder is expected to be built after a valid connection -// Based on the xmpp_component example. -func TestDecoder(t *testing.T) { - testComponentAddess := fmt.Sprintf("%s:%s", testComponentDomain, testComponentPort) + // Performs a Component connection with a handshake. It expects to have an ID sent its way through the "uchan" + // channel of this file. Otherwise it will hang for ever. + h := func(t *testing.T, c net.Conn) { + decoder := xml.NewDecoder(c) + checkOpenStreamHandshakeID(t, c, decoder, <-uchan) + readHandshakeComponent(t, decoder) + fmt.Fprintln(c, "") // That's all the server needs to return (see xep-0114) + return + } + + // Init mock server + testComponentAddess := fmt.Sprintf("%s:%d", testComponentDomain, testHandshakePort) mock := ServerMock{} - mock.Start(t, testComponentAddess, handlerConnectSuccess) + mock.Start(t, testComponentAddess, h) + // Init component opts := ComponentOptions{ TransportConfiguration: TransportConfiguration{ Address: testComponentAddess, @@ -63,12 +103,352 @@ func TestDecoder(t *testing.T) { if err != nil { t.Errorf("%+v", err) } - _, err = c.transport.Connect() - if err != nil { - t.Errorf("%+v", err) + + // Try connecting, and storing the resulting streamID in a map. + m := make(map[string]bool) + for _, _ = range uuidsArray { + streamId, _ := c.transport.Connect() + m[c.handshake(streamId)] = true } + if len(uuidsArray) != len(m) { + t.Errorf("Handshake does not produce a unique id. Expected: %d unique ids, got: %d", len(uuidsArray), len(m)) + } +} + +// Test that NewStreamManager can accept a Component. +// +// This validates that Component conforms to StreamClient interface. +func TestStreamManager(t *testing.T) { + NewStreamManager(&Component{}, nil) +} + +// Tests that the decoder is properly initialized when connecting a component to a server. +// The decoder is expected to be built after a valid connection +// Based on the xmpp_component example. +func TestDecoder(t *testing.T) { + c, _ := mockConnection(t, testDecoderPort, handlerForComponentHandshakeDefaultID) if c.transport.GetDecoder() == nil { t.Errorf("Failed to initialize decoder. Decoder is nil.") } - +} + +// Tests sending an IQ to the server, and getting the response +func TestSendIq(t *testing.T) { + //Connecting to a mock server, initialized with given port and handler function + c, m := mockConnection(t, testSendIqPort, handlerForComponentIQSend) + + ctx, _ := context.WithTimeout(context.Background(), 30*time.Second) + iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"}) + disco := iqReq.DiscoInfo() + iqReq.Payload = disco + + var res chan stanza.IQ + res, _ = c.SendIQ(ctx, iqReq) + + select { + case <-res: + case <-time.After(100 * time.Millisecond): + t.Errorf("Failed to receive response, to sent IQ, from mock server") + } + + m.Stop() +} + +// Tests sending raw xml to the mock server. +// TODO : check the server response client side ? +// Right now, the server response is not checked and an err is passed in a channel if the test is supposed to err. +// In this test, we use IQs +func TestSendRaw(t *testing.T) { + // Error channel for the handler + errChan := make(chan error) + // Handler for the mock server + h := func(t *testing.T, c net.Conn) { + // Completes the connection by exchanging handshakes + handlerForComponentHandshakeDefaultID(t, c) + receiveRawIq(t, c, errChan) + return + } + + type testCase struct { + req string + shouldErr bool + } + testRequests := make(map[string]testCase) + // Sending a correct IQ of type get. Not supposed to err + testRequests["Correct IQ"] = testCase{ + req: ``, + shouldErr: false, + } + // Sending an IQ with a missing ID. Should err + testRequests["IQ with missing ID"] = testCase{ + req: ``, + shouldErr: true, + } + + // Tests for all the IQs + for name, tcase := range testRequests { + t.Run(name, func(st *testing.T) { + //Connecting to a mock server, initialized with given port and handler function + c, m := mockConnection(t, testSendRawPort, h) + + // Sending raw xml from test case + err := c.SendRaw(tcase.req) + if err != nil { + t.Errorf("Error sending Raw string") + } + // Just wait a little so the message has time to arrive + select { + case <-time.After(100 * time.Millisecond): + case err = <-errChan: + if err == nil && tcase.shouldErr { + t.Errorf("Failed to get closing stream err") + } + } + c.transport.Close() + m.Stop() + }) + } +} + +// Tests the Disconnect method for Components +func TestDisconnect(t *testing.T) { + c, m := mockConnection(t, testDisconnectPort, handlerForComponentHandshakeDefaultID) + err := c.transport.Ping() + if err != nil { + t.Errorf("Could not ping but not disconnected yet") + } + c.Disconnect() + err = c.transport.Ping() + if err == nil { + t.Errorf("Did not disconnect properly") + } + m.Stop() +} + +// Tests that a streamManager successfully disconnects when a handshake fails between the component and the server. +func TestStreamManagerDisconnect(t *testing.T) { + // Init mock server + testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, testSManDisconnectPort) + mock := ServerMock{} + // Handler fails the handshake, which is currently the only option to disconnect completely when using a streamManager + // a failed handshake being a permanent error, except for a "conflict" + mock.Start(t, testComponentAddress, handlerComponentFailedHandshakeDefaultID) + + //================================== + // Create Component to connect to it + c := makeBasicComponent(defaultComponentName, testComponentAddress, t) + + //======================================== + // Connect the new Component to the server + cm := NewStreamManager(c, nil) + errChan := make(chan error) + runSMan := func(errChan chan error) { + errChan <- cm.Run() + } + + go runSMan(errChan) + select { + case <-errChan: + case <-time.After(100 * time.Millisecond): + t.Errorf("The component and server seem to still be connected while they should not.") + } + mock.Stop() +} + +//============================================================================= +// Basic XMPP Server Mock Handlers. +// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. +// Used in the mock server as a Handler +func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) { + decoder := xml.NewDecoder(c) + checkOpenStreamHandshakeDefaultID(t, c, decoder) + readHandshakeComponent(t, decoder) + fmt.Fprintln(c, "") // That's all the server needs to return (see xep-0114) + return +} + +// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. +// This handler is supposed to fail by sending a "message" stanza instead of a stanza to finalize the handshake. +func handlerComponentFailedHandshakeDefaultID(t *testing.T, c net.Conn) { + decoder := xml.NewDecoder(c) + checkOpenStreamHandshakeDefaultID(t, c, decoder) + readHandshakeComponent(t, decoder) + + // Send a message, instead of a "" tag, to fail the handshake process dans disconnect the client. + me := stanza.Message{ + Attrs: stanza.Attrs{Type: stanza.MessageTypeChat, From: defaultServerName, To: defaultComponentName, Lang: "en"}, + Body: "Fail my handshake.", + } + s, _ := xml.Marshal(me) + fmt.Fprintln(c, string(s)) + + return +} + +// Reads from the connection with the Component. Expects a handshake request, and returns the tag. +func readHandshakeComponent(t *testing.T, decoder *xml.Decoder) { + se, err := stanza.NextStart(decoder) + if err != nil { + t.Errorf("cannot read auth: %s", err) + return + } + nv := &stanza.Handshake{} + // Decode element into pointer storage + if err = decoder.DecodeElement(nv, &se); err != nil { + t.Errorf("cannot decode handshake: %s", err) + return + } + if len(strings.TrimSpace(nv.Value)) == 0 { + t.Errorf("did not receive handshake ID") + } +} + +func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) { + checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID) +} + +// Used for ID and handshake related tests +func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) { + c.SetDeadline(time.Now().Add(defaultTimeout)) + defer c.SetDeadline(time.Time{}) + + for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion. + token, err := decoder.Token() + if err != nil { + t.Errorf("cannot read next token: %s", err) + } + + switch elem := token.(type) { + // Wait for first startElement + case xml.StartElement: + if elem.Name.Space != stanza.NSStream || elem.Name.Local != "stream" { + err = errors.New("xmpp: expected but got <" + elem.Name.Local + "> in " + elem.Name.Space) + return + } + if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil { + t.Errorf("cannot write server stream open: %s", err) + } + return + } + } +} + +//============================================================================= +// Sends IQ response to Component request. +// No parsing of the request here. We just check that it's valid, and send the default response. +func handlerForComponentIQSend(t *testing.T, c net.Conn) { + // Completes the connection by exchanging handshakes + handlerForComponentHandshakeDefaultID(t, c) + + // Decoder to parse the request + decoder := xml.NewDecoder(c) + + iqReq, err := receiveIq(t, c, decoder) + if err != nil { + t.Errorf("Error receiving the IQ stanza : %v", err) + } else if !iqReq.IsValid() { + t.Errorf("server received an IQ stanza : %v", iqReq) + } + + // Crafting response + iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"}) + disco := iqResp.DiscoInfo() + disco.AddFeatures("vcard-temp", + `http://jabber.org/protocol/address`) + + disco.AddIdentity("Multicast", "service", "multicast") + iqResp.Payload = disco + + // Sending response to the Component + mResp, err := xml.Marshal(iqResp) + _, err = fmt.Fprintln(c, string(mResp)) + if err != nil { + t.Errorf("Could not send response stanza : %s", err) + } + return +} + +// Reads next request coming from the Component. Expecting it to be an IQ request +func receiveIq(t *testing.T, c net.Conn, decoder *xml.Decoder) (stanza.IQ, error) { + c.SetDeadline(time.Now().Add(defaultTimeout)) + defer c.SetDeadline(time.Time{}) + var iqStz stanza.IQ + err := decoder.Decode(&iqStz) + if err != nil { + t.Errorf("cannot read the received IQ stanza: %s", err) + } + if !iqStz.IsValid() { + t.Errorf("received IQ stanza is invalid : %s", err) + } + return iqStz, nil +} + +func receiveRawIq(t *testing.T, c net.Conn, errChan chan error) { + c.SetDeadline(time.Now().Add(defaultTimeout)) + defer c.SetDeadline(time.Time{}) + decoder := xml.NewDecoder(c) + var iq stanza.IQ + err := decoder.Decode(&iq) + if err != nil || !iq.IsValid() { + s := stanza.StreamError{ + XMLName: xml.Name{Local: "stream:error"}, + Error: xml.Name{Local: "xml-not-well-formed"}, + Text: `XML was not well-formed`, + } + raw, _ := xml.Marshal(s) + fmt.Fprintln(c, string(raw)) + fmt.Fprintln(c, ``) // TODO : check this client side + errChan <- fmt.Errorf("invalid xml") + return + } + errChan <- nil + return +} + +//=============================== +// Init mock server and connection +// Creating a mock server and connecting a Component to it. Initialized with given port and handler function +// The Component and mock are both returned +func mockConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) { + // Init mock server + testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port) + mock := ServerMock{} + mock.Start(t, testComponentAddress, handler) + + //================================== + // Create Component to connect to it + c := makeBasicComponent(defaultComponentName, testComponentAddress, t) + + //======================================== + // Connect the new Component to the server + err := c.Connect() + if err != nil { + t.Errorf("%+v", err) + } + + return c, &mock +} + +func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component { + opts := ComponentOptions{ + TransportConfiguration: TransportConfiguration{ + Address: mockServerAddr, + Domain: "localhost", + }, + Domain: testComponentDomain, + Secret: "mypass", + Name: name, + Category: "gateway", + Type: "service", + } + router := NewRouter() + c, err := NewComponent(opts, router) + if err != nil { + t.Errorf("%+v", err) + } + c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration) + if err != nil { + t.Errorf("%+v", err) + } + return c } diff --git a/network.go b/network.go index 75a0a60..8b03f3f 100644 --- a/network.go +++ b/network.go @@ -23,7 +23,7 @@ func ensurePort(addr string, port int) string { // This is IPV4 without port return addr + ":" + strconv.Itoa(port) case 1: - // This is IPV$ with port + // This is IPV6 with port return addr default: // This is IPV6 without port, as you need to use bracket with port in IPV6 diff --git a/network_test.go b/network_test.go index 116ecef..470f150 100644 --- a/network_test.go +++ b/network_test.go @@ -1,12 +1,10 @@ package xmpp import ( + "strings" "testing" ) -type params struct { -} - func TestParseAddr(t *testing.T) { tests := []struct { name string @@ -33,3 +31,36 @@ func TestParseAddr(t *testing.T) { }) } } + +func TestEnsurePort(t *testing.T) { + testAddresses := []string{ + "1ca3:6c07:ee3a:89ca:e065:9a70:71d:daad", + "1ca3:6c07:ee3a:89ca:e065:9a70:71d:daad:5252", + "[::1]", + "127.0.0.1:5555", + "127.0.0.1", + "[::1]:5555", + } + + for _, oldAddr := range testAddresses { + t.Run(oldAddr, func(st *testing.T) { + newAddr := ensurePort(oldAddr, 5222) + + if len(newAddr) < len(oldAddr) { + st.Errorf("incorrect Result: transformed address is shorter than input : %v (old) > %v (new)", newAddr, oldAddr) + } + // If IPv6, the new address needs brackets to specify a port, like so : [2001:db8:85a3:0:0:8a2e:370:7334]:5222 + if strings.Count(newAddr, "[") < strings.Count(oldAddr, "[") || + strings.Count(newAddr, "]") < strings.Count(oldAddr, "]") { + + st.Errorf("incorrect Result. Transformed address seems to not have correct brakets : %v => %v", oldAddr, newAddr) + } + + // Check if we messed up the colons, or didn't properly add a port + if strings.Count(newAddr, ":") < strings.Count(oldAddr, ":") { + st.Errorf("incorrect Result: transformed address doesn't seem to have a port %v (=> %v, no port ?)", oldAddr, newAddr) + } + }) + } + +} diff --git a/router_test.go b/router_test.go index b3d253e..2b5cf82 100644 --- a/router_test.go +++ b/router_test.go @@ -146,7 +146,7 @@ func TestTypeMatcher(t *testing.T) { // We do not match on other types conn = NewSenderMock() - iqVersion := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"}) + iqVersion := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"}) iqVersion.Payload = &stanza.DiscoInfo{ XMLName: xml.Name{ Space: "jabber:iq:version", @@ -163,27 +163,27 @@ func TestCompositeMatcher(t *testing.T) { router := NewRouter() router.NewRoute(). IQNamespaces("jabber:iq:version"). - StanzaType("get"). + StanzaType(string(stanza.IQTypeGet)). HandlerFunc(func(s Sender, p stanza.Packet) { _ = s.SendRaw(successFlag) }) // Data set - getVersionIq := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"}) + getVersionIq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"}) getVersionIq.Payload = &stanza.Version{ XMLName: xml.Name{ Space: "jabber:iq:version", Local: "query", }} - setVersionIq := stanza.NewIQ(stanza.Attrs{Type: "set", From: "service.localhost", To: "test@localhost", Id: "1"}) + setVersionIq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeSet, From: "service.localhost", To: "test@localhost", Id: "1"}) setVersionIq.Payload = &stanza.Version{ XMLName: xml.Name{ Space: "jabber:iq:version", Local: "query", }} - GetDiscoIq := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"}) + GetDiscoIq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"}) GetDiscoIq.Payload = &stanza.DiscoInfo{ XMLName: xml.Name{ Space: "http://jabber.org/protocol/disco#info", @@ -238,7 +238,7 @@ func TestCatchallMatcher(t *testing.T) { } conn = NewSenderMock() - iqVersion := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"}) + iqVersion := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"}) iqVersion.Payload = &stanza.DiscoInfo{ XMLName: xml.Name{ Space: "jabber:iq:version", diff --git a/stanza/component.go b/stanza/component.go index 33ced33..32a36b0 100644 --- a/stanza/component.go +++ b/stanza/component.go @@ -12,7 +12,7 @@ import ( type Handshake struct { XMLName xml.Name `xml:"jabber:component:accept handshake"` // TODO Add handshake value with test for proper serialization - // Value string `xml:",innerxml"` + Value string `xml:",innerxml"` } func (Handshake) Name() string { diff --git a/stanza/error.go b/stanza/error.go index bcc947f..0f416e4 100644 --- a/stanza/error.go +++ b/stanza/error.go @@ -54,7 +54,7 @@ func (x *Err) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { textName := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: "text"} if elt.XMLName == textName { - x.Text = string(elt.Content) + x.Text = elt.Content } else if elt.XMLName.Space == "urn:ietf:params:xml:ns:xmpp-stanzas" { x.Reason = elt.XMLName.Local } @@ -94,16 +94,32 @@ func (x Err) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) { // Reason if x.Reason != "" { reason := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: x.Reason} - e.EncodeToken(xml.StartElement{Name: reason}) - e.EncodeToken(xml.EndElement{Name: reason}) + err = e.EncodeToken(xml.StartElement{Name: reason}) + if err != nil { + return err + } + err = e.EncodeToken(xml.EndElement{Name: reason}) + if err != nil { + return err + } + } // Text if x.Text != "" { text := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: "text"} - e.EncodeToken(xml.StartElement{Name: text}) - e.EncodeToken(xml.CharData(x.Text)) - e.EncodeToken(xml.EndElement{Name: text}) + err = e.EncodeToken(xml.StartElement{Name: text}) + if err != nil { + return err + } + err = e.EncodeToken(xml.CharData(x.Text)) + if err != nil { + return err + } + err = e.EncodeToken(xml.EndElement{Name: text}) + if err != nil { + return err + } } return e.EncodeToken(xml.EndElement{Name: start.Name}) diff --git a/stanza/iq.go b/stanza/iq.go index 923cf28..499c261 100644 --- a/stanza/iq.go +++ b/stanza/iq.go @@ -2,6 +2,7 @@ package stanza import ( "encoding/xml" + "strings" "github.com/google/uuid" ) @@ -23,7 +24,7 @@ type IQ struct { // Info/Query // child element, which specifies the semantics of the particular // request." Payload IQPayload `xml:",omitempty"` - Error Err `xml:"error,omitempty"` + Error *Err `xml:"error,omitempty"` // Any is used to decode unknown payload as a generic structure Any *Node `xml:",any"` } @@ -52,7 +53,7 @@ func (iq IQ) MakeError(xerror Err) IQ { iq.Type = "error" iq.From = to iq.To = from - iq.Error = xerror + iq.Error = &xerror return iq } @@ -106,7 +107,7 @@ func (iq *IQ) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { if err != nil { return err } - iq.Error = xmppError + iq.Error = &xmppError continue } if iqExt := TypeRegistry.GetIQExtension(tt.Name); iqExt != nil { @@ -132,3 +133,39 @@ func (iq *IQ) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { } } } + +// Following RFC-3920 for IQs +func (iq *IQ) IsValid() bool { + // ID is required + if len(strings.TrimSpace(iq.Id)) == 0 { + return false + } + + // Type is required + if iq.Type.IsEmpty() { + return false + } + + // Type get and set must contain one and only one child element that specifies the semantics + if iq.Type == IQTypeGet || iq.Type == IQTypeSet { + if iq.Payload == nil && iq.Any == nil { + return false + } + } + + // A result must include zero or one child element + if iq.Type == IQTypeResult { + if iq.Payload != nil && iq.Any != nil { + return false + } + } + + //Error type must contain an "error" child element + if iq.Type == IQTypeError { + if iq.Error == nil { + return false + } + } + + return true +} diff --git a/stanza/iq_test.go b/stanza/iq_test.go index 54a8fc5..3223566 100644 --- a/stanza/iq_test.go +++ b/stanza/iq_test.go @@ -187,3 +187,38 @@ func TestUnknownPayload(t *testing.T) { t.Errorf("could not extract namespace: '%s'", parsedIQ.Any.XMLName.Space) } } + +func TestIsValid(t *testing.T) { + type testCase struct { + iq string + shouldErr bool + } + testIQs := make(map[string]testCase) + testIQs["Valid IQ"] = testCase{ + ` + + `, + false, + } + testIQs["Invalid IQ"] = testCase{ + ` + + `, + true, + } + + for name, tcase := range testIQs { + t.Run(name, func(st *testing.T) { + parsedIQ := stanza.IQ{} + err := xml.Unmarshal([]byte(tcase.iq), &parsedIQ) + if err != nil { + t.Errorf("Unmarshal error: %#v (%s)", err, tcase.iq) + return + } + if !parsedIQ.IsValid() && !tcase.shouldErr { + t.Errorf("failed iq validation for : %s", tcase.iq) + } + }) + } + +} diff --git a/stanza/node.go b/stanza/node.go index 6afa7bc..308729c 100644 --- a/stanza/node.go +++ b/stanza/node.go @@ -46,9 +46,18 @@ func (n Node) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) { start.Name = n.XMLName err = e.EncodeToken(start) - e.EncodeElement(n.Nodes, xml.StartElement{Name: n.XMLName}) + if err != nil { + return err + } + err = e.EncodeElement(n.Nodes, xml.StartElement{Name: n.XMLName}) + if err != nil { + return err + } if n.Content != "" { - e.EncodeToken(xml.CharData(n.Content)) + err = e.EncodeToken(xml.CharData(n.Content)) + if err != nil { + return err + } } return e.EncodeToken(xml.EndElement{Name: start.Name}) } diff --git a/stanza/packet_enum.go b/stanza/packet_enum.go index 103966a..84dd476 100644 --- a/stanza/packet_enum.go +++ b/stanza/packet_enum.go @@ -1,5 +1,7 @@ package stanza +import "strings" + type StanzaType string // RFC 6120: part of A.5 Client Namespace and A.6 Server Namespace @@ -23,3 +25,7 @@ const ( PresenceTypeUnsubscribe StanzaType = "unsubscribe" PresenceTypeUnsubscribed StanzaType = "unsubscribed" ) + +func (s StanzaType) IsEmpty() bool { + return len(strings.TrimSpace(string(s))) == 0 +} diff --git a/stanza/sasl_auth.go b/stanza/sasl_auth.go index d04174f..29648ee 100644 --- a/stanza/sasl_auth.go +++ b/stanza/sasl_auth.go @@ -107,6 +107,6 @@ func (s *StreamSession) IsOptional() bool { // Registry init func init() { - TypeRegistry.MapExtension(PKTIQ, xml.Name{"urn:ietf:params:xml:ns:xmpp-bind", "bind"}, Bind{}) - TypeRegistry.MapExtension(PKTIQ, xml.Name{"urn:ietf:params:xml:ns:xmpp-session", "session"}, StreamSession{}) + TypeRegistry.MapExtension(PKTIQ, xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-bind", Local: "bind"}, Bind{}) + TypeRegistry.MapExtension(PKTIQ, xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-session", Local: "session"}, StreamSession{}) } diff --git a/stanza/stream.go b/stanza/stream.go index 290abfe..203cc83 100644 --- a/stanza/stream.go +++ b/stanza/stream.go @@ -8,7 +8,7 @@ import "encoding/xml" type Stream struct { XMLName xml.Name `xml:"http://etherx.jabber.org/streams stream"` From string `xml:"from,attr"` - To string `xml:"to,attr"` + To string `xml:"to,attr"` Id string `xml:"id,attr"` Version string `xml:"version,attr"` } diff --git a/stanza/stream_features.go b/stanza/stream_features.go index 11cd96b..14358f0 100644 --- a/stanza/stream_features.go +++ b/stanza/stream_features.go @@ -15,7 +15,7 @@ type StreamFeatures struct { // Server capabilities hash Caps Caps // Stream features - StartTLS tlsStartTLS + StartTLS TlsStartTLS Mechanisms saslMechanisms Bind Bind StreamManagement streamManagement @@ -60,13 +60,13 @@ type Caps struct { // StartTLS feature // Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4 -type tlsStartTLS struct { +type TlsStartTLS struct { XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"` Required bool } // UnmarshalXML implements custom parsing startTLS required flag -func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { +func (stls *TlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { stls.XMLName = start.Name // Check subelements to extract required field as boolean @@ -98,7 +98,7 @@ func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) er } } -func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) { +func (sf *StreamFeatures) DoesStartTLS() (feature TlsStartTLS, isSupported bool) { if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" { return sf.StartTLS, true } diff --git a/stream_manager.go b/stream_manager.go index bf7fba8..aebd8a4 100644 --- a/stream_manager.go +++ b/stream_manager.go @@ -74,7 +74,7 @@ func (sm *StreamManager) Run() error { return errors.New("missing stream client") } - handler := func(e Event) { + handler := func(e Event) error { switch e.State { case StateConnected: sm.Metrics.setConnectTime() @@ -82,17 +82,18 @@ func (sm *StreamManager) Run() error { sm.Metrics.setLoginTime() case StateDisconnected: // Reconnect on disconnection - sm.resume(e.SMState) + return sm.resume(e.SMState) case StateStreamError: sm.client.Disconnect() // Only try reconnecting if we have not been kicked by another session to avoid connection loop. // TODO: Make this conflict exception a permanent error if e.StreamError != "conflict" { - sm.connect() + return sm.connect() } case StatePermanentError: // Do not attempt to reconnect } + return nil } sm.client.SetHandler(handler) diff --git a/test.sh b/test.sh index 9730026..725dcaf 100755 --- a/test.sh +++ b/test.sh @@ -5,7 +5,7 @@ export GO111MODULE=on echo "" > coverage.txt for d in $(go list ./... | grep -v vendor); do - go test -race -coverprofile=profile.out -covermode=atomic ${d} + go test -race -coverprofile=profile.out -covermode=atomic "${d}" if [ -f profile.out ]; then cat profile.out >> coverage.txt rm profile.out