252 lines
9.2 KiB
Swift
252 lines
9.2 KiB
Swift
import Foundation
|
|
import Network
|
|
|
|
/// Debug switch for this file: when `true`, `log(_:data:read:)` and `logState(_:)`
/// print socket traffic and connection state changes to standard output.
private let doPrint = false
|
|
|
|
/// A `Socket` that starts as a plain TCP connection and can be switched to TLS
/// on the same connection by calling `startTls()`.
///
/// The upgrade is implemented with a local "bridge": a Unix-domain-socket
/// `NWListener` plus a TLS `NWConnection` that connects to it. Bytes leaving the
/// bridge are written to the raw TCP connection, and bytes arriving on the raw
/// connection are fed back into the bridge, so the TLS session is tunnelled
/// over the already-open raw connection.
///
/// Everything that happens on the socket is published through `events`.
final class StartTLSSocket: Socket {
    /// `events` is the stream consumers iterate; `eventsContinuation` is the
    /// producer side used internally to publish state changes and received data.
    let (events, eventsContinuation) = AsyncStream.makeStream(of: SocketEvent.self, bufferingPolicy: .unbounded)

    /// Queue on which every connection and the bridge listener are started.
    private let queue: DispatchQueue
    /// Filesystem path of the Unix domain socket used by the TLS bridge.
    private let bridgeSocketPath: URL
    /// Remote host name; also used as the TLS server name.
    private let host: String
    /// Remote TCP port.
    private let port: Int
    /// When `true`, the server certificate is accepted without evaluation.
    private let allowInsecure: Bool

    /// Plain TCP connection to `host:port`.
    private var rawConnection: NWConnection?
    /// TLS connection to the local bridge; non-nil once the bridge listener is ready.
    private var secConnection: NWConnection?
    /// Local listener the secure connection connects to.
    private var bridge: NWListener?
    /// Server side of the bridge, accepted by `bridge`.
    private var brgConnection: NWConnection?

    /// Creates the socket and prepares (but does not start) the raw connection.
    ///
    /// - Parameters:
    ///   - id: Suffix that makes the dispatch queue label and the bridge socket
    ///     path unique to this socket.
    ///   - host: Remote host name.
    ///   - port: Remote TCP port; must fit in `UInt16`.
    ///   - allowInsecure: Skip server certificate evaluation after `startTls()`.
    init(id: String, host: String, port: Int, allowInsecure: Bool) {
        self.host = host
        self.port = port
        self.allowInsecure = allowInsecure
        queue = DispatchQueue(label: "another.xmpp.network.queue_\(id)")

        bridgeSocketPath = URL(fileURLWithPath: "/tmp/tls_bridge_listener\(id).sock")
        // Best effort: drop a socket file left behind at the same path.
        try? FileManager.default.removeItem(at: bridgeSocketPath)

        // tcp options
        let tcpOptions = NWProtocolTCP.Options()
        tcpOptions.noDelay = true
        tcpOptions.connectionTimeout = 5
        tcpOptions.enableFastOpen = true
        tcpOptions.disableAckStretching = true

        // Use the configured parameters for the connection so the service
        // class set here is actually applied.
        let params = NWParameters(tls: nil, tcp: tcpOptions)
        params.serviceClass = .responsiveData

        rawConnection = NWConnection(host: .name(host, nil), port: .init(integerLiteral: UInt16(port)), using: params)
        rawConnection?.stateUpdateHandler = { [weak self] state in
            guard let self else { return }
            self.logState("Raw connection \(state)")
            switch state {
            case .ready:
                self.eventsContinuation.yield(.state(.connected))
                self.rawRead()

            case .waiting(let error), .failed(let error):
                // Both states are reported to consumers as a disconnect.
                self.logState("Raw connection error: \(error.localizedDescription)")
                self.eventsContinuation.yield(.state(.disconnected(error)))

            default:
                break
            }
        }
    }

    deinit {
        eventsContinuation.finish()
        rawConnection?.cancel()
        brgConnection?.cancel()
        secConnection?.cancel()
        bridge?.cancel()
    }

    /// Starts the raw TCP connection; progress is reported through `events`.
    func connect() async {
        rawConnection?.start(queue: queue)
    }

    /// Sends `data` over the secure connection if one exists, otherwise over
    /// the raw connection. Failures are reported through `events`.
    func send(_ data: Data) async {
        if secConnection != nil {
            secWrite(data: data)
        } else {
            rawWrite(data: data)
        }
    }

    /// Begins the TLS upgrade by setting up the bridge listener and the secure
    /// connection. `.startTlsReady` is published once the secure connection is ready.
    ///
    /// - Throws: Whatever `NWListener` creation throws.
    func startTls() throws {
        try initBridgeAndSecureConnection()
    }
}
|
|
|
|
/// Read/write loops, TLS bridge setup and debug logging for `StartTLSSocket`.
private extension StartTLSSocket {
    /// Upper bound, in bytes, for a single `receive` call on any connection.
    static let readChunkSize = 4096 * 2

    /// Builds a send completion that publishes `.disconnected` when a write fails.
    func makeSendCompletion() -> NWConnection.SendCompletion {
        .contentProcessed { [weak self] error in
            if let error {
                self?.eventsContinuation.yield(.state(.disconnected(error)))
            }
        }
    }

    /// Receives the next chunk from the raw TCP connection, then re-arms itself.
    ///
    /// While a bridge connection exists the bytes are forwarded into it;
    /// otherwise they are published as `.dataReceived`.
    func rawRead() {
        rawConnection?.receive(minimumIncompleteLength: 1, maximumLength: StartTLSSocket.readChunkSize) { [weak self] data, _, _, error in
            guard let self else { return }
            if let error {
                self.eventsContinuation.yield(.state(.disconnected(error)))
                return
            }
            // NOTE(review): a nil payload without an error ends the read loop
            // and publishes no event — confirm whether that case should be
            // reported to consumers as a disconnect.
            guard let data else { return }

            self.log("raw", data: data, read: true)
            if let bridgeConnection = self.brgConnection {
                bridgeConnection.send(content: data, completion: self.makeSendCompletion())
            } else {
                self.eventsContinuation.yield(.dataReceived(data))
            }
            self.rawRead()
        }
    }

    /// Writes `data` to the raw TCP connection.
    func rawWrite(data: Data) {
        log("raw", data: data, read: false)
        rawConnection?.send(content: data, completion: makeSendCompletion())
    }

    /// Receives the next chunk from the bridge connection, writes it to the raw
    /// connection, then re-arms itself.
    func brgRead() {
        brgConnection?.receive(minimumIncompleteLength: 1, maximumLength: StartTLSSocket.readChunkSize) { [weak self] data, _, _, error in
            guard let self else { return }
            if let error {
                self.eventsContinuation.yield(.state(.disconnected(error)))
                return
            }
            guard let data else { return }

            self.rawWrite(data: data)
            self.brgRead()
        }
    }

    /// Receives the next chunk from the secure connection, publishes it as
    /// `.dataReceived`, then re-arms itself.
    func secRead() {
        secConnection?.receive(minimumIncompleteLength: 1, maximumLength: StartTLSSocket.readChunkSize) { [weak self] data, _, _, error in
            guard let self else { return }
            if let error {
                self.eventsContinuation.yield(.state(.disconnected(error)))
                return
            }
            guard let data else { return }

            self.log("sec", data: data, read: true)
            self.eventsContinuation.yield(.dataReceived(data))
            self.secRead()
        }
    }

    /// Writes `data` to the secure connection.
    func secWrite(data: Data) {
        log("sec", data: data, read: false)
        secConnection?.send(content: data, completion: makeSendCompletion())
    }

    /// Creates and starts the bridge listener on `bridgeSocketPath`.
    ///
    /// When the listener becomes ready the secure connection is created and
    /// connected to it; the connection the listener accepts becomes the bridge
    /// connection whose traffic is exchanged with the raw connection.
    ///
    /// - Throws: Whatever `NWListener(using:)` throws.
    func initBridgeAndSecureConnection() throws {
        let params = NWParameters()
        params.defaultProtocolStack.transportProtocol = NWProtocolTCP.Options()
        params.requiredLocalEndpoint = NWEndpoint.unix(path: bridgeSocketPath.path)
        params.allowLocalEndpointReuse = false

        let listener = try NWListener(using: params)
        listener.newConnectionLimit = 1

        listener.stateUpdateHandler = { [weak self] state in
            guard let self else { return }
            self.logState("Bridge \(state)")
            switch state {
            case .ready:
                self.startSecureConnection()

            case .waiting(let error), .failed(let error):
                self.eventsContinuation.yield(.state(.disconnected(error)))

            default:
                break
            }
        }

        listener.newConnectionHandler = { [weak self] connection in
            self?.attachBridgeConnection(connection)
        }

        bridge = listener
        listener.start(queue: queue)
    }

    /// Builds the TLS options for the secure connection: TLS 1.2–1.3, server
    /// name set to `host`, and a verify block honouring `allowInsecure`.
    func makeTLSOptions() -> NWProtocolTLS.Options {
        let tlsOptions = NWProtocolTLS.Options()
        let secOptions = tlsOptions.securityProtocolOptions
        sec_protocol_options_set_min_tls_protocol_version(secOptions, .TLSv12)
        sec_protocol_options_set_max_tls_protocol_version(secOptions, .TLSv13)
        // The server name is set explicitly because the connection's endpoint
        // is the local bridge socket rather than the remote host.
        sec_protocol_options_set_tls_server_name(secOptions, host)

        // Copy the flag so the verify block, which is retained by the
        // connection's options, does not hold a strong reference to `self`.
        let skipVerification = allowInsecure
        sec_protocol_options_set_verify_block(secOptions, { _, secTrust, complete in
            guard !skipVerification else {
                complete(true)
                return
            }
            let trust = sec_trust_copy_ref(secTrust).takeRetainedValue()
            var evaluationError: CFError?
            complete(SecTrustEvaluateWithError(trust, &evaluationError))
        }, queue)

        return tlsOptions
    }

    /// Creates the TLS connection to the bridge socket, stores it in
    /// `secConnection` and starts it. Publishes `.startTlsReady` and begins
    /// reading once the connection is ready.
    func startSecureConnection() {
        // tcp options for secure connection
        let tcpOptions = NWProtocolTCP.Options()
        tcpOptions.noDelay = true
        tcpOptions.enableFastOpen = true
        tcpOptions.disableAckStretching = true

        let params = NWParameters(tls: makeTLSOptions(), tcp: tcpOptions)
        params.serviceClass = .responsiveData

        let connection = NWConnection(to: .unix(path: bridgeSocketPath.path), using: params)
        connection.stateUpdateHandler = { [weak self] state in
            guard let self else { return }
            self.logState("Secure connection \(state)")
            switch state {
            case .ready:
                self.eventsContinuation.yield(.state(.startTlsReady))
                self.secRead()

            case .waiting(let error), .failed(let error):
                self.eventsContinuation.yield(.state(.disconnected(error)))

            default:
                break
            }
        }

        secConnection = connection
        connection.start(queue: queue)
    }

    /// Adopts a connection accepted by the bridge listener as `brgConnection`
    /// and starts it; reading begins once it is ready.
    func attachBridgeConnection(_ connection: NWConnection) {
        brgConnection = connection
        // Weak capture: the connection is owned by `self`, so a strong capture
        // here would form a retain cycle and `deinit` would never run.
        connection.stateUpdateHandler = { [weak self] state in
            guard let self else { return }
            self.logState("Bridge connection \(state)")
            switch state {
            case .ready:
                self.brgRead()

            case .waiting(let error), .failed(let error):
                self.eventsContinuation.yield(.state(.disconnected(error)))

            default:
                break
            }
        }
        connection.start(queue: queue)
    }

    /// Prints a chunk of traffic when `doPrint` is enabled.
    ///
    /// - Parameters:
    ///   - socket: Label of the connection the data belongs to ("raw" or "sec").
    ///   - data: Payload to print, decoded as UTF-8.
    ///   - read: `true` for received data, `false` for sent data.
    func log(_ socket: String, data: Data, read: Bool) {
        guard doPrint else { return }
        // Raw traffic is skipped once the secure connection exists
        // (presumably because it is TLS ciphertext from then on).
        if socket == "raw", secConnection != nil { return }

        let direction = read ? "read-in" : "write-out"
        // Lossy UTF-8 decoding so non-ASCII payloads are still printed.
        let str = String(decoding: data, as: UTF8.self)
        print("\nStartTLSSocket-\(socket) \(direction): \(str)\n")
    }

    /// Prints a connection state message when `doPrint` is enabled.
    func logState(_ str: String) {
        guard doPrint else { return }
        print("Connection state: \(str)")
    }
}
|