import Foundation
import Network

/// Compile-time debug switch for this file's console logging of socket traffic.
private let doPrint: Bool = false

/// A `Socket` that opens a TLS-wrapped TCP connection directly (TLS from the first byte)
/// using Network.framework, and publishes state changes and received bytes on `events`.
final class DirectTLSSocket: Socket {
    /// Socket events (state changes and received data). Buffering is unbounded, so events
    /// yielded before anyone iterates the stream are kept.
    let (events, eventsContinuation) = AsyncStream.makeStream(of: SocketEvent.self, bufferingPolicy: .unbounded)

    /// Queue on which the connection, its callbacks and the TLS verify block run.
    private let queue: DispatchQueue
    /// The underlying connection; stays `nil` only when `init` rejected the port.
    private var nwConnection: NWConnection?

    /// Maximum number of bytes requested per `receive` call.
    private static let maxReadLength = 4096 * 2

    /// Configures (but does not start) a TLS connection to `host:port`; call `connect()` to start it.
    ///
    /// - Parameters:
    ///   - id: Suffix for the dispatch queue label, used to tell socket instances apart.
    ///   - host: Server host name; also sent as the TLS SNI server name.
    ///   - port: TCP port. A value that is not a valid `NWEndpoint.Port` (e.g. negative or
    ///     above `UInt16.max`) creates no connection; a `.disconnected` event carrying
    ///     `EINVAL` is queued on `events` instead, and `connect()`/`send(_:)` become no-ops.
    ///   - allowInsecure: When `true`, the server certificate is accepted without evaluation.
    init(id: String, host: String, port: Int, allowInsecure: Bool) {
        queue = DispatchQueue(label: "another.xmpp.network.queue_\(id)")

        // `UInt16(port)` would trap on an out-of-range value; report it as a failed connection instead.
        guard let endpointPort = UInt16(exactly: port).flatMap(NWEndpoint.Port.init(rawValue:)) else {
            eventsContinuation.yield(.state(.disconnected(NWError.posix(.EINVAL))))
            return
        }

        // TCP options
        let tcpOptions = NWProtocolTCP.Options()
        tcpOptions.noDelay = true
        tcpOptions.connectionTimeout = 5
        tcpOptions.enableFastOpen = true
        tcpOptions.disableAckStretching = true

        // TLS options
        let tlsOptions = NWProtocolTLS.Options()
        let secOptions = tlsOptions.securityProtocolOptions
        sec_protocol_options_set_min_tls_protocol_version(secOptions, .TLSv12)
        sec_protocol_options_set_max_tls_protocol_version(secOptions, .TLSv13)
        // sec_protocol_options_set_peer_authentication_required(secOptions, false)
        // The String is bridged to a NUL-terminated UTF-8 C string for the duration of the call.
        sec_protocol_options_set_tls_server_name(secOptions, host)
        sec_protocol_options_set_verify_block(secOptions, { _, secTrust, complete in
            guard !allowInsecure else {
                // Explicit opt-out: accept any certificate.
                complete(true)
                return
            }
            let trust = sec_trust_copy_ref(secTrust).takeRetainedValue()
            complete(SecTrustEvaluateWithError(trust, nil))
        }, queue)

        // Build the parameters once, with both protocol option sets, so that the service
        // class actually reaches the connection (previously it was set on a discarded object).
        let params = NWParameters(tls: tlsOptions, tcp: tcpOptions)
        params.serviceClass = .responsiveData

        let connection = NWConnection(host: .name(host, nil), port: endpointPort, using: params)
        connection.stateUpdateHandler = { [weak self] state in
            guard let self else { return }
            switch state {
            case .ready:
                self.eventsContinuation.yield(.state(.connected))
                self.read()

            case .waiting(let error), .failed(let error):
                // NOTE(review): `.waiting` is not terminal; Network.framework may keep retrying
                // and later reach `.ready` again. It is reported as a disconnect here.
                if doPrint {
                    print(error.localizedDescription)
                }
                self.eventsContinuation.yield(.state(.disconnected(error)))

            default:
                break
            }
        }
        nwConnection = connection
    }

    deinit {
        eventsContinuation.finish()
        nwConnection?.cancel()
    }

    /// Starts the connection on the socket's queue; the outcome arrives via `events`.
    func connect() async {
        nwConnection?.start(queue: queue)
    }

    /// Queues `data` for sending; a send failure is reported as a `.disconnected` event.
    func send(_ data: Data) async {
        log(data: data, read: false)
        nwConnection?.send(content: data, completion: .contentProcessed { [weak self] error in
            guard let self, let error else { return }
            self.eventsContinuation.yield(.state(.disconnected(error)))
        })
    }

    /// Receives one chunk, forwards it on `events`, and re-arms itself until an error or EOF.
    private func read() {
        nwConnection?.receive(minimumIncompleteLength: 1, maximumLength: Self.maxReadLength) { [weak self] data, _, isComplete, error in
            guard let self else { return }
            // Deliver any bytes that arrived alongside an error/EOF before reporting it.
            if let data, !data.isEmpty {
                self.log(data: data, read: true)
                self.eventsContinuation.yield(.dataReceived(data))
            }
            if let error {
                self.eventsContinuation.yield(.state(.disconnected(error)))
                return
            }
            if isComplete {
                // Peer closed the stream: previously the loop stopped silently here and the
                // consumer never learned the connection was gone.
                self.eventsContinuation.yield(.state(.disconnected(NWError.posix(.ECONNRESET))))
                return
            }
            self.read()
        }
    }

    /// Prints `data` with its direction when the file-level `doPrint` switch is on.
    private func log(data: Data, read: Bool) {
        guard doPrint else { return }
        let direction = read ? "read-in" : "write-out"
        // Decode as UTF-8 (repairing invalid or chunk-split sequences) rather than ASCII,
        // which rejected any non-ASCII text and logged "?".
        let text = String(decoding: data, as: UTF8.self)
        print("\nDirectTls \(direction): \(text)\n")
    }
}