Merge pull request #135 from remicorniere/Error_Handling

Added callback to process errors after connection.
This commit is contained in:
remicorniere 2019-12-12 14:51:00 +00:00 committed by GitHub
commit 27130d7292
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 865 additions and 386 deletions

View file

@ -0,0 +1,95 @@
package main
/*
xmpp_chat_client is a demo client that connect on an XMPP server to chat with other members
Note that this example sends to a very specific user. User logic is not implemented here.
*/
import (
. "bufio"
"fmt"
"os"
"gosrc.io/xmpp"
"gosrc.io/xmpp/stanza"
)
const (
currentUserAddress = "localhost:5222"
currentUserJid = "testuser@localhost"
currentUserPass = "testpass"
correspondantJid = "testuser2@localhost"
)
func main() {
config := xmpp.Config{
TransportConfiguration: xmpp.TransportConfiguration{
Address: currentUserAddress,
},
Jid: currentUserJid,
Credential: xmpp.Password(currentUserPass),
Insecure: true}
var client *xmpp.Client
var err error
router := xmpp.NewRouter()
router.HandleFunc("message", handleMessage)
if client, err = xmpp.NewClient(config, router, errorHandler); err != nil {
fmt.Println("Error new client")
}
// Connecting client and handling messages
// To use a stream manager, just write something like this instead :
//cm := xmpp.NewStreamManager(client, startMessaging)
//log.Fatal(cm.Run()) //=> this will lock the calling goroutine
if err = client.Connect(); err != nil {
fmt.Printf("XMPP connection failed: %s", err)
return
}
startMessaging(client)
}
func startMessaging(client xmpp.Sender) {
reader := NewReader(os.Stdin)
textChan := make(chan string)
var text string
for {
fmt.Print("Enter text: ")
go readInput(reader, textChan)
select {
case <-killChan:
return
case text = <-textChan:
reply := stanza.Message{Attrs: stanza.Attrs{To: correspondantJid}, Body: text}
err := client.Send(reply)
if err != nil {
fmt.Printf("There was a problem sending the message : %v", reply)
return
}
}
}
}
func readInput(reader *Reader, textChan chan string) {
text, _ := reader.ReadString('\n')
textChan <- text
}
var killChan = make(chan struct{})
// If an error occurs, this is used
func errorHandler(err error) {
fmt.Printf("%v", err)
killChan <- struct{}{}
}
func handleMessage(s xmpp.Sender, p stanza.Packet) {
msg, ok := p.(stanza.Message)
if !ok {
_, _ = fmt.Fprintf(os.Stdout, "Ignoring packet: %T\n", p)
return
}
_, _ = fmt.Fprintf(os.Stdout, "Body = %s - from = %s\n", msg.Body, msg.From)
}

View file

@ -35,7 +35,7 @@ func main() {
IQNamespaces("jabber:iq:version"). IQNamespaces("jabber:iq:version").
HandlerFunc(handleVersion) HandlerFunc(handleVersion)
component, err := xmpp.NewComponent(opts, router) component, err := xmpp.NewComponent(opts, router, handleError)
if err != nil { if err != nil {
log.Fatalf("%+v", err) log.Fatalf("%+v", err)
} }
@ -47,6 +47,10 @@ func main() {
log.Fatal(cm.Run()) log.Fatal(cm.Run())
} }
func handleError(err error) {
fmt.Println(err.Error())
}
func handleMessage(_ xmpp.Sender, p stanza.Packet) { func handleMessage(_ xmpp.Sender, p stanza.Packet) {
msg, ok := p.(stanza.Message) msg, ok := p.(stanza.Message)
if !ok { if !ok {

View file

@ -53,7 +53,7 @@ func main() {
handleIQ(s, p, player) handleIQ(s, p, player)
}) })
client, err := xmpp.NewClient(config, router) client, err := xmpp.NewClient(config, router, errorHandler)
if err != nil { if err != nil {
log.Fatalf("%+v", err) log.Fatalf("%+v", err)
} }
@ -61,6 +61,9 @@ func main() {
cm := xmpp.NewStreamManager(client, nil) cm := xmpp.NewStreamManager(client, nil)
log.Fatal(cm.Run()) log.Fatal(cm.Run())
} }
func errorHandler(err error) {
fmt.Println(err.Error())
}
func handleMessage(s xmpp.Sender, p stanza.Packet, player *mpg123.Player) { func handleMessage(s xmpp.Sender, p stanza.Packet, player *mpg123.Player) {
msg, ok := p.(stanza.Message) msg, ok := p.(stanza.Message)

View file

@ -28,7 +28,7 @@ func main() {
router := xmpp.NewRouter() router := xmpp.NewRouter()
router.HandleFunc("message", handleMessage) router.HandleFunc("message", handleMessage)
client, err := xmpp.NewClient(config, router) client, err := xmpp.NewClient(config, router, errorHandler)
if err != nil { if err != nil {
log.Fatalf("%+v", err) log.Fatalf("%+v", err)
} }
@ -39,6 +39,10 @@ func main() {
log.Fatal(cm.Run()) log.Fatal(cm.Run())
} }
func errorHandler(err error) {
fmt.Println(err.Error())
}
func handleMessage(s xmpp.Sender, p stanza.Packet) { func handleMessage(s xmpp.Sender, p stanza.Packet) {
msg, ok := p.(stanza.Message) msg, ok := p.(stanza.Message)
if !ok { if !ok {

View file

@ -26,7 +26,7 @@ func main() {
router := xmpp.NewRouter() router := xmpp.NewRouter()
router.HandleFunc("message", handleMessage) router.HandleFunc("message", handleMessage)
client, err := xmpp.NewClient(config, router) client, err := xmpp.NewClient(config, router, errorHandler)
if err != nil { if err != nil {
log.Fatalf("%+v", err) log.Fatalf("%+v", err)
} }
@ -37,6 +37,10 @@ func main() {
log.Fatal(cm.Run()) log.Fatal(cm.Run())
} }
func errorHandler(err error) {
fmt.Println(err.Error())
}
func handleMessage(s xmpp.Sender, p stanza.Packet) { func handleMessage(s xmpp.Sender, p stanza.Packet) {
msg, ok := p.(stanza.Message) msg, ok := p.(stanza.Message)
if !ok { if !ok {

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"time" "time"
@ -21,6 +20,7 @@ type ConnState = uint8
// This is a the list of events happening on the connection that the // This is a the list of events happening on the connection that the
// client can be notified about. // client can be notified about.
const ( const (
InitialPresence = "<presence/>"
StateDisconnected ConnState = iota StateDisconnected ConnState = iota
StateConnected StateConnected
StateSessionEstablished StateSessionEstablished
@ -98,6 +98,8 @@ type Client struct {
router *Router router *Router
// Track and broadcast connection state // Track and broadcast connection state
EventManager EventManager
// Handle errors from client execution
ErrorHandler func(error)
} }
/* /*
@ -107,7 +109,7 @@ Setting up the client / Checking the parameters
// NewClient generates a new XMPP client, based on Config passed as parameters. // NewClient generates a new XMPP client, based on Config passed as parameters.
// If host is not specified, the DNS SRV should be used to find the host from the domainpart of the JID. // If host is not specified, the DNS SRV should be used to find the host from the domainpart of the JID.
// Default the port to 5222. // Default the port to 5222.
func NewClient(config Config, r *Router) (c *Client, err error) { func NewClient(config Config, r *Router, errorHandler func(error)) (c *Client, err error) {
if config.KeepaliveInterval == 0 { if config.KeepaliveInterval == 0 {
config.KeepaliveInterval = time.Second * 30 config.KeepaliveInterval = time.Second * 30
} }
@ -143,6 +145,7 @@ func NewClient(config Config, r *Router) (c *Client, err error) {
c = new(Client) c = new(Client)
c.config = config c.config = config
c.router = r c.router = r
c.ErrorHandler = errorHandler
if c.config.ConnectTimeout == 0 { if c.config.ConnectTimeout == 0 {
c.config.ConnectTimeout = 15 // 15 second as default c.config.ConnectTimeout = 15 // 15 second as default
@ -191,16 +194,13 @@ func (c *Client) Resume(state SMState) error {
go keepalive(c.transport, c.config.KeepaliveInterval, keepaliveQuit) go keepalive(c.transport, c.config.KeepaliveInterval, keepaliveQuit)
// Start the receiver go routine // Start the receiver go routine
state = c.Session.SMState state = c.Session.SMState
// Leaving this channel here for later. Not used atm. We should return this instead of an error because right go c.recv(state, keepaliveQuit)
// now the returned error is lost in limbo.
errChan := make(chan error)
go c.recv(state, keepaliveQuit, errChan)
// We're connected and can now receive and send messages. // We're connected and can now receive and send messages.
//fmt.Fprintf(client.conn, "<presence xml:lang='en'><show>%s</show><status>%s</status></presence>", "chat", "Online") //fmt.Fprintf(client.conn, "<presence xml:lang='en'><show>%s</show><status>%s</status></presence>", "chat", "Online")
// TODO: Do we always want to send initial presence automatically ? // TODO: Do we always want to send initial presence automatically ?
// Do we need an option to avoid that or do we rely on client to send the presence itself ? // Do we need an option to avoid that or do we rely on client to send the presence itself ?
_, err = fmt.Fprintf(c.transport, "<presence/>") err = c.sendWithWriter(c.transport, []byte(InitialPresence))
return err return err
} }
@ -273,11 +273,11 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error {
// Go routines // Go routines
// Loop: Receive data from server // Loop: Receive data from server
func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan<- error) { func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) {
for { for {
val, err := stanza.NextPacket(c.transport.GetDecoder()) val, err := stanza.NextPacket(c.transport.GetDecoder())
if err != nil { if err != nil {
errChan <- err c.ErrorHandler(err)
close(keepaliveQuit) close(keepaliveQuit)
c.disconnected(state) c.disconnected(state)
return return
@ -289,7 +289,7 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan
c.router.route(c, val) c.router.route(c, val)
close(keepaliveQuit) close(keepaliveQuit)
c.streamError(packet.Error.Local, packet.Text) c.streamError(packet.Error.Local, packet.Text)
errChan <- errors.New("stream error: " + packet.Error.Local) c.ErrorHandler(errors.New("stream error: " + packet.Error.Local))
return return
// Process Stream management nonzas // Process Stream management nonzas
case stanza.SMRequest: case stanza.SMRequest:
@ -299,7 +299,7 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan
}, H: state.Inbound} }, H: state.Inbound}
err = c.Send(answer) err = c.Send(answer)
if err != nil { if err != nil {
errChan <- err c.ErrorHandler(err)
return return
} }
default: default:

View file

@ -1,10 +1,10 @@
package xmpp package xmpp
import ( import (
"context"
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt" "fmt"
"net"
"testing" "testing"
"time" "time"
@ -15,8 +15,7 @@ const (
// Default port is not standard XMPP port to avoid interfering // Default port is not standard XMPP port to avoid interfering
// with local running XMPP server // with local running XMPP server
testXMPPAddress = "localhost:15222" testXMPPAddress = "localhost:15222"
testClientDomain = "localhost"
defaultTimeout = 2 * time.Second
) )
func TestEventManager(t *testing.T) { func TestEventManager(t *testing.T) {
@ -40,7 +39,7 @@ func TestEventManager(t *testing.T) {
func TestClient_Connect(t *testing.T) { func TestClient_Connect(t *testing.T) {
// Setup Mock server // Setup Mock server
mock := ServerMock{} mock := ServerMock{}
mock.Start(t, testXMPPAddress, handlerConnectSuccess) mock.Start(t, testXMPPAddress, handlerClientConnectSuccess)
// Test / Check result // Test / Check result
config := Config{ config := Config{
@ -54,7 +53,7 @@ func TestClient_Connect(t *testing.T) {
var client *Client var client *Client
var err error var err error
router := NewRouter() router := NewRouter()
if client, err = NewClient(config, router); err != nil { if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
t.Errorf("connect create XMPP client: %s", err) t.Errorf("connect create XMPP client: %s", err)
} }
@ -82,7 +81,7 @@ func TestClient_NoInsecure(t *testing.T) {
var client *Client var client *Client
var err error var err error
router := NewRouter() router := NewRouter()
if client, err = NewClient(config, router); err != nil { if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
t.Errorf("cannot create XMPP client: %s", err) t.Errorf("cannot create XMPP client: %s", err)
} }
@ -112,7 +111,7 @@ func TestClient_FeaturesTracking(t *testing.T) {
var client *Client var client *Client
var err error var err error
router := NewRouter() router := NewRouter()
if client, err = NewClient(config, router); err != nil { if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
t.Errorf("cannot create XMPP client: %s", err) t.Errorf("cannot create XMPP client: %s", err)
} }
@ -127,7 +126,7 @@ func TestClient_FeaturesTracking(t *testing.T) {
func TestClient_RFC3921Session(t *testing.T) { func TestClient_RFC3921Session(t *testing.T) {
// Setup Mock server // Setup Mock server
mock := ServerMock{} mock := ServerMock{}
mock.Start(t, testXMPPAddress, handlerConnectWithSession) mock.Start(t, testXMPPAddress, handlerClientConnectWithSession)
// Test / Check result // Test / Check result
config := Config{ config := Config{
@ -142,7 +141,7 @@ func TestClient_RFC3921Session(t *testing.T) {
var client *Client var client *Client
var err error var err error
router := NewRouter() router := NewRouter()
if client, err = NewClient(config, router); err != nil { if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
t.Errorf("connect create XMPP client: %s", err) t.Errorf("connect create XMPP client: %s", err)
} }
@ -153,54 +152,256 @@ func TestClient_RFC3921Session(t *testing.T) {
mock.Stop() mock.Stop()
} }
// Testing sending an IQ to the mock server and reading its response.
func TestClient_SendIQ(t *testing.T) {
done := make(chan struct{})
// Handler for Mock server
h := func(t *testing.T, sc *ServerConn) {
handlerClientConnectSuccess(t, sc)
discardPresence(t, sc)
respondToIQ(t, sc)
done <- struct{}{}
}
client, mock := mockClientConnection(t, h, testClientIqPort)
ctx, _ := context.WithTimeout(context.Background(), 30*time.Second)
iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"})
disco := iqReq.DiscoInfo()
iqReq.Payload = disco
// Handle a possible error
errChan := make(chan error)
errorHandler := func(err error) {
errChan <- err
}
client.ErrorHandler = errorHandler
res, err := client.SendIQ(ctx, iqReq)
if err != nil {
t.Errorf(err.Error())
}
select {
case <-res: // If the server responds with an IQ, we pass the test
case err := <-errChan: // If the server sends an error, or there is a connection error
t.Fatal(err.Error())
case <-time.After(defaultChannelTimeout): // If we timeout
t.Fatal("Failed to receive response, to sent IQ, from mock server")
}
select {
case <-done:
mock.Stop()
case <-time.After(defaultChannelTimeout):
t.Fatal("The mock server failed to finish its job !")
}
}
func TestClient_SendIQFail(t *testing.T) {
done := make(chan struct{})
// Handler for Mock server
h := func(t *testing.T, sc *ServerConn) {
handlerClientConnectSuccess(t, sc)
discardPresence(t, sc)
respondToIQ(t, sc)
done <- struct{}{}
}
client, mock := mockClientConnection(t, h, testClientIqFailPort)
//==================
// Create an IQ to send
ctx, _ := context.WithTimeout(context.Background(), 30*time.Second)
iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"})
disco := iqReq.DiscoInfo()
iqReq.Payload = disco
// Removing the id to make the stanza invalid. The IQ constructor makes a random one if none is specified
// so we need to overwrite it.
iqReq.Id = ""
// Handle a possible error
errChan := make(chan error)
errorHandler := func(err error) {
errChan <- err
}
client.ErrorHandler = errorHandler
res, _ := client.SendIQ(ctx, iqReq)
// Test
select {
case <-res: // If the server responds with an IQ
t.Errorf("Server should not respond with an IQ since the request is expected to be invalid !")
case <-errChan: // If the server sends an error, the test passes
case <-time.After(defaultChannelTimeout): // If we timeout
t.Errorf("Failed to receive response, to sent IQ, from mock server")
}
select {
case <-done:
mock.Stop()
case <-time.After(defaultChannelTimeout):
t.Errorf("The mock server failed to finish its job !")
}
}
func TestClient_SendRaw(t *testing.T) {
done := make(chan struct{})
// Handler for Mock server
h := func(t *testing.T, sc *ServerConn) {
handlerClientConnectSuccess(t, sc)
discardPresence(t, sc)
respondToIQ(t, sc)
done <- struct{}{}
}
type testCase struct {
req string
shouldErr bool
port int
}
testRequests := make(map[string]testCase)
// Sending a correct IQ of type get. Not supposed to err
testRequests["Correct IQ"] = testCase{
req: `<iq type="get" id="91bd0bba-012f-4d92-bb17-5fc41e6fe545" from="test1@localhost/mremond-mbp" to="testServer" lang="en"><query xmlns="http://jabber.org/protocol/disco#info"></query></iq>`,
shouldErr: false,
port: testClientRawPort + 100,
}
// Sending an IQ with a missing ID. Should err
testRequests["IQ with missing ID"] = testCase{
req: `<iq type="get" from="test1@localhost/mremond-mbp" to="testServer" lang="en"><query xmlns="http://jabber.org/protocol/disco#info"></query></iq>`,
shouldErr: true,
port: testClientRawPort,
}
// A handler for the client.
// In the failing test, the server returns a stream error, which triggers this handler, client side.
errChan := make(chan error)
errHandler := func(err error) {
errChan <- err
}
// Tests for all the IQs
for name, tcase := range testRequests {
t.Run(name, func(st *testing.T) {
//Connecting to a mock server, initialized with given port and handler function
c, m := mockClientConnection(t, h, tcase.port)
c.ErrorHandler = errHandler
// Sending raw xml from test case
err := c.SendRaw(tcase.req)
if err != nil {
t.Errorf("Error sending Raw string")
}
// Just wait a little so the message has time to arrive
select {
// We don't use the default "long" timeout here because waiting it out means passing the test.
case <-time.After(100 * time.Millisecond):
case err = <-errChan:
if err == nil && tcase.shouldErr {
t.Errorf("Failed to get closing stream err")
} else if err != nil && !tcase.shouldErr {
t.Errorf("This test is not supposed to err !")
}
}
c.transport.Close()
select {
case <-done:
m.Stop()
case <-time.After(defaultChannelTimeout):
t.Errorf("The mock server failed to finish its job !")
}
})
}
}
func TestClient_Disconnect(t *testing.T) {
c, m := mockClientConnection(t, handlerClientConnectSuccess, testClientBasePort)
err := c.transport.Ping()
if err != nil {
t.Errorf("Could not ping but not disconnected yet")
}
c.Disconnect()
err = c.transport.Ping()
if err == nil {
t.Errorf("Did not disconnect properly")
}
m.Stop()
}
func TestClient_DisconnectStreamManager(t *testing.T) {
// Init mock server
// Setup Mock server
mock := ServerMock{}
mock.Start(t, testXMPPAddress, handlerAbortTLS)
// Test / Check result
config := Config{
TransportConfiguration: TransportConfiguration{
Address: testXMPPAddress,
},
Jid: "test@localhost",
Credential: Password("test"),
}
var client *Client
var err error
router := NewRouter()
if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
t.Errorf("cannot create XMPP client: %s", err)
}
sman := NewStreamManager(client, nil)
errChan := make(chan error)
runSMan := func(errChan chan error) {
errChan <- sman.Run()
}
go runSMan(errChan)
select {
case <-errChan:
case <-time.After(defaultChannelTimeout):
// When insecure is not allowed:
t.Errorf("should fail as insecure connection is not allowed and server does not support TLS")
}
mock.Stop()
}
//============================================================================= //=============================================================================
// Basic XMPP Server Mock Handlers. // Basic XMPP Server Mock Handlers.
const serverStreamOpen = "<?xml version='1.0'?><stream:stream to='%s' id='%s' xmlns='%s' xmlns:stream='%s' version='1.0'>"
// Test connection with a basic straightforward workflow // Test connection with a basic straightforward workflow
func handlerConnectSuccess(t *testing.T, c net.Conn) { func handlerClientConnectSuccess(t *testing.T, sc *ServerConn) {
decoder := xml.NewDecoder(c) checkClientOpenStream(t, sc)
checkOpenStream(t, c, decoder) sendStreamFeatures(t, sc) // Send initial features
readAuth(t, sc.decoder)
fmt.Fprintln(sc.connection, "<success xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\"/>")
sendStreamFeatures(t, c, decoder) // Send initial features checkClientOpenStream(t, sc) // Reset stream
readAuth(t, decoder) sendBindFeature(t, sc) // Send post auth features
fmt.Fprintln(c, "<success xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\"/>") bind(t, sc)
checkOpenStream(t, c, decoder) // Reset stream
sendBindFeature(t, c, decoder) // Send post auth features
bind(t, c, decoder)
} }
// We expect client will abort on TLS // We expect client will abort on TLS
func handlerAbortTLS(t *testing.T, c net.Conn) { func handlerAbortTLS(t *testing.T, sc *ServerConn) {
decoder := xml.NewDecoder(c) checkClientOpenStream(t, sc)
checkOpenStream(t, c, decoder) sendStreamFeatures(t, sc) // Send initial features
sendStreamFeatures(t, c, decoder) // Send initial features
} }
// Test connection with mandatory session (RFC-3921) // Test connection with mandatory session (RFC-3921)
func handlerConnectWithSession(t *testing.T, c net.Conn) { func handlerClientConnectWithSession(t *testing.T, sc *ServerConn) {
decoder := xml.NewDecoder(c) checkClientOpenStream(t, sc)
checkOpenStream(t, c, decoder)
sendStreamFeatures(t, c, decoder) // Send initial features sendStreamFeatures(t, sc) // Send initial features
readAuth(t, decoder) readAuth(t, sc.decoder)
fmt.Fprintln(c, "<success xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\"/>") fmt.Fprintln(sc.connection, "<success xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\"/>")
checkOpenStream(t, c, decoder) // Reset stream checkClientOpenStream(t, sc) // Reset stream
sendRFC3921Feature(t, c, decoder) // Send post auth features sendRFC3921Feature(t, sc) // Send post auth features
bind(t, c, decoder) bind(t, sc)
session(t, c, decoder) session(t, sc)
} }
func checkOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { func checkClientOpenStream(t *testing.T, sc *ServerConn) {
c.SetDeadline(time.Now().Add(defaultTimeout)) sc.connection.SetDeadline(time.Now().Add(defaultTimeout))
defer c.SetDeadline(time.Time{}) defer sc.connection.SetDeadline(time.Time{})
for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion. for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion.
var token xml.Token var token xml.Token
token, err := decoder.Token() token, err := sc.decoder.Token()
if err != nil { if err != nil {
t.Errorf("cannot read next token: %s", err) t.Errorf("cannot read next token: %s", err)
} }
@ -212,7 +413,7 @@ func checkOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) {
err = errors.New("xmpp: expected <stream> but got <" + elem.Name.Local + "> in " + elem.Name.Space) err = errors.New("xmpp: expected <stream> but got <" + elem.Name.Local + "> in " + elem.Name.Space)
return return
} }
if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", "streamid1", stanza.NSClient, stanza.NSStream); err != nil { if _, err := fmt.Fprintf(sc.connection, serverStreamOpen, "localhost", "streamid1", stanza.NSClient, stanza.NSStream); err != nil {
t.Errorf("cannot write server stream open: %s", err) t.Errorf("cannot write server stream open: %s", err)
} }
return return
@ -220,105 +421,35 @@ func checkOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) {
} }
} }
func sendStreamFeatures(t *testing.T, c net.Conn, _ *xml.Decoder) { func mockClientConnection(t *testing.T, serverHandler func(*testing.T, *ServerConn), port int) (*Client, *ServerMock) {
// This is a basic server, supporting only 1 stream feature: SASL Plain Auth mock := &ServerMock{}
features := `<stream:features> testServerAddress := fmt.Sprintf("%s:%d", testClientDomain, port)
<mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl">
<mechanism>PLAIN</mechanism> mock.Start(t, testServerAddress, serverHandler)
</mechanisms>
</stream:features>` config := Config{
if _, err := fmt.Fprintln(c, features); err != nil { TransportConfiguration: TransportConfiguration{
t.Errorf("cannot send stream feature: %s", err) Address: testServerAddress,
} },
Jid: "test@localhost",
Credential: Password("test"),
Insecure: true}
var client *Client
var err error
router := NewRouter()
if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
t.Errorf("connect create XMPP client: %s", err)
} }
// TODO return err in case of error reading the auth params if err = client.Connect(); err != nil {
func readAuth(t *testing.T, decoder *xml.Decoder) string { t.Errorf("XMPP connection failed: %s", err)
se, err := stanza.NextStart(decoder)
if err != nil {
t.Errorf("cannot read auth: %s", err)
return ""
} }
var nv interface{} return client, mock
nv = &stanza.SASLAuth{}
// Decode element into pointer storage
if err = decoder.DecodeElement(nv, &se); err != nil {
t.Errorf("cannot decode auth: %s", err)
return ""
} }
switch v := nv.(type) { // This really should not be used as is.
case *stanza.SASLAuth: // It's just meant to be a placeholder when error handling is not needed at this level
return v.Value func clientDefaultErrorHandler(err error) {
}
return ""
}
func sendBindFeature(t *testing.T, c net.Conn, _ *xml.Decoder) {
// This is a basic server, supporting only 1 stream feature after auth: resource binding
features := `<stream:features>
<bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/>
</stream:features>`
if _, err := fmt.Fprintln(c, features); err != nil {
t.Errorf("cannot send stream feature: %s", err)
}
}
func sendRFC3921Feature(t *testing.T, c net.Conn, _ *xml.Decoder) {
// This is a basic server, supporting only 2 features after auth: resource & session binding
features := `<stream:features>
<bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/>
<session xmlns='urn:ietf:params:xml:ns:xmpp-session'/>
</stream:features>`
if _, err := fmt.Fprintln(c, features); err != nil {
t.Errorf("cannot send stream feature: %s", err)
}
}
func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) {
se, err := stanza.NextStart(decoder)
if err != nil {
t.Errorf("cannot read bind: %s", err)
return
}
iq := &stanza.IQ{}
// Decode element into pointer storage
if err = decoder.DecodeElement(&iq, &se); err != nil {
t.Errorf("cannot decode bind iq: %s", err)
return
}
// TODO Check all elements
switch iq.Payload.(type) {
case *stanza.Bind:
result := `<iq id='%s' type='result'>
<bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'>
<jid>%s</jid>
</bind>
</iq>`
fmt.Fprintf(c, result, iq.Id, "test@localhost/test") // TODO use real JID
}
}
func session(t *testing.T, c net.Conn, decoder *xml.Decoder) {
se, err := stanza.NextStart(decoder)
if err != nil {
t.Errorf("cannot read session: %s", err)
return
}
iq := &stanza.IQ{}
// Decode element into pointer storage
if err = decoder.DecodeElement(&iq, &se); err != nil {
t.Errorf("cannot decode session iq: %s", err)
return
}
switch iq.Payload.(type) {
case *stanza.StreamSession:
result := `<iq id='%s' type='result'/>`
fmt.Fprintf(c, result, iq.Id)
}
} }

View file

@ -49,10 +49,11 @@ type Component struct {
// read / write // read / write
socketProxy io.ReadWriter // TODO socketProxy io.ReadWriter // TODO
ErrorHandler func(error)
} }
func NewComponent(opts ComponentOptions, r *Router) (*Component, error) { func NewComponent(opts ComponentOptions, r *Router, errorHandler func(error)) (*Component, error) {
c := Component{ComponentOptions: opts, router: r} c := Component{ComponentOptions: opts, router: r, ErrorHandler: errorHandler}
return &c, nil return &c, nil
} }
@ -84,7 +85,7 @@ func (c *Component) Resume(sm SMState) error {
c.updateState(StateConnected) c.updateState(StateConnected)
// Authentication // Authentication
if _, err := fmt.Fprintf(c.transport, "<handshake>%s</handshake>", c.handshake(streamId)); err != nil { if err := c.sendWithWriter(c.transport, []byte(fmt.Sprintf("<handshake>%s</handshake>", c.handshake(streamId)))); err != nil {
c.updateState(StateStreamError) c.updateState(StateStreamError)
return NewConnError(errors.New("cannot send handshake "+err.Error()), false) return NewConnError(errors.New("cannot send handshake "+err.Error()), false)
@ -104,10 +105,7 @@ func (c *Component) Resume(sm SMState) error {
case stanza.Handshake: case stanza.Handshake:
// Start the receiver go routine // Start the receiver go routine
c.updateState(StateSessionEstablished) c.updateState(StateSessionEstablished)
// Leaving this channel here for later. Not used atm. We should return this instead of an error because right go c.recv()
// now the returned error is lost in limbo.
errChan := make(chan error)
go c.recv(errChan) // Sends to errChan
return err // Should be empty at this point return err // Should be empty at this point
default: default:
c.updateState(StatePermanentError) c.updateState(StatePermanentError)
@ -128,13 +126,13 @@ func (c *Component) SetHandler(handler EventHandler) {
} }
// Receiver Go routine receiver // Receiver Go routine receiver
func (c *Component) recv(errChan chan<- error) { func (c *Component) recv() {
for { for {
val, err := stanza.NextPacket(c.transport.GetDecoder()) val, err := stanza.NextPacket(c.transport.GetDecoder())
if err != nil { if err != nil {
c.updateState(StateDisconnected) c.updateState(StateDisconnected)
errChan <- err c.ErrorHandler(err)
return return
} }
// Handle stream errors // Handle stream errors
@ -142,7 +140,7 @@ func (c *Component) recv(errChan chan<- error) {
case stanza.StreamError: case stanza.StreamError:
c.router.route(c, val) c.router.route(c, val)
c.streamError(p.Error.Local, p.Text) c.streamError(p.Error.Local, p.Text)
errChan <- errors.New("stream error: " + p.Error.Local) c.ErrorHandler(errors.New("stream error: " + p.Error.Local))
return return
} }
c.router.route(c, val) c.router.route(c, val)
@ -161,12 +159,18 @@ func (c *Component) Send(packet stanza.Packet) error {
return errors.New("cannot marshal packet " + err.Error()) return errors.New("cannot marshal packet " + err.Error())
} }
if _, err := fmt.Fprintf(transport, string(data)); err != nil { if err := c.sendWithWriter(transport, data); err != nil {
return errors.New("cannot send packet " + err.Error()) return errors.New("cannot send packet " + err.Error())
} }
return nil return nil
} }
func (c *Component) sendWithWriter(writer io.Writer, packet []byte) error {
var err error
_, err = writer.Write(packet)
return err
}
// SendIQ sends an IQ set or get stanza to the server. If a result is received // SendIQ sends an IQ set or get stanza to the server. If a result is received
// the provided handler function will automatically be called. // the provided handler function will automatically be called.
// //
@ -197,7 +201,7 @@ func (c *Component) SendRaw(packet string) error {
} }
var err error var err error
_, err = fmt.Fprintf(transport, packet) err = c.sendWithWriter(transport, []byte(packet))
return err return err
} }

View file

@ -5,8 +5,8 @@ import (
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt" "fmt"
"github.com/google/uuid"
"gosrc.io/xmpp/stanza" "gosrc.io/xmpp/stanza"
"net"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -15,19 +15,7 @@ import (
// Tests are ran in parallel, so each test creating a server must use a different port so we do not get any // Tests are ran in parallel, so each test creating a server must use a different port so we do not get any
// conflict. Using iota for this should do the trick. // conflict. Using iota for this should do the trick.
const ( const (
testComponentDomain = "localhost" defaultChannelTimeout = 5 * time.Second
defaultServerName = "testServer"
defaultStreamID = "91bd0bba-012f-4d92-bb17-5fc41e6fe545"
defaultComponentName = "Test Component"
// Default port is not standard XMPP port to avoid interfering
// with local running XMPP server
testHandshakePort = iota + 15222
testDecoderPort
testSendIqPort
testSendRawPort
testDisconnectPort
testSManDisconnectPort
) )
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
@ -47,17 +35,15 @@ func TestHandshake(t *testing.T) {
} }
// Tests connection process with a handshake exchange // Tests connection process with a handshake exchange
// Tests multiple session IDs. All connections should generate a unique stream ID // Tests multiple session IDs. All serverConnections should generate a unique stream ID
func TestGenerateHandshake(t *testing.T) { func TestGenerateHandshakeId(t *testing.T) {
// Using this array with a channel to make a queue of values to test // Using this array with a channel to make a queue of values to test
// These are stream IDs that will be used to test the connection process, mixing them with the "secret" to generate // These are stream IDs that will be used to test the connection process, mixing them with the "secret" to generate
// some handshake value // some handshake value
var uuidsArray = [5]string{ var uuidsArray = [5]string{}
"cc9b3249-9582-4780-825f-4311b42f9b0e", for i := 1; i < len(uuidsArray); i++ {
"bba8be3c-d98e-4e26-b9bb-9ed34578a503", id, _ := uuid.NewRandom()
"dae72822-80e8-496b-b763-ab685f53a188", uuidsArray[i] = id.String()
"a45d6c06-de49-4bb0-935b-1a2201b71028",
"7dc6924f-0eca-4237-9898-18654b8d891e",
} }
// Channel to pass stream IDs as a queue // Channel to pass stream IDs as a queue
@ -69,11 +55,11 @@ func TestGenerateHandshake(t *testing.T) {
// Performs a Component connection with a handshake. It expects to have an ID sent its way through the "uchan" // Performs a Component connection with a handshake. It expects to have an ID sent its way through the "uchan"
// channel of this file. Otherwise it will hang for ever. // channel of this file. Otherwise it will hang for ever.
h := func(t *testing.T, c net.Conn) { h := func(t *testing.T, sc *ServerConn) {
decoder := xml.NewDecoder(c)
checkOpenStreamHandshakeID(t, c, decoder, <-uchan) checkOpenStreamHandshakeID(t, sc, <-uchan)
readHandshakeComponent(t, decoder) readHandshakeComponent(t, sc.decoder)
fmt.Fprintln(c, "<handshake/>") // That's all the server needs to return (see xep-0114) fmt.Fprintln(sc.connection, "<handshake/>") // That's all the server needs to return (see xep-0114)
return return
} }
@ -95,7 +81,7 @@ func TestGenerateHandshake(t *testing.T) {
Type: "service", Type: "service",
} }
router := NewRouter() router := NewRouter()
c, err := NewComponent(opts, router) c, err := NewComponent(opts, router, componentDefaultErrorHandler)
if err != nil { if err != nil {
t.Errorf("%+v", err) t.Errorf("%+v", err)
} }
@ -126,7 +112,7 @@ func TestStreamManager(t *testing.T) {
// The decoder is expected to be built after a valid connection // The decoder is expected to be built after a valid connection
// Based on the xmpp_component example. // Based on the xmpp_component example.
func TestDecoder(t *testing.T) { func TestDecoder(t *testing.T) {
c, _ := mockConnection(t, testDecoderPort, handlerForComponentHandshakeDefaultID) c, _ := mockComponentConnection(t, testDecoderPort, handlerForComponentHandshakeDefaultID)
if c.transport.GetDecoder() == nil { if c.transport.GetDecoder() == nil {
t.Errorf("Failed to initialize decoder. Decoder is nil.") t.Errorf("Failed to initialize decoder. Decoder is nil.")
} }
@ -134,63 +120,137 @@ func TestDecoder(t *testing.T) {
// Tests sending an IQ to the server, and getting the response // Tests sending an IQ to the server, and getting the response
func TestSendIq(t *testing.T) { func TestSendIq(t *testing.T) {
done := make(chan struct{})
h := func(t *testing.T, sc *ServerConn) {
handlerForComponentIQSend(t, sc)
done <- struct{}{}
}
//Connecting to a mock server, initialized with given port and handler function //Connecting to a mock server, initialized with given port and handler function
c, m := mockConnection(t, testSendIqPort, handlerForComponentIQSend) c, m := mockComponentConnection(t, testSendIqPort, h)
ctx, _ := context.WithTimeout(context.Background(), 30*time.Second) ctx, _ := context.WithTimeout(context.Background(), 30*time.Second)
iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"}) iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"})
disco := iqReq.DiscoInfo() disco := iqReq.DiscoInfo()
iqReq.Payload = disco iqReq.Payload = disco
// Handle a possible error
errChan := make(chan error)
errorHandler := func(err error) {
errChan <- err
}
c.ErrorHandler = errorHandler
var res chan stanza.IQ var res chan stanza.IQ
res, _ = c.SendIQ(ctx, iqReq) res, _ = c.SendIQ(ctx, iqReq)
select { select {
case <-res: case <-res:
case <-time.After(100 * time.Millisecond): case err := <-errChan:
t.Errorf(err.Error())
case <-time.After(defaultChannelTimeout):
t.Errorf("Failed to receive response, to sent IQ, from mock server") t.Errorf("Failed to receive response, to sent IQ, from mock server")
} }
select {
case <-done:
m.Stop() m.Stop()
case <-time.After(defaultChannelTimeout):
t.Errorf("The mock server failed to finish its job !")
}
}
// Checking that error handling is done properly client side when an invalid IQ is sent and the server responds in kind.
func TestSendIqFail(t *testing.T) {
done := make(chan struct{})
h := func(t *testing.T, sc *ServerConn) {
handlerForComponentIQSend(t, sc)
done <- struct{}{}
}
//Connecting to a mock server, initialized with given port and handler function
c, m := mockComponentConnection(t, testSendIqFailPort, h)
ctx, _ := context.WithTimeout(context.Background(), 30*time.Second)
iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"})
// Removing the id to make the stanza invalid. The IQ constructor makes a random one if none is specified
// so we need to overwrite it.
iqReq.Id = ""
disco := iqReq.DiscoInfo()
iqReq.Payload = disco
errChan := make(chan error)
errorHandler := func(err error) {
errChan <- err
}
c.ErrorHandler = errorHandler
var res chan stanza.IQ
res, _ = c.SendIQ(ctx, iqReq)
select {
case r := <-res: // Do we get an IQ response from the server ?
t.Errorf("We should not be getting an IQ response here : this should fail !")
fmt.Println(r)
case <-errChan: // Do we get a stream error from the server ?
// If we get an error from the server, the test passes.
case <-time.After(defaultChannelTimeout): // Timeout ?
t.Errorf("Failed to receive response, to sent IQ, from mock server")
}
select {
case <-done:
m.Stop()
case <-time.After(defaultChannelTimeout):
t.Errorf("The mock server failed to finish its job !")
}
} }
// Tests sending raw xml to the mock server. // Tests sending raw xml to the mock server.
// TODO : check the server response client side ?
// Right now, the server response is not checked and an err is passed in a channel if the test is supposed to err. // Right now, the server response is not checked and an err is passed in a channel if the test is supposed to err.
// In this test, we use IQs // In this test, we use IQs
func TestSendRaw(t *testing.T) { func TestSendRaw(t *testing.T) {
// Error channel for the handler done := make(chan struct{})
errChan := make(chan error)
// Handler for the mock server // Handler for the mock server
h := func(t *testing.T, c net.Conn) { h := func(t *testing.T, sc *ServerConn) {
// Completes the connection by exchanging handshakes // Completes the connection by exchanging handshakes
handlerForComponentHandshakeDefaultID(t, c) handlerForComponentHandshakeDefaultID(t, sc)
receiveRawIq(t, c, errChan) respondToIQ(t, sc)
return done <- struct{}{}
} }
type testCase struct { type testCase struct {
req string req string
shouldErr bool shouldErr bool
port int
} }
testRequests := make(map[string]testCase) testRequests := make(map[string]testCase)
// Sending a correct IQ of type get. Not supposed to err // Sending a correct IQ of type get. Not supposed to err
testRequests["Correct IQ"] = testCase{ testRequests["Correct IQ"] = testCase{
req: `<iq type="get" id="91bd0bba-012f-4d92-bb17-5fc41e6fe545" from="test1@localhost/mremond-mbp" to="testServer" lang="en"><query xmlns="http://jabber.org/protocol/disco#info"></query></iq>`, req: `<iq type="get" id="91bd0bba-012f-4d92-bb17-5fc41e6fe545" from="test1@localhost/mremond-mbp" to="testServer" lang="en"><query xmlns="http://jabber.org/protocol/disco#info"></query></iq>`,
shouldErr: false, shouldErr: false,
port: testSendRawPort + 100,
} }
// Sending an IQ with a missing ID. Should err // Sending an IQ with a missing ID. Should err
testRequests["IQ with missing ID"] = testCase{ testRequests["IQ with missing ID"] = testCase{
req: `<iq type="get" from="test1@localhost/mremond-mbp" to="testServer" lang="en"><query xmlns="http://jabber.org/protocol/disco#info"></query></iq>`, req: `<iq type="get" from="test1@localhost/mremond-mbp" to="testServer" lang="en"><query xmlns="http://jabber.org/protocol/disco#info"></query></iq>`,
shouldErr: true, shouldErr: true,
port: testSendRawPort + 200,
}
// A handler for the component.
// In the failing test, the server returns a stream error, which triggers this handler, component side.
errChan := make(chan error)
errHandler := func(err error) {
errChan <- err
} }
// Tests for all the IQs // Tests for all the IQs
for name, tcase := range testRequests { for name, tcase := range testRequests {
t.Run(name, func(st *testing.T) { t.Run(name, func(st *testing.T) {
//Connecting to a mock server, initialized with given port and handler function //Connecting to a mock server, initialized with given port and handler function
c, m := mockConnection(t, testSendRawPort, h) c, m := mockComponentConnection(t, tcase.port, h)
c.ErrorHandler = errHandler
// Sending raw xml from test case // Sending raw xml from test case
err := c.SendRaw(tcase.req) err := c.SendRaw(tcase.req)
if err != nil { if err != nil {
@ -198,21 +258,29 @@ func TestSendRaw(t *testing.T) {
} }
// Just wait a little so the message has time to arrive // Just wait a little so the message has time to arrive
select { select {
case <-time.After(100 * time.Millisecond): // We don't use the default "long" timeout here because waiting it out means passing the test.
case <-time.After(200 * time.Millisecond):
case err = <-errChan: case err = <-errChan:
if err == nil && tcase.shouldErr { if err == nil && tcase.shouldErr {
t.Errorf("Failed to get closing stream err") t.Errorf("Failed to get closing stream err")
} else if err != nil && !tcase.shouldErr {
t.Errorf("This test is not supposed to err ! => %s", err.Error())
} }
} }
c.transport.Close() c.transport.Close()
select {
case <-done:
m.Stop() m.Stop()
case <-time.After(defaultChannelTimeout):
t.Errorf("The mock server failed to finish its job !")
}
}) })
} }
} }
// Tests the Disconnect method for Components // Tests the Disconnect method for Components
func TestDisconnect(t *testing.T) { func TestDisconnect(t *testing.T) {
c, m := mockConnection(t, testDisconnectPort, handlerForComponentHandshakeDefaultID) c, m := mockComponentConnection(t, testDisconnectPort, handlerForComponentHandshakeDefaultID)
err := c.transport.Ping() err := c.transport.Ping()
if err != nil { if err != nil {
t.Errorf("Could not ping but not disconnected yet") t.Errorf("Could not ping but not disconnected yet")
@ -257,22 +325,106 @@ func TestStreamManagerDisconnect(t *testing.T) {
//============================================================================= //=============================================================================
// Basic XMPP Server Mock Handlers. // Basic XMPP Server Mock Handlers.
// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant.
// Used in the mock server as a Handler //===============================
func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) { // Init mock server and connection
decoder := xml.NewDecoder(c) // Creating a mock server and connecting a Component to it. Initialized with given port and handler function
checkOpenStreamHandshakeDefaultID(t, c, decoder) // The Component and mock are both returned
readHandshakeComponent(t, decoder) func mockComponentConnection(t *testing.T, port int, handler func(t *testing.T, sc *ServerConn)) (*Component, *ServerMock) {
fmt.Fprintln(c, "<handshake/>") // That's all the server needs to return (see xep-0114) // Init mock server
testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port)
mock := &ServerMock{}
mock.Start(t, testComponentAddress, handler)
//==================================
// Create Component to connect to it
c := makeBasicComponent(defaultComponentName, testComponentAddress, t)
//========================================
// Connect the new Component to the server
err := c.Connect()
if err != nil {
t.Errorf("%+v", err)
}
// Now that the Component is connected, let's set the xml.Decoder for the server
return c, mock
}
func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component {
opts := ComponentOptions{
TransportConfiguration: TransportConfiguration{
Address: mockServerAddr,
Domain: "localhost",
},
Domain: testComponentDomain,
Secret: "mypass",
Name: name,
Category: "gateway",
Type: "service",
}
router := NewRouter()
c, err := NewComponent(opts, router, componentDefaultErrorHandler)
if err != nil {
t.Errorf("%+v", err)
}
c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration)
if err != nil {
t.Errorf("%+v", err)
}
return c
}
// This really should not be used as is.
// It's just meant to be a placeholder when error handling is not needed at this level
func componentDefaultErrorHandler(err error) {
}
// Sends IQ response to Component request.
// No parsing of the request here. We just check that it's valid, and send the default response.
func handlerForComponentIQSend(t *testing.T, sc *ServerConn) {
// Completes the connection by exchanging handshakes
handlerForComponentHandshakeDefaultID(t, sc)
respondToIQ(t, sc)
}
// Used for ID and handshake related tests
func checkOpenStreamHandshakeID(t *testing.T, sc *ServerConn, streamID string) {
sc.connection.SetDeadline(time.Now().Add(defaultTimeout))
defer sc.connection.SetDeadline(time.Time{})
for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion.
token, err := sc.decoder.Token()
if err != nil {
t.Errorf("cannot read next token: %s", err)
}
switch elem := token.(type) {
// Wait for first startElement
case xml.StartElement:
if elem.Name.Space != stanza.NSStream || elem.Name.Local != "stream" {
err = errors.New("xmpp: expected <stream> but got <" + elem.Name.Local + "> in " + elem.Name.Space)
return return
} }
if _, err := fmt.Fprintf(sc.connection, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil {
t.Errorf("cannot write server stream open: %s", err)
}
return
}
}
}
func checkOpenStreamHandshakeDefaultID(t *testing.T, sc *ServerConn) {
checkOpenStreamHandshakeID(t, sc, defaultStreamID)
}
// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. // Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant.
// This handler is supposed to fail by sending a "message" stanza instead of a <handshake/> stanza to finalize the handshake. // This handler is supposed to fail by sending a "message" stanza instead of a <handshake/> stanza to finalize the handshake.
func handlerComponentFailedHandshakeDefaultID(t *testing.T, c net.Conn) { func handlerComponentFailedHandshakeDefaultID(t *testing.T, sc *ServerConn) {
decoder := xml.NewDecoder(c) checkOpenStreamHandshakeDefaultID(t, sc)
checkOpenStreamHandshakeDefaultID(t, c, decoder) readHandshakeComponent(t, sc.decoder)
readHandshakeComponent(t, decoder)
// Send a message, instead of a "<handshake/>" tag, to fail the handshake process dans disconnect the client. // Send a message, instead of a "<handshake/>" tag, to fail the handshake process dans disconnect the client.
me := stanza.Message{ me := stanza.Message{
@ -280,7 +432,7 @@ func handlerComponentFailedHandshakeDefaultID(t *testing.T, c net.Conn) {
Body: "Fail my handshake.", Body: "Fail my handshake.",
} }
s, _ := xml.Marshal(me) s, _ := xml.Marshal(me)
fmt.Fprintln(c, string(s)) fmt.Fprintln(sc.connection, string(s))
return return
} }
@ -303,152 +455,11 @@ func readHandshakeComponent(t *testing.T, decoder *xml.Decoder) {
} }
} }
func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) { // Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant.
checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID) // Used in the mock server as a Handler
} func handlerForComponentHandshakeDefaultID(t *testing.T, sc *ServerConn) {
checkOpenStreamHandshakeDefaultID(t, sc)
// Used for ID and handshake related tests readHandshakeComponent(t, sc.decoder)
func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) { fmt.Fprintln(sc.connection, "<handshake/>") // That's all the server needs to return (see xep-0114)
c.SetDeadline(time.Now().Add(defaultTimeout))
defer c.SetDeadline(time.Time{})
for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion.
token, err := decoder.Token()
if err != nil {
t.Errorf("cannot read next token: %s", err)
}
switch elem := token.(type) {
// Wait for first startElement
case xml.StartElement:
if elem.Name.Space != stanza.NSStream || elem.Name.Local != "stream" {
err = errors.New("xmpp: expected <stream> but got <" + elem.Name.Local + "> in " + elem.Name.Space)
return return
} }
if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil {
t.Errorf("cannot write server stream open: %s", err)
}
return
}
}
}
//=============================================================================
// Sends IQ response to Component request.
// No parsing of the request here. We just check that it's valid, and send the default response.
func handlerForComponentIQSend(t *testing.T, c net.Conn) {
// Completes the connection by exchanging handshakes
handlerForComponentHandshakeDefaultID(t, c)
// Decoder to parse the request
decoder := xml.NewDecoder(c)
iqReq, err := receiveIq(t, c, decoder)
if err != nil {
t.Errorf("Error receiving the IQ stanza : %v", err)
} else if !iqReq.IsValid() {
t.Errorf("server received an IQ stanza : %v", iqReq)
}
// Crafting response
iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"})
disco := iqResp.DiscoInfo()
disco.AddFeatures("vcard-temp",
`http://jabber.org/protocol/address`)
disco.AddIdentity("Multicast", "service", "multicast")
iqResp.Payload = disco
// Sending response to the Component
mResp, err := xml.Marshal(iqResp)
_, err = fmt.Fprintln(c, string(mResp))
if err != nil {
t.Errorf("Could not send response stanza : %s", err)
}
return
}
// Reads next request coming from the Component. Expecting it to be an IQ request
func receiveIq(t *testing.T, c net.Conn, decoder *xml.Decoder) (stanza.IQ, error) {
c.SetDeadline(time.Now().Add(defaultTimeout))
defer c.SetDeadline(time.Time{})
var iqStz stanza.IQ
err := decoder.Decode(&iqStz)
if err != nil {
t.Errorf("cannot read the received IQ stanza: %s", err)
}
if !iqStz.IsValid() {
t.Errorf("received IQ stanza is invalid : %s", err)
}
return iqStz, nil
}
func receiveRawIq(t *testing.T, c net.Conn, errChan chan error) {
c.SetDeadline(time.Now().Add(defaultTimeout))
defer c.SetDeadline(time.Time{})
decoder := xml.NewDecoder(c)
var iq stanza.IQ
err := decoder.Decode(&iq)
if err != nil || !iq.IsValid() {
s := stanza.StreamError{
XMLName: xml.Name{Local: "stream:error"},
Error: xml.Name{Local: "xml-not-well-formed"},
Text: `XML was not well-formed`,
}
raw, _ := xml.Marshal(s)
fmt.Fprintln(c, string(raw))
fmt.Fprintln(c, `</stream:stream>`) // TODO : check this client side
errChan <- fmt.Errorf("invalid xml")
return
}
errChan <- nil
return
}
//===============================
// Init mock server and connection
// Creating a mock server and connecting a Component to it. Initialized with given port and handler function
// The Component and mock are both returned
func mockConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) {
// Init mock server
testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port)
mock := ServerMock{}
mock.Start(t, testComponentAddress, handler)
//==================================
// Create Component to connect to it
c := makeBasicComponent(defaultComponentName, testComponentAddress, t)
//========================================
// Connect the new Component to the server
err := c.Connect()
if err != nil {
t.Errorf("%+v", err)
}
return c, &mock
}
func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component {
opts := ComponentOptions{
TransportConfiguration: TransportConfiguration{
Address: mockServerAddr,
Domain: "localhost",
},
Domain: testComponentDomain,
Secret: "mypass",
Name: name,
Category: "gateway",
Type: "service",
}
router := NewRouter()
c, err := NewComponent(opts, router)
if err != nil {
t.Errorf("%+v", err)
}
c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration)
if err != nil {
t.Errorf("%+v", err)
}
return c
}

2
doc.go
View file

@ -29,7 +29,7 @@ Components
XMPP components can typically be used to extends the features of an XMPP XMPP components can typically be used to extends the features of an XMPP
server, in a portable way, using component protocol over persistent TCP server, in a portable way, using component protocol over persistent TCP
connections. serverConnections.
Component protocol is defined in XEP-114 (https://xmpp.org/extensions/xep-0114.html). Component protocol is defined in XEP-114 (https://xmpp.org/extensions/xep-0114.html).

View file

@ -119,7 +119,7 @@ func (s *Session) startTlsIfSupported(o Config) {
return return
} }
// If we do not allow cleartext connections, make it explicit that server do not support starttls // If we do not allow cleartext serverConnections, make it explicit that server do not support starttls
if !o.Insecure { if !o.Insecure {
s.err = errors.New("XMPP server does not advertise support for starttls") s.err = errors.New("XMPP server does not advertise support for starttls")
} }

View file

@ -1,17 +1,47 @@
package xmpp package xmpp
import ( import (
"encoding/xml"
"fmt"
"gosrc.io/xmpp/stanza"
"net" "net"
"testing" "testing"
"time"
) )
//============================================================================= //=============================================================================
// TCP Server Mock // TCP Server Mock
const (
defaultTimeout = 2 * time.Second
testComponentDomain = "localhost"
defaultServerName = "testServer"
defaultStreamID = "91bd0bba-012f-4d92-bb17-5fc41e6fe545"
defaultComponentName = "Test Component"
serverStreamOpen = "<?xml version='1.0'?><stream:stream to='%s' id='%s' xmlns='%s' xmlns:stream='%s' version='1.0'>"
// Default port is not standard XMPP port to avoid interfering
// with local running XMPP server
// Component tests
testHandshakePort = iota + 15222
testDecoderPort
testSendIqPort
testSendIqFailPort
testSendRawPort
testDisconnectPort
testSManDisconnectPort
// Client tests
testClientBasePort
testClientRawPort
testClientIqPort
testClientIqFailPort
)
// ClientHandler is passed by the test client to provide custom behaviour to // ClientHandler is passed by the test client to provide custom behaviour to
// the TCP server mock. This allows customizing the server behaviour to allow // the TCP server mock. This allows customizing the server behaviour to allow
// testing clients under various scenarii. // testing clients under various scenarii.
type ClientHandler func(t *testing.T, conn net.Conn) type ClientHandler func(t *testing.T, serverConn *ServerConn)
// ServerMock is a simple TCP server that can be use to mock basic server // ServerMock is a simple TCP server that can be use to mock basic server
// behaviour to test clients. // behaviour to test clients.
@ -19,10 +49,15 @@ type ServerMock struct {
t *testing.T t *testing.T
handler ClientHandler handler ClientHandler
listener net.Listener listener net.Listener
connections []net.Conn serverConnections []*ServerConn
done chan struct{} done chan struct{}
} }
type ServerConn struct {
connection net.Conn
decoder *xml.Decoder
}
// Start launches the mock TCP server, listening to an actual address / port. // Start launches the mock TCP server, listening to an actual address / port.
func (mock *ServerMock) Start(t *testing.T, addr string, handler ClientHandler) { func (mock *ServerMock) Start(t *testing.T, addr string, handler ClientHandler) {
mock.t = t mock.t = t
@ -38,9 +73,9 @@ func (mock *ServerMock) Stop() {
if mock.listener != nil { if mock.listener != nil {
mock.listener.Close() mock.listener.Close()
} }
// Close all existing connections // Close all existing serverConnections
for _, c := range mock.connections { for _, c := range mock.serverConnections {
c.Close() c.connection.Close()
} }
} }
@ -60,13 +95,14 @@ func (mock *ServerMock) init(addr string) error {
return nil return nil
} }
// loop accepts connections and creates a go routine per connection. // loop accepts serverConnections and creates a go routine per connection.
// The go routine is running the client handler, that is used to provide the // The go routine is running the client handler, that is used to provide the
// real TCP server behaviour. // real TCP server behaviour.
func (mock *ServerMock) loop() { func (mock *ServerMock) loop() {
listener := mock.listener listener := mock.listener
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
serverConn := &ServerConn{conn, xml.NewDecoder(conn)}
if err != nil { if err != nil {
select { select {
case <-mock.done: case <-mock.done:
@ -76,8 +112,195 @@ func (mock *ServerMock) loop() {
} }
return return
} }
mock.connections = append(mock.connections, conn) mock.serverConnections = append(mock.serverConnections, serverConn)
// TODO Create and pass a context to cancel the handler if they are still around = avoid possible leak on complex handlers // TODO Create and pass a context to cancel the handler if they are still around = avoid possible leak on complex handlers
go mock.handler(mock.t, conn) go mock.handler(mock.t, serverConn)
}
}
//======================================================================================================================
// A few functions commonly used for tests. Trying to avoid duplicates in client and component test files.
//======================================================================================================================
func respondToIQ(t *testing.T, sc *ServerConn) {
// Decoder to parse the request
iqReq, err := receiveIq(sc)
if err != nil {
t.Fatalf("failed to receive IQ : %s", err.Error())
}
if !iqReq.IsValid() {
mockIQError(sc.connection)
return
}
// Crafting response
iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"})
disco := iqResp.DiscoInfo()
disco.AddFeatures("vcard-temp",
`http://jabber.org/protocol/address`)
disco.AddIdentity("Multicast", "service", "multicast")
iqResp.Payload = disco
// Sending response to the Component
mResp, err := xml.Marshal(iqResp)
_, err = fmt.Fprintln(sc.connection, string(mResp))
if err != nil {
t.Errorf("Could not send response stanza : %s", err)
}
return
}
// When a presence stanza is automatically sent (right now it's the case in the client), we may want to discard it
// and test further stanzas.
func discardPresence(t *testing.T, sc *ServerConn) {
sc.connection.SetDeadline(time.Now().Add(defaultTimeout))
defer sc.connection.SetDeadline(time.Time{})
var presenceStz stanza.Presence
recvBuf := make([]byte, len(InitialPresence))
_, err := sc.connection.Read(recvBuf[:]) // recv data
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
t.Errorf("read timeout: %s", err)
} else {
t.Errorf("read error: %s", err)
}
}
xml.Unmarshal(recvBuf, &presenceStz)
if err != nil {
t.Errorf("Expected presence but this happened : %s", err.Error())
}
}
// Reads next request coming from the Component. Expecting it to be an IQ request
func receiveIq(sc *ServerConn) (*stanza.IQ, error) {
sc.connection.SetDeadline(time.Now().Add(defaultTimeout))
defer sc.connection.SetDeadline(time.Time{})
var iqStz stanza.IQ
err := sc.decoder.Decode(&iqStz)
if err != nil {
return nil, err
}
return &iqStz, nil
}
// Should be used in server handlers when an IQ sent by a client or component is invalid.
// This responds as expected from a "real" server, aside from the error message.
func mockIQError(c net.Conn) {
s := stanza.StreamError{
XMLName: xml.Name{Local: "stream:error"},
Error: xml.Name{Local: "xml-not-well-formed"},
Text: `XML was not well-formed`,
}
raw, _ := xml.Marshal(s)
fmt.Fprintln(c, string(raw))
fmt.Fprintln(c, `</stream:stream>`)
}
func sendStreamFeatures(t *testing.T, sc *ServerConn) {
// This is a basic server, supporting only 1 stream feature: SASL Plain Auth
features := `<stream:features>
<mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl">
<mechanism>PLAIN</mechanism>
</mechanisms>
</stream:features>`
if _, err := fmt.Fprintln(sc.connection, features); err != nil {
t.Errorf("cannot send stream feature: %s", err)
}
}
// TODO return err in case of error reading the auth params
func readAuth(t *testing.T, decoder *xml.Decoder) string {
se, err := stanza.NextStart(decoder)
if err != nil {
t.Errorf("cannot read auth: %s", err)
return ""
}
var nv interface{}
nv = &stanza.SASLAuth{}
// Decode element into pointer storage
if err = decoder.DecodeElement(nv, &se); err != nil {
t.Errorf("cannot decode auth: %s", err)
return ""
}
switch v := nv.(type) {
case *stanza.SASLAuth:
return v.Value
}
return ""
}
func sendBindFeature(t *testing.T, sc *ServerConn) {
// This is a basic server, supporting only 1 stream feature after auth: resource binding
features := `<stream:features>
<bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/>
</stream:features>`
if _, err := fmt.Fprintln(sc.connection, features); err != nil {
t.Errorf("cannot send stream feature: %s", err)
}
}
func sendRFC3921Feature(t *testing.T, sc *ServerConn) {
// This is a basic server, supporting only 2 features after auth: resource & session binding
features := `<stream:features>
<bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/>
<session xmlns='urn:ietf:params:xml:ns:xmpp-session'/>
</stream:features>`
if _, err := fmt.Fprintln(sc.connection, features); err != nil {
t.Errorf("cannot send stream feature: %s", err)
}
}
func bind(t *testing.T, sc *ServerConn) {
se, err := stanza.NextStart(sc.decoder)
if err != nil {
t.Errorf("cannot read bind: %s", err)
return
}
iq := &stanza.IQ{}
// Decode element into pointer storage
if err = sc.decoder.DecodeElement(&iq, &se); err != nil {
t.Errorf("cannot decode bind iq: %s", err)
return
}
// TODO Check all elements
switch iq.Payload.(type) {
case *stanza.Bind:
result := `<iq id='%s' type='result'>
<bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'>
<jid>%s</jid>
</bind>
</iq>`
fmt.Fprintf(sc.connection, result, iq.Id, "test@localhost/test") // TODO use real JID
}
}
func session(t *testing.T, sc *ServerConn) {
se, err := stanza.NextStart(sc.decoder)
if err != nil {
t.Errorf("cannot read session: %s", err)
return
}
iq := &stanza.IQ{}
// Decode element into pointer storage
if err = sc.decoder.DecodeElement(&iq, &se); err != nil {
t.Errorf("cannot decode session iq: %s", err)
return
}
switch iq.Payload.(type) {
case *stanza.StreamSession:
result := `<iq id='%s' type='result'/>`
fmt.Fprintf(sc.connection, result, iq.Id)
} }
} }