diff --git a/_examples/xmpp_component/xmpp_component.go b/_examples/xmpp_component/xmpp_component.go
index e36b287..0452888 100644
--- a/_examples/xmpp_component/xmpp_component.go
+++ b/_examples/xmpp_component/xmpp_component.go
@@ -58,7 +58,7 @@ func handleMessage(_ xmpp.Sender, p stanza.Packet) {
func discoInfo(c xmpp.Sender, p stanza.Packet, opts xmpp.ComponentOptions) {
// Type conversion & sanity checks
iq, ok := p.(stanza.IQ)
- if !ok || iq.Type != "get" {
+ if !ok || iq.Type != stanza.IQTypeGet {
return
}
@@ -73,7 +73,7 @@ func discoInfo(c xmpp.Sender, p stanza.Packet, opts xmpp.ComponentOptions) {
func discoItems(c xmpp.Sender, p stanza.Packet) {
// Type conversion & sanity checks
iq, ok := p.(stanza.IQ)
- if !ok || iq.Type != "get" {
+ if !ok || iq.Type != stanza.IQTypeGet {
return
}
diff --git a/_examples/xmpp_jukebox/xmpp_jukebox.go b/_examples/xmpp_jukebox/xmpp_jukebox.go
index 10e5dfc..91f453c 100644
--- a/_examples/xmpp_jukebox/xmpp_jukebox.go
+++ b/_examples/xmpp_jukebox/xmpp_jukebox.go
@@ -106,7 +106,7 @@ func handleIQ(s xmpp.Sender, p stanza.Packet, player *mpg123.Player) {
func sendUserTune(s xmpp.Sender, artist string, title string) {
tune := stanza.Tune{Artist: artist, Title: title}
- iq := stanza.NewIQ(stanza.Attrs{Type: "set", Id: "usertune-1", Lang: "en"})
+ iq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeSet, Id: "usertune-1", Lang: "en"})
payload := stanza.PubSub{Publish: &stanza.Publish{Node: "http://jabber.org/protocol/tune", Item: stanza.Item{Tune: &tune}}}
iq.Payload = &payload
_ = s.Send(iq)
diff --git a/auth.go b/auth.go
index 726e15a..b8d20b9 100644
--- a/auth.go
+++ b/auth.go
@@ -60,7 +60,10 @@ func authPlain(socket io.ReadWriter, decoder *xml.Decoder, mech string, user str
raw := "\x00" + user + "\x00" + secret
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
base64.StdEncoding.Encode(enc, []byte(raw))
- fmt.Fprintf(socket, "%s", stanza.NSSASL, mech, enc)
+ _, err := fmt.Fprintf(socket, "%s", stanza.NSSASL, mech, enc)
+ if err != nil {
+ return err
+ }
// Next message should be either success or failure.
val, err := stanza.NextPacket(decoder)
diff --git a/cert_checker.go b/cert_checker.go
index fcee7b1..30a265a 100644
--- a/cert_checker.go
+++ b/cert_checker.go
@@ -79,7 +79,10 @@ func (c *ServerCheck) Check() error {
}
if _, ok := f.DoesStartTLS(); ok {
- fmt.Fprintf(tcpconn, "")
+ _, err = fmt.Fprintf(tcpconn, "")
+ if err != nil {
+ return err
+ }
var k stanza.TLSProceed
if err = decoder.DecodeElement(&k, nil); err != nil {
diff --git a/client.go b/client.go
index b7111f2..14537db 100644
--- a/client.go
+++ b/client.go
@@ -50,7 +50,7 @@ type SMState struct {
// EventHandler is use to pass events about state of the connection to
// client implementation.
-type EventHandler func(Event)
+type EventHandler func(Event) error
type EventManager struct {
// Store current state
@@ -188,13 +188,16 @@ func (c *Client) Resume(state SMState) error {
go keepalive(c.transport, keepaliveQuit)
// Start the receiver go routine
state = c.Session.SMState
- go c.recv(state, keepaliveQuit)
+ // Leaving this channel here for later. Not used atm. We should return this instead of an error because right
+ // now the returned error is lost in limbo.
+ errChan := make(chan error)
+ go c.recv(state, keepaliveQuit, errChan)
// We're connected and can now receive and send messages.
//fmt.Fprintf(client.conn, "%s%s", "chat", "Online")
// TODO: Do we always want to send initial presence automatically ?
// Do we need an option to avoid that or do we rely on client to send the presence itself ?
- fmt.Fprintf(c.transport, "")
+ _, err = fmt.Fprintf(c.transport, "")
return err
}
@@ -235,7 +238,7 @@ func (c *Client) Send(packet stanza.Packet) error {
// result := <- client.SendIQ(ctx, iq)
//
func (c *Client) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) {
- if iq.Attrs.Type != "set" && iq.Attrs.Type != "get" {
+ if iq.Attrs.Type != stanza.IQTypeSet && iq.Attrs.Type != stanza.IQTypeGet {
return nil, ErrCanOnlySendGetOrSetIq
}
if err := c.Send(iq); err != nil {
@@ -267,13 +270,14 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error {
// Go routines
// Loop: Receive data from server
-func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) {
+func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan<- error) {
for {
val, err := stanza.NextPacket(c.transport.GetDecoder())
if err != nil {
+ errChan <- err
close(keepaliveQuit)
c.disconnected(state)
- return err
+ return
}
// Handle stream errors
@@ -282,18 +286,22 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error)
c.router.route(c, val)
close(keepaliveQuit)
c.streamError(packet.Error.Local, packet.Text)
- return errors.New("stream error: " + packet.Error.Local)
+ errChan <- errors.New("stream error: " + packet.Error.Local)
+ return
// Process Stream management nonzas
case stanza.SMRequest:
answer := stanza.SMAnswer{XMLName: xml.Name{
Space: stanza.NSStreamManagement,
Local: "a",
}, H: state.Inbound}
- c.Send(answer)
+ err = c.Send(answer)
+ if err != nil {
+ errChan <- err
+ return
+ }
default:
state.Inbound++
}
-
// Do normal route processing in a go-routine so we can immediately
// start receiving other stanzas. This also allows route handlers to
// send and receive more stanzas.
diff --git a/component.go b/component.go
index cb468db..471f1db 100644
--- a/component.go
+++ b/component.go
@@ -72,11 +72,13 @@ func (c *Component) Resume(sm SMState) error {
c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration)
if err != nil {
c.updateState(StatePermanentError)
+
return NewConnError(err, true)
}
if streamId, err = c.transport.Connect(); err != nil {
c.updateState(StatePermanentError)
+
return NewConnError(err, true)
}
c.updateState(StateConnected)
@@ -84,6 +86,7 @@ func (c *Component) Resume(sm SMState) error {
// Authentication
if _, err := fmt.Fprintf(c.transport, "%s", c.handshake(streamId)); err != nil {
c.updateState(StateStreamError)
+
return NewConnError(errors.New("cannot send handshake "+err.Error()), false)
}
@@ -101,12 +104,16 @@ func (c *Component) Resume(sm SMState) error {
case stanza.Handshake:
// Start the receiver go routine
c.updateState(StateSessionEstablished)
- go c.recv()
- return nil
+ // Leaving this channel here for later. Not used atm. We should return this instead of an error because right
+ // now the returned error is lost in limbo.
+ errChan := make(chan error)
+ go c.recv(errChan) // Sends to errChan
+ return err // Should be empty at this point
default:
c.updateState(StatePermanentError)
return NewConnError(errors.New("expecting handshake result, got "+v.Name()), true)
}
+ return err
}
func (c *Component) Disconnect() {
@@ -121,20 +128,22 @@ func (c *Component) SetHandler(handler EventHandler) {
}
// Receiver Go routine receiver
-func (c *Component) recv() (err error) {
+func (c *Component) recv(errChan chan<- error) {
+
for {
val, err := stanza.NextPacket(c.transport.GetDecoder())
if err != nil {
c.updateState(StateDisconnected)
- return err
+ errChan <- err
+ return
}
-
// Handle stream errors
switch p := val.(type) {
case stanza.StreamError:
c.router.route(c, val)
c.streamError(p.Error.Local, p.Text)
- return errors.New("stream error: " + p.Error.Local)
+ errChan <- errors.New("stream error: " + p.Error.Local)
+ return
}
c.router.route(c, val)
}
@@ -168,7 +177,7 @@ func (c *Component) Send(packet stanza.Packet) error {
// result := <- client.SendIQ(ctx, iq)
//
func (c *Component) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) {
- if iq.Attrs.Type != "set" && iq.Attrs.Type != "get" {
+ if iq.Attrs.Type != stanza.IQTypeSet && iq.Attrs.Type != stanza.IQTypeGet {
return nil, ErrCanOnlySendGetOrSetIq
}
if err := c.Send(iq); err != nil {
diff --git a/component_test.go b/component_test.go
index 8938769..4e115f0 100644
--- a/component_test.go
+++ b/component_test.go
@@ -1,12 +1,34 @@
package xmpp
import (
+ "context"
+ "encoding/xml"
+ "errors"
"fmt"
+ "gosrc.io/xmpp/stanza"
+ "net"
+ "strings"
"testing"
+ "time"
)
-const testComponentDomain = "localhost"
-const testComponentPort = "15222"
+// Tests are ran in parallel, so each test creating a server must use a different port so we do not get any
+// conflict. Using iota for this should do the trick.
+const (
+ testComponentDomain = "localhost"
+ defaultServerName = "testServer"
+ defaultStreamID = "91bd0bba-012f-4d92-bb17-5fc41e6fe545"
+ defaultComponentName = "Test Component"
+
+ // Default port is not standard XMPP port to avoid interfering
+ // with local running XMPP server
+ testHandshakePort = iota + 15222
+ testDecoderPort
+ testSendIqPort
+ testSendRawPort
+ testDisconnectPort
+ testSManDisconnectPort
+)
func TestHandshake(t *testing.T) {
opts := ComponentOptions{
@@ -24,25 +46,43 @@ func TestHandshake(t *testing.T) {
}
}
+// Tests connection process with a handshake exchange
+// Tests multiple session IDs. All connections should generate a unique stream ID
func TestGenerateHandshake(t *testing.T) {
- // TODO
-}
+ // Using this array with a channel to make a queue of values to test
+ // These are stream IDs that will be used to test the connection process, mixing them with the "secret" to generate
+ // some handshake value
+ var uuidsArray = [5]string{
+ "cc9b3249-9582-4780-825f-4311b42f9b0e",
+ "bba8be3c-d98e-4e26-b9bb-9ed34578a503",
+ "dae72822-80e8-496b-b763-ab685f53a188",
+ "a45d6c06-de49-4bb0-935b-1a2201b71028",
+ "7dc6924f-0eca-4237-9898-18654b8d891e",
+ }
-// Test that NewStreamManager can accept a Component.
-//
-// This validates that Component conforms to StreamClient interface.
-func TestStreamManager(t *testing.T) {
- NewStreamManager(&Component{}, nil)
-}
+ // Channel to pass stream IDs as a queue
+ var uchan = make(chan string, len(uuidsArray))
+ // Populate test channel
+ for _, elt := range uuidsArray {
+ uchan <- elt
+ }
-// Tests that the decoder is properly initialized when connecting a component to a server.
-// The decoder is expected to be built after a valid connection
-// Based on the xmpp_component example.
-func TestDecoder(t *testing.T) {
- testComponentAddess := fmt.Sprintf("%s:%s", testComponentDomain, testComponentPort)
+ // Performs a Component connection with a handshake. It expects to have an ID sent its way through the "uchan"
+ // channel of this file. Otherwise it will hang for ever.
+ h := func(t *testing.T, c net.Conn) {
+ decoder := xml.NewDecoder(c)
+ checkOpenStreamHandshakeID(t, c, decoder, <-uchan)
+ readHandshakeComponent(t, decoder)
+ fmt.Fprintln(c, "") // That's all the server needs to return (see xep-0114)
+ return
+ }
+
+ // Init mock server
+ testComponentAddess := fmt.Sprintf("%s:%d", testComponentDomain, testHandshakePort)
mock := ServerMock{}
- mock.Start(t, testComponentAddess, handlerConnectSuccess)
+ mock.Start(t, testComponentAddess, h)
+ // Init component
opts := ComponentOptions{
TransportConfiguration: TransportConfiguration{
Address: testComponentAddess,
@@ -63,12 +103,352 @@ func TestDecoder(t *testing.T) {
if err != nil {
t.Errorf("%+v", err)
}
- _, err = c.transport.Connect()
- if err != nil {
- t.Errorf("%+v", err)
+
+ // Try connecting, and storing the resulting streamID in a map.
+ m := make(map[string]bool)
+ for _, _ = range uuidsArray {
+ streamId, _ := c.transport.Connect()
+ m[c.handshake(streamId)] = true
}
+ if len(uuidsArray) != len(m) {
+ t.Errorf("Handshake does not produce a unique id. Expected: %d unique ids, got: %d", len(uuidsArray), len(m))
+ }
+}
+
+// Test that NewStreamManager can accept a Component.
+//
+// This validates that Component conforms to StreamClient interface.
+func TestStreamManager(t *testing.T) {
+ NewStreamManager(&Component{}, nil)
+}
+
+// Tests that the decoder is properly initialized when connecting a component to a server.
+// The decoder is expected to be built after a valid connection
+// Based on the xmpp_component example.
+func TestDecoder(t *testing.T) {
+ c, _ := mockConnection(t, testDecoderPort, handlerForComponentHandshakeDefaultID)
if c.transport.GetDecoder() == nil {
t.Errorf("Failed to initialize decoder. Decoder is nil.")
}
-
+}
+
+// Tests sending an IQ to the server, and getting the response
+func TestSendIq(t *testing.T) {
+ //Connecting to a mock server, initialized with given port and handler function
+ c, m := mockConnection(t, testSendIqPort, handlerForComponentIQSend)
+
+ ctx, _ := context.WithTimeout(context.Background(), 30*time.Second)
+ iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"})
+ disco := iqReq.DiscoInfo()
+ iqReq.Payload = disco
+
+ var res chan stanza.IQ
+ res, _ = c.SendIQ(ctx, iqReq)
+
+ select {
+ case <-res:
+ case <-time.After(100 * time.Millisecond):
+ t.Errorf("Failed to receive response, to sent IQ, from mock server")
+ }
+
+ m.Stop()
+}
+
+// Tests sending raw xml to the mock server.
+// TODO : check the server response client side ?
+// Right now, the server response is not checked and an err is passed in a channel if the test is supposed to err.
+// In this test, we use IQs
+func TestSendRaw(t *testing.T) {
+ // Error channel for the handler
+ errChan := make(chan error)
+ // Handler for the mock server
+ h := func(t *testing.T, c net.Conn) {
+ // Completes the connection by exchanging handshakes
+ handlerForComponentHandshakeDefaultID(t, c)
+ receiveRawIq(t, c, errChan)
+ return
+ }
+
+ type testCase struct {
+ req string
+ shouldErr bool
+ }
+ testRequests := make(map[string]testCase)
+ // Sending a correct IQ of type get. Not supposed to err
+ testRequests["Correct IQ"] = testCase{
+ req: ``,
+ shouldErr: false,
+ }
+ // Sending an IQ with a missing ID. Should err
+ testRequests["IQ with missing ID"] = testCase{
+ req: ``,
+ shouldErr: true,
+ }
+
+ // Tests for all the IQs
+ for name, tcase := range testRequests {
+ t.Run(name, func(st *testing.T) {
+ //Connecting to a mock server, initialized with given port and handler function
+ c, m := mockConnection(t, testSendRawPort, h)
+
+ // Sending raw xml from test case
+ err := c.SendRaw(tcase.req)
+ if err != nil {
+ t.Errorf("Error sending Raw string")
+ }
+ // Just wait a little so the message has time to arrive
+ select {
+ case <-time.After(100 * time.Millisecond):
+ case err = <-errChan:
+ if err == nil && tcase.shouldErr {
+ t.Errorf("Failed to get closing stream err")
+ }
+ }
+ c.transport.Close()
+ m.Stop()
+ })
+ }
+}
+
+// Tests the Disconnect method for Components
+func TestDisconnect(t *testing.T) {
+ c, m := mockConnection(t, testDisconnectPort, handlerForComponentHandshakeDefaultID)
+ err := c.transport.Ping()
+ if err != nil {
+ t.Errorf("Could not ping but not disconnected yet")
+ }
+ c.Disconnect()
+ err = c.transport.Ping()
+ if err == nil {
+ t.Errorf("Did not disconnect properly")
+ }
+ m.Stop()
+}
+
+// Tests that a streamManager successfully disconnects when a handshake fails between the component and the server.
+func TestStreamManagerDisconnect(t *testing.T) {
+ // Init mock server
+ testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, testSManDisconnectPort)
+ mock := ServerMock{}
+ // Handler fails the handshake, which is currently the only option to disconnect completely when using a streamManager
+ // a failed handshake being a permanent error, except for a "conflict"
+ mock.Start(t, testComponentAddress, handlerComponentFailedHandshakeDefaultID)
+
+ //==================================
+ // Create Component to connect to it
+ c := makeBasicComponent(defaultComponentName, testComponentAddress, t)
+
+ //========================================
+ // Connect the new Component to the server
+ cm := NewStreamManager(c, nil)
+ errChan := make(chan error)
+ runSMan := func(errChan chan error) {
+ errChan <- cm.Run()
+ }
+
+ go runSMan(errChan)
+ select {
+ case <-errChan:
+ case <-time.After(100 * time.Millisecond):
+ t.Errorf("The component and server seem to still be connected while they should not.")
+ }
+ mock.Stop()
+}
+
+//=============================================================================
+// Basic XMPP Server Mock Handlers.
+// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant.
+// Used in the mock server as a Handler
+func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) {
+ decoder := xml.NewDecoder(c)
+ checkOpenStreamHandshakeDefaultID(t, c, decoder)
+ readHandshakeComponent(t, decoder)
+ fmt.Fprintln(c, "") // That's all the server needs to return (see xep-0114)
+ return
+}
+
+// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant.
+// This handler is supposed to fail by sending a "message" stanza instead of a stanza to finalize the handshake.
+func handlerComponentFailedHandshakeDefaultID(t *testing.T, c net.Conn) {
+ decoder := xml.NewDecoder(c)
+ checkOpenStreamHandshakeDefaultID(t, c, decoder)
+ readHandshakeComponent(t, decoder)
+
+ // Send a message, instead of a "" tag, to fail the handshake process dans disconnect the client.
+ me := stanza.Message{
+ Attrs: stanza.Attrs{Type: stanza.MessageTypeChat, From: defaultServerName, To: defaultComponentName, Lang: "en"},
+ Body: "Fail my handshake.",
+ }
+ s, _ := xml.Marshal(me)
+ fmt.Fprintln(c, string(s))
+
+ return
+}
+
+// Reads from the connection with the Component. Expects a handshake request, and returns the tag.
+func readHandshakeComponent(t *testing.T, decoder *xml.Decoder) {
+ se, err := stanza.NextStart(decoder)
+ if err != nil {
+ t.Errorf("cannot read auth: %s", err)
+ return
+ }
+ nv := &stanza.Handshake{}
+ // Decode element into pointer storage
+ if err = decoder.DecodeElement(nv, &se); err != nil {
+ t.Errorf("cannot decode handshake: %s", err)
+ return
+ }
+ if len(strings.TrimSpace(nv.Value)) == 0 {
+ t.Errorf("did not receive handshake ID")
+ }
+}
+
+func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) {
+ checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID)
+}
+
+// Used for ID and handshake related tests
+func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) {
+ c.SetDeadline(time.Now().Add(defaultTimeout))
+ defer c.SetDeadline(time.Time{})
+
+ for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion.
+ token, err := decoder.Token()
+ if err != nil {
+ t.Errorf("cannot read next token: %s", err)
+ }
+
+ switch elem := token.(type) {
+ // Wait for first startElement
+ case xml.StartElement:
+ if elem.Name.Space != stanza.NSStream || elem.Name.Local != "stream" {
+ err = errors.New("xmpp: expected but got <" + elem.Name.Local + "> in " + elem.Name.Space)
+ return
+ }
+ if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil {
+ t.Errorf("cannot write server stream open: %s", err)
+ }
+ return
+ }
+ }
+}
+
+//=============================================================================
+// Sends IQ response to Component request.
+// No parsing of the request here. We just check that it's valid, and send the default response.
+func handlerForComponentIQSend(t *testing.T, c net.Conn) {
+ // Completes the connection by exchanging handshakes
+ handlerForComponentHandshakeDefaultID(t, c)
+
+ // Decoder to parse the request
+ decoder := xml.NewDecoder(c)
+
+ iqReq, err := receiveIq(t, c, decoder)
+ if err != nil {
+ t.Errorf("Error receiving the IQ stanza : %v", err)
+ } else if !iqReq.IsValid() {
+ t.Errorf("server received an IQ stanza : %v", iqReq)
+ }
+
+ // Crafting response
+ iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"})
+ disco := iqResp.DiscoInfo()
+ disco.AddFeatures("vcard-temp",
+ `http://jabber.org/protocol/address`)
+
+ disco.AddIdentity("Multicast", "service", "multicast")
+ iqResp.Payload = disco
+
+ // Sending response to the Component
+ mResp, err := xml.Marshal(iqResp)
+ _, err = fmt.Fprintln(c, string(mResp))
+ if err != nil {
+ t.Errorf("Could not send response stanza : %s", err)
+ }
+ return
+}
+
+// Reads next request coming from the Component. Expecting it to be an IQ request
+func receiveIq(t *testing.T, c net.Conn, decoder *xml.Decoder) (stanza.IQ, error) {
+ c.SetDeadline(time.Now().Add(defaultTimeout))
+ defer c.SetDeadline(time.Time{})
+ var iqStz stanza.IQ
+ err := decoder.Decode(&iqStz)
+ if err != nil {
+ t.Errorf("cannot read the received IQ stanza: %s", err)
+ }
+ if !iqStz.IsValid() {
+ t.Errorf("received IQ stanza is invalid : %s", err)
+ }
+ return iqStz, nil
+}
+
+func receiveRawIq(t *testing.T, c net.Conn, errChan chan error) {
+ c.SetDeadline(time.Now().Add(defaultTimeout))
+ defer c.SetDeadline(time.Time{})
+ decoder := xml.NewDecoder(c)
+ var iq stanza.IQ
+ err := decoder.Decode(&iq)
+ if err != nil || !iq.IsValid() {
+ s := stanza.StreamError{
+ XMLName: xml.Name{Local: "stream:error"},
+ Error: xml.Name{Local: "xml-not-well-formed"},
+ Text: `XML was not well-formed`,
+ }
+ raw, _ := xml.Marshal(s)
+ fmt.Fprintln(c, string(raw))
+ fmt.Fprintln(c, ``) // TODO : check this client side
+ errChan <- fmt.Errorf("invalid xml")
+ return
+ }
+ errChan <- nil
+ return
+}
+
+//===============================
+// Init mock server and connection
+// Creating a mock server and connecting a Component to it. Initialized with given port and handler function
+// The Component and mock are both returned
+func mockConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) {
+ // Init mock server
+ testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port)
+ mock := ServerMock{}
+ mock.Start(t, testComponentAddress, handler)
+
+ //==================================
+ // Create Component to connect to it
+ c := makeBasicComponent(defaultComponentName, testComponentAddress, t)
+
+ //========================================
+ // Connect the new Component to the server
+ err := c.Connect()
+ if err != nil {
+ t.Errorf("%+v", err)
+ }
+
+ return c, &mock
+}
+
+func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component {
+ opts := ComponentOptions{
+ TransportConfiguration: TransportConfiguration{
+ Address: mockServerAddr,
+ Domain: "localhost",
+ },
+ Domain: testComponentDomain,
+ Secret: "mypass",
+ Name: name,
+ Category: "gateway",
+ Type: "service",
+ }
+ router := NewRouter()
+ c, err := NewComponent(opts, router)
+ if err != nil {
+ t.Errorf("%+v", err)
+ }
+ c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration)
+ if err != nil {
+ t.Errorf("%+v", err)
+ }
+ return c
}
diff --git a/network.go b/network.go
index 75a0a60..8b03f3f 100644
--- a/network.go
+++ b/network.go
@@ -23,7 +23,7 @@ func ensurePort(addr string, port int) string {
// This is IPV4 without port
return addr + ":" + strconv.Itoa(port)
case 1:
- // This is IPV$ with port
+ // This is IPV6 with port
return addr
default:
// This is IPV6 without port, as you need to use bracket with port in IPV6
diff --git a/network_test.go b/network_test.go
index 116ecef..470f150 100644
--- a/network_test.go
+++ b/network_test.go
@@ -1,12 +1,10 @@
package xmpp
import (
+ "strings"
"testing"
)
-type params struct {
-}
-
func TestParseAddr(t *testing.T) {
tests := []struct {
name string
@@ -33,3 +31,36 @@ func TestParseAddr(t *testing.T) {
})
}
}
+
+func TestEnsurePort(t *testing.T) {
+ testAddresses := []string{
+ "1ca3:6c07:ee3a:89ca:e065:9a70:71d:daad",
+ "1ca3:6c07:ee3a:89ca:e065:9a70:71d:daad:5252",
+ "[::1]",
+ "127.0.0.1:5555",
+ "127.0.0.1",
+ "[::1]:5555",
+ }
+
+ for _, oldAddr := range testAddresses {
+ t.Run(oldAddr, func(st *testing.T) {
+ newAddr := ensurePort(oldAddr, 5222)
+
+ if len(newAddr) < len(oldAddr) {
+ st.Errorf("incorrect Result: transformed address is shorter than input : %v (old) > %v (new)", newAddr, oldAddr)
+ }
+ // If IPv6, the new address needs brackets to specify a port, like so : [2001:db8:85a3:0:0:8a2e:370:7334]:5222
+ if strings.Count(newAddr, "[") < strings.Count(oldAddr, "[") ||
+ strings.Count(newAddr, "]") < strings.Count(oldAddr, "]") {
+
+ st.Errorf("incorrect Result. Transformed address seems to not have correct brakets : %v => %v", oldAddr, newAddr)
+ }
+
+ // Check if we messed up the colons, or didn't properly add a port
+ if strings.Count(newAddr, ":") < strings.Count(oldAddr, ":") {
+ st.Errorf("incorrect Result: transformed address doesn't seem to have a port %v (=> %v, no port ?)", oldAddr, newAddr)
+ }
+ })
+ }
+
+}
diff --git a/router_test.go b/router_test.go
index b3d253e..2b5cf82 100644
--- a/router_test.go
+++ b/router_test.go
@@ -146,7 +146,7 @@ func TestTypeMatcher(t *testing.T) {
// We do not match on other types
conn = NewSenderMock()
- iqVersion := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"})
+ iqVersion := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"})
iqVersion.Payload = &stanza.DiscoInfo{
XMLName: xml.Name{
Space: "jabber:iq:version",
@@ -163,27 +163,27 @@ func TestCompositeMatcher(t *testing.T) {
router := NewRouter()
router.NewRoute().
IQNamespaces("jabber:iq:version").
- StanzaType("get").
+ StanzaType(string(stanza.IQTypeGet)).
HandlerFunc(func(s Sender, p stanza.Packet) {
_ = s.SendRaw(successFlag)
})
// Data set
- getVersionIq := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"})
+ getVersionIq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"})
getVersionIq.Payload = &stanza.Version{
XMLName: xml.Name{
Space: "jabber:iq:version",
Local: "query",
}}
- setVersionIq := stanza.NewIQ(stanza.Attrs{Type: "set", From: "service.localhost", To: "test@localhost", Id: "1"})
+ setVersionIq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeSet, From: "service.localhost", To: "test@localhost", Id: "1"})
setVersionIq.Payload = &stanza.Version{
XMLName: xml.Name{
Space: "jabber:iq:version",
Local: "query",
}}
- GetDiscoIq := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"})
+ GetDiscoIq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"})
GetDiscoIq.Payload = &stanza.DiscoInfo{
XMLName: xml.Name{
Space: "http://jabber.org/protocol/disco#info",
@@ -238,7 +238,7 @@ func TestCatchallMatcher(t *testing.T) {
}
conn = NewSenderMock()
- iqVersion := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"})
+ iqVersion := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"})
iqVersion.Payload = &stanza.DiscoInfo{
XMLName: xml.Name{
Space: "jabber:iq:version",
diff --git a/stanza/component.go b/stanza/component.go
index 33ced33..32a36b0 100644
--- a/stanza/component.go
+++ b/stanza/component.go
@@ -12,7 +12,7 @@ import (
type Handshake struct {
XMLName xml.Name `xml:"jabber:component:accept handshake"`
// TODO Add handshake value with test for proper serialization
- // Value string `xml:",innerxml"`
+ Value string `xml:",innerxml"`
}
func (Handshake) Name() string {
diff --git a/stanza/error.go b/stanza/error.go
index bcc947f..0f416e4 100644
--- a/stanza/error.go
+++ b/stanza/error.go
@@ -54,7 +54,7 @@ func (x *Err) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
textName := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: "text"}
if elt.XMLName == textName {
- x.Text = string(elt.Content)
+ x.Text = elt.Content
} else if elt.XMLName.Space == "urn:ietf:params:xml:ns:xmpp-stanzas" {
x.Reason = elt.XMLName.Local
}
@@ -94,16 +94,32 @@ func (x Err) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) {
// Reason
if x.Reason != "" {
reason := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: x.Reason}
- e.EncodeToken(xml.StartElement{Name: reason})
- e.EncodeToken(xml.EndElement{Name: reason})
+ err = e.EncodeToken(xml.StartElement{Name: reason})
+ if err != nil {
+ return err
+ }
+ err = e.EncodeToken(xml.EndElement{Name: reason})
+ if err != nil {
+ return err
+ }
+
}
// Text
if x.Text != "" {
text := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: "text"}
- e.EncodeToken(xml.StartElement{Name: text})
- e.EncodeToken(xml.CharData(x.Text))
- e.EncodeToken(xml.EndElement{Name: text})
+ err = e.EncodeToken(xml.StartElement{Name: text})
+ if err != nil {
+ return err
+ }
+ err = e.EncodeToken(xml.CharData(x.Text))
+ if err != nil {
+ return err
+ }
+ err = e.EncodeToken(xml.EndElement{Name: text})
+ if err != nil {
+ return err
+ }
}
return e.EncodeToken(xml.EndElement{Name: start.Name})
diff --git a/stanza/iq.go b/stanza/iq.go
index 923cf28..499c261 100644
--- a/stanza/iq.go
+++ b/stanza/iq.go
@@ -2,6 +2,7 @@ package stanza
import (
"encoding/xml"
+ "strings"
"github.com/google/uuid"
)
@@ -23,7 +24,7 @@ type IQ struct { // Info/Query
// child element, which specifies the semantics of the particular
// request."
Payload IQPayload `xml:",omitempty"`
- Error Err `xml:"error,omitempty"`
+ Error *Err `xml:"error,omitempty"`
// Any is used to decode unknown payload as a generic structure
Any *Node `xml:",any"`
}
@@ -52,7 +53,7 @@ func (iq IQ) MakeError(xerror Err) IQ {
iq.Type = "error"
iq.From = to
iq.To = from
- iq.Error = xerror
+ iq.Error = &xerror
return iq
}
@@ -106,7 +107,7 @@ func (iq *IQ) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
if err != nil {
return err
}
- iq.Error = xmppError
+ iq.Error = &xmppError
continue
}
if iqExt := TypeRegistry.GetIQExtension(tt.Name); iqExt != nil {
@@ -132,3 +133,39 @@ func (iq *IQ) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
}
}
}
+
+// Following RFC-3920 for IQs
+func (iq *IQ) IsValid() bool {
+ // ID is required
+ if len(strings.TrimSpace(iq.Id)) == 0 {
+ return false
+ }
+
+ // Type is required
+ if iq.Type.IsEmpty() {
+ return false
+ }
+
+ // Type get and set must contain one and only one child element that specifies the semantics
+ if iq.Type == IQTypeGet || iq.Type == IQTypeSet {
+ if iq.Payload == nil && iq.Any == nil {
+ return false
+ }
+ }
+
+ // A result must include zero or one child element
+ if iq.Type == IQTypeResult {
+ if iq.Payload != nil && iq.Any != nil {
+ return false
+ }
+ }
+
+ //Error type must contain an "error" child element
+ if iq.Type == IQTypeError {
+ if iq.Error == nil {
+ return false
+ }
+ }
+
+ return true
+}
diff --git a/stanza/iq_test.go b/stanza/iq_test.go
index 54a8fc5..3223566 100644
--- a/stanza/iq_test.go
+++ b/stanza/iq_test.go
@@ -187,3 +187,38 @@ func TestUnknownPayload(t *testing.T) {
t.Errorf("could not extract namespace: '%s'", parsedIQ.Any.XMLName.Space)
}
}
+
+func TestIsValid(t *testing.T) {
+ type testCase struct {
+ iq string
+ shouldErr bool
+ }
+ testIQs := make(map[string]testCase)
+ testIQs["Valid IQ"] = testCase{
+ `
+
+ `,
+ false,
+ }
+ testIQs["Invalid IQ"] = testCase{
+ `
+
+ `,
+ true,
+ }
+
+ for name, tcase := range testIQs {
+ t.Run(name, func(st *testing.T) {
+ parsedIQ := stanza.IQ{}
+ err := xml.Unmarshal([]byte(tcase.iq), &parsedIQ)
+ if err != nil {
+ t.Errorf("Unmarshal error: %#v (%s)", err, tcase.iq)
+ return
+ }
+ if !parsedIQ.IsValid() && !tcase.shouldErr {
+ t.Errorf("failed iq validation for : %s", tcase.iq)
+ }
+ })
+ }
+
+}
diff --git a/stanza/node.go b/stanza/node.go
index 6afa7bc..308729c 100644
--- a/stanza/node.go
+++ b/stanza/node.go
@@ -46,9 +46,18 @@ func (n Node) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) {
start.Name = n.XMLName
err = e.EncodeToken(start)
- e.EncodeElement(n.Nodes, xml.StartElement{Name: n.XMLName})
+ if err != nil {
+ return err
+ }
+ err = e.EncodeElement(n.Nodes, xml.StartElement{Name: n.XMLName})
+ if err != nil {
+ return err
+ }
if n.Content != "" {
- e.EncodeToken(xml.CharData(n.Content))
+ err = e.EncodeToken(xml.CharData(n.Content))
+ if err != nil {
+ return err
+ }
}
return e.EncodeToken(xml.EndElement{Name: start.Name})
}
diff --git a/stanza/packet_enum.go b/stanza/packet_enum.go
index 103966a..84dd476 100644
--- a/stanza/packet_enum.go
+++ b/stanza/packet_enum.go
@@ -1,5 +1,7 @@
package stanza
+import "strings"
+
type StanzaType string
// RFC 6120: part of A.5 Client Namespace and A.6 Server Namespace
@@ -23,3 +25,7 @@ const (
PresenceTypeUnsubscribe StanzaType = "unsubscribe"
PresenceTypeUnsubscribed StanzaType = "unsubscribed"
)
+
+func (s StanzaType) IsEmpty() bool {
+ return len(strings.TrimSpace(string(s))) == 0
+}
diff --git a/stanza/sasl_auth.go b/stanza/sasl_auth.go
index d04174f..29648ee 100644
--- a/stanza/sasl_auth.go
+++ b/stanza/sasl_auth.go
@@ -107,6 +107,6 @@ func (s *StreamSession) IsOptional() bool {
// Registry init
func init() {
- TypeRegistry.MapExtension(PKTIQ, xml.Name{"urn:ietf:params:xml:ns:xmpp-bind", "bind"}, Bind{})
- TypeRegistry.MapExtension(PKTIQ, xml.Name{"urn:ietf:params:xml:ns:xmpp-session", "session"}, StreamSession{})
+ TypeRegistry.MapExtension(PKTIQ, xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-bind", Local: "bind"}, Bind{})
+ TypeRegistry.MapExtension(PKTIQ, xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-session", Local: "session"}, StreamSession{})
}
diff --git a/stanza/stream.go b/stanza/stream.go
index 290abfe..203cc83 100644
--- a/stanza/stream.go
+++ b/stanza/stream.go
@@ -8,7 +8,7 @@ import "encoding/xml"
type Stream struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams stream"`
From string `xml:"from,attr"`
- To string `xml:"to,attr"`
+ To string `xml:"to,attr"`
Id string `xml:"id,attr"`
Version string `xml:"version,attr"`
}
diff --git a/stanza/stream_features.go b/stanza/stream_features.go
index 11cd96b..14358f0 100644
--- a/stanza/stream_features.go
+++ b/stanza/stream_features.go
@@ -15,7 +15,7 @@ type StreamFeatures struct {
// Server capabilities hash
Caps Caps
// Stream features
- StartTLS tlsStartTLS
+ StartTLS TlsStartTLS
Mechanisms saslMechanisms
Bind Bind
StreamManagement streamManagement
@@ -60,13 +60,13 @@ type Caps struct {
// StartTLS feature
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
-type tlsStartTLS struct {
+type TlsStartTLS struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
Required bool
}
// UnmarshalXML implements custom parsing startTLS required flag
-func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
+func (stls *TlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
stls.XMLName = start.Name
// Check subelements to extract required field as boolean
@@ -98,7 +98,7 @@ func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) er
}
}
-func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) {
+func (sf *StreamFeatures) DoesStartTLS() (feature TlsStartTLS, isSupported bool) {
if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
return sf.StartTLS, true
}
diff --git a/stream_manager.go b/stream_manager.go
index bf7fba8..aebd8a4 100644
--- a/stream_manager.go
+++ b/stream_manager.go
@@ -74,7 +74,7 @@ func (sm *StreamManager) Run() error {
return errors.New("missing stream client")
}
- handler := func(e Event) {
+ handler := func(e Event) error {
switch e.State {
case StateConnected:
sm.Metrics.setConnectTime()
@@ -82,17 +82,18 @@ func (sm *StreamManager) Run() error {
sm.Metrics.setLoginTime()
case StateDisconnected:
// Reconnect on disconnection
- sm.resume(e.SMState)
+ return sm.resume(e.SMState)
case StateStreamError:
sm.client.Disconnect()
// Only try reconnecting if we have not been kicked by another session to avoid connection loop.
// TODO: Make this conflict exception a permanent error
if e.StreamError != "conflict" {
- sm.connect()
+ return sm.connect()
}
case StatePermanentError:
// Do not attempt to reconnect
}
+ return nil
}
sm.client.SetHandler(handler)
diff --git a/test.sh b/test.sh
index 9730026..725dcaf 100755
--- a/test.sh
+++ b/test.sh
@@ -5,7 +5,7 @@ export GO111MODULE=on
echo "" > coverage.txt
for d in $(go list ./... | grep -v vendor); do
- go test -race -coverprofile=profile.out -covermode=atomic ${d}
+ go test -race -coverprofile=profile.out -covermode=atomic "${d}"
if [ -f profile.out ]; then
cat profile.out >> coverage.txt
rm profile.out