// RFC 6120, chapter 5

import Foundation
import Network

/// Kind of socket a connection is made with.
///
/// `ConnectionModule` picks `.directTls` for SRV records flagged `isSecure`
/// and `.startTls` for all others.
enum SocketType: Codable & Equatable {
    case startTls
    case directTls
    // case p2p // in future, maybe..., probably...
}

/// Connection-level state changes reported by a `Socket`.
enum SocketState {
    case connected
    /// The socket went away; a non-nil error is forwarded as `.socketError`,
    /// a nil one as `.socketDisconnected`.
    case disconnected(Error?)
    /// Forwarded to the client as `.startTlsDone`.
    case startTlsReady
}

/// Everything a `Socket` publishes through its `events` stream.
enum SocketEvent {
    case state(SocketState)
    case dataReceived(Data)
}

/// Minimal transport abstraction used by `ConnectionModule`.
protocol Socket {
    /// State changes and inbound data, in the order the socket produced them.
    var events: AsyncStream<SocketEvent> { get }
    func connect() async
    func send(_ data: Data) async
}

/// Owns the active socket, bridges its events into client `Event`s and drives
/// the STARTTLS handshake from inbound stanzas.
final class ConnectionModule: XmppModule {
    /// Failures raised by this module itself while handling `.startTls`.
    enum StartTlsError: Error {
        /// `.startTls` was requested while the active socket is missing or is
        /// not a `StartTLSSocket` (e.g. a peer sent `<proceed/>` over direct TLS).
        case socketNotUpgradable
    }

    /// XML namespace of the `starttls` / `proceed` elements.
    private static let tlsNamespace = "urn:ietf:params:xml:ns:xmpp-tls"

    let id = "Connection module"

    private var socket: Socket?
    /// Sink for events that originate from the socket rather than from `process`.
    private let callback: (Event) async -> Void

    /// - Parameter fire: Called with every event produced by the socket.
    init(_ fire: @escaping (Event) async -> Void) {
        callback = fire
    }

    /// This module keeps no data in `ClientState`; the state is returned untouched.
    func reduce(oldState: ClientState, with _: Event) -> ClientState {
        oldState
    }

    /// Reacts to connection-related events.
    ///
    /// - Parameters:
    ///   - state: Current client state; read only when handling `.tryConnect`.
    ///   - event: Event to handle; events this module does not know yield `nil`.
    /// - Returns: A follow-up event to dispatch, or `nil` when there is none.
    func process(state: ClientState, with event: Event) async -> Event? {
        switch event {
        case .tryConnect:
            return connect(using: state)

        case .xmlOutbound(let xml):
            await socket?.send(xml.data)
            return nil

        // For StartTLS process
        case .xmlInbound(let element):
            // Request STARTTLS as soon as the server offers it (almost always required).
            if element.name == "stream:features",
               element.nodes.contains(where: { $0.name == "starttls" }) {
                let request = XMLElement(
                    name: "starttls",
                    xmlns: Self.tlsNamespace,
                    attributes: [:],
                    content: nil,
                    nodes: []
                )
                return .xmlOutbound(request)
            } else if element.name == "proceed" && element.xmlns == Self.tlsNamespace {
                // Server accepted the request: start the TLS upgrade.
                return .startTls
            } else {
                return nil
            }

        case .startTls:
            // `.startTls` is triggered by inbound (untrusted) data, so a socket
            // that cannot be upgraded is reported as a failure, not a crash.
            guard let tlsSocket = socket as? StartTLSSocket else {
                return .startTlsFailed(StartTlsError.socketNotUpgradable)
            }
            do {
                try tlsSocket.startTls()
                return nil
            } catch {
                return .startTlsFailed(error)
            }

        case .startTlsFailed:
            socket = nil
            return nil

        default:
            return nil
        }
    }

    /// Creates a socket for the SRV record at `state.srvRecordIndex`, starts
    /// connecting and begins forwarding its events to `callback`.
    ///
    /// - Returns: `.allRecordsUnreachable` when the index is past the last
    ///   record, otherwise `nil`.
    private func connect(using state: ClientState) -> Event? {
        guard state.srvRecordIndex < state.srvRecords.count else {
            return .allRecordsUnreachable
        }
        let record = state.srvRecords[state.srvRecordIndex]
        let connectionId = state.jid.uStr

        let type: SocketType
        let newSocket: Socket
        if record.isSecure {
            newSocket = DirectTLSSocket(
                id: connectionId,
                host: record.target,
                port: record.port,
                allowInsecure: state.allowInsecure
            )
            type = .directTls
        } else {
            newSocket = StartTLSSocket(
                id: connectionId,
                host: record.target,
                port: record.port,
                allowInsecure: state.allowInsecure
            )
            type = .startTls
        }
        socket = newSocket

        // Connecting and event forwarding run in their own tasks so that
        // `process` returns without waiting for the network.
        Task { await newSocket.connect() }
        Task {
            for await message in newSocket.events {
                await forward(message, over: type)
            }
        }
        return nil
    }

    /// Translates one socket event into the matching client event and hands it
    /// to `callback`.
    private func forward(_ message: SocketEvent, over type: SocketType) async {
        switch message {
        case .state(.connected):
            await callback(.socketConnected(type))
        case .state(.disconnected(let error)):
            if let error {
                await callback(.socketError(error))
            } else {
                await callback(.socketDisconnected)
            }
        case .state(.startTlsReady):
            await callback(.startTlsDone)
        case .dataReceived(let data):
            await callback(.socketReceived(data))
        }
    }
}