diff --git a/src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java b/src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java index dbc73c78f..5af911cc7 100644 --- a/src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java +++ b/src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java @@ -4,6 +4,10 @@ import android.annotation.TargetApi; import android.os.Build; import android.util.Base64; +import com.google.common.base.Objects; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; + import org.bouncycastle.crypto.Digest; import org.bouncycastle.crypto.macs.HMac; import org.bouncycastle.crypto.params.KeyParameter; @@ -11,6 +15,7 @@ import org.bouncycastle.crypto.params.KeyParameter; import java.nio.charset.Charset; import java.security.InvalidKeyException; import java.security.SecureRandom; +import java.util.concurrent.ExecutionException; import eu.siacs.conversations.entities.Account; import eu.siacs.conversations.utils.CryptoHelper; @@ -27,17 +32,46 @@ abstract class ScramMechanism extends SaslMechanism { protected abstract Digest getDigest(); - private KeyPair getKeyPair(final String password, final String salt, final int iterations) { - try { + private static final Cache CACHE = CacheBuilder.newBuilder().maximumSize(10).build(); + + private static class CacheKey { + final String algorithm; + final String password; + final String salt; + final int iterations; + + private CacheKey(String algorithm, String password, String salt, int iterations) { + this.algorithm = algorithm; + this.password = password; + this.salt = salt; + this.iterations = iterations; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CacheKey cacheKey = (CacheKey) o; + return iterations == cacheKey.iterations && + Objects.equal(algorithm, cacheKey.algorithm) && + Objects.equal(password, cacheKey.password) && + Objects.equal(salt, cacheKey.salt); + } + + @Override + public int hashCode() { + return Objects.hashCode(algorithm, password, salt, iterations); + } + } + + private KeyPair getKeyPair(final String password, final String salt, final int iterations) throws ExecutionException { + return CACHE.get(new CacheKey(getHMAC().getAlgorithmName(), password, salt, iterations), () -> { final byte[] saltedPassword, serverKey, clientKey; saltedPassword = hi(password.getBytes(), Base64.decode(salt, Base64.DEFAULT), iterations); serverKey = hmac(saltedPassword, SERVER_KEY_BYTES); clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES); - return new KeyPair(clientKey, serverKey); - } catch (final InvalidKeyException | NumberFormatException e) { - return null; - } + }); } private final String clientNonce; @@ -162,8 +196,10 @@ abstract class ScramMechanism extends SaslMechanism { final byte[] authMessage = (clientFirstMessageBare + ',' + new String(serverFirstMessage) + ',' + clientFinalMessageWithoutProof).getBytes(); - final KeyPair keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount); - if (keys == null) { + final KeyPair keys; + try { + keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount); + } catch (ExecutionException e) { throw new AuthenticationException("Invalid keys generated"); } final byte[] clientSignature;