// See RFC 6120 (XMPP Core), Section 5: STARTTLS Negotiation.
|
|
import Foundation
|
|
import Network
|
|
|
|
/// How a connection's transport is secured.
///
/// `ConnectionModule` chooses `.directTls` for SRV records whose `isSecure` flag is set
/// and `.startTls` for all others.
enum SocketType: Codable, Equatable {
    case startTls, directTls
    // case p2p  // possible future addition
}
|
|
|
|
/// Connection status a socket reports through `SocketEvent.state`.
enum SocketState {
    /// The socket reports that it is connected.
    case connected
    /// The socket reports a disconnect; the payload is an optional error
    /// (`ConnectionModule` forwards non-nil as `.socketError`, nil as `.socketDisconnected`).
    case disconnected((any Error)?)
    /// The socket reports that its STARTTLS upgrade is ready (forwarded as `.startTlsDone`).
    case startTlsReady
}
|
|
|
|
/// A single item emitted on `Socket.events`: either a connection-status change
/// or a chunk of bytes read from the connection.
enum SocketEvent {
    case state(SocketState), dataReceived(Data)
}
|
|
|
|
/// Transport abstraction used by `ConnectionModule`.
///
/// The module starts a connection with `connect()`, writes outbound bytes with
/// `send(_:)`, and iterates `events` to learn about status changes and inbound data.
protocol Socket {
    /// Begins connecting; the outcome is reported through `events`.
    func connect() async

    /// Writes `data` to the connection.
    func send(_ data: Data) async

    /// Stream of status changes and received data produced by this socket.
    var events: AsyncStream<SocketEvent> { get }
}
|
|
|
|
/// XMPP module that owns the transport socket and drives STARTTLS negotiation
/// (RFC 6120, Section 5).
///
/// On `.tryConnect` it opens a socket for the current SRV record and forwards that
/// socket's events to the callback given at init. On `.xmlOutbound` it writes to the
/// socket. On `.xmlInbound` it answers STARTTLS offers and reacts to the server's
/// `proceed` / `failure` replies. It never changes `ClientState` (`reduce` returns
/// the old state).
final class ConnectionModule: XmppModule {
    /// Errors this module reports through `.startTlsFailed`.
    enum StartTlsError: Error {
        /// A TLS upgrade was requested, but the current socket is not a `StartTLSSocket`
        /// (e.g. a direct-TLS socket, or no socket at all).
        case notStartTlsSocket
        /// The server replied `<failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`.
        case rejectedByServer
    }

    /// XML namespace of the STARTTLS elements (`starttls`, `proceed`, `failure`).
    private static let tlsNamespace = "urn:ietf:params:xml:ns:xmpp-tls"

    let id = "Connection module"

    /// Socket for the current connection attempt; `nil` before the first
    /// `.tryConnect` and after `.startTlsFailed`.
    private var socket: Socket?

    /// Receives every client event derived from socket activity.
    private let callback: (Event) async -> Void

    /// - Parameter fire: Called with each event produced by the socket.
    init(_ fire: @escaping (Event) async -> Void) {
        callback = fire
    }

    func reduce(oldState: ClientState, with _: Event) -> ClientState {
        oldState
    }

    /// Handles connection-related events.
    ///
    /// - Returns: A follow-up event for the client to process, or `nil` when there is none.
    func process(state: ClientState, with event: Event) async -> Event? {
        switch event {
        case .tryConnect:
            // Reject negative indices too; subscripting with one would trap.
            guard state.srvRecordIndex >= 0, state.srvRecordIndex < state.srvRecords.count else {
                return .allRecordsUnreachable
            }

            let record = state.srvRecords[state.srvRecordIndex]
            let conId = state.jid.uStr
            let type: SocketType
            let newSocket: Socket
            if record.isSecure {
                newSocket = DirectTLSSocket(id: conId, host: record.target, port: record.port, allowInsecure: state.allowInsecure)
                type = .directTls
            } else {
                newSocket = StartTLSSocket(id: conId, host: record.target, port: record.port, allowInsecure: state.allowInsecure)
                type = .startTls
            }
            socket = newSocket

            Task {
                await newSocket.connect()
            }
            listen(to: newSocket, as: type)
            return nil

        case .xmlOutbound(let xml):
            // Dropped silently when no socket is open.
            await socket?.send(xml.data)
            return nil

        // STARTTLS negotiation (RFC 6120, Section 5).
        case .xmlInbound(let element):
            // Answer a STARTTLS offer, but only on a socket that can actually be upgraded;
            // a direct-TLS connection is already encrypted.
            if element.name == "stream:features",
               socket is StartTLSSocket,
               element.nodes.contains(where: { $0.name == "starttls" }) {
                let req = XMLElement(
                    name: "starttls",
                    xmlns: Self.tlsNamespace,
                    attributes: [:],
                    content: nil,
                    nodes: []
                )
                return .xmlOutbound(req)
            }

            guard element.xmlns == Self.tlsNamespace else { return nil }
            switch element.name {
            case "proceed":
                return .startTls
            case "failure":
                // RFC 6120, Section 5.4.2.2: the server refused the upgrade.
                return .startTlsFailed(StartTlsError.rejectedByServer)
            default:
                return nil
            }

        case .startTls:
            // Report a failure instead of crashing the process when the upgrade cannot apply.
            guard let startTlsSocket = socket as? StartTLSSocket else {
                return .startTlsFailed(StartTlsError.notStartTlsSocket)
            }
            do {
                try startTlsSocket.startTls()
                return nil
            } catch {
                return .startTlsFailed(error)
            }

        case .startTlsFailed:
            socket = nil
            return nil

        default:
            return nil
        }
    }

    /// Starts a task that forwards every event emitted by `socket` to `callback`.
    ///
    /// Only the callback is captured, not `self`.
    private func listen(to socket: Socket, as type: SocketType) {
        let forward = callback
        Task {
            for await socketEvent in socket.events {
                await forward(ConnectionModule.clientEvent(for: socketEvent, socketType: type))
            }
        }
    }

    /// Maps a socket event to the client event reported for it.
    private static func clientEvent(for socketEvent: SocketEvent, socketType: SocketType) -> Event {
        switch socketEvent {
        case .state(.connected):
            return .socketConnected(socketType)
        case .state(.disconnected(let error)):
            if let error {
                return .socketError(error)
            }
            return .socketDisconnected
        case .state(.startTlsReady):
            return .startTlsDone
        case .dataReceived(let data):
            return .socketReceived(data)
        }
    }
}
|