diff --git a/src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java b/src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java index f6024210a..b255b6f42 100644 --- a/src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java +++ b/src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java @@ -1,8 +1,12 @@ package eu.siacs.conversations.crypto.sasl; +import com.google.common.base.Strings; + import java.security.SecureRandom; import eu.siacs.conversations.entities.Account; +import eu.siacs.conversations.xml.Element; +import eu.siacs.conversations.xml.Namespace; import eu.siacs.conversations.xml.TagWriter; public abstract class SaslMechanism { @@ -68,6 +72,17 @@ public abstract class SaslMechanism { } public enum Version { - SASL, SASL_2 + SASL, SASL_2; + + public static Version of(final Element element) { + switch ( Strings.nullToEmpty(element.getNamespace())) { + case Namespace.SASL: + return SASL; + case Namespace.SASL_2: + return SASL_2; + default: + throw new IllegalArgumentException("Unrecognized SASL namespace"); + } + } } } diff --git a/src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java b/src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java index 2222da3e2..bc77246e8 100644 --- a/src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java +++ b/src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java @@ -469,63 +469,102 @@ public class XmppConnection implements Runnable { } else if (nextTag.isStart("proceed")) { switchOverToTls(); } else if (nextTag.isStart("success")) { - final String challenge = tagReader.readElement(nextTag).getContent(); + final Element success = tagReader.readElement(nextTag); + final SaslMechanism.Version version; + try { + version = SaslMechanism.Version.of(success); + } catch (final IllegalArgumentException e) { + throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER); + } + final String challenge; + if (version == SaslMechanism.Version.SASL) { + challenge = success.getContent(); + } else if (version == SaslMechanism.Version.SASL_2) { + challenge = success.findChildContent("additional-data"); + } else { + throw new AssertionError("Missing implementation for " + version); + } try { saslMechanism.getResponse(challenge); } catch (final SaslMechanism.AuthenticationException e) { Log.e(Config.LOGTAG, String.valueOf(e)); throw new StateChangingException(Account.State.UNAUTHORIZED); } - Log.d(Config.LOGTAG, account.getJid().asBareJid().toString() + ": logged in"); - account.setKey(Account.PINNED_MECHANISM_KEY, - String.valueOf(saslMechanism.getPriority())); - tagReader.reset(); - sendStartStream(); - final Tag tag = tagReader.readTag(); - if (tag != null && tag.isStart("stream")) { - processStream(); - } else { - throw new StateChangingException(Account.State.STREAM_OPENING_ERROR); + Log.d( + Config.LOGTAG, + account.getJid().asBareJid().toString() + + ": logged in (using " + + version + + ")"); + account.setKey( + Account.PINNED_MECHANISM_KEY, String.valueOf(saslMechanism.getPriority())); + if (version == SaslMechanism.Version.SASL) { + tagReader.reset(); + sendStartStream(); + final Tag tag = tagReader.readTag(); + if (tag != null && tag.isStart("stream")) { + processStream(); + } else { + throw new StateChangingException(Account.State.STREAM_OPENING_ERROR); + } + break; } - break; } else if (nextTag.isStart("failure")) { final Element failure = tagReader.readElement(nextTag); - if (Namespace.SASL.equals(failure.getNamespace())) { - if (failure.hasChild("temporary-auth-failure")) { - throw new StateChangingException(Account.State.TEMPORARY_AUTH_FAILURE); - } else if (failure.hasChild("account-disabled")) { - final String text = failure.findChildContent("text"); - if ( Strings.isNullOrEmpty(text)) { - throw new StateChangingException(Account.State.UNAUTHORIZED); - } - final Matcher matcher = Patterns.AUTOLINK_WEB_URL.matcher(text); - if (matcher.find()) { - final HttpUrl url; - try { - url = HttpUrl.get(text.substring(matcher.start(), matcher.end())); - } catch (final IllegalArgumentException e) { - throw new StateChangingException(Account.State.UNAUTHORIZED); - } - if (url.isHttps()) { - this.redirectionUrl = url; - throw new StateChangingException(Account.State.PAYMENT_REQUIRED); - } - } - } - throw new StateChangingException(Account.State.UNAUTHORIZED); - } else if (Namespace.TLS.equals(failure.getNamespace())) { + if (Namespace.TLS.equals(failure.getNamespace())) { throw new StateChangingException(Account.State.TLS_ERROR); - } else { + } + final SaslMechanism.Version version; + try { + version = SaslMechanism.Version.of(failure); + } catch (final IllegalArgumentException e) { throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER); } + Log.d(Config.LOGTAG, account.getJid().asBareJid() + ": login failure " + version); + if (failure.hasChild("temporary-auth-failure")) { + throw new StateChangingException(Account.State.TEMPORARY_AUTH_FAILURE); + } else if (failure.hasChild("account-disabled")) { + final String text = failure.findChildContent("text"); + if (Strings.isNullOrEmpty(text)) { + throw new StateChangingException(Account.State.UNAUTHORIZED); + } + final Matcher matcher = Patterns.AUTOLINK_WEB_URL.matcher(text); + if (matcher.find()) { + final HttpUrl url; + try { + url = HttpUrl.get(text.substring(matcher.start(), matcher.end())); + } catch (final IllegalArgumentException e) { + throw new StateChangingException(Account.State.UNAUTHORIZED); + } + if (url.isHttps()) { + this.redirectionUrl = url; + throw new StateChangingException(Account.State.PAYMENT_REQUIRED); + } + } + } + throw new StateChangingException(Account.State.UNAUTHORIZED); } else if (nextTag.isStart("challenge")) { - final String challenge = tagReader.readElement(nextTag).getContent(); - final Element response = new Element("response", Namespace.SASL); + final Element challenge = tagReader.readElement(nextTag); + final SaslMechanism.Version version; try { - response.setContent(saslMechanism.getResponse(challenge)); + version = SaslMechanism.Version.of(challenge); + } catch (final IllegalArgumentException e) { + throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER); + } + final Element response; + if (version == SaslMechanism.Version.SASL) { + response = new Element("response", Namespace.SASL); + } else if (version == SaslMechanism.Version.SASL_2) { + response = new Element("response", Namespace.SASL_2); + } else { + throw new AssertionError("Missing implementation for " + version); + } + try { + response.setContent(saslMechanism.getResponse(challenge.getContent())); } catch (final SaslMechanism.AuthenticationException e) { // TODO: Send auth abort tag. Log.e(Config.LOGTAG, e.toString()); + throw new StateChangingException(Account.State.UNAUTHORIZED); } tagWriter.writeElement(response); } else if (nextTag.isStart("enabled")) { @@ -848,7 +887,6 @@ public class XmppConnection implements Runnable { private void processStreamFeatures(final Tag currentTag) throws IOException { this.streamFeatures = tagReader.readElement(currentTag); - Log.d(Config.LOGTAG, this.streamFeatures.toString()); final boolean isSecure = features.encryptionEnabled || Config.ALLOW_NON_TLS_CONNECTIONS || account.isOnion(); final boolean needsBinding = !isBound && !account.isOptionSet(Account.OPTION_REGISTER); @@ -907,7 +945,6 @@ public class XmppConnection implements Runnable { private void authenticate(final SaslMechanism.Version version) throws IOException { final List mechanisms = extractMechanisms(streamFeatures.findChild("mechanisms")); - final Element auth = new Element("auth", Namespace.SASL); if (mechanisms.contains(External.MECHANISM) && account.getPrivateKeyAlias() != null) { saslMechanism = new External(tagWriter, account, mXmppConnectionService.getRNG()); } else if (mechanisms.contains(ScramSha512.MECHANISM)) { @@ -923,25 +960,38 @@ public class XmppConnection implements Runnable { } else if (mechanisms.contains(Anonymous.MECHANISM)) { saslMechanism = new Anonymous(tagWriter, account, mXmppConnectionService.getRNG()); } - if (saslMechanism != null) { - final int pinnedMechanism = account.getKeyAsInt(Account.PINNED_MECHANISM_KEY, -1); - if (pinnedMechanism > saslMechanism.getPriority()) { - Log.e(Config.LOGTAG, "Auth failed. Authentication mechanism " + saslMechanism.getMechanism() + - " has lower priority (" + saslMechanism.getPriority() + - ") than pinned priority (" + pinnedMechanism + - "). Possible downgrade attack?"); - throw new StateChangingException(Account.State.DOWNGRADE_ATTACK); - } - Log.d(Config.LOGTAG, account.getJid().toString() + ": Authenticating with " + saslMechanism.getMechanism()); - auth.setAttribute("mechanism", saslMechanism.getMechanism()); - if (!saslMechanism.getClientFirstMessage().isEmpty()) { - auth.setContent(saslMechanism.getClientFirstMessage()); - } - tagWriter.writeElement(auth); - } else { + if (saslMechanism == null) { Log.d(Config.LOGTAG, account.getJid().asBareJid() + ": unable to find supported SASL mechanism in " + mechanisms); throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER); } + final int pinnedMechanism = account.getKeyAsInt(Account.PINNED_MECHANISM_KEY, -1); + if (pinnedMechanism > saslMechanism.getPriority()) { + Log.e(Config.LOGTAG, "Auth failed. Authentication mechanism " + saslMechanism.getMechanism() + + " has lower priority (" + saslMechanism.getPriority() + + ") than pinned priority (" + pinnedMechanism + + "). Possible downgrade attack?"); + throw new StateChangingException(Account.State.DOWNGRADE_ATTACK); + } + final String firstMessage = saslMechanism.getClientFirstMessage(); + final Element authenticate; + if (version == SaslMechanism.Version.SASL) { + authenticate = new Element("auth", Namespace.SASL); + if (!Strings.isNullOrEmpty(firstMessage)) { + authenticate.setContent(firstMessage); + } + } else if (version == SaslMechanism.Version.SASL_2) { + authenticate = new Element("authenticate", Namespace.SASL_2); + if (!Strings.isNullOrEmpty(firstMessage)) { + authenticate.addChild("initial-response").setContent(firstMessage); + } + // TODO place to add extensions + } else { + throw new AssertionError("Missing implementation for " + version); + } + + Log.d(Config.LOGTAG, account.getJid().toString() + ": Authenticating with "+version+ "/" + saslMechanism.getMechanism()); + authenticate.setAttribute("mechanism", saslMechanism.getMechanism()); + tagWriter.writeElement(authenticate); } private List extractMechanisms(final Element stream) {