From 8f014d5525e234a7bd25f120b27480dd15044c82 Mon Sep 17 00:00:00 2001 From: Daniel Gultsch Date: Fri, 6 Oct 2023 13:28:55 +0200 Subject: [PATCH] implement Private DNS (DoT) due to limitations in the MiniDNS library this does not work when 'Validate hostname with DNSSEC' is enabled in the expert settings --- .../de/gultsch/minidns/AndroidDNSClient.java | 123 ++++++++++++ .../java/de/gultsch/minidns/DNSServer.java | 104 ++++++++++ .../java/de/gultsch/minidns/DNSSocket.java | 190 ++++++++++++++++++ .../de/gultsch/minidns/NetworkDataSource.java | 160 +++++++++++++++ .../java/de/gultsch/minidns/Transport.java | 23 +++ .../siacs/conversations/utils/Resolver.java | 5 +- 6 files changed, 604 insertions(+), 1 deletion(-) create mode 100644 src/main/java/de/gultsch/minidns/AndroidDNSClient.java create mode 100644 src/main/java/de/gultsch/minidns/DNSServer.java create mode 100644 src/main/java/de/gultsch/minidns/DNSSocket.java create mode 100644 src/main/java/de/gultsch/minidns/NetworkDataSource.java create mode 100644 src/main/java/de/gultsch/minidns/Transport.java diff --git a/src/main/java/de/gultsch/minidns/AndroidDNSClient.java b/src/main/java/de/gultsch/minidns/AndroidDNSClient.java new file mode 100644 index 000000000..194ad23bd --- /dev/null +++ b/src/main/java/de/gultsch/minidns/AndroidDNSClient.java @@ -0,0 +1,123 @@ +package de.gultsch.minidns; + +import android.content.Context; +import android.net.ConnectivityManager; +import android.net.LinkProperties; +import android.net.Network; +import android.os.Build; + +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; + +import de.measite.minidns.AbstractDNSClient; +import de.measite.minidns.DNSMessage; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.List; + +public class AndroidDNSClient extends AbstractDNSClient { + private final Context context; + private final NetworkDataSource networkDataSource = new NetworkDataSource(); + private boolean askForDnssec = false; + + public AndroidDNSClient(final Context context) { + super(); + this.setDataSource(networkDataSource); + this.context = context; + } + + private static String getPrivateDnsServerName(final LinkProperties linkProperties) { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) { + return linkProperties.getPrivateDnsServerName(); + } else { + return null; + } + } + + private static boolean isPrivateDnsActive(final LinkProperties linkProperties) { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) { + return linkProperties.isPrivateDnsActive(); + } else { + return false; + } + } + + @Override + protected DNSMessage.Builder newQuestion(final DNSMessage.Builder message) { + message.setRecursionDesired(true); + message.getEdnsBuilder() + .setUdpPayloadSize(networkDataSource.getUdpPayloadSize()) + .setDnssecOk(askForDnssec); + return message; + } + + @Override + protected DNSMessage query(final DNSMessage.Builder queryBuilder) throws IOException { + final DNSMessage question = newQuestion(queryBuilder).build(); + for (final DNSServer dnsServer : getDNSServers()) { + final DNSMessage response = this.networkDataSource.query(question, dnsServer); + if (response == null) { + continue; + } + switch (response.responseCode) { + case NO_ERROR: + case NX_DOMAIN: + break; + default: + continue; + } + + return response; + } + return null; + } + + public boolean isAskForDnssec() { + return askForDnssec; + } + + public void setAskForDnssec(boolean askForDnssec) { + this.askForDnssec = askForDnssec; + } + + private List getDNSServers() { + final ImmutableList.Builder dnsServerBuilder = new ImmutableList.Builder<>(); + final ConnectivityManager connectivityManager = + (ConnectivityManager) context.getSystemService(Context.CONNECTIVITY_SERVICE); + final Network[] networks = getActiveNetworks(connectivityManager); + for (final Network network : networks) { + final LinkProperties linkProperties = connectivityManager.getLinkProperties(network); + if (linkProperties == null) { + continue; + } + final String privateDnsServerName = getPrivateDnsServerName(linkProperties); + if (Strings.isNullOrEmpty(privateDnsServerName)) { + final boolean isPrivateDns = isPrivateDnsActive(linkProperties); + for (final InetAddress dnsServer : linkProperties.getDnsServers()) { + if (isPrivateDns) { + dnsServerBuilder.add(new DNSServer(dnsServer, Transport.TLS)); + } else { + dnsServerBuilder.add(new DNSServer(dnsServer)); + } + } + } else { + dnsServerBuilder.add(new DNSServer(privateDnsServerName, Transport.TLS)); + } + } + return dnsServerBuilder.build(); + } + + private Network[] getActiveNetworks(final ConnectivityManager connectivityManager) { + if (connectivityManager == null) { + return new Network[0]; + } + if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.M) { + final Network activeNetwork = connectivityManager.getActiveNetwork(); + if (activeNetwork != null) { + return new Network[] {activeNetwork}; + } + } + return connectivityManager.getAllNetworks(); + } +} diff --git a/src/main/java/de/gultsch/minidns/DNSServer.java b/src/main/java/de/gultsch/minidns/DNSServer.java new file mode 100644 index 000000000..7486ec2c6 --- /dev/null +++ b/src/main/java/de/gultsch/minidns/DNSServer.java @@ -0,0 +1,104 @@ +package de.gultsch.minidns; + +import com.google.common.base.MoreObjects; +import com.google.common.base.Objects; +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; + +import java.net.InetAddress; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import javax.annotation.Nonnull; + +public final class DNSServer { + + public final InetAddress inetAddress; + public final String hostname; + public final int port; + public final List transports; + + public DNSServer(InetAddress inetAddress, Integer port, Transport transport) { + this.inetAddress = inetAddress; + this.port = port == null ? 0 : port; + this.transports = Collections.singletonList(transport); + this.hostname = null; + } + + public DNSServer(final String hostname, final Integer port, final Transport transport) { + Preconditions.checkArgument( + Arrays.asList(Transport.HTTPS, Transport.TLS).contains(transport), + "hostname validation only works with TLS based transports"); + this.hostname = hostname; + this.port = port == null ? 0 : port; + this.transports = Collections.singletonList(transport); + this.inetAddress = null; + } + + public DNSServer(final String hostname, final Transport transport) { + this(hostname, Transport.DEFAULT_PORTS.get(transport), transport); + } + + public DNSServer(InetAddress inetAddress, Transport transport) { + this(inetAddress, Transport.DEFAULT_PORTS.get(transport), transport); + } + + public DNSServer(final InetAddress inetAddress) { + this(inetAddress, 53, Arrays.asList(Transport.UDP, Transport.TCP)); + } + + public DNSServer(final InetAddress inetAddress, int port, List transports) { + this(inetAddress, null, port, transports); + } + + private DNSServer( + final InetAddress inetAddress, + final String hostname, + final int port, + final List transports) { + this.inetAddress = inetAddress; + this.hostname = hostname; + this.port = port; + this.transports = transports; + } + + public Transport uniqueTransport() { + return Iterables.getOnlyElement(this.transports); + } + + public DNSServer asUniqueTransport(final Transport transport) { + Preconditions.checkArgument( + this.transports.contains(transport), + "This DNS server does not have transport ", + transport); + return new DNSServer(inetAddress, hostname, port, Collections.singletonList(transport)); + } + + @Override + @Nonnull + public String toString() { + return MoreObjects.toStringHelper(this) + .add("inetAddress", inetAddress) + .add("hostname", hostname) + .add("port", port) + .add("transports", transports) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DNSServer dnsServer = (DNSServer) o; + return port == dnsServer.port + && Objects.equal(inetAddress, dnsServer.inetAddress) + && Objects.equal(hostname, dnsServer.hostname) + && Objects.equal(transports, dnsServer.transports); + } + + @Override + public int hashCode() { + return Objects.hashCode(inetAddress, hostname, port, transports); + } +} diff --git a/src/main/java/de/gultsch/minidns/DNSSocket.java b/src/main/java/de/gultsch/minidns/DNSSocket.java new file mode 100644 index 000000000..4b096e0b2 --- /dev/null +++ b/src/main/java/de/gultsch/minidns/DNSSocket.java @@ -0,0 +1,190 @@ +package de.gultsch.minidns; + +import android.util.Log; + +import com.google.common.base.Preconditions; +import com.google.common.base.Strings; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; + +import de.measite.minidns.DNSMessage; + +import eu.siacs.conversations.Config; + +import org.conscrypt.OkHostnameVerifier; + +import java.io.Closeable; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.EOFException; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketAddress; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Semaphore; + +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +final class DNSSocket implements Closeable { + + private static final int CONNECT_TIMEOUT = 5_000; + + private final Semaphore semaphore = new Semaphore(1); + private final Map> inFlightQueries = new HashMap<>(); + private final Socket socket; + private final DataInputStream dataInputStream; + private final DataOutputStream dataOutputStream; + + private DNSSocket( + final Socket socket, + final DataInputStream dataInputStream, + final DataOutputStream dataOutputStream) { + this.socket = socket; + this.dataInputStream = dataInputStream; + this.dataOutputStream = dataOutputStream; + new Thread(this::readDNSMessages).start(); + } + + private void readDNSMessages() { + try { + while (socket.isConnected()) { + final DNSMessage response = readDNSMessage(); + final SettableFuture future; + synchronized (inFlightQueries) { + future = inFlightQueries.remove(response.id); + } + if (future != null) { + future.set(response); + } else { + Log.e(Config.LOGTAG, "no in flight query found for response id " + response.id); + } + } + evictInFlightQueries(new EOFException()); + } catch (final IOException e) { + evictInFlightQueries(e); + } + } + + private void evictInFlightQueries(final Exception e) { + synchronized (inFlightQueries) { + final Iterator>> iterator = + inFlightQueries.entrySet().iterator(); + while (iterator.hasNext()) { + final Map.Entry> entry = iterator.next(); + entry.getValue().setException(e); + iterator.remove(); + } + } + } + + private static DNSSocket of(final Socket socket) throws IOException { + final DataInputStream dataInputStream = new DataInputStream(socket.getInputStream()); + final DataOutputStream dataOutputStream = new DataOutputStream(socket.getOutputStream()); + return new DNSSocket(socket, dataInputStream, dataOutputStream); + } + + public static DNSSocket connect(final DNSServer dnsServer) throws IOException { + switch (dnsServer.uniqueTransport()) { + case TCP: + return connectTcpSocket(dnsServer); + case TLS: + return connectTlsSocket(dnsServer); + default: + throw new IllegalStateException("This is not a socket based transport"); + } + } + + private static DNSSocket connectTcpSocket(final DNSServer dnsServer) throws IOException { + Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TCP); + final SocketAddress socketAddress = + new InetSocketAddress(dnsServer.inetAddress, dnsServer.port); + final Socket socket = new Socket(); + socket.connect(socketAddress, CONNECT_TIMEOUT); + return DNSSocket.of(socket); + } + + private static DNSSocket connectTlsSocket(final DNSServer dnsServer) throws IOException { + Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TLS); + final SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault(); + final SSLSocket sslSocket; + if (Strings.isNullOrEmpty(dnsServer.hostname)) { + final SocketAddress socketAddress = + new InetSocketAddress(dnsServer.inetAddress, dnsServer.port); + sslSocket = (SSLSocket) factory.createSocket(dnsServer.inetAddress, dnsServer.port); + sslSocket.connect(socketAddress, 5_000); + } else { + sslSocket = (SSLSocket) factory.createSocket(dnsServer.hostname, dnsServer.port); + final SSLSession session = sslSocket.getSession(); + final Certificate[] peerCertificates = session.getPeerCertificates(); + if (peerCertificates.length == 0 || !(peerCertificates[0] instanceof X509Certificate)) { + throw new IOException("Peer did not provide X509 certificates"); + } + final X509Certificate certificate = (X509Certificate) peerCertificates[0]; + if (!OkHostnameVerifier.strictInstance().verify(dnsServer.hostname, certificate)) { + throw new SSLPeerUnverifiedException("Peer did not provide valid certificates"); + } + } + return DNSSocket.of(sslSocket); + } + + public DNSMessage query(final DNSMessage query) throws IOException, InterruptedException { + try { + return queryAsync(query).get(); + } catch (final ExecutionException e) { + final Throwable cause = e.getCause(); + if (cause instanceof IOException) { + throw (IOException) cause; + } else { + throw new IOException(e); + } + } + } + + public ListenableFuture queryAsync(final DNSMessage query) + throws InterruptedException, IOException { + final SettableFuture responseFuture = SettableFuture.create(); + synchronized (this.inFlightQueries) { + this.inFlightQueries.put(query.id, responseFuture); + } + this.semaphore.acquire(); + try { + query.writeTo(this.dataOutputStream); + this.dataOutputStream.flush(); + } finally { + this.semaphore.release(); + } + return responseFuture; + } + + private DNSMessage readDNSMessage() throws IOException { + final int length = this.dataInputStream.readUnsignedShort(); + byte[] data = new byte[length]; + int read = 0; + while (read < length) { + read += this.dataInputStream.read(data, read, length - read); + } + return new DNSMessage(data); + } + + @Override + public void close() throws IOException { + this.socket.close(); + } + + public void closeQuietly() { + try { + this.socket.close(); + } catch (final IOException ignored) { + + } + } +} diff --git a/src/main/java/de/gultsch/minidns/NetworkDataSource.java b/src/main/java/de/gultsch/minidns/NetworkDataSource.java new file mode 100644 index 000000000..1ba56f3c0 --- /dev/null +++ b/src/main/java/de/gultsch/minidns/NetworkDataSource.java @@ -0,0 +1,160 @@ +package de.gultsch.minidns; + +import android.util.Log; + +import androidx.annotation.NonNull; + +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.cache.RemovalListener; +import com.google.common.collect.ImmutableList; + +import de.measite.minidns.DNSMessage; +import de.measite.minidns.MiniDNSException; +import de.measite.minidns.source.DNSDataSource; +import de.measite.minidns.util.MultipleIoException; + +import eu.siacs.conversations.Config; + +import java.io.IOException; +import java.net.DatagramPacket; +import java.net.DatagramSocket; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +public class NetworkDataSource extends DNSDataSource { + + private static final LoadingCache socketCache = + CacheBuilder.newBuilder() + .removalListener( + (RemovalListener) + notification -> { + final DNSServer dnsServer = notification.getKey(); + final DNSSocket dnsSocket = notification.getValue(); + if (dnsSocket == null) { + return; + } + Log.d(Config.LOGTAG, "closing connection to " + dnsServer); + dnsSocket.closeQuietly(); + }) + .expireAfterAccess(5, TimeUnit.MINUTES) + .build( + new CacheLoader() { + @Override + @NonNull + public DNSSocket load(@NonNull final DNSServer dnsServer) + throws Exception { + Log.d(Config.LOGTAG, "establishing connection to " + dnsServer); + return DNSSocket.connect(dnsServer); + } + }); + + private static List transportsForPort(final int port) { + final ImmutableList.Builder transportBuilder = new ImmutableList.Builder<>(); + for (final Map.Entry entry : Transport.DEFAULT_PORTS.entrySet()) { + if (entry.getValue().equals(port)) { + transportBuilder.add(entry.getKey()); + } + } + return transportBuilder.build(); + } + + @Override + public DNSMessage query(final DNSMessage message, final InetAddress address, final int port) + throws IOException { + final List transports = transportsForPort(port); + Log.w( + Config.LOGTAG, + "using legacy DataSource interface. guessing transports " + + transports + + " from port"); + if (transports.isEmpty()) { + throw new IOException(String.format("No transports found for port %d", port)); + } + return query(message, new DNSServer(address, port, transports)); + } + + public DNSMessage query(final DNSMessage message, final DNSServer dnsServer) + throws IOException { + Log.d(Config.LOGTAG, "using " + dnsServer); + final List ioExceptions = new ArrayList<>(); + for (final Transport transport : dnsServer.transports) { + try { + final DNSMessage response = + queryWithUniqueTransport(message, dnsServer.asUniqueTransport(transport)); + if (response != null && !response.truncated) { + return response; + } + } catch (final IOException e) { + ioExceptions.add(e); + } catch (final InterruptedException e) { + return null; + } + } + MultipleIoException.throwIfRequired(ioExceptions); + return null; + } + + private DNSMessage queryWithUniqueTransport(final DNSMessage message, final DNSServer dnsServer) + throws IOException, InterruptedException { + final Transport transport = dnsServer.uniqueTransport(); + switch (transport) { + case UDP: + return queryUdp(message, dnsServer.inetAddress, dnsServer.port); + case TCP: + case TLS: + return queryDnsSocket(message, dnsServer); + default: + throw new IOException( + String.format("Transport %s has not been implemented", transport)); + } + } + + protected DNSMessage queryUdp( + final DNSMessage message, final InetAddress address, final int port) + throws IOException { + final DatagramPacket request = message.asDatagram(address, port); + final byte[] buffer = new byte[udpPayloadSize]; + try (final DatagramSocket socket = new DatagramSocket()) { + socket.setSoTimeout(timeout); + socket.send(request); + final DatagramPacket response = new DatagramPacket(buffer, buffer.length); + socket.receive(response); + DNSMessage dnsMessage = new DNSMessage(response.getData()); + if (dnsMessage.id != message.id) { + throw new MiniDNSException.IdMismatch(message, dnsMessage); + } + return dnsMessage; + } + } + + protected DNSMessage queryDnsSocket(final DNSMessage message, final DNSServer dnsServer) + throws IOException, InterruptedException { + final DNSSocket cachedDnsSocket = socketCache.getIfPresent(dnsServer); + if (cachedDnsSocket != null) { + try { + return cachedDnsSocket.query(message); + } catch (final IOException e) { + Log.d( + Config.LOGTAG, + "IOException occurred at cached socket. invalidating and falling through to new socket creation"); + socketCache.invalidate(dnsServer); + } + } + try { + return socketCache.get(dnsServer).query(message); + } catch (final ExecutionException e) { + final Throwable cause = e.getCause(); + if (cause instanceof IOException) { + throw (IOException) cause; + } else { + throw new IOException(cause); + } + } + } +} diff --git a/src/main/java/de/gultsch/minidns/Transport.java b/src/main/java/de/gultsch/minidns/Transport.java new file mode 100644 index 000000000..3aabfacaa --- /dev/null +++ b/src/main/java/de/gultsch/minidns/Transport.java @@ -0,0 +1,23 @@ +package de.gultsch.minidns; + +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +public enum Transport { + UDP, + TCP, + TLS, + HTTPS; + + public static final Map DEFAULT_PORTS; + + static { + final ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + builder.put(Transport.UDP, 53); + builder.put(Transport.TCP, 53); + builder.put(Transport.TLS, 853); + builder.put(Transport.HTTPS, 443); + DEFAULT_PORTS = builder.build(); + } +} diff --git a/src/main/java/eu/siacs/conversations/utils/Resolver.java b/src/main/java/eu/siacs/conversations/utils/Resolver.java index 463d6eb73..444deda4a 100644 --- a/src/main/java/eu/siacs/conversations/utils/Resolver.java +++ b/src/main/java/eu/siacs/conversations/utils/Resolver.java @@ -15,6 +15,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import de.gultsch.minidns.AndroidDNSClient; import de.measite.minidns.AbstractDNSClient; import de.measite.minidns.DNSCache; import de.measite.minidns.DNSClient; @@ -274,7 +275,9 @@ public class Resolver { private static ResolverResult resolveWithFallback(DNSName dnsName, Class type, boolean validateHostname) throws IOException { final Question question = new Question(dnsName, Record.TYPE.getType(type)); if (!validateHostname) { - return ResolverApi.INSTANCE.resolve(question); + final AndroidDNSClient androidDNSClient = new AndroidDNSClient(SERVICE); + final ResolverApi resolverApi = new ResolverApi(androidDNSClient); + return resolverApi.resolve(question); } try { return DnssecResolverApi.INSTANCE.resolveDnssecReliable(question);