diff --git a/router.go b/router.go index bf6d5e6..df3b19f 100644 --- a/router.go +++ b/router.go @@ -4,6 +4,7 @@ import ( "context" "encoding/xml" "strings" + "sync" "gosrc.io/xmpp/stanza" ) @@ -27,7 +28,8 @@ type Router struct { // Routes to be matched, in order. routes []*Route - iqResultRoutes map[string]*IqResultRoute + iqResultRoutes map[string]*IqResultRoute + iqResultRouteLock sync.RWMutex } // NewRouter returns a new router instance. @@ -42,8 +44,16 @@ func NewRouter() *Router { func (r *Router) route(s Sender, p stanza.Packet) { iq, isIq := p.(stanza.IQ) if isIq { - if route, ok := r.iqResultRoutes[iq.Id]; ok { + r.iqResultRouteLock.RLock() + route, ok := r.iqResultRoutes[iq.Id] + r.iqResultRouteLock.RUnlock() + if ok { + r.iqResultRouteLock.Lock() + delete(r.iqResultRoutes, iq.Id) + r.iqResultRouteLock.Unlock() + close(route.matched) route.handler.HandlePacket(s, p) + return } } @@ -86,16 +96,20 @@ func (r *Router) NewIqResultRoute(ctx context.Context, id string) *IqResultRoute context: ctx, matched: make(chan struct{}), } + r.iqResultRouteLock.Lock() r.iqResultRoutes[id] = route + r.iqResultRouteLock.Unlock() go func() { select { case <-route.context.Done(): + r.iqResultRouteLock.Lock() + delete(r.iqResultRoutes, id) + r.iqResultRouteLock.Unlock() if route.timeoutHandler != nil { route.timeoutHandler(route.context.Err()) } case <-route.matched: } - delete(r.iqResultRoutes, id) }() return route }