diff --git a/client.go b/client.go index 32b2394..c8000bd 100644 --- a/client.go +++ b/client.go @@ -235,14 +235,14 @@ func (c *Client) Send(packet stanza.Packet) error { // // Handle the result here // }) // -func (c *Client) SendIQ(ctx context.Context, iq stanza.IQ, handler HandlerFunc) (*IqResultRoute, error) { +func (c *Client) SendIQ(ctx context.Context, iq stanza.IQ, handler IQResultHandlerFunc) (*IQResultRoute, error) { if iq.Attrs.Type != "set" && iq.Attrs.Type != "get" { return nil, ErrCanOnlySendGetOrSetIq } if err := c.Send(iq); err != nil { return nil, err } - return c.router.NewIqResultRoute(ctx, iq.Attrs.Id).HandlerFunc(handler), nil + return c.router.NewIQResultRoute(ctx, iq.Attrs.Id).HandlerFunc(handler), nil } // SendRaw sends an XMPP stanza as a string to the server. diff --git a/router.go b/router.go index df3b19f..05f4b09 100644 --- a/router.go +++ b/router.go @@ -28,14 +28,14 @@ type Router struct { // Routes to be matched, in order. routes []*Route - iqResultRoutes map[string]*IqResultRoute - iqResultRouteLock sync.RWMutex + IQResultRoutes map[string]*IQResultRoute + IQResultRouteLock sync.RWMutex } // NewRouter returns a new router instance. func NewRouter() *Router { return &Router{ - iqResultRoutes: make(map[string]*IqResultRoute), + IQResultRoutes: make(map[string]*IQResultRoute), } } @@ -44,15 +44,15 @@ func NewRouter() *Router { func (r *Router) route(s Sender, p stanza.Packet) { iq, isIq := p.(stanza.IQ) if isIq { - r.iqResultRouteLock.RLock() - route, ok := r.iqResultRoutes[iq.Id] - r.iqResultRouteLock.RUnlock() + r.IQResultRouteLock.RLock() + route, ok := r.IQResultRoutes[iq.Id] + r.IQResultRouteLock.RUnlock() if ok { - r.iqResultRouteLock.Lock() - delete(r.iqResultRoutes, iq.Id) - r.iqResultRouteLock.Unlock() + r.IQResultRouteLock.Lock() + delete(r.IQResultRoutes, iq.Id) + r.IQResultRouteLock.Unlock() close(route.matched) - route.handler.HandlePacket(s, p) + route.handler.HandleIQ(route.context, s, iq) return } } @@ -88,29 +88,31 @@ func (r *Router) NewRoute() *Route { return route } -// NewIqResultRoute register a route that will catch an IQ result stanza with +// NewIQResultRoute register a route that will catch an IQ result stanza with // the given Id. The route will only match ones, after which it will automatically // be unregistered -func (r *Router) NewIqResultRoute(ctx context.Context, id string) *IqResultRoute { - route := &IqResultRoute{ +func (r *Router) NewIQResultRoute(ctx context.Context, id string) *IQResultRoute { + route := &IQResultRoute{ context: ctx, matched: make(chan struct{}), } - r.iqResultRouteLock.Lock() - r.iqResultRoutes[id] = route - r.iqResultRouteLock.Unlock() + r.IQResultRouteLock.Lock() + r.IQResultRoutes[id] = route + r.IQResultRouteLock.Unlock() + go func() { select { case <-route.context.Done(): - r.iqResultRouteLock.Lock() - delete(r.iqResultRoutes, id) - r.iqResultRouteLock.Unlock() + r.IQResultRouteLock.Lock() + delete(r.IQResultRoutes, id) + r.IQResultRouteLock.Unlock() if route.timeoutHandler != nil { route.timeoutHandler(route.context.Err()) } case <-route.matched: } }() + return route } @@ -135,42 +137,56 @@ func (r *Router) HandleFunc(name string, f func(s Sender, p stanza.Packet)) *Rou return r.NewRoute().Packet(name).HandlerFunc(f) } -// HandleIqResult register a temporary route -func (r *Router) HandleIqResult(id string, handler Handler) *IqResultRoute { - return r.NewIqResultRoute(context.Background(), id).Handler(handler) -} - -func (r *Router) HandleFuncIqResult(id string, f func(s Sender, p stanza.Packet)) *IqResultRoute { - return r.NewIqResultRoute(context.Background(), id).HandlerFunc(f) -} - // ============================================================================ -// IqResultRoute + +// TimeoutHandlerFunc is a function type for handling IQ result timeouts. type TimeoutHandlerFunc func(err error) -type IqResultRoute struct { +// IQResultRoute is a temporary route to match IQ result stanzas +type IQResultRoute struct { context context.Context matched chan struct{} - handler Handler + handler IQResultHandler timeoutHandler TimeoutHandlerFunc } -func (r *IqResultRoute) Handler(handler Handler) *IqResultRoute { +// Handler adds an IQ handler to the route. +func (r *IQResultRoute) Handler(handler IQResultHandler) *IQResultRoute { r.handler = handler return r } -func (r *IqResultRoute) HandlerFunc(f HandlerFunc) *IqResultRoute { +// HandlerFunc updates the route to call a handler function when an IQ result is received. +func (r *IQResultRoute) HandlerFunc(f IQResultHandlerFunc) *IQResultRoute { return r.Handler(f) } -func (r *IqResultRoute) TimeoutHandlerFunc(f TimeoutHandlerFunc) *IqResultRoute { +// TimeoutHandlerFunc registers a function that will be called automatically when +// the IQ result route is cancelled (most likely due to a timeout on the context). +func (r *IQResultRoute) TimeoutHandlerFunc(f TimeoutHandlerFunc) *IQResultRoute { r.timeoutHandler = f return r } +// ============================================================================ +// IQ result handler + +// IQResultHandler is a utility interface for IQ result handlers +type IQResultHandler interface { + HandleIQ(ctx context.Context, s Sender, iq stanza.IQ) +} + +// IQResultHandlerFunc is an adapter to allow using functions as IQ result handlers. +type IQResultHandlerFunc func(ctx context.Context, s Sender, iq stanza.IQ) + +// HandleIQ is a proxy function to implement IQResultHandler using a function. +func (f IQResultHandlerFunc) HandleIQ(ctx context.Context, s Sender, iq stanza.IQ) { + f(ctx, s, iq) +} + // ============================================================================ // Route + type Handler interface { HandlePacket(s Sender, p stanza.Packet) } diff --git a/router_test.go b/router_test.go index 138999c..b63553d 100644 --- a/router_test.go +++ b/router_test.go @@ -13,18 +13,18 @@ import ( // ============================================================================ // Test route & matchers -func TestIqResultRoutes(t *testing.T) { +func TestIQResultRoutes(t *testing.T) { t.Parallel() router := NewRouter() - if router.iqResultRoutes == nil { + if router.IQResultRoutes == nil { t.Fatal("NewRouter does not initialize isResultRoutes") } // Check other IQ does not matcah conn := NewSenderMock() iq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, Id: "4321"}) - router.NewIqResultRoute(context.Background(), "1234").HandlerFunc(func(s Sender, p stanza.Packet) { + router.NewIQResultRoute(context.Background(), "1234").HandlerFunc(func(ctx context.Context, s Sender, iq stanza.IQ) { _ = s.SendRaw(successFlag) }) if conn.String() == successFlag { @@ -51,7 +51,7 @@ func TestIqResultRoutes(t *testing.T) { conn = NewSenderMock() ctx, cancel := context.WithCancel(context.Background()) iq = stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, Id: "1234"}) - router.NewIqResultRoute(ctx, "1234").HandlerFunc(func(s Sender, p stanza.Packet) { + router.NewIQResultRoute(ctx, "1234").HandlerFunc(func(ctx context.Context, s Sender, iq stanza.IQ) { _ = s.SendRaw(successFlag) }).TimeoutHandlerFunc(func(err error) { conn.SendRaw(cancelledFlag)