Allow transports to define their own ping mechanism

This commit is contained in:
Wichert Akkerman 2019-10-15 20:56:11 +02:00 committed by Mickaël Rémond
parent d0f2b492ac
commit 36e153f981
4 changed files with 25 additions and 2 deletions

View file

@ -276,8 +276,8 @@ func keepalive(transport Transport, quit <-chan struct{}) {
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
if n, err := fmt.Fprintf(transport, "\n"); err != nil || n != 1 { if err := transport.Ping(); err != nil {
// When keep alive fails, we force close the transportection. In all cases, the recv will also fail. // When keepalive fails, we force close the transport. In all cases, the recv will also fail.
ticker.Stop() ticker.Stop()
_ = transport.Close() _ = transport.Close()
return return

View file

@ -25,6 +25,7 @@ type Transport interface {
IsSecure() bool IsSecure() bool
Ping() error
Read(p []byte) (n int, err error) Read(p []byte) (n int, err error)
Write(p []byte) (n int, err error) Write(p []byte) (n int, err error)
Close() error Close() error

View file

@ -9,6 +9,8 @@ import (
"nhooyr.io/websocket" "nhooyr.io/websocket"
) )
const pingTimeout = time.Duration(5) * time.Second
type WebsocketTransport struct { type WebsocketTransport struct {
Config TransportConfiguration Config TransportConfiguration
wsConn *websocket.Conn wsConn *websocket.Conn
@ -46,6 +48,14 @@ func (t WebsocketTransport) IsSecure() bool {
return strings.HasPrefix(t.Config.Address, "wss:") return strings.HasPrefix(t.Config.Address, "wss:")
} }
func (t WebsocketTransport) Ping() error {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
// Note that we do not use wsConn.Ping(), because not all websocket servers
// (ejabberd for example) implement ping frames
return t.wsConn.Write(ctx, websocket.MessageText, []byte(" "))
}
func (t WebsocketTransport) Read(p []byte) (n int, err error) { func (t WebsocketTransport) Read(p []byte) (n int, err error) {
return t.netConn.Read(p) return t.netConn.Read(p)
} }

View file

@ -2,6 +2,7 @@ package xmpp
import ( import (
"crypto/tls" "crypto/tls"
"errors"
"net" "net"
"time" "time"
) )
@ -59,6 +60,17 @@ func (t *XMPPTransport) StartTLS(domain string) error {
return nil return nil
} }
func (t XMPPTransport) Ping() error {
n, err := t.conn.Write([]byte("\n"))
if err != nil {
return err
}
if n != 1 {
return errors.New("Could not write ping")
}
return nil
}
func (t XMPPTransport) Read(p []byte) (n int, err error) { func (t XMPPTransport) Read(p []byte) (n int, err error) {
return t.conn.Read(p) return t.conn.Read(p)
} }