Compare commits

...

2 commits

Author SHA1 Message Date
Bohdan Horbeshko e55463fc98 Fix passing Component to StreamManager
0a4acd12c3, which fixes #160, introduced
a regression as it assumed only Client may implement StreamClient, and
passing a component triggers an unconditional "client is not
disconnected" error.
2021-12-18 10:55:35 -05:00
bodqhrohro 5f99e1cd06 Support partial JIDs in Bare/Full methods 2021-12-14 12:01:36 +01:00
3 changed files with 58 additions and 28 deletions

View file

@ -51,12 +51,22 @@ func NewJid(sjid string) (*Jid, error) {
} }
func (j *Jid) Full() string { func (j *Jid) Full() string {
if j.Resource == "" {
return j.Bare()
} else if j.Node == "" {
return j.Node + "/" + j.Resource
} else {
return j.Node + "@" + j.Domain + "/" + j.Resource return j.Node + "@" + j.Domain + "/" + j.Resource
} }
}
func (j *Jid) Bare() string { func (j *Jid) Bare() string {
if j.Node == "" {
return j.Domain
} else {
return j.Node + "@" + j.Domain return j.Node + "@" + j.Domain
} }
}
// ============================================================================ // ============================================================================
// Helpers, for parsing / validation // Helpers, for parsing / validation

View file

@ -61,26 +61,41 @@ func TestIncorrectJids(t *testing.T) {
} }
func TestFull(t *testing.T) { func TestFull(t *testing.T) {
jid := "test@domain.com/my resource" fullJids := []string{
parsedJid, err := NewJid(jid) "test@domain.com/my resource",
"test@domain.com",
"domain.com",
}
for _, sjid := range fullJids {
parsedJid, err := NewJid(sjid)
if err != nil { if err != nil {
t.Errorf("could not parse jid: %v", err) t.Errorf("could not parse jid: %v", err)
} }
fullJid := parsedJid.Full() fullJid := parsedJid.Full()
if fullJid != jid { if fullJid != sjid {
t.Errorf("incorrect full jid: %s", fullJid) t.Errorf("incorrect full jid: %s", fullJid)
} }
} }
}
func TestBare(t *testing.T) { func TestBare(t *testing.T) {
jid := "test@domain.com" tests := []struct {
fullJid := jid + "/my resource" jidstr string
parsedJid, err := NewJid(fullJid) expected string
}{
{jidstr: "test@domain.com", expected: "test@domain.com"},
{jidstr: "test@domain.com/resource", expected: "test@domain.com"},
{jidstr: "domain.com", expected: "domain.com"},
}
for _, tt := range tests {
parsedJid, err := NewJid(tt.jidstr)
if err != nil { if err != nil {
t.Errorf("could not parse jid: %v", err) t.Errorf("could not parse jid: %v", err)
} }
bareJid := parsedJid.Bare() bareJid := parsedJid.Bare()
if bareJid != jid { if bareJid != tt.expected {
t.Errorf("incorrect bare jid: %s", bareJid) t.Errorf("incorrect bare jid: %s", bareJid)
} }
} }
}

View file

@ -114,10 +114,16 @@ func (sm *StreamManager) Stop() {
func (sm *StreamManager) connect() error { func (sm *StreamManager) connect() error {
if sm.client != nil { if sm.client != nil {
if c, ok := sm.client.(*Client); ok { var scs *SyncConnState
if c.CurrentState.getState() == StateDisconnected { if client, ok := sm.client.(*Client); ok {
scs = &client.CurrentState
}
if component, ok := sm.client.(*Component); ok {
scs = &component.CurrentState
}
if scs != nil && scs.getState() == StateDisconnected {
sm.Metrics = initMetrics() sm.Metrics = initMetrics()
err := c.Connect() err := sm.client.Connect()
if err != nil { if err != nil {
return err return err
} }
@ -127,7 +133,6 @@ func (sm *StreamManager) connect() error {
return nil return nil
} }
} }
}
return errors.New("client is not disconnected") return errors.New("client is not disconnected")
} }