0227596f90
Since a transport (and a streamlogger) does not implement io.ByteReader xml.Decoder wraps it using `bufio.NewReader(transport)` so it can easily read bytes one at a time. This has the unfortuante effect of resulting in a panic if we try to parse a stanza that is larger than the default buffer size of 4096 bytes. To fix this we wrap the transport using `bufio.NewReaderSize()` with a much larger buffer size.
179 lines
4.2 KiB
Go
179 lines
4.2 KiB
Go
package xmpp
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/xml"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
"time"
|
|
|
|
"gosrc.io/xmpp/stanza"
|
|
"nhooyr.io/websocket"
|
|
)
|
|
|
|
const maxPacketSize = 32768
|
|
|
|
const pingTimeout = time.Duration(5) * time.Second
|
|
|
|
var ServerDoesNotSupportXmppOverWebsocket = errors.New("The websocket server does not support the xmpp subprotocol")
|
|
|
|
type WebsocketTransport struct {
|
|
Config TransportConfiguration
|
|
decoder *xml.Decoder
|
|
wsConn *websocket.Conn
|
|
queue chan []byte
|
|
logFile io.Writer
|
|
|
|
closeCtx context.Context
|
|
closeFunc context.CancelFunc
|
|
}
|
|
|
|
func (t *WebsocketTransport) Connect() (string, error) {
|
|
t.queue = make(chan []byte, 256)
|
|
t.closeCtx, t.closeFunc = context.WithCancel(context.Background())
|
|
|
|
var ctx context.Context
|
|
ctx = context.Background()
|
|
if t.Config.ConnectTimeout > 0 {
|
|
var cancelConnect context.CancelFunc
|
|
ctx, cancelConnect = context.WithTimeout(t.closeCtx, time.Duration(t.Config.ConnectTimeout)*time.Second)
|
|
defer cancelConnect()
|
|
}
|
|
|
|
wsConn, response, err := websocket.Dial(ctx, t.Config.Address, &websocket.DialOptions{
|
|
Subprotocols: []string{"xmpp"},
|
|
})
|
|
if err != nil {
|
|
return "", NewConnError(err, true)
|
|
}
|
|
if response.Header.Get("Sec-WebSocket-Protocol") != "xmpp" {
|
|
t.cleanup(websocket.StatusBadGateway)
|
|
return "", NewConnError(ServerDoesNotSupportXmppOverWebsocket, true)
|
|
}
|
|
|
|
wsConn.SetReadLimit(maxPacketSize)
|
|
t.wsConn = wsConn
|
|
t.startReader()
|
|
|
|
t.decoder = xml.NewDecoder(bufio.NewReaderSize(t, maxPacketSize))
|
|
t.decoder.CharsetReader = t.Config.CharsetReader
|
|
|
|
return t.StartStream()
|
|
}
|
|
|
|
func (t WebsocketTransport) StartStream() (string, error) {
|
|
if _, err := fmt.Fprintf(t, `<open xmlns="urn:ietf:params:xml:ns:xmpp-framing" to="%s" version="1.0" />`, t.Config.Domain); err != nil {
|
|
t.cleanup(websocket.StatusBadGateway)
|
|
return "", NewConnError(err, true)
|
|
}
|
|
|
|
sessionID, err := stanza.InitStream(t.GetDecoder())
|
|
if err != nil {
|
|
t.Close()
|
|
return "", NewConnError(err, false)
|
|
}
|
|
return sessionID, nil
|
|
}
|
|
|
|
// startReader runs a go function that keeps reading from the websocket. This
|
|
// is required to allow Ping() to work: Ping requires a Reader to be running
|
|
// to process incoming control frames.
|
|
func (t WebsocketTransport) startReader() {
|
|
go func() {
|
|
buffer := make([]byte, maxPacketSize)
|
|
for {
|
|
_, reader, err := t.wsConn.Reader(t.closeCtx)
|
|
if err != nil {
|
|
return
|
|
}
|
|
n, err := reader.Read(buffer)
|
|
if err != nil && err != io.EOF {
|
|
return
|
|
}
|
|
if n > 0 {
|
|
// We need to make a copy, otherwise we will overwrite the slice content
|
|
// on the next iteration of the for loop.
|
|
tmp := make([]byte, n)
|
|
copy(tmp, buffer)
|
|
t.queue <- tmp
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (t WebsocketTransport) StartTLS() error {
|
|
return ErrTLSNotSupported
|
|
}
|
|
|
|
func (t WebsocketTransport) DoesStartTLS() bool {
|
|
return false
|
|
}
|
|
|
|
func (t WebsocketTransport) GetDomain() string {
|
|
return t.Config.Domain
|
|
}
|
|
|
|
func (t WebsocketTransport) GetDecoder() *xml.Decoder {
|
|
return t.decoder
|
|
}
|
|
|
|
func (t WebsocketTransport) IsSecure() bool {
|
|
return strings.HasPrefix(t.Config.Address, "wss:")
|
|
}
|
|
|
|
func (t WebsocketTransport) Ping() error {
|
|
ctx, cancel := context.WithTimeout(t.closeCtx, pingTimeout)
|
|
defer cancel()
|
|
return t.wsConn.Ping(ctx)
|
|
}
|
|
|
|
func (t *WebsocketTransport) Read(p []byte) (int, error) {
|
|
select {
|
|
case <-t.closeCtx.Done():
|
|
return 0, t.closeCtx.Err()
|
|
case data := <-t.queue:
|
|
if t.logFile != nil && len(data) > 0 {
|
|
_, _ = fmt.Fprintf(t.logFile, "RECV:\n%s\n\n", data)
|
|
}
|
|
copy(p, data)
|
|
return len(data), nil
|
|
}
|
|
}
|
|
|
|
func (t WebsocketTransport) Write(p []byte) (int, error) {
|
|
if t.logFile != nil {
|
|
_, _ = fmt.Fprintf(t.logFile, "SEND:\n%s\n\n", p)
|
|
}
|
|
return len(p), t.wsConn.Write(t.closeCtx, websocket.MessageText, p)
|
|
}
|
|
|
|
func (t WebsocketTransport) Close() error {
|
|
t.Write([]byte("<close xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" />"))
|
|
return t.cleanup(websocket.StatusGoingAway)
|
|
}
|
|
|
|
func (t *WebsocketTransport) LogTraffic(logFile io.Writer) {
|
|
t.logFile = logFile
|
|
}
|
|
|
|
func (t *WebsocketTransport) cleanup(code websocket.StatusCode) error {
|
|
var err error
|
|
if t.queue != nil {
|
|
close(t.queue)
|
|
t.queue = nil
|
|
}
|
|
if t.wsConn != nil {
|
|
err = t.wsConn.Close(websocket.StatusGoingAway, "Done")
|
|
t.wsConn = nil
|
|
}
|
|
if t.closeFunc != nil {
|
|
t.closeFunc()
|
|
t.closeFunc = nil
|
|
t.closeCtx = nil
|
|
}
|
|
return err
|
|
}
|