go-xmpp/registry.go
2019-06-04 18:47:44 +02:00

106 lines
2.9 KiB
Go

package xmpp
import (
"encoding/xml"
"reflect"
"sync"
)
// MsgExtension is the type of the extension values passed to MapExtension and
// returned by GetMsgExtension. It is the empty interface, so any value satisfies it.
type MsgExtension interface{}
// packetType identifies the kind of packet (presence, message, or IQ) an
// extension is registered for.
type packetType uint8
// Packet types used as the first component of a registry key.
const (
// PKTPresence is the packet type for presence packets.
PKTPresence packetType = iota
// PKTMessage is the packet type for message packets.
PKTMessage
// PKTIQ is the packet type for IQ packets.
PKTIQ
)
// typeRegistry is the registry for msg and IQ types. It is a global variable.
// TODO: Move to the client init process to remove the dependency on a global variable.
// That should make it possible to share the decoder.
// TODO: Ensure that a client can add its own custom namespace to the registry (or overload existing ones).
var typeRegistry = newRegistry()
// registryKey identifies one per-namespace store inside the registry: a
// separate store is kept for each combination of packet type and XML namespace.
type registryKey struct {
// packetType is the kind of packet the extension belongs to.
packetType packetType
// namespace is the XML namespace of the extension element (xml.Name.Space).
namespace string
}
// registryForNamespace maps an XML local tag name to the extension type
// registered for it. The key "*" is the wildcard entry consulted when no
// exact local name matches.
type registryForNamespace map[string]reflect.Type
// registry holds the registered extension types, grouped by packet type and
// XML namespace, together with the lock that guards them.
type registry struct {
// msgTypes maps a (packet type, namespace) key to the store of extension
// types registered under that key.
msgTypes map[registryKey]registryForNamespace
// msgTypesLock guards concurrent access to msgTypes and the stores it holds.
msgTypesLock *sync.RWMutex
}
// newRegistry builds an empty registry: the type map is allocated and ready
// for writes, and the read/write lock is a fresh, unlocked RWMutex.
func newRegistry() *registry {
	r := new(registry)
	r.msgTypes = map[registryKey]registryForNamespace{}
	r.msgTypesLock = new(sync.RWMutex)
	return r
}
// MapExtension stores extension type for packet payload.
// The match is done per packetType (iq, message, or presence) and XML tag name.
// You can use the alias "*" as local XML name to be able to match all unknown tag name for that
// packet type and namespace.
//
// pktType and name.Space select the per-namespace store; name.Local is the key
// inside that store. The dynamic type of extension (as reported by
// reflect.TypeOf) is what gets recorded. Registering the same packet type and
// name again replaces the previous entry.
//
// MapExtension is safe for concurrent use.
func (r *registry) MapExtension(pktType packetType, name xml.Name, extension MsgExtension) {
	key := registryKey{pktType, name.Space}

	// The lookup, the lazy creation of the per-namespace store, and the insert
	// must all happen under one write lock. Checking under a read lock and then
	// re-locking for the write lets two goroutines both observe a missing store,
	// each create their own, and the later one silently discard the earlier
	// registration.
	r.msgTypesLock.Lock()
	defer r.msgTypesLock.Unlock()

	store := r.msgTypes[key]
	if store == nil {
		store = make(map[string]reflect.Type)
		r.msgTypes[key] = store
	}
	store[name.Local] = reflect.TypeOf(extension)
}
// GetExtensionType looks up the extension type registered for the given packet
// type and XML name. The store is selected by packet type and name.Space, then
// searched by name.Local. When there is no exact entry and the caller did not
// itself ask for "*", the wildcard entry "*" of the same store is returned
// instead. The result is nil when nothing matches.
func (r *registry) GetExtensionType(pktType packetType, name xml.Name) reflect.Type {
	r.msgTypesLock.RLock()
	defer r.msgTypesLock.RUnlock()

	// Indexing a missing (nil) store is fine: every lookup simply yields nil.
	byLocalName := r.msgTypes[registryKey{packetType: pktType, namespace: name.Space}]

	exact := byLocalName[name.Local]
	if exact != nil || name.Local == "*" {
		return exact
	}
	return byLocalName["*"]
}
// GetMsgExtension creates a new value of the extension type registered for
// message packets under the given XML name and returns a pointer to it as a
// MsgExtension. It returns nil when no type is registered for that name.
func (r *registry) GetMsgExtension(name xml.Name) MsgExtension {
	extType := r.GetExtensionType(PKTMessage, name)
	if extType == nil {
		return nil
	}

	// reflect.New yields a pointer to a zero value of the registered type.
	instance, ok := reflect.New(extType).Interface().(MsgExtension)
	if !ok {
		return nil
	}
	return instance
}
// GetIQExtension creates a new value of the extension type registered for IQ
// packets under the given XML name and returns a pointer to it as an
// IQPayload. It returns nil when no type is registered for that name, or when
// the pointer to the registered type does not satisfy IQPayload.
func (r *registry) GetIQExtension(name xml.Name) IQPayload {
	extType := r.GetExtensionType(PKTIQ, name)
	if extType == nil {
		return nil
	}

	// reflect.New yields a pointer to a zero value of the registered type.
	candidate := reflect.New(extType).Interface()
	if payload, ok := candidate.(IQPayload); ok {
		return payload
	}
	return nil
}