From 7a4364be955379cf140e19f4f8de235b13131c23 Mon Sep 17 00:00:00 2001 From: Mickael Remond Date: Tue, 4 Jun 2019 17:04:25 +0200 Subject: [PATCH] Refactor / clean up registry --- auth.go | 4 +- cmd/xmpp-check/xmpp-check.go | 1 - cmd/xmpp_jukebox/xmpp_jukebox.go | 13 +++-- control_test.go | 4 +- iot/control.go => iot_control.go | 14 ++--- iq.go | 47 ++++++---------- message.go | 3 +- message_test.go | 34 ------------ msg_receipts.go | 20 ++++--- msg_receipts_test.go | 42 ++++++++++++++ pep/user_tune.go => pep.go | 25 ++------- registry.go | 94 +++++++++++++++++++++++--------- registry_test.go | 47 ++++++++++++++++ 13 files changed, 205 insertions(+), 143 deletions(-) rename iot/control.go => iot_control.go (74%) create mode 100644 msg_receipts_test.go rename pep/user_tune.go => pep.go (72%) create mode 100644 registry_test.go diff --git a/auth.go b/auth.go index de38cac..36ce19e 100644 --- a/auth.go +++ b/auth.go @@ -105,14 +105,12 @@ type auth struct { } type BindBind struct { + IQPayload XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"` Resource string `xml:"resource,omitempty"` Jid string `xml:"jid,omitempty"` } -func (*BindBind) IsIQPayload() { -} - // Session is obsolete in RFC 6121. // Added for compliance with RFC 3121. // Remove when ejabberd purely conforms to RFC 6121. diff --git a/cmd/xmpp-check/xmpp-check.go b/cmd/xmpp-check/xmpp-check.go index 6e7ceab..77af605 100644 --- a/cmd/xmpp-check/xmpp-check.go +++ b/cmd/xmpp-check/xmpp-check.go @@ -29,7 +29,6 @@ func main() { func runCheck(address, domain string) { client, err := xmpp.NewChecker(address, domain) - // client, err := xmpp.NewChecker("mickael.m.in-app.io:5222", "mickael.m.in-app.io") if err != nil { log.Fatal("Error: ", err) diff --git a/cmd/xmpp_jukebox/xmpp_jukebox.go b/cmd/xmpp_jukebox/xmpp_jukebox.go index bd3e610..a60e164 100644 --- a/cmd/xmpp_jukebox/xmpp_jukebox.go +++ b/cmd/xmpp_jukebox/xmpp_jukebox.go @@ -12,8 +12,6 @@ import ( "github.com/processone/mpg123" "github.com/processone/soundcloud" "gosrc.io/xmpp" - "gosrc.io/xmpp/iot" - "gosrc.io/xmpp/pep" ) // Get the actual song Stream URL from SoundCloud website song URL and play it with mpg123 player. @@ -65,7 +63,7 @@ func processMessage(client *xmpp.Client, p *mpg123.Player, packet *xmpp.Message) func processIq(client *xmpp.Client, p *mpg123.Player, packet *xmpp.IQ) { switch payload := packet.Payload[0].(type) { // We support IOT Control IQ - case *iot.ControlSet: + case *xmpp.ControlSet: var url string for _, element := range payload.Fields { if element.XMLName.Local == "string" && element.Name == "url" { @@ -75,7 +73,7 @@ func processIq(client *xmpp.Client, p *mpg123.Player, packet *xmpp.IQ) { } playSCURL(p, url) - setResponse := new(iot.ControlSetResponse) + setResponse := new(xmpp.ControlSetResponse) reply := xmpp.IQ{PacketAttrs: xmpp.PacketAttrs{To: packet.From, Type: "result", Id: packet.Id}, Payload: []xmpp.IQPayload{setResponse}} _ = client.Send(reply) // TODO add Soundclound artist / title retrieval @@ -86,8 +84,11 @@ func processIq(client *xmpp.Client, p *mpg123.Player, packet *xmpp.IQ) { } func sendUserTune(client *xmpp.Client, artist string, title string) { - tune := pep.Tune{Artist: artist, Title: title} - _ = client.SendRaw(tune.XMPPFormat()) + tune := xmpp.Tune{Artist: artist, Title: title} + iq := xmpp.NewIQ("set", "", "", "usertune-1", "en") + payload := xmpp.PubSub{Publish: xmpp.Publish{Node: "http://jabber.org/protocol/tune", Item: xmpp.Item{Tune: tune}}} + iq.AddPayload(&payload) + _ = client.Send(iq) } func playSCURL(p *mpg123.Player, rawURL string) { diff --git a/control_test.go b/control_test.go index f52f663..2a2a027 100644 --- a/control_test.go +++ b/control_test.go @@ -3,8 +3,6 @@ package xmpp // import "gosrc.io/xmpp" import ( "encoding/xml" "testing" - - "gosrc.io/xmpp/iot" ) func TestControlSet(t *testing.T) { @@ -22,7 +20,7 @@ func TestControlSet(t *testing.T) { t.Errorf("Unmarshal(%s) returned error", data) } - if cs, ok := parsedIQ.Payload[0].(*iot.ControlSet); !ok { + if cs, ok := parsedIQ.Payload[0].(*ControlSet); !ok { t.Errorf("Paylod is not an iot control set: %v", cs) } } diff --git a/iot/control.go b/iot_control.go similarity index 74% rename from iot/control.go rename to iot_control.go index 580f26a..795e5e3 100644 --- a/iot/control.go +++ b/iot_control.go @@ -1,15 +1,15 @@ -package iot // import "gosrc.io/xmpp/iot" +package xmpp // import "gosrc.io/xmpp/iot" -import "encoding/xml" +import ( + "encoding/xml" +) type ControlSet struct { + IQPayload XMLName xml.Name `xml:"urn:xmpp:iot:control set"` Fields []ControlField `xml:",any"` } -func (*ControlSet) IsIQPayload() { -} - type ControlGetForm struct { XMLName xml.Name `xml:"urn:xmpp:iot:control getForm"` } @@ -21,8 +21,6 @@ type ControlField struct { } type ControlSetResponse struct { + IQPayload XMLName xml.Name `xml:"urn:xmpp:iot:control setResponse"` } - -func (*ControlSetResponse) IsIQPayload() { -} diff --git a/iq.go b/iq.go index 00ec811..bfc13e3 100644 --- a/iq.go +++ b/iq.go @@ -2,10 +2,7 @@ package xmpp // import "gosrc.io/xmpp" import ( "encoding/xml" - "reflect" "strconv" - - "gosrc.io/xmpp/iot" ) /* @@ -19,6 +16,7 @@ TODO support ability to put Raw payload inside IQ // presence or iq stanza. // It is intended to be added in the payload of the erroneous stanza. type Err struct { + IQPayload XMLName xml.Name `xml:"error"` Code int `xml:"code,attr,omitempty"` Type string `xml:"type,attr,omitempty"` @@ -26,8 +24,6 @@ type Err struct { Text string `xml:"urn:ietf:params:xml:ns:xmpp-stanzas text,omitempty"` } -func (*Err) IsIQPayload() {} - // UnmarshalXML implements custom parsing for IQs func (x *Err) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { x.XMLName = start.Name @@ -208,22 +204,16 @@ func (iq *IQ) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { case xml.StartElement: level++ if level <= 1 { - var elt interface{} - payloadType := tt.Name.Space + " " + tt.Name.Local - if payloadType := iqTypeRegistry[payloadType]; payloadType != nil { - val := reflect.New(payloadType) - elt = val.Interface() - } else { - // TODO: Fix me. We do nothing of that element here. - elt = new(Node) - } - - if iqPl, ok := elt.(IQPayload); ok { - err = d.DecodeElement(elt, &tt) + if iqExt := typeRegistry.GetIQExtension(tt.Name); iqExt != nil { + // Decode payload extension + err = d.DecodeElement(iqExt, &tt) if err != nil { return err } - iq.Payload = append(iq.Payload, iqPl) + iq.Payload = append(iq.Payload, iqExt) + } else { + // TODO: Fix me. We do nothing of that element here. + // elt = new(Node) } } @@ -239,13 +229,12 @@ func (iq *IQ) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { // ============================================================================ // Generic IQ Payload -type IQPayload interface { - IsIQPayload() -} +type IQPayload interface{} // Node is a generic structure to represent XML data. It is used to parse // unreferenced or custom stanza payload. type Node struct { + IQPayload XMLName xml.Name Attrs []xml.Attr `xml:"-"` Content string `xml:",innerxml"` @@ -284,8 +273,6 @@ func (n Node) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) { return e.EncodeToken(xml.EndElement{Name: start.Name}) } -func (*Node) IsIQPayload() {} - // ============================================================================ // Disco @@ -295,14 +282,13 @@ const ( ) type DiscoInfo struct { + IQPayload XMLName xml.Name `xml:"http://jabber.org/protocol/disco#info query"` Node string `xml:"node,attr,omitempty"` Identity Identity `xml:"identity"` Features []Feature `xml:"feature"` } -func (*DiscoInfo) IsIQPayload() {} - type Identity struct { XMLName xml.Name `xml:"identity,omitempty"` Name string `xml:"name,attr,omitempty"` @@ -318,13 +304,12 @@ type Feature struct { // ============================================================================ type DiscoItems struct { + IQPayload XMLName xml.Name `xml:"http://jabber.org/protocol/disco#items query"` Node string `xml:"node,attr,omitempty"` Items []DiscoItem `xml:"item"` } -func (*DiscoItems) IsIQPayload() {} - type DiscoItem struct { XMLName xml.Name `xml:"item"` Name string `xml:"name,attr,omitempty"` @@ -333,8 +318,8 @@ type DiscoItem struct { } func init() { - iqTypeRegistry["http://jabber.org/protocol/disco#info query"] = reflect.TypeOf(DiscoInfo{}) - iqTypeRegistry["http://jabber.org/protocol/disco#items query"] = reflect.TypeOf(DiscoItems{}) - iqTypeRegistry["urn:ietf:params:xml:ns:xmpp-bind bind"] = reflect.TypeOf(BindBind{}) - iqTypeRegistry["urn:xmpp:iot:control set"] = reflect.TypeOf(iot.ControlSet{}) + typeRegistry.MapExtension(PKTIQ, xml.Name{"http://jabber.org/protocol/disco#info", "query"}, DiscoInfo{}) + typeRegistry.MapExtension(PKTIQ, xml.Name{"http://jabber.org/protocol/disco#items", "query"}, DiscoItems{}) + typeRegistry.MapExtension(PKTIQ, xml.Name{"urn:ietf:params:xml:ns:xmpp-bind", "bind"}, BindBind{}) + typeRegistry.MapExtension(PKTIQ, xml.Name{"urn:xmpp:iot:control", "set"}, ControlSet{}) } diff --git a/message.go b/message.go index 039df80..5b99348 100644 --- a/message.go +++ b/message.go @@ -86,8 +86,7 @@ func (msg *Message) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { switch tt := t.(type) { case xml.StartElement: - elementType := tt.Name.Space - if msgExt := typeRegistry.getmsgType(elementType); msgExt != nil { + if msgExt := typeRegistry.GetMsgExtension(tt.Name); msgExt != nil { // Decode message extension err = d.DecodeElement(msgExt, &tt) if err != nil { diff --git a/message_test.go b/message_test.go index 39738ae..fb30029 100644 --- a/message_test.go +++ b/message_test.go @@ -47,37 +47,3 @@ func TestDecodeError(t *testing.T) { t.Errorf("incorrect error type: %s", parsedMessage.Error.Type) } } - -func TestDecodeXEP0184(t *testing.T) { - str := ` - My lord, dispatch; read o'er these articles. - -` - parsedMessage := xmpp.Message{} - if err := xml.Unmarshal([]byte(str), &parsedMessage); err != nil { - t.Errorf("message receipt unmarshall error: %v", err) - return - } - - if parsedMessage.Body != "My lord, dispatch; read o'er these articles." { - t.Errorf("Unexpected body: '%s'", parsedMessage.Body) - } - - if len(parsedMessage.Extensions) < 1 { - t.Errorf("no extension found on parsed message") - return - } - - switch ext := parsedMessage.Extensions[0].(type) { - case *xmpp.Receipt: - if ext.XMLName.Local != "request" { - t.Errorf("unexpected extension: %s:%s", ext.XMLName.Space, ext.XMLName.Local) - } - default: - t.Errorf("could not find receipt extension") - } - -} diff --git a/msg_receipts.go b/msg_receipts.go index 5d317b3..a61740b 100644 --- a/msg_receipts.go +++ b/msg_receipts.go @@ -7,17 +7,19 @@ Support for: - XEP-0184 - Message Delivery Receipts: https://xmpp.org/extensions/xep-0184.html */ -const ( - NSReceipts = "urn:xmpp:receipts" -) - -// XEP-0184 message receipt markers -type Receipt struct { +// Used on outgoing message, to tell the recipient that you are requesting a message receipt / ack. +type ReceiptRequest struct { MsgExtension - XMLName xml.Name - Id string + XMLName xml.Name `xml:"urn:xmpp:receipts request"` +} + +type ReceiptReceived struct { + MsgExtension + XMLName xml.Name `xml:"urn:xmpp:receipts received"` + ID string } func init() { - typeRegistry.RegisterMsgExt(NSReceipts, Receipt{}) + typeRegistry.MapExtension(PKTMessage, xml.Name{"urn:xmpp:receipts", "request"}, ReceiptRequest{}) + typeRegistry.MapExtension(PKTMessage, xml.Name{"urn:xmpp:receipts", "received"}, ReceiptReceived{}) } diff --git a/msg_receipts_test.go b/msg_receipts_test.go new file mode 100644 index 0000000..0db0c64 --- /dev/null +++ b/msg_receipts_test.go @@ -0,0 +1,42 @@ +package xmpp_test + +import ( + "encoding/xml" + "testing" + + "gosrc.io/xmpp" +) + +func TestDecodeRequest(t *testing.T) { + str := ` + My lord, dispatch; read o'er these articles. + +` + parsedMessage := xmpp.Message{} + if err := xml.Unmarshal([]byte(str), &parsedMessage); err != nil { + t.Errorf("message receipt unmarshall error: %v", err) + return + } + + if parsedMessage.Body != "My lord, dispatch; read o'er these articles." { + t.Errorf("Unexpected body: '%s'", parsedMessage.Body) + } + + if len(parsedMessage.Extensions) < 1 { + t.Errorf("no extension found on parsed message") + return + } + + switch ext := parsedMessage.Extensions[0].(type) { + case *xmpp.ReceiptRequest: + if ext.XMLName.Local != "request" { + t.Errorf("unexpected extension: %s:%s", ext.XMLName.Space, ext.XMLName.Local) + } + default: + t.Errorf("could not find receipts extension") + } + +} diff --git a/pep/user_tune.go b/pep.go similarity index 72% rename from pep/user_tune.go rename to pep.go index c1fe5ad..1870ca8 100644 --- a/pep/user_tune.go +++ b/pep.go @@ -1,29 +1,21 @@ -package pep // import "gosrc.io/xmpp/pep" +package xmpp // import "gosrc.io/xmpp/pep" import ( "encoding/xml" - - "gosrc.io/xmpp" ) -type iq struct { - XMLName xml.Name `xml:"jabber:client iq"` - C pubSub // c for "contains" - xmpp.PacketAttrs // Rename h for "header" ? -} - -type pubSub struct { +type PubSub struct { XMLName xml.Name `xml:"http://jabber.org/protocol/pubsub pubsub"` - Publish publish + Publish Publish } -type publish struct { +type Publish struct { XMLName xml.Name `xml:"publish"` Node string `xml:"node,attr"` - Item item + Item Item } -type item struct { +type Item struct { XMLName xml.Name `xml:"item"` Tune Tune } @@ -67,11 +59,6 @@ type Tune struct { } */ -func (t *Tune) XMPPFormat() (s string) { - packet, _ := xml.Marshal(iq{PacketAttrs: xmpp.PacketAttrs{Id: "tunes", Type: "set"}, C: pubSub{Publish: publish{Node: "http://jabber.org/protocol/tune", Item: item{Tune: *t}}}}) - return string(packet) -} - /* func (*Tune) XMPPFormat() string { return fmt.Sprintf( diff --git a/registry.go b/registry.go index 6f62199..49e6178 100644 --- a/registry.go +++ b/registry.go @@ -1,6 +1,7 @@ package xmpp import ( + "encoding/xml" "reflect" "sync" ) @@ -11,43 +12,76 @@ type MsgExtension interface{} // TODO: Move to the client init process to remove the dependency on a global variable. // That should make it possible to be able to share the decoder. // TODO: Ensure that a client can add its own custom namespace to the registry (or overload existing ones). + +type packetType uint8 + +const ( + PKTPresence packetType = iota + PKTMessage + PKTIQ +) + var typeRegistry = newRegistry() -type namespace = string - -type registry struct { - // Key is namespace of message extension - msgTypes map[namespace]reflect.Type - msgTypesLock *sync.RWMutex - - iqTypes map[namespace]reflect.Type +// We store different registries per packet type and namespace. +type registryKey struct { + packetType packetType + namespace string } -func newRegistry() registry { - return registry{ - msgTypes: make(map[namespace]reflect.Type), +type registryForNamespace map[string]reflect.Type + +type registry struct { + // We store different registries per packet type and namespace. + msgTypes map[registryKey]registryForNamespace + // Handle concurrent access + msgTypesLock *sync.RWMutex +} + +func newRegistry() *registry { + return ®istry{ + msgTypes: make(map[registryKey]registryForNamespace), msgTypesLock: &sync.RWMutex{}, - iqTypes: make(map[namespace]reflect.Type), } } -// Mutexes are not needed when adding a Message or IQ extension in init function. -// However, forcing the use of the mutex protect the data structure against unexpected use -// of the registry by developers using the library. -func (r registry) RegisterMsgExt(namespace string, extension MsgExtension) { +// MapExtension stores extension type for packet payload. +// The match is done per packetType (iq, message, or presence) and XML tag name. +// You can use the alias "*" as local XML name to be able to match all unknown tag name for that +// packet type and namespace. +func (r *registry) MapExtension(pktType packetType, name xml.Name, extension MsgExtension) { + key := registryKey{pktType, name.Space} + r.msgTypesLock.RLock() + store := r.msgTypes[key] + r.msgTypesLock.RUnlock() + r.msgTypesLock.Lock() defer r.msgTypesLock.Unlock() - r.msgTypes[namespace] = reflect.TypeOf(extension) + if store == nil { + store = make(map[string]reflect.Type) + } + store[name.Local] = reflect.TypeOf(extension) + r.msgTypes[key] = store } -func (r registry) getMsgExtType(namespace string) reflect.Type { +// GetExtensionType returns extension type for packet payload, based on packet type and tag name. +func (r *registry) GetExtensionType(pktType packetType, name xml.Name) reflect.Type { + key := registryKey{pktType, name.Space} + r.msgTypesLock.RLock() defer r.msgTypesLock.RUnlock() - return r.msgTypes[namespace] + store := r.msgTypes[key] + result := store[name.Local] + if result == nil && name.Local != "*" { + return store["*"] + } + return result } -func (r registry) getmsgType(namespace string) MsgExtension { - if extensionType := r.getMsgExtType(namespace); extensionType != nil { +// GetMsgExtension returns an instance of MsgExtension, by matching packet type and XML +// tag name against the registry. +func (r *registry) GetMsgExtension(name xml.Name) MsgExtension { + if extensionType := r.GetExtensionType(PKTMessage, name); extensionType != nil { val := reflect.New(extensionType) elt := val.Interface() if msgExt, ok := elt.(MsgExtension); ok { @@ -57,9 +91,15 @@ func (r registry) getmsgType(namespace string) MsgExtension { return nil } -// Registry to support message extensions -//var msgTypeRegistry = make(map[string]reflect.Type) - -// Registry to instantiate the right IQ payload element -// Key is namespace and key of the payload -var iqTypeRegistry = make(map[string]reflect.Type) +// GetIQExtension returns an instance of IQPayload, by matching packet type and XML +// tag name against the registry. +func (r *registry) GetIQExtension(name xml.Name) IQPayload { + if extensionType := r.GetExtensionType(PKTIQ, name); extensionType != nil { + val := reflect.New(extensionType) + elt := val.Interface() + if iqExt, ok := elt.(IQPayload); ok { + return iqExt + } + } + return nil +} diff --git a/registry_test.go b/registry_test.go new file mode 100644 index 0000000..4b3ba81 --- /dev/null +++ b/registry_test.go @@ -0,0 +1,47 @@ +package xmpp // import "gosrc.io/xmpp" + +import ( + "encoding/xml" + "reflect" + "testing" +) + +func TestRegistry_RegisterMsgExt(t *testing.T) { + // Setup registry + typeRegistry := newRegistry() + + // Register an element + name := xml.Name{Space: "urn:xmpp:receipts", Local: "request"} + typeRegistry.MapExtension(PKTMessage, name, ReceiptRequest{}) + + // Match that element + receipt := typeRegistry.GetMsgExtension(name) + if receipt == nil { + t.Error("cannot read element type from registry") + return + } + + switch r := receipt.(type) { + case *ReceiptRequest: + default: + t.Errorf("Registry did not return expected type ReceiptRequest: %v", reflect.TypeOf(r)) + } +} + +func BenchmarkRegistryGet(b *testing.B) { + // Setup registry + typeRegistry := newRegistry() + + // Register an element + name := xml.Name{Space: "urn:xmpp:receipts", Local: "request"} + typeRegistry.MapExtension(PKTMessage, name, ReceiptRequest{}) + + for i := 0; i < b.N; i++ { + // Match that element + receipt := typeRegistry.GetExtensionType(PKTMessage, name) + if receipt == nil { + b.Error("cannot read element type from registry") + return + } + } +}