From 4addeaa3564d7144bad3d8f8eee7b24dff770df3 Mon Sep 17 00:00:00 2001 From: Daniel Gultsch Date: Thu, 2 Mar 2023 18:44:27 +0100 Subject: [PATCH] use futures in DNS resolver --- .../conversations/android/dns/Resolver.java | 343 +++++------------- .../android/dns/ServiceRecord.java | 124 +++++++ .../android/xmpp/XmppConnection.java | 14 +- 3 files changed, 231 insertions(+), 250 deletions(-) create mode 100644 app/src/main/java/im/conversations/android/dns/ServiceRecord.java diff --git a/app/src/main/java/im/conversations/android/dns/Resolver.java b/app/src/main/java/im/conversations/android/dns/Resolver.java index 741338f3d..4db9d2d4c 100644 --- a/app/src/main/java/im/conversations/android/dns/Resolver.java +++ b/app/src/main/java/im/conversations/android/dns/Resolver.java @@ -2,9 +2,12 @@ package im.conversations.android.dns; import android.app.Application; import android.content.Context; -import androidx.annotation.NonNull; -import com.google.common.base.MoreObjects; -import com.google.common.base.Objects; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Ordering; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; import de.measite.minidns.AbstractDNSClient; import de.measite.minidns.DNSCache; import de.measite.minidns.DNSClient; @@ -24,14 +27,17 @@ import de.measite.minidns.record.CNAME; import de.measite.minidns.record.Data; import de.measite.minidns.record.InternetAddressRR; import de.measite.minidns.record.SRV; +import im.conversations.android.database.model.Connection; import java.io.IOException; import java.lang.reflect.Field; -import java.net.Inet4Address; import java.net.InetAddress; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.stream.Collectors; import org.jxmpp.jid.DomainJid; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,6 +51,8 @@ public class Resolver { private static final String DIRECT_TLS_SERVICE = "_xmpps-client"; private static final String STARTTLS_SERVICE = "_xmpp-client"; + private static final Executor EXECUTOR = Executors.newFixedThreadPool(4); + private static Context SERVICE; public static void init(final Application application) { @@ -75,13 +83,15 @@ public class Resolver { } } - public static List fromHardCoded(final String hostname, final int port) { - final Result result = new Result(); - result.hostname = DNSName.from(hostname); - result.port = port; - result.directTls = useDirectTls(port); - result.authenticated = true; - return Collections.singletonList(result); + public static List fromHardCoded(final Connection connection) { + return Collections.singletonList( + new ServiceRecord( + null, + DNSName.from(connection.hostname), + connection.port, + connection.directTls, + 0, + true)); } public static void checkDomain(final DomainJid jid) { @@ -106,183 +116,132 @@ public class Resolver { } } - public static boolean useDirectTls(final int port) { - return port == 443 || port == 5223; - } - - public static List resolve(String domain) { - final List ipResults = fromIpAddress(domain); + public static List resolve(final String domain) { + final List ipResults = fromIpAddress(domain); if (ipResults.size() > 0) { return ipResults; } - final List results = new ArrayList<>(); - final List fallbackResults = new ArrayList<>(); - final Thread[] threads = new Thread[3]; - threads[0] = - new Thread( - () -> { - try { - final List list = resolveSrv(domain, true); - synchronized (results) { - results.addAll(list); - } - } catch (final Throwable throwable) { - LOGGER.debug("error resolving SRV record (direct TLS)", throwable); + final ListenableFuture> directTlsSrvRecords = + Futures.submitAsync(() -> resolveSrv(domain, true), EXECUTOR); + final ListenableFuture> startTlsSrvRecords = + Futures.submitAsync(() -> resolveSrv(domain, false), EXECUTOR); + final ListenableFuture> srvRecords = + Futures.transform( + Futures.allAsList(directTlsSrvRecords, startTlsSrvRecords), + input -> { + final var list = + input.stream() + .flatMap(List::stream) + .collect(Collectors.toList()); + if (list.isEmpty()) { + throw new IllegalStateException("No SRV records found"); } - }); - threads[1] = - new Thread( - () -> { - try { - final List list = resolveSrv(domain, false); - synchronized (results) { - results.addAll(list); - } - - } catch (Throwable throwable) { - LOGGER.debug( - "error resolving SRV record (direct STARTTLS)", throwable); - } - }); - threads[2] = - new Thread( - () -> { - List list = resolveNoSrvRecords(DNSName.from(domain), true); - synchronized (fallbackResults) { - fallbackResults.addAll(list); - } - }); - for (final Thread thread : threads) { - thread.start(); - } + return list; + }, + MoreExecutors.directExecutor()); + final ListenableFuture> fallback = + Futures.submit(() -> resolveNoSrvRecords(DNSName.from(domain), true), EXECUTOR); + final var resultFuture = + Futures.catchingAsync( + srvRecords, + Exception.class, + input -> fallback, + MoreExecutors.directExecutor()); try { - threads[0].join(); - threads[1].join(); - if (results.size() > 0) { - threads[2].interrupt(); - synchronized (results) { - Collections.sort(results); - LOGGER.info("{}", results); - return new ArrayList<>(results); - } - } else { - threads[2].join(); - synchronized (fallbackResults) { - Collections.sort(fallbackResults); - LOGGER.info("fallback {}", fallbackResults); - return new ArrayList<>(fallbackResults); - } - } - } catch (InterruptedException e) { - for (Thread thread : threads) { - thread.interrupt(); - } + return Ordering.natural().sortedCopy(resultFuture.get()); + } catch (final Exception e) { return Collections.emptyList(); } } - private static List fromIpAddress(String domain) { + private static List fromIpAddress(final String domain) { if (!IP.matches(domain)) { return Collections.emptyList(); } + final InetAddress ip; try { - Result result = new Result(); - result.ip = InetAddress.getByName(domain); - result.port = DEFAULT_PORT_XMPP; - return Collections.singletonList(result); - } catch (UnknownHostException e) { + ip = InetAddress.getByName(domain); + } catch (final UnknownHostException e) { return Collections.emptyList(); } + return Collections.singletonList(new ServiceRecord(ip, null, DEFAULT_PORT_XMPP, false, 0, false)); } - private static List resolveSrv(String domain, final boolean directTls) - throws IOException { + private static ListenableFuture> resolveSrv( + final String domain, final boolean directTls) throws IOException { DNSName dnsName = DNSName.from( (directTls ? DIRECT_TLS_SERVICE : STARTTLS_SERVICE) + "._tcp." + domain); - ResolverResult result = resolveWithFallback(dnsName, SRV.class); - final List results = new ArrayList<>(); - final List threads = new ArrayList<>(); - for (SRV record : result.getAnswersOrEmptySet()) { + final ResolverResult result = resolveWithFallback(dnsName, SRV.class); + final List>> results = new ArrayList<>(); + for (final SRV record : result.getAnswersOrEmptySet()) { if (record.name.length() == 0 && record.priority == 0) { continue; } - threads.add( - new Thread( + results.add( + Futures.submit( () -> { - final List ipv4s = + final List ipv4s = resolveIp( record, A.class, result.isAuthenticData(), directTls); - if (ipv4s.size() == 0) { - Result resolverResult = Result.fromRecord(record, directTls); - resolverResult.authenticated = result.isAuthenticData(); - ipv4s.add(resolverResult); + if (ipv4s.isEmpty()) { + return Collections.singletonList( + ServiceRecord.fromRecord( + record, directTls, result.isAuthenticData())); + } else { + return ipv4s; } - synchronized (results) { - results.addAll(ipv4s); - } - })); - threads.add( - new Thread( - () -> { - final List ipv6s = - resolveIp( - record, - AAAA.class, - result.isAuthenticData(), - directTls); - synchronized (results) { - results.addAll(ipv6s); - } - })); + }, + EXECUTOR)); + results.add( + Futures.submit( + () -> + resolveIp( + record, + AAAA.class, + result.isAuthenticData(), + directTls), + EXECUTOR)); } - for (Thread thread : threads) { - thread.start(); - } - for (Thread thread : threads) { - try { - thread.join(); - } catch (InterruptedException e) { - return Collections.emptyList(); - } - } - return results; + return Futures.transform( + Futures.allAsList(results), + input -> input.stream().flatMap(List::stream).collect(Collectors.toList()), + MoreExecutors.directExecutor()); } - private static List resolveIp( + private static List resolveIp( SRV srv, Class type, boolean authenticated, boolean directTls) { - List list = new ArrayList<>(); + final ImmutableList.Builder builder = new ImmutableList.Builder<>(); try { ResolverResult results = resolveWithFallback(srv.name, type, authenticated); for (D record : results.getAnswersOrEmptySet()) { - Result resolverResult = Result.fromRecord(srv, directTls); - resolverResult.authenticated = - results.isAuthenticData() - && authenticated; // TODO technically it doesn’t matter if the IP - // was authenticated - resolverResult.ip = record.getInetAddress(); - list.add(resolverResult); + builder.add( + ServiceRecord.fromRecord( + srv, + directTls, + results.isAuthenticData() && authenticated, + record.getInetAddress())); } } catch (final Throwable t) { LOGGER.info("error resolving {}", type.getSimpleName(), t); } - return list; + return builder.build(); } - private static List resolveNoSrvRecords(DNSName dnsName, boolean withCnames) { - List results = new ArrayList<>(); + private static List resolveNoSrvRecords(DNSName dnsName, boolean includeCName) { + List results = new ArrayList<>(); try { for (A a : resolveWithFallback(dnsName, A.class, false).getAnswersOrEmptySet()) { - results.add(Result.createDefault(dnsName, a.getInetAddress())); + results.add(ServiceRecord.createDefault(dnsName, a.getInetAddress())); } for (AAAA aaaa : resolveWithFallback(dnsName, AAAA.class, false).getAnswersOrEmptySet()) { - results.add(Result.createDefault(dnsName, aaaa.getInetAddress())); + results.add(ServiceRecord.createDefault(dnsName, aaaa.getInetAddress())); } - if (results.size() == 0 && withCnames) { + if (results.size() == 0 && includeCName) { for (CNAME cname : resolveWithFallback(dnsName, CNAME.class, false).getAnswersOrEmptySet()) { results.addAll(resolveNoSrvRecords(cname.name, false)); @@ -291,7 +250,7 @@ public class Resolver { } catch (Throwable throwable) { LOGGER.info("Error resolving fallback records", throwable); } - results.add(Result.createDefault(dnsName)); + results.add(ServiceRecord.createDefault(dnsName)); return results; } @@ -327,108 +286,4 @@ public class Resolver { return false; } - public static class Result implements Comparable { - private InetAddress ip; - private DNSName hostname; - private int port = DEFAULT_PORT_XMPP; - private boolean directTls = false; - private boolean authenticated = false; - private int priority; - - static Result fromRecord(SRV srv, boolean directTls) { - Result result = new Result(); - result.port = srv.port; - result.hostname = srv.name; - result.directTls = directTls; - result.priority = srv.priority; - return result; - } - - static Result createDefault(DNSName hostname, InetAddress ip) { - Result result = new Result(); - result.port = DEFAULT_PORT_XMPP; - result.hostname = hostname; - result.ip = ip; - return result; - } - - static Result createDefault(DNSName hostname) { - return createDefault(hostname, null); - } - - public InetAddress getIp() { - return ip; - } - - public int getPort() { - return port; - } - - public DNSName getHostname() { - return hostname; - } - - public boolean isDirectTls() { - return directTls; - } - - public boolean isAuthenticated() { - return authenticated; - } - - @Override - public int compareTo(@NonNull Result result) { - // TODO use comparison chain. get rid of IPv4 preference - if (result.priority == priority) { - if (directTls == result.directTls) { - if (ip == null && result.ip == null) { - return 0; - } else if (ip != null && result.ip != null) { - if (ip instanceof Inet4Address && result.ip instanceof Inet4Address) { - return 0; - } else { - return ip instanceof Inet4Address ? -1 : 1; - } - } else { - return ip != null ? -1 : 1; - } - } else { - return directTls ? -1 : 1; - } - } else { - return priority - result.priority; - } - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Result result = (Result) o; - return port == result.port - && directTls == result.directTls - && authenticated == result.authenticated - && priority == result.priority - && Objects.equal(ip, result.ip) - && Objects.equal(hostname, result.hostname); - } - - @Override - public int hashCode() { - return Objects.hashCode(ip, hostname, port, directTls, authenticated, priority); - } - - @NonNull - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("ip", ip) - .add("hostname", hostname) - .add("port", port) - .add("directTls", directTls) - .add("authenticated", authenticated) - .add("priority", priority) - .toString(); - } - } } diff --git a/app/src/main/java/im/conversations/android/dns/ServiceRecord.java b/app/src/main/java/im/conversations/android/dns/ServiceRecord.java new file mode 100644 index 000000000..eb6a3dd8d --- /dev/null +++ b/app/src/main/java/im/conversations/android/dns/ServiceRecord.java @@ -0,0 +1,124 @@ +package im.conversations.android.dns; + +import androidx.annotation.NonNull; + +import com.google.common.base.MoreObjects; +import com.google.common.base.Objects; + +import java.net.InetAddress; + +import de.measite.minidns.DNSName; +import de.measite.minidns.record.SRV; + +public class ServiceRecord implements Comparable { + private final InetAddress ip; + private final DNSName hostname; + private final int port; + private final boolean directTls; + private final int priority; + private final boolean authenticated; + + public ServiceRecord( + InetAddress ip, + DNSName hostname, + int port, + boolean directTls, + int priority, + boolean authenticated) { + this.ip = ip; + this.hostname = hostname; + this.port = port; + this.directTls = directTls; + this.authenticated = authenticated; + this.priority = priority; + } + + public static ServiceRecord fromRecord( + final SRV srv, + final boolean directTls, + final boolean authenticated, + final InetAddress ip) { + return new ServiceRecord(ip, srv.name, srv.port, directTls, srv.priority, authenticated); + } + + public static ServiceRecord fromRecord( + final SRV srv, final boolean directTls, final boolean authenticated) { + return fromRecord(srv, directTls, authenticated, null); + } + + static ServiceRecord createDefault(final DNSName hostname, final InetAddress ip) { + return new ServiceRecord(ip, hostname, Resolver.DEFAULT_PORT_XMPP, false, 0, false); + } + + static ServiceRecord createDefault(final DNSName hostname) { + return createDefault(hostname, null); + } + + public InetAddress getIp() { + return ip; + } + + public int getPort() { + return port; + } + + public DNSName getHostname() { + return hostname; + } + + public boolean isDirectTls() { + return directTls; + } + + public boolean isAuthenticated() { + return authenticated; + } + + @Override + public int compareTo(@NonNull ServiceRecord result) { + if (result.priority == priority) { + if (directTls == result.directTls) { + if (ip == null && result.ip == null) { + return 0; + } else { + return ip != null ? -1 : 1; + } + } else { + return directTls ? -1 : 1; + } + } else { + return priority - result.priority; + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ServiceRecord result = (ServiceRecord) o; + return port == result.port + && directTls == result.directTls + && authenticated == result.authenticated + && priority == result.priority + && Objects.equal(ip, result.ip) + && Objects.equal(hostname, result.hostname); + } + + @Override + public int hashCode() { + return Objects.hashCode(ip, hostname, port, directTls, authenticated, priority); + } + + @NonNull + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("ip", ip) + .add("hostname", hostname) + .add("port", port) + .add("directTls", directTls) + .add("authenticated", authenticated) + .add("priority", priority) + .toString(); + } +} diff --git a/app/src/main/java/im/conversations/android/xmpp/XmppConnection.java b/app/src/main/java/im/conversations/android/xmpp/XmppConnection.java index d21c6c54e..11c94ca54 100644 --- a/app/src/main/java/im/conversations/android/xmpp/XmppConnection.java +++ b/app/src/main/java/im/conversations/android/xmpp/XmppConnection.java @@ -26,6 +26,7 @@ import im.conversations.android.database.model.Account; import im.conversations.android.database.model.Connection; import im.conversations.android.database.model.Credential; import im.conversations.android.dns.Resolver; +import im.conversations.android.dns.ServiceRecord; import im.conversations.android.socks.SocksSocketFactory; import im.conversations.android.tls.SSLSockets; import im.conversations.android.tls.XmppDomainVerifier; @@ -322,12 +323,13 @@ public class XmppConnection implements Runnable { } } else { final String domain = account.address.getDomain().toString(); - final List results; + final List results; if (connection != null) { - results = Resolver.fromHardCoded(connection.hostname, connection.port); + results = Resolver.fromHardCoded(connection); } else { results = Resolver.resolve(domain); } + LOGGER.info("{}", results); if (Thread.currentThread().isInterrupted()) { LOGGER.debug(account.address + ": Thread was interrupted"); return; @@ -336,7 +338,7 @@ public class XmppConnection implements Runnable { LOGGER.warn("Resolver results were empty"); return; } - final Resolver.Result storedBackupResult; + final ServiceRecord storedBackupResult; if (connection != null) { storedBackupResult = null; } else { @@ -351,9 +353,9 @@ public class XmppConnection implements Runnable { + storedBackupResult); } } - for (Iterator iterator = results.iterator(); - iterator.hasNext(); ) { - final Resolver.Result result = iterator.next(); + for (Iterator iterator = results.iterator(); + iterator.hasNext(); ) { + final ServiceRecord result = iterator.next(); if (Thread.currentThread().isInterrupted()) { LOGGER.debug(account.address + ": Thread was interrupted"); return;