another.im-ios/AnotherIM/xmpp/modules/dns/SRVResolverModule.swift
2024-12-16 08:19:58 +01:00

186 lines
6.4 KiB
Swift

// RFC 6120, section 3 (TCP Binding), in particular 3.2: resolving the server's domain via DNS SRV records
import Combine
import dnssd
import Foundation
// MARK: Public
/// Reasons an SRV lookup can fail.
///
/// - `srvReferenceError`: `DNSServiceQueryRecord` reported success but produced no service reference.
/// - `srvSocketError`: no socket could be obtained for the service reference.
/// - `srvTimeout`: the query did not complete before its timeout expired.
/// - `srvUnableToComplete`: `DNSServiceQueryRecord` itself returned an error.
/// - `srvProcessError`: `DNSServiceProcessResult` failed while reading a reply.
enum SRVResolverError: Error {
    case srvReferenceError, srvSocketError, srvTimeout
    case srvUnableToComplete, srvProcessError
}
/// XMPP module that resolves the domain part of the client's JID through DNS SRV
/// lookups (RFC 6120, section 3.2) and stores the resulting records in the client state.
final class SRVResolverModule: XmppModule {
    let id = "SRV resolver module"

    /// Stores freshly resolved SRV records and resets the record cursor.
    ///
    /// - Parameters:
    ///   - oldState: The current client state.
    ///   - event: The event being reduced; only `.domainResolved` changes the state.
    /// - Returns: The updated state (unchanged for every other event).
    func reduce(oldState: ClientState, with event: Event) -> ClientState {
        var newState = oldState
        switch event {
        case .domainResolved(let records):
            newState.srvRecords = records
            newState.srvRecordIndex = -1 // will be increased on each connection attempt
        default:
            break
        }
        return newState
    }

    /// Performs the SRV lookup when `.resolveDomain` is received.
    ///
    /// - Parameters:
    ///   - state: The current client state; its JID's domain part is resolved.
    ///   - event: The event to handle.
    /// - Returns: `.domainResolved` with the records, `.domainResolvingError` on failure,
    ///   or `nil` for events this module does not handle.
    func process(state: ClientState, with event: Event) async -> Event? {
        switch event {
        case .resolveDomain:
            let domain = state.jid.domainPart
            do {
                let records = try await SRVResolver.resolve(domain: domain)
                return .domainResolved(records)
            } catch let error as SRVResolverError {
                return .domainResolvingError(error)
            } catch {
                // The resolver only throws SRVResolverError today; report anything else
                // as a generic failure instead of crashing on a forced cast.
                return .domainResolvingError(.srvUnableToComplete)
            }
        default:
            return nil
        }
    }
}
// MARK: Private
/// Runs the DNS SRV lookups needed to connect an XMPP client (RFC 6120, section 3.2).
private enum SRVResolver {
    /// Resolves the SRV records advertised for `domain`.
    ///
    /// Queries `_xmpp-client._tcp.<domain>` and `_xmpps-client._tcp.<domain>`
    /// concurrently and merges their answers, ordered by ascending priority.
    /// A failing query does not discard the answer of the other one; the call only
    /// throws when both queries fail. When no record is found at all, a single
    /// fallback record for `domain` itself is returned (RFC 6120, section 3.2.2).
    ///
    /// - Parameter domain: The domain to resolve.
    /// - Returns: The records to try in order; never empty.
    /// - Throws: The first `SRVResolverError` reported, when both queries fail.
    static func resolve(domain: String) async throws -> [SRVRecord] {
        // Non-TLS (STARTTLS) and direct-TLS endpoints are looked up in parallel.
        async let plainLookup = lookup(target: "_xmpp-client._tcp." + domain)
        async let tlsLookup = lookup(target: "_xmpps-client._tcp." + domain)
        let outcomes = await [plainLookup, tlsLookup]

        var records: [SRVRecord] = []
        var succeeded = 0
        var firstError: SRVResolverError?
        for outcome in outcomes {
            switch outcome {
            case .success(let found):
                succeeded += 1
                records += found
            case .failure(let error):
                if firstError == nil {
                    firstError = error
                }
            }
        }
        // Only give up when no lookup produced an answer at all.
        if succeeded == 0, let error = firstError {
            throw error
        }

        // Lower priority values are tried first.
        records.sort { $0.priority < $1.priority }
        // Fallback according to RFC 6120 section 3.2.2: use the domain itself.
        if records.isEmpty {
            records.append(.init(fallbackTarget: domain))
        }
        return records
    }

    /// Runs a single `SRVRequest` for `target` and returns its outcome as a value,
    /// so one failing lookup does not cancel or hide the other.
    private static func lookup(target: String) async -> Result<[SRVRecord], SRVResolverError> {
        await withCheckedContinuation { continuation in
            SRVRequest(target: target) { result in
                continuation.resume(returning: result)
            }
            .runQuery()
        }
    }
}
/// Completion of one SRV query: the collected records, or the reason it failed.
private typealias SRVRequestCompletion = (Result<[SRVRecord], SRVResolverError>) -> Void

/// A single DNS SRV query for `target`, implemented on top of `dnssd`.
///
/// `runQuery()` starts `DNSServiceQueryRecord`, watches the service socket with a
/// read source and bounds the query with a `timeout`-second timer. Both sources run
/// on the private serial `queue`. `completion` is called at most once.
///
/// Lifetime: the C callback receives the request through `Mem.bridge`
/// (presumably unretained — verify against `Mem`). The read source's cancel handler
/// captures `self` strongly while `self` owns that source, which keeps the request
/// alive until teardown. The cancel handler then drops the sources to break that cycle.
private class SRVRequest {
    private let queue = DispatchQueue(label: "srv.resolving")
    private var dispatchSourceRead: DispatchSourceRead?
    private var timeoutTimer: DispatchSourceTimer?
    private var serviceRef: DNSServiceRef?
    private var socket: dnssd_sock_t = -1
    private let timeout = TimeInterval(5)
    private let target: String
    /// Set by `finish(_:)`; prevents `completion` from being called more than once.
    private var isFinished = false
    var records = [SRVRecord]()
    var completion: SRVRequestCompletion

    init(target: String, completion: @escaping SRVRequestCompletion) {
        self.target = target
        self.completion = completion
    }

    /// Starts the query; the outcome is reported through `completion`.
    func runQuery() {
        let status = DNSServiceQueryRecord(
            &serviceRef,
            kDNSServiceFlagsReturnIntermediates,
            UInt32(kDNSServiceInterfaceIndexAny),
            target.cString(using: .utf8),
            UInt16(kDNSServiceType_SRV),
            UInt16(kDNSServiceClass_IN), { _, flags, _, _, _, _, _, rdLen, rdata, _, context in
                guard let context = context else {
                    return
                }
                let request: SRVRequest = Mem.bridge(context)
                // Ignore replies that arrive after the request has already completed.
                guard !request.isFinished else {
                    return
                }
                if
                    let data = rdata?.assumingMemoryBound(to: UInt8.self),
                    let record = SRVRecord(data: Data(bytes: data, count: Int(rdLen)))
                    // swiftlint:disable:next opening_brace
                {
                    request.records.append(record)
                }
                // No further replies are pending: report what has been collected.
                if flags & kDNSServiceFlagsMoreComing == 0 {
                    request.finish(.success(request.records))
                }
            },
            Mem.bridge(self)
        )
        guard status == DNSServiceErrorType(kDNSServiceErr_NoError) else {
            finish(.failure(.srvUnableToComplete))
            return
        }
        guard let sdRef = serviceRef else {
            finish(.failure(.srvReferenceError))
            return
        }
        socket = DNSServiceRefSockFD(sdRef)
        guard socket != -1 else {
            // No read source owns the reference yet, so release it here to avoid a leak.
            DNSServiceRefDeallocate(sdRef)
            serviceRef = nil
            finish(.failure(.srvSocketError))
            return
        }

        // Configure both sources before resuming either, so every property that
        // `finish(_:)` touches is set before any handler can run on `queue`.
        let timer = DispatchSource.makeTimerSource(flags: [], queue: queue)
        timer.setEventHandler { [weak self] in
            self?.finish(.failure(.srvTimeout))
        }
        timer.schedule(deadline: .now() + timeout)
        timeoutTimer = timer

        let readSource = DispatchSource.makeReadSource(fileDescriptor: socket, queue: queue)
        readSource.setEventHandler { [weak self] in
            // Reads the daemon's reply and invokes the DNSServiceQueryRecord callback above.
            let res = DNSServiceProcessResult(sdRef)
            if res != kDNSServiceErr_NoError {
                self?.finish(.failure(.srvProcessError))
            }
        }
        readSource.setCancelHandler {
            // The strong `self` capture keeps the request alive until this point;
            // dropping the sources here breaks the self <-> source cycle.
            DNSServiceRefDeallocate(self.serviceRef)
            self.serviceRef = nil
            self.dispatchSourceRead = nil
            self.timeoutTimer = nil
        }
        dispatchSourceRead = readSource

        timer.resume()
        readSource.resume()
    }

    /// Reports `result` once and stops the query.
    ///
    /// Cancels the timer and the read source; the read source's cancel handler then
    /// deallocates the service reference. Any call after the first one is ignored.
    /// Runs on `queue` once the sources are running, or synchronously inside
    /// `runQuery()` before they are created.
    private func finish(_ result: Result<[SRVRecord], SRVResolverError>) {
        guard !isFinished else {
            return
        }
        isFinished = true
        timeoutTimer?.cancel()
        dispatchSourceRead?.cancel()
        completion(result)
    }
}