package xmpp import ( "context" "encoding/xml" "errors" "fmt" "io" "strings" "time" "gosrc.io/xmpp/stanza" "nhooyr.io/websocket" ) const maxPacketSize = 32768 const pingTimeout = time.Duration(5) * time.Second var ServerDoesNotSupportXmppOverWebsocket = errors.New("The websocket server does not support the xmpp subprotocol") type WebsocketTransport struct { Config TransportConfiguration decoder *xml.Decoder wsConn *websocket.Conn queue chan []byte logFile io.Writer closeCtx context.Context closeFunc context.CancelFunc } func (t *WebsocketTransport) Connect() (string, error) { t.queue = make(chan []byte, 256) t.closeCtx, t.closeFunc = context.WithCancel(context.Background()) var ctx context.Context ctx = context.Background() if t.Config.ConnectTimeout > 0 { var cancelConnect context.CancelFunc ctx, cancelConnect = context.WithTimeout(t.closeCtx, time.Duration(t.Config.ConnectTimeout)*time.Second) defer cancelConnect() } wsConn, response, err := websocket.Dial(ctx, t.Config.Address, &websocket.DialOptions{ Subprotocols: []string{"xmpp"}, }) if err != nil { return "", NewConnError(err, true) } if response.Header.Get("Sec-WebSocket-Protocol") != "xmpp" { t.cleanup(websocket.StatusBadGateway) return "", NewConnError(ServerDoesNotSupportXmppOverWebsocket, true) } wsConn.SetReadLimit(maxPacketSize) t.wsConn = wsConn t.startReader() t.decoder = xml.NewDecoder(t) t.decoder.CharsetReader = t.Config.CharsetReader return t.StartStream() } func (t WebsocketTransport) StartStream() (string, error) { if _, err := fmt.Fprintf(t, ``, t.Config.Domain); err != nil { t.cleanup(websocket.StatusBadGateway) return "", NewConnError(err, true) } sessionID, err := stanza.InitStream(t.GetDecoder()) if err != nil { t.Close() return "", NewConnError(err, false) } return sessionID, nil } // startReader runs a go function that keeps reading from the websocket. This // is required to allow Ping() to work: Ping requires a Reader to be running // to process incoming control frames. func (t WebsocketTransport) startReader() { go func() { buffer := make([]byte, maxPacketSize) for { _, reader, err := t.wsConn.Reader(t.closeCtx) if err != nil { return } n, err := reader.Read(buffer) if err != nil && err != io.EOF { return } if n > 0 { // We need to make a copy, otherwise we will overwrite the slice content // on the next iteration of the for loop. tmp := make([]byte, n) copy(tmp, buffer) t.queue <- tmp } } }() } func (t WebsocketTransport) StartTLS() error { return ErrTLSNotSupported } func (t WebsocketTransport) DoesStartTLS() bool { return false } func (t WebsocketTransport) GetDomain() string { return t.Config.Domain } func (t WebsocketTransport) GetDecoder() *xml.Decoder { return t.decoder } func (t WebsocketTransport) IsSecure() bool { return strings.HasPrefix(t.Config.Address, "wss:") } func (t WebsocketTransport) Ping() error { ctx, cancel := context.WithTimeout(t.closeCtx, pingTimeout) defer cancel() return t.wsConn.Ping(ctx) } func (t *WebsocketTransport) Read(p []byte) (int, error) { select { case <-t.closeCtx.Done(): return 0, t.closeCtx.Err() case data := <-t.queue: if t.logFile != nil && len(data) > 0 { _, _ = fmt.Fprintf(t.logFile, "RECV:\n%s\n\n", data) } copy(p, data) return len(data), nil } } func (t WebsocketTransport) Write(p []byte) (int, error) { if t.logFile != nil { _, _ = fmt.Fprintf(t.logFile, "SEND:\n%s\n\n", p) } return len(p), t.wsConn.Write(t.closeCtx, websocket.MessageText, p) } func (t WebsocketTransport) Close() error { t.Write([]byte("")) return t.cleanup(websocket.StatusGoingAway) } func (t *WebsocketTransport) LogTraffic(logFile io.Writer) { t.logFile = logFile } func (t *WebsocketTransport) cleanup(code websocket.StatusCode) error { var err error if t.queue != nil { close(t.queue) t.queue = nil } if t.wsConn != nil { err = t.wsConn.Close(websocket.StatusGoingAway, "Done") t.wsConn = nil } if t.closeFunc != nil { t.closeFunc() t.closeFunc = nil t.closeCtx = nil } return err }