store full sasl mechanism (not just priority)

This commit is contained in:
Daniel Gultsch 2022-09-15 12:22:05 +02:00
parent 82efb6f1db
commit 495f79921d
8 changed files with 117 additions and 43 deletions

View file

@ -100,7 +100,7 @@ public class MagicCreateActivity extends XmppActivity implements TextWatcher {
account.setOption(Account.OPTION_MAGIC_CREATE, true); account.setOption(Account.OPTION_MAGIC_CREATE, true);
account.setOption(Account.OPTION_FIXED_USERNAME, fixedUsername); account.setOption(Account.OPTION_FIXED_USERNAME, fixedUsername);
if (this.preAuth != null) { if (this.preAuth != null) {
account.setKey(Account.PRE_AUTH_REGISTRATION_TOKEN, this.preAuth); account.setKey(Account.KEY_PRE_AUTH_REGISTRATION_TOKEN, this.preAuth);
} }
xmppConnectionService.createAccount(account); xmppConnectionService.createAccount(account);
} }

View file

@ -3,6 +3,7 @@ package eu.siacs.conversations.crypto.sasl;
import android.util.Log; import android.util.Log;
import com.google.common.base.CaseFormat; import com.google.common.base.CaseFormat;
import com.google.common.base.Strings;
import java.util.Collection; import java.util.Collection;
@ -27,6 +28,17 @@ public enum ChannelBinding {
} }
} }
public static ChannelBinding get(final String name) {
if (Strings.isNullOrEmpty(name)) {
return NONE;
}
try {
return valueOf(name);
} catch (final IllegalArgumentException e) {
return NONE;
}
}
public static ChannelBinding best(final Collection<ChannelBinding> bindings) { public static ChannelBinding best(final Collection<ChannelBinding> bindings) {
if (bindings.contains(TLS_EXPORTER)) { if (bindings.contains(TLS_EXPORTER)) {
return TLS_EXPORTER; return TLS_EXPORTER;

View file

@ -3,6 +3,7 @@ package eu.siacs.conversations.crypto.sasl;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocket;
@ -129,5 +130,10 @@ public abstract class SaslMechanism {
return null; return null;
} }
} }
public SaslMechanism of(final String mechanism, final ChannelBinding channelBinding) {
return of(Collections.singleton(mechanism), Collections.singleton(channelBinding));
}
} }
} }

View file

@ -16,7 +16,7 @@ import javax.net.ssl.SSLSocket;
import eu.siacs.conversations.entities.Account; import eu.siacs.conversations.entities.Account;
abstract class ScramPlusMechanism extends ScramMechanism { public abstract class ScramPlusMechanism extends ScramMechanism {
private static final String EXPORTER_LABEL = "EXPORTER-Channel-Binding"; private static final String EXPORTER_LABEL = "EXPORTER-Channel-Binding";
@ -51,8 +51,7 @@ abstract class ScramPlusMechanism extends ScramMechanism {
} }
return unique; return unique;
} else if (this.channelBinding == ChannelBinding.TLS_SERVER_END_POINT) { } else if (this.channelBinding == ChannelBinding.TLS_SERVER_END_POINT) {
final byte[] endPoint = getServerEndPointChannelBinding(sslSocket.getSession()); return getServerEndPointChannelBinding(sslSocket.getSession());
return endPoint;
} else { } else {
throw new AuthenticationException( throw new AuthenticationException(
String.format("%s is not a valid channel binding", channelBinding)); String.format("%s is not a valid channel binding", channelBinding));
@ -103,4 +102,8 @@ abstract class ScramPlusMechanism extends ScramMechanism {
messageDigest.update(encodedCertificate); messageDigest.update(encodedCertificate);
return messageDigest.digest(); return messageDigest.digest();
} }
public ChannelBinding getChannelBinding() {
return this.channelBinding;
}
} }

View file

@ -25,6 +25,9 @@ import eu.siacs.conversations.R;
import eu.siacs.conversations.crypto.PgpDecryptionService; import eu.siacs.conversations.crypto.PgpDecryptionService;
import eu.siacs.conversations.crypto.axolotl.AxolotlService; import eu.siacs.conversations.crypto.axolotl.AxolotlService;
import eu.siacs.conversations.crypto.axolotl.XmppAxolotlSession; import eu.siacs.conversations.crypto.axolotl.XmppAxolotlSession;
import eu.siacs.conversations.crypto.sasl.ChannelBinding;
import eu.siacs.conversations.crypto.sasl.SaslMechanism;
import eu.siacs.conversations.crypto.sasl.ScramPlusMechanism;
import eu.siacs.conversations.services.AvatarService; import eu.siacs.conversations.services.AvatarService;
import eu.siacs.conversations.services.XmppConnectionService; import eu.siacs.conversations.services.XmppConnectionService;
import eu.siacs.conversations.utils.UIHelper; import eu.siacs.conversations.utils.UIHelper;
@ -50,9 +53,9 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
public static final String STATUS = "status"; public static final String STATUS = "status";
public static final String STATUS_MESSAGE = "status_message"; public static final String STATUS_MESSAGE = "status_message";
public static final String RESOURCE = "resource"; public static final String RESOURCE = "resource";
public static final String PINNED_MECHANISM = "pinned_mechanism";
public static final String PINNED_CHANNEL_BINDING = "pinned_channel_binding";
public static final String PINNED_MECHANISM_KEY = "pinned_mechanism";
public static final String PRE_AUTH_REGISTRATION_TOKEN = "pre_auth_registration";
public static final int OPTION_USETLS = 0; public static final int OPTION_USETLS = 0;
public static final int OPTION_DISABLED = 1; public static final int OPTION_DISABLED = 1;
@ -64,8 +67,13 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
public static final int OPTION_HTTP_UPLOAD_AVAILABLE = 7; public static final int OPTION_HTTP_UPLOAD_AVAILABLE = 7;
public static final int OPTION_UNVERIFIED = 8; public static final int OPTION_UNVERIFIED = 8;
public static final int OPTION_FIXED_USERNAME = 9; public static final int OPTION_FIXED_USERNAME = 9;
private static final String KEY_PGP_SIGNATURE = "pgp_signature"; private static final String KEY_PGP_SIGNATURE = "pgp_signature";
private static final String KEY_PGP_ID = "pgp_id"; private static final String KEY_PGP_ID = "pgp_id";
private static final String KEY_PINNED_MECHANISM = "pinned_mechanism";
public static final String KEY_PRE_AUTH_REGISTRATION_TOKEN = "pre_auth_registration";
protected final JSONObject keys; protected final JSONObject keys;
private final Roster roster = new Roster(this); private final Roster roster = new Roster(this);
private final Collection<Jid> blocklist = new CopyOnWriteArraySet<>(); private final Collection<Jid> blocklist = new CopyOnWriteArraySet<>();
@ -90,18 +98,20 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
private XmppConnection xmppConnection = null; private XmppConnection xmppConnection = null;
private long mEndGracePeriod = 0L; private long mEndGracePeriod = 0L;
private final Map<Jid, Bookmark> bookmarks = new HashMap<>(); private final Map<Jid, Bookmark> bookmarks = new HashMap<>();
private Presence.Status presenceStatus = Presence.Status.ONLINE; private Presence.Status presenceStatus;
private String presenceStatusMessage = null; private String presenceStatusMessage;
private String pinnedMechanism;
private String pinnedChannelBinding;
public Account(final Jid jid, final String password) { public Account(final Jid jid, final String password) {
this(java.util.UUID.randomUUID().toString(), jid, this(java.util.UUID.randomUUID().toString(), jid,
password, 0, null, "", null, null, null, 5222, Presence.Status.ONLINE, null); password, 0, null, "", null, null, null, 5222, Presence.Status.ONLINE, null, null, null);
} }
private Account(final String uuid, final Jid jid, private Account(final String uuid, final Jid jid,
final String password, final int options, final String rosterVersion, final String keys, final String password, final int options, final String rosterVersion, final String keys,
final String avatar, String displayName, String hostname, int port, final String avatar, String displayName, String hostname, int port,
final Presence.Status status, String statusMessage) { final Presence.Status status, String statusMessage, final String pinnedMechanism, final String pinnedChannelBinding) {
this.uuid = uuid; this.uuid = uuid;
this.jid = jid; this.jid = jid;
this.password = password; this.password = password;
@ -120,19 +130,21 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
this.port = port; this.port = port;
this.presenceStatus = status; this.presenceStatus = status;
this.presenceStatusMessage = statusMessage; this.presenceStatusMessage = statusMessage;
this.pinnedMechanism = pinnedMechanism;
this.pinnedChannelBinding = pinnedChannelBinding;
} }
public static Account fromCursor(final Cursor cursor) { public static Account fromCursor(final Cursor cursor) {
final Jid jid; final Jid jid;
try { try {
String resource = cursor.getString(cursor.getColumnIndexOrThrow(RESOURCE)); final String resource = cursor.getString(cursor.getColumnIndexOrThrow(RESOURCE));
jid = Jid.of( jid = Jid.of(
cursor.getString(cursor.getColumnIndexOrThrow(USERNAME)), cursor.getString(cursor.getColumnIndexOrThrow(USERNAME)),
cursor.getString(cursor.getColumnIndexOrThrow(SERVER)), cursor.getString(cursor.getColumnIndexOrThrow(SERVER)),
resource == null || resource.trim().isEmpty() ? null : resource); resource == null || resource.trim().isEmpty() ? null : resource);
} catch (final IllegalArgumentException ignored) { } catch (final IllegalArgumentException e) {
Log.d(Config.LOGTAG, cursor.getString(cursor.getColumnIndexOrThrow(USERNAME)) + "@" + cursor.getString(cursor.getColumnIndexOrThrow(SERVER))); Log.d(Config.LOGTAG, cursor.getString(cursor.getColumnIndexOrThrow(USERNAME)) + "@" + cursor.getString(cursor.getColumnIndexOrThrow(SERVER)));
throw new AssertionError(ignored); throw new AssertionError(e);
} }
return new Account(cursor.getString(cursor.getColumnIndexOrThrow(UUID)), return new Account(cursor.getString(cursor.getColumnIndexOrThrow(UUID)),
jid, jid,
@ -145,7 +157,9 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
cursor.getString(cursor.getColumnIndexOrThrow(HOSTNAME)), cursor.getString(cursor.getColumnIndexOrThrow(HOSTNAME)),
cursor.getInt(cursor.getColumnIndexOrThrow(PORT)), cursor.getInt(cursor.getColumnIndexOrThrow(PORT)),
Presence.Status.fromShowString(cursor.getString(cursor.getColumnIndexOrThrow(STATUS))), Presence.Status.fromShowString(cursor.getString(cursor.getColumnIndexOrThrow(STATUS))),
cursor.getString(cursor.getColumnIndexOrThrow(STATUS_MESSAGE))); cursor.getString(cursor.getColumnIndexOrThrow(STATUS_MESSAGE)),
cursor.getString(cursor.getColumnIndexOrThrow(PINNED_MECHANISM)),
cursor.getString(cursor.getColumnIndexOrThrow(PINNED_CHANNEL_BINDING)));
} }
public boolean httpUploadAvailable(long size) { public boolean httpUploadAvailable(long size) {
@ -289,6 +303,38 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
} }
} }
public void setPinnedMechanism(final SaslMechanism mechanism) {
this.pinnedMechanism = mechanism.getMechanism();
if (mechanism instanceof ScramPlusMechanism) {
this.pinnedChannelBinding = ((ScramPlusMechanism) mechanism).getChannelBinding().toString();
}
}
public void resetPinnedMechanism() {
this.pinnedMechanism = null;
this.pinnedChannelBinding = null;
setKey(Account.KEY_PINNED_MECHANISM, String.valueOf(-1));
}
public int getPinnedMechanismPriority() {
final int fallback = getKeyAsInt(KEY_PINNED_MECHANISM, -1);
if (Strings.isNullOrEmpty(this.pinnedMechanism)) {
return fallback;
}
final SaslMechanism saslMechanism = getPinnedMechanism();
if (saslMechanism == null) {
return fallback;
} else {
return saslMechanism.getPriority();
}
}
public SaslMechanism getPinnedMechanism() {
final String mechanism = Strings.nullToEmpty(this.pinnedMechanism);
final ChannelBinding channelBinding = ChannelBinding.get(this.pinnedChannelBinding);
return new SaslMechanism.Factory(this).of(mechanism, channelBinding);
}
public State getTrueStatus() { public State getTrueStatus() {
return this.status; return this.status;
} }
@ -361,8 +407,8 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
} }
} }
public boolean setPrivateKeyAlias(String alias) { public void setPrivateKeyAlias(final String alias) {
return setKey("private_key_alias", alias); setKey("private_key_alias", alias);
} }
public String getPrivateKeyAlias() { public String getPrivateKeyAlias() {
@ -388,6 +434,8 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
values.put(STATUS, presenceStatus.toShowString()); values.put(STATUS, presenceStatus.toShowString());
values.put(STATUS_MESSAGE, presenceStatusMessage); values.put(STATUS_MESSAGE, presenceStatusMessage);
values.put(RESOURCE, jid.getResource()); values.put(RESOURCE, jid.getResource());
values.put(PINNED_MECHANISM, pinnedMechanism);
values.put(PINNED_CHANNEL_BINDING, pinnedChannelBinding);
return values; return values;
} }

View file

@ -64,7 +64,7 @@ import eu.siacs.conversations.xmpp.mam.MamReference;
public class DatabaseBackend extends SQLiteOpenHelper { public class DatabaseBackend extends SQLiteOpenHelper {
private static final String DATABASE_NAME = "history"; private static final String DATABASE_NAME = "history";
private static final int DATABASE_VERSION = 49; private static final int DATABASE_VERSION = 50;
private static boolean requiresMessageIndexRebuild = false; private static boolean requiresMessageIndexRebuild = false;
private static DatabaseBackend instance = null; private static DatabaseBackend instance = null;
@ -230,6 +230,8 @@ public class DatabaseBackend extends SQLiteOpenHelper {
+ Account.KEYS + " TEXT, " + Account.KEYS + " TEXT, "
+ Account.HOSTNAME + " TEXT, " + Account.HOSTNAME + " TEXT, "
+ Account.RESOURCE + " TEXT," + Account.RESOURCE + " TEXT,"
+ Account.PINNED_MECHANISM + " TEXT,"
+ Account.PINNED_CHANNEL_BINDING + " TEXT,"
+ Account.PORT + " NUMBER DEFAULT 5222)"); + Account.PORT + " NUMBER DEFAULT 5222)");
db.execSQL("create table " + Conversation.TABLENAME + " (" db.execSQL("create table " + Conversation.TABLENAME + " ("
+ Conversation.UUID + " TEXT PRIMARY KEY, " + Conversation.NAME + Conversation.UUID + " TEXT PRIMARY KEY, " + Conversation.NAME
@ -589,6 +591,11 @@ public class DatabaseBackend extends SQLiteOpenHelper {
db.endTransaction(); db.endTransaction();
requiresMessageIndexRebuild = true; requiresMessageIndexRebuild = true;
} }
if (oldVersion < 50 && newVersion >= 50) {
db.execSQL("ALTER TABLE " + Account.TABLENAME + " ADD COLUMN " + Account.PINNED_MECHANISM + " TEXT");
db.execSQL("ALTER TABLE " + Account.TABLENAME + " ADD COLUMN " + Account.PINNED_CHANNEL_BINDING + " TEXT");
}
} }
private void canonicalizeJids(SQLiteDatabase db) { private void canonicalizeJids(SQLiteDatabase db) {
@ -938,20 +945,19 @@ public class DatabaseBackend extends SQLiteOpenHelper {
contactJid.asBareJid().toString() + "/%", contactJid.asBareJid().toString() + "/%",
contactJid.asBareJid().toString() contactJid.asBareJid().toString()
}; };
Cursor cursor = db.query(Conversation.TABLENAME, null, try(final Cursor cursor = db.query(Conversation.TABLENAME, null,
Conversation.ACCOUNT + "=? AND (" + Conversation.CONTACTJID Conversation.ACCOUNT + "=? AND (" + Conversation.CONTACTJID
+ " like ? OR " + Conversation.CONTACTJID + "=?)", selectionArgs, null, null, null); + " like ? OR " + Conversation.CONTACTJID + "=?)", selectionArgs, null, null, null)) {
if (cursor.getCount() == 0) { if (cursor.getCount() == 0) {
cursor.close(); return null;
return null; }
cursor.moveToFirst();
final Conversation conversation = Conversation.fromCursor(cursor);
if (conversation.getJid() instanceof InvalidJid) {
return null;
}
return conversation;
} }
cursor.moveToFirst();
Conversation conversation = Conversation.fromCursor(cursor);
cursor.close();
if (conversation.getJid() instanceof InvalidJid) {
return null;
}
return conversation;
} }
public void updateConversation(final Conversation conversation) { public void updateConversation(final Conversation conversation) {
@ -1024,14 +1030,14 @@ public class DatabaseBackend extends SQLiteOpenHelper {
} }
public void readRoster(Roster roster) { public void readRoster(Roster roster) {
SQLiteDatabase db = this.getReadableDatabase(); final SQLiteDatabase db = this.getReadableDatabase();
Cursor cursor; final String[] args = {roster.getAccount().getUuid()};
String[] args = {roster.getAccount().getUuid()}; try (final Cursor cursor =
cursor = db.query(Contact.TABLENAME, null, Contact.ACCOUNT + "=?", args, null, null, null); db.query(Contact.TABLENAME, null, Contact.ACCOUNT + "=?", args, null, null, null)) {
while (cursor.moveToNext()) { while (cursor.moveToNext()) {
roster.initContact(Contact.fromCursor(cursor)); roster.initContact(Contact.fromCursor(cursor));
}
} }
cursor.close();
} }
public void writeRoster(final Roster roster) { public void writeRoster(final Roster roster) {

View file

@ -181,7 +181,7 @@ public class EditAccountActivity extends OmemoActivity implements OnAccountUpdat
} }
if (inNeedOfSaslAccept()) { if (inNeedOfSaslAccept()) {
mAccount.setKey(Account.PINNED_MECHANISM_KEY, String.valueOf(-1)); mAccount.resetPinnedMechanism();
if (!xmppConnectionService.updateAccount(mAccount)) { if (!xmppConnectionService.updateAccount(mAccount)) {
Toast.makeText(EditAccountActivity.this, R.string.unable_to_update_account, Toast.LENGTH_SHORT).show(); Toast.makeText(EditAccountActivity.this, R.string.unable_to_update_account, Toast.LENGTH_SHORT).show();
} }
@ -421,7 +421,7 @@ public class EditAccountActivity extends OmemoActivity implements OnAccountUpdat
} else { } else {
preset = jid.getDomain(); preset = jid.getDomain();
} }
final Intent intent = SignupUtils.getTokenRegistrationIntent(this, preset, mAccount.getKey(Account.PRE_AUTH_REGISTRATION_TOKEN)); final Intent intent = SignupUtils.getTokenRegistrationIntent(this, preset, mAccount.getKey(Account.KEY_PRE_AUTH_REGISTRATION_TOKEN));
StartConversationActivity.addInviteUri(intent, getIntent()); StartConversationActivity.addInviteUri(intent, getIntent());
startActivity(intent); startActivity(intent);
return; return;
@ -892,7 +892,7 @@ public class EditAccountActivity extends OmemoActivity implements OnAccountUpdat
} }
private boolean inNeedOfSaslAccept() { private boolean inNeedOfSaslAccept() {
return mAccount != null && mAccount.getLastErrorStatus() == Account.State.DOWNGRADE_ATTACK && mAccount.getKeyAsInt(Account.PINNED_MECHANISM_KEY, -1) >= 0 && !accountInfoEdited(); return mAccount != null && mAccount.getLastErrorStatus() == Account.State.DOWNGRADE_ATTACK && mAccount.getPinnedMechanismPriority() >= 0 && !accountInfoEdited();
} }
private void shareBarcode() { private void shareBarcode() {

View file

@ -692,8 +692,7 @@ public class XmppConnection implements Runnable {
Log.d( Log.d(
Config.LOGTAG, Config.LOGTAG,
account.getJid().asBareJid().toString() + ": logged in (using " + version + ")"); account.getJid().asBareJid().toString() + ": logged in (using " + version + ")");
// TODO store mechanism name account.setPinnedMechanism(saslMechanism);
account.setKey(Account.PINNED_MECHANISM_KEY, String.valueOf(saslMechanism.getPriority()));
if (version == SaslMechanism.Version.SASL_2) { if (version == SaslMechanism.Version.SASL_2) {
final String authorizationIdentifier = final String authorizationIdentifier =
success.findChildContent("authorization-identifier"); success.findChildContent("authorization-identifier");
@ -1264,7 +1263,7 @@ public class XmppConnection implements Runnable {
+ mechanisms); + mechanisms);
throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER); throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER);
} }
final int pinnedMechanism = account.getKeyAsInt(Account.PINNED_MECHANISM_KEY, -1); final int pinnedMechanism = account.getPinnedMechanismPriority();
if (pinnedMechanism > saslMechanism.getPriority()) { if (pinnedMechanism > saslMechanism.getPriority()) {
Log.e( Log.e(
Config.LOGTAG, Config.LOGTAG,
@ -1345,7 +1344,7 @@ public class XmppConnection implements Runnable {
} }
private void register() { private void register() {
final String preAuth = account.getKey(Account.PRE_AUTH_REGISTRATION_TOKEN); final String preAuth = account.getKey(Account.KEY_PRE_AUTH_REGISTRATION_TOKEN);
if (preAuth != null && features.invite()) { if (preAuth != null && features.invite()) {
final IqPacket preAuthRequest = new IqPacket(IqPacket.TYPE.SET); final IqPacket preAuthRequest = new IqPacket(IqPacket.TYPE.SET);
preAuthRequest.addChild("preauth", Namespace.PARS).setAttribute("token", preAuth); preAuthRequest.addChild("preauth", Namespace.PARS).setAttribute("token", preAuth);