Allow transports to define their own ping mechanism
This commit is contained in:
parent
d0f2b492ac
commit
36e153f981
|
@ -276,8 +276,8 @@ func keepalive(transport Transport, quit <-chan struct{}) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if n, err := fmt.Fprintf(transport, "\n"); err != nil || n != 1 {
|
if err := transport.Ping(); err != nil {
|
||||||
// When keep alive fails, we force close the transportection. In all cases, the recv will also fail.
|
// When keepalive fails, we force close the transport. In all cases, the recv will also fail.
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
_ = transport.Close()
|
_ = transport.Close()
|
||||||
return
|
return
|
||||||
|
|
|
@ -25,6 +25,7 @@ type Transport interface {
|
||||||
|
|
||||||
IsSecure() bool
|
IsSecure() bool
|
||||||
|
|
||||||
|
Ping() error
|
||||||
Read(p []byte) (n int, err error)
|
Read(p []byte) (n int, err error)
|
||||||
Write(p []byte) (n int, err error)
|
Write(p []byte) (n int, err error)
|
||||||
Close() error
|
Close() error
|
||||||
|
|
|
@ -9,6 +9,8 @@ import (
|
||||||
"nhooyr.io/websocket"
|
"nhooyr.io/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const pingTimeout = time.Duration(5) * time.Second
|
||||||
|
|
||||||
type WebsocketTransport struct {
|
type WebsocketTransport struct {
|
||||||
Config TransportConfiguration
|
Config TransportConfiguration
|
||||||
wsConn *websocket.Conn
|
wsConn *websocket.Conn
|
||||||
|
@ -46,6 +48,14 @@ func (t WebsocketTransport) IsSecure() bool {
|
||||||
return strings.HasPrefix(t.Config.Address, "wss:")
|
return strings.HasPrefix(t.Config.Address, "wss:")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t WebsocketTransport) Ping() error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||||
|
defer cancel()
|
||||||
|
// Note that we do not use wsConn.Ping(), because not all websocket servers
|
||||||
|
// (ejabberd for example) implement ping frames
|
||||||
|
return t.wsConn.Write(ctx, websocket.MessageText, []byte(" "))
|
||||||
|
}
|
||||||
|
|
||||||
func (t WebsocketTransport) Read(p []byte) (n int, err error) {
|
func (t WebsocketTransport) Read(p []byte) (n int, err error) {
|
||||||
return t.netConn.Read(p)
|
return t.netConn.Read(p)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package xmpp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -59,6 +60,17 @@ func (t *XMPPTransport) StartTLS(domain string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t XMPPTransport) Ping() error {
|
||||||
|
n, err := t.conn.Write([]byte("\n"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n != 1 {
|
||||||
|
return errors.New("Could not write ping")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t XMPPTransport) Read(p []byte) (n int, err error) {
|
func (t XMPPTransport) Read(p []byte) (n int, err error) {
|
||||||
return t.conn.Read(p)
|
return t.conn.Read(p)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue