Add a go function to always read websockets
Websocket need to have a Reader running at all times in order to allow Ping to work (because a Reader is the only thing that will correctly handle control frames). To faciliate this a go function is introduced that will always read from the websocket until it is cancelled. Read data is passed to the transport via a channel.
This commit is contained in:
parent
92329b48e6
commit
ffadd331dd
|
@ -6,7 +6,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -14,6 +13,8 @@ import (
|
|||
"nhooyr.io/websocket"
|
||||
)
|
||||
|
||||
const maxPacketSize = 32768
|
||||
|
||||
const pingTimeout = time.Duration(5) * time.Second
|
||||
|
||||
var ServerDoesNotSupportXmppOverWebsocket = errors.New("The websocket server does not support the xmpp subprotocol")
|
||||
|
@ -22,17 +23,23 @@ type WebsocketTransport struct {
|
|||
Config TransportConfiguration
|
||||
decoder *xml.Decoder
|
||||
wsConn *websocket.Conn
|
||||
netConn net.Conn
|
||||
queue chan []byte
|
||||
logFile io.Writer
|
||||
|
||||
closeCtx context.Context
|
||||
closeFunc context.CancelFunc
|
||||
}
|
||||
|
||||
func (t *WebsocketTransport) Connect() (string, error) {
|
||||
ctx := context.Background()
|
||||
t.queue = make(chan []byte, 256)
|
||||
t.closeCtx, t.closeFunc = context.WithCancel(context.Background())
|
||||
|
||||
var ctx context.Context
|
||||
ctx = context.Background()
|
||||
if t.Config.ConnectTimeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(t.Config.ConnectTimeout)*time.Second)
|
||||
defer cancel()
|
||||
var cancelConnect context.CancelFunc
|
||||
ctx, cancelConnect = context.WithTimeout(t.closeCtx, time.Duration(t.Config.ConnectTimeout)*time.Second)
|
||||
defer cancelConnect()
|
||||
}
|
||||
|
||||
wsConn, response, err := websocket.Dial(ctx, t.Config.Address, &websocket.DialOptions{
|
||||
|
@ -42,28 +49,30 @@ func (t *WebsocketTransport) Connect() (string, error) {
|
|||
return "", NewConnError(err, true)
|
||||
}
|
||||
if response.Header.Get("Sec-WebSocket-Protocol") != "xmpp" {
|
||||
_ = wsConn.Close(websocket.StatusBadGateway, "Could not negotiate XMPP subprotocol")
|
||||
t.cleanup(websocket.StatusBadGateway)
|
||||
return "", NewConnError(ServerDoesNotSupportXmppOverWebsocket, true)
|
||||
}
|
||||
|
||||
wsConn.SetReadLimit(maxPacketSize)
|
||||
t.wsConn = wsConn
|
||||
t.netConn = websocket.NetConn(ctx, t.wsConn, websocket.MessageText)
|
||||
t.startReader()
|
||||
|
||||
handshake := fmt.Sprintf("<open xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" to=\"%s\" version=\"1.0\" />", t.Config.Domain)
|
||||
handshake := fmt.Sprintf(`<open xmlns="urn:ietf:params:xml:ns:xmpp-framing" to="%s" version="1.0" />`, t.Config.Domain)
|
||||
if _, err = t.Write([]byte(handshake)); err != nil {
|
||||
_ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
|
||||
t.cleanup(websocket.StatusBadGateway)
|
||||
return "", NewConnError(err, false)
|
||||
}
|
||||
|
||||
handshakeResponse := make([]byte, 2048)
|
||||
if _, err = t.Read(handshakeResponse); err != nil {
|
||||
_ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
|
||||
t.cleanup(websocket.StatusBadGateway)
|
||||
|
||||
return "", NewConnError(err, false)
|
||||
}
|
||||
|
||||
var openResponse = stanza.WebsocketOpen{}
|
||||
if err = xml.Unmarshal(handshakeResponse, &openResponse); err != nil {
|
||||
_ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
|
||||
t.cleanup(websocket.StatusBadGateway)
|
||||
return "", NewConnError(err, false)
|
||||
}
|
||||
|
||||
|
@ -73,6 +82,32 @@ func (t *WebsocketTransport) Connect() (string, error) {
|
|||
return openResponse.Id, nil
|
||||
}
|
||||
|
||||
// startReader runs a go function that keeps reading from the websocket. This
|
||||
// is required to allow Ping() to work: Ping requires a Reader to be running
|
||||
// to process incoming control frames.
|
||||
func (t WebsocketTransport) startReader() {
|
||||
go func() {
|
||||
buffer := make([]byte, maxPacketSize)
|
||||
for {
|
||||
_, reader, err := t.wsConn.Reader(t.closeCtx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n, err := reader.Read(buffer)
|
||||
if err != nil && err != io.EOF {
|
||||
return
|
||||
}
|
||||
if n > 0 {
|
||||
// We need to make a copy, otherwise we will overwrite the slice content
|
||||
// on the next iteration of the for loop.
|
||||
tmp := make([]byte, len(buffer))
|
||||
copy(tmp, buffer)
|
||||
t.queue <- tmp
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (t WebsocketTransport) StartTLS() error {
|
||||
return TLSNotSupported
|
||||
}
|
||||
|
@ -90,31 +125,52 @@ func (t WebsocketTransport) IsSecure() bool {
|
|||
}
|
||||
|
||||
func (t WebsocketTransport) Ping() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||
ctx, cancel := context.WithTimeout(t.closeCtx, pingTimeout)
|
||||
defer cancel()
|
||||
return t.wsConn.Ping(ctx)
|
||||
}
|
||||
|
||||
func (t *WebsocketTransport) Read(p []byte) (n int, err error) {
|
||||
n, err = t.netConn.Read(p)
|
||||
if t.logFile != nil && n > 0 {
|
||||
_, _ = fmt.Fprintf(t.logFile, "RECV:\n%s\n\n", p)
|
||||
func (t *WebsocketTransport) Read(p []byte) (int, error) {
|
||||
select {
|
||||
case <-t.closeCtx.Done():
|
||||
return 0, t.closeCtx.Err()
|
||||
case data := <-t.queue:
|
||||
if t.logFile != nil && len(data) > 0 {
|
||||
_, _ = fmt.Fprintf(t.logFile, "RECV:\n%s\n\n", data)
|
||||
}
|
||||
copy(p, data)
|
||||
return len(data), nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (t WebsocketTransport) Write(p []byte) (n int, err error) {
|
||||
func (t WebsocketTransport) Write(p []byte) (int, error) {
|
||||
if t.logFile != nil {
|
||||
_, _ = fmt.Fprintf(t.logFile, "SEND:\n%s\n\n", p)
|
||||
}
|
||||
return t.netConn.Write(p)
|
||||
return len(p), t.wsConn.Write(t.closeCtx, websocket.MessageText, p)
|
||||
}
|
||||
|
||||
func (t WebsocketTransport) Close() error {
|
||||
t.Write([]byte("<close xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" />"))
|
||||
return t.netConn.Close()
|
||||
return t.wsConn.Close(websocket.StatusGoingAway, "Done")
|
||||
}
|
||||
|
||||
func (t *WebsocketTransport) LogTraffic(logFile io.Writer) {
|
||||
t.logFile = logFile
|
||||
}
|
||||
|
||||
func (t *WebsocketTransport) cleanup(code websocket.StatusCode) {
|
||||
if t.queue != nil {
|
||||
close(t.queue)
|
||||
t.queue = nil
|
||||
}
|
||||
if t.wsConn != nil {
|
||||
t.wsConn.Close(websocket.StatusGoingAway, "Done")
|
||||
t.wsConn = nil
|
||||
}
|
||||
if t.closeFunc != nil {
|
||||
t.closeFunc()
|
||||
t.closeFunc = nil
|
||||
t.closeCtx = nil
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue