signal-protocol/omemo: fix null-pointer issues

Fixes #44 and #58
This commit is contained in:
Marvin W 2017-04-18 17:55:20 +02:00
parent f95b4f4e09
commit 7e388fb2bc
No known key found for this signature in database
GPG key ID: 072E9235DB996F2A
13 changed files with 205 additions and 144 deletions

View file

@ -34,7 +34,7 @@ public class AccountSettingWidget : Plugins.AccountSettingsWidget, Box {
if (row == null) {
fingerprint.set_markup("%s\n<span font='8'>%s</span>".printf(_("Own fingerprint"), _("Will be generated on first connect")));
} else {
uint8[] arr = Base64.decode(row[plugin.db.identity.identity_key_public_base64]);
uint8[] arr = Base64.decode(((!)row)[plugin.db.identity.identity_key_public_base64]);
arr = arr[1:arr.length];
string res = "";
foreach (uint8 i in arr) {

View file

@ -9,21 +9,22 @@ public class Bundle {
public Bundle(StanzaNode? node) {
this.node = node;
assert(Plugin.ensure_context());
}
public int32 signed_pre_key_id { owned get {
if (node == null) return -1;
string id = node.get_deep_attribute("signedPreKeyPublic", "signedPreKeyId");
string? id = ((!)node).get_deep_attribute("signedPreKeyPublic", "signedPreKeyId");
if (id == null) return -1;
return int.parse(id);
return int.parse((!)id);
}}
public ECPublicKey? signed_pre_key { owned get {
if (node == null) return null;
string? key = node.get_deep_string_content("signedPreKeyPublic");
string? key = ((!)node).get_deep_string_content("signedPreKeyPublic");
if (key == null) return null;
try {
return Plugin.context.decode_public_key(Base64.decode(key));
return Plugin.get_context().decode_public_key(Base64.decode((!)key));
} catch (Error e) {
return null;
}
@ -31,17 +32,17 @@ public class Bundle {
public uint8[]? signed_pre_key_signature { owned get {
if (node == null) return null;
string? sig = node.get_deep_string_content("signedPreKeySignature");
string? sig = ((!)node).get_deep_string_content("signedPreKeySignature");
if (sig == null) return null;
return Base64.decode(sig);
return Base64.decode((!)sig);
}}
public ECPublicKey? identity_key { owned get {
if (node == null) return null;
string? key = node.get_deep_string_content("identityKey");
string? key = ((!)node).get_deep_string_content("identityKey");
if (key == null) return null;
try {
return Plugin.context.decode_public_key(Base64.decode(key));
return Plugin.get_context().decode_public_key(Base64.decode((!)key));
} catch (Error e) {
return null;
}
@ -49,9 +50,9 @@ public class Bundle {
public ArrayList<PreKey> pre_keys { owned get {
ArrayList<PreKey> list = new ArrayList<PreKey>();
if (node == null || node.get_subnode("prekeys") == null) return list;
node.get_deep_subnodes("prekeys", "preKeyPublic")
.filter((node) => node.get_attribute("preKeyId") != null)
if (node == null || ((!)node).get_subnode("prekeys") == null) return list;
((!)node).get_deep_subnodes("prekeys", "preKeyPublic")
.filter((node) => ((!)node).get_attribute("preKeyId") != null)
.map<PreKey>(PreKey.create)
.foreach((key) => list.add(key));
return list;
@ -76,7 +77,7 @@ public class Bundle {
string? key = node.get_string_content();
if (key == null) return null;
try {
return Plugin.context.decode_public_key(Base64.decode(key));
return Plugin.get_context().decode_public_key(Base64.decode((!)key));
} catch (Error e) {
return null;
}

View file

@ -12,8 +12,8 @@ public class Database : Qlite.Database {
public Column<int> id = new Column.Integer("id") { primary_key = true, auto_increment = true };
public Column<int> account_id = new Column.Integer("account_id") { unique = true, not_null = true };
public Column<int> device_id = new Column.Integer("device_id") { not_null = true };
public Column<string> identity_key_private_base64 = new Column.Text("identity_key_private_base64") { not_null = true };
public Column<string> identity_key_public_base64 = new Column.Text("identity_key_public_base64") { not_null = true };
public Column<string> identity_key_private_base64 = new Column.NonNullText("identity_key_private_base64");
public Column<string> identity_key_public_base64 = new Column.NonNullText("identity_key_public_base64");
internal IdentityTable(Database db) {
base(db, "identity");
@ -24,7 +24,7 @@ public class Database : Qlite.Database {
public class SignedPreKeyTable : Table {
public Column<int> identity_id = new Column.Integer("identity_id") { not_null = true };
public Column<int> signed_pre_key_id = new Column.Integer("signed_pre_key_id") { not_null = true };
public Column<string> record_base64 = new Column.Text("record_base64") { not_null = true };
public Column<string> record_base64 = new Column.NonNullText("record_base64");
internal SignedPreKeyTable(Database db) {
base(db, "signed_pre_key");
@ -36,7 +36,7 @@ public class Database : Qlite.Database {
public class PreKeyTable : Table {
public Column<int> identity_id = new Column.Integer("identity_id") { not_null = true };
public Column<int> pre_key_id = new Column.Integer("pre_key_id") { not_null = true };
public Column<string> record_base64 = new Column.Text("record_base64") { not_null = true };
public Column<string> record_base64 = new Column.NonNullText("record_base64");
internal PreKeyTable(Database db) {
base(db, "pre_key");
@ -47,9 +47,9 @@ public class Database : Qlite.Database {
public class SessionTable : Table {
public Column<int> identity_id = new Column.Integer("identity_id") { not_null = true };
public Column<string> address_name = new Column.Text("name") { not_null = true };
public Column<string> address_name = new Column.NonNullText("name");
public Column<int> device_id = new Column.Integer("device_id") { not_null = true };
public Column<string> record_base64 = new Column.Text("record_base64") { not_null = true };
public Column<string> record_base64 = new Column.NonNullText("record_base64");
internal SessionTable(Database db) {
base(db, "session");

View file

@ -70,16 +70,22 @@ public class Manager : StreamInteractionModule, Object {
}
private void on_pre_message_received(Entities.Message message, Xmpp.Message.Stanza message_stanza, Conversation conversation) {
if (MessageFlag.get_flag(message_stanza) != null && MessageFlag.get_flag(message_stanza).decrypted) {
MessageFlag? flag = MessageFlag.get_flag(message_stanza);
if (flag != null && ((!)flag).decrypted) {
message.encryption = Encryption.OMEMO;
}
}
private void on_pre_message_send(Entities.Message message, Xmpp.Message.Stanza message_stanza, Conversation conversation) {
if (message.encryption == Encryption.OMEMO) {
StreamModule module = stream_interactor.get_stream(conversation.account).get_module(StreamModule.IDENTITY);
Core.XmppStream? stream = stream_interactor.get_stream(conversation.account);
if (stream == null) {
message.marked = Entities.Message.Marked.UNSENT;
return;
}
StreamModule module = ((!)stream).get_module(StreamModule.IDENTITY);
EncryptState enc_state = module.encrypt(message_stanza, conversation.account.bare_jid.to_string());
MessageState state = null;
MessageState state;
lock (message_states) {
if (message_states.has_key(message)) {
state = message_states.get(message);
@ -95,18 +101,18 @@ public class Manager : StreamInteractionModule, Object {
if (!state.will_send_now) {
if (message.marked == Entities.Message.Marked.WONTSEND) {
if (Plugin.DEBUG) print(@"OMEMO: message $(message.stanza_id) was not sent: $state\n");
if (Plugin.DEBUG) print(@"OMEMO: message was not sent: $state\n");
} else {
if (Plugin.DEBUG) print(@"OMEMO: message $(message.stanza_id) will be delayed: $state\n");
if (Plugin.DEBUG) print(@"OMEMO: message will be delayed: $state\n");
if (state.waiting_own_sessions > 0) {
module.start_sessions_with(stream_interactor.get_stream(conversation.account), conversation.account.bare_jid.to_string());
module.start_sessions_with((!)stream, conversation.account.bare_jid.to_string());
}
if (state.waiting_other_sessions > 0) {
module.start_sessions_with(stream_interactor.get_stream(conversation.account), message.counterpart.bare_jid.to_string());
if (state.waiting_other_sessions > 0 && message.counterpart != null) {
module.start_sessions_with((!)stream, ((!)message.counterpart).bare_jid.to_string());
}
if (state.waiting_other_devicelist) {
module.request_user_devicelist(stream_interactor.get_stream(conversation.account), message.counterpart.bare_jid.to_string());
if (state.waiting_other_devicelist && message.counterpart != null) {
module.request_user_devicelist((!)stream, ((!)message.counterpart).bare_jid.to_string());
}
}
}
@ -120,8 +126,7 @@ public class Manager : StreamInteractionModule, Object {
stream_interactor.module_manager.get_module(account, StreamModule.IDENTITY).session_start_failed.connect((jid, device_id) => on_session_started(account, jid, true));
}
private void on_stream_negotiated(Account account) {
Core.XmppStream stream = stream_interactor.get_stream(account);
private void on_stream_negotiated(Account account, Core.XmppStream stream) {
stream_interactor.module_manager.get_module(account, StreamModule.IDENTITY).request_user_devicelist(stream, account.bare_jid.to_string());
}
@ -134,7 +139,7 @@ public class Manager : StreamInteractionModule, Object {
MessageState state = message_states[msg];
if (account.bare_jid.to_string() == jid) {
state.waiting_own_sessions--;
} else if (msg.counterpart.bare_jid.to_string() == jid) {
} else if (msg.counterpart != null && ((!)msg.counterpart).bare_jid.to_string() == jid) {
state.waiting_other_sessions--;
}
if (state.should_retry_now()) {
@ -144,8 +149,10 @@ public class Manager : StreamInteractionModule, Object {
}
}
foreach (Entities.Message msg in send_now) {
Entities.Conversation conv = stream_interactor.get_module(ConversationManager.IDENTITY).get_conversation(msg.counterpart, account);
stream_interactor.get_module(MessageProcessor.IDENTITY).send_xmpp_message(msg, conv, true);
if (msg.counterpart == null) continue;
Entities.Conversation? conv = stream_interactor.get_module(ConversationManager.IDENTITY).get_conversation((!)msg.counterpart, account);
if (conv == null) continue;
stream_interactor.get_module(MessageProcessor.IDENTITY).send_xmpp_message(msg, (!)conv, true);
}
}
@ -158,7 +165,7 @@ public class Manager : StreamInteractionModule, Object {
MessageState state = message_states[msg];
if (account.bare_jid.to_string() == jid) {
state.waiting_own_devicelist = false;
} else if (msg.counterpart.bare_jid.to_string() == jid) {
} else if (msg.counterpart != null && ((!)msg.counterpart).bare_jid.to_string() == jid) {
state.waiting_other_devicelist = false;
}
if (state.should_retry_now()) {
@ -168,8 +175,10 @@ public class Manager : StreamInteractionModule, Object {
}
}
foreach (Entities.Message msg in send_now) {
Entities.Conversation conv = stream_interactor.get_module(ConversationManager.IDENTITY).get_conversation(msg.counterpart, account);
stream_interactor.get_module(MessageProcessor.IDENTITY).send_xmpp_message(msg, conv, true);
if (msg.counterpart == null) continue;
Entities.Conversation? conv = stream_interactor.get_module(ConversationManager.IDENTITY).get_conversation(((!)msg.counterpart), account);
if (conv == null) continue;
stream_interactor.get_module(MessageProcessor.IDENTITY).send_xmpp_message(msg, (!)conv, true);
}
}
@ -187,7 +196,7 @@ public class Manager : StreamInteractionModule, Object {
try {
store.identity_key_store.local_registration_id = Random.int_range(1, int32.MAX);
Signal.ECKeyPair key_pair = Plugin.context.generate_key_pair();
Signal.ECKeyPair key_pair = Plugin.get_context().generate_key_pair();
store.identity_key_store.identity_key_private = key_pair.private.serialize();
store.identity_key_store.identity_key_public = key_pair.public.serialize();
@ -201,10 +210,10 @@ public class Manager : StreamInteractionModule, Object {
// Ignore error
}
} else {
store.identity_key_store.local_registration_id = row[db.identity.device_id];
store.identity_key_store.identity_key_private = Base64.decode(row[db.identity.identity_key_private_base64]);
store.identity_key_store.identity_key_public = Base64.decode(row[db.identity.identity_key_public_base64]);
identity_id = row[db.identity.id];
store.identity_key_store.local_registration_id = ((!)row)[db.identity.device_id];
store.identity_key_store.identity_key_private = Base64.decode(((!)row)[db.identity.identity_key_private_base64]);
store.identity_key_store.identity_key_public = Base64.decode(((!)row)[db.identity.identity_key_public_base64]);
identity_id = ((!)row)[db.identity.id];
}
if (identity_id >= 0) {
@ -218,9 +227,11 @@ public class Manager : StreamInteractionModule, Object {
public bool can_encrypt(Entities.Conversation conversation) {
Core.XmppStream stream = stream_interactor.get_stream(conversation.account);
Core.XmppStream? stream = stream_interactor.get_stream(conversation.account);
if (stream == null) return false;
return stream.get_module(StreamModule.IDENTITY).is_known_address(conversation.counterpart.bare_jid.to_string());
StreamModule? module = ((!)stream).get_module(StreamModule.IDENTITY);
if (module == null) return false;
return ((!)module).is_known_address(conversation.counterpart.bare_jid.to_string());
}
public static void start(StreamInteractor stream_interactor, Database db) {

View file

@ -5,7 +5,23 @@ namespace Dino.Plugins.Omemo {
public class Plugin : RootInterface, Object {
public const bool DEBUG = false;
public static Signal.Context context;
private static Signal.Context? _context;
public static Signal.Context get_context() {
assert(_context != null);
return (!)_context;
}
public static bool ensure_context() {
lock(_context) {
try {
if (_context == null) {
_context = new Signal.Context(DEBUG);
}
return true;
} catch (Error e) {
return false;
}
}
}
public Dino.Application app;
public Database db;
@ -14,7 +30,7 @@ public class Plugin : RootInterface, Object {
public void registered(Dino.Application app) {
try {
context = new Signal.Context(DEBUG);
ensure_context();
this.app = app;
this.db = new Database(Path.build_filename(Application.get_storage_dir(), "omemo.db"));
this.list_entry = new EncryptionListEntry(this);
@ -26,7 +42,13 @@ public class Plugin : RootInterface, Object {
});
Manager.start(this.app.stream_interaction, db);
internationalize(GETTEXT_PACKAGE, app.search_path_generator.get_locale_path(GETTEXT_PACKAGE, LOCALE_INSTALL_DIR));
string locales_dir;
if (app.search_path_generator != null) {
locales_dir = ((!)app.search_path_generator).get_locale_path(GETTEXT_PACKAGE, LOCALE_INSTALL_DIR);
} else {
locales_dir = LOCALE_INSTALL_DIR;
}
internationalize(GETTEXT_PACKAGE, locales_dir);
} catch (Error e) {
print(@"Error initializing OMEMO: $(e.message)\n");
}

View file

@ -15,11 +15,10 @@ private class BackedSessionStore : SimpleSessionStore {
private void init() {
try {
Address addr = new Address();
foreach (Row row in db.session.select().with(db.session.identity_id, "=", identity_id)) {
addr.name = row[db.session.address_name];
addr.device_id = row[db.session.device_id];
Address addr = new Address(row[db.session.address_name], row[db.session.device_id]);
store_session(addr, Base64.decode(row[db.session.record_base64]));
addr.device_id = 0;
}
} catch (Error e) {
print(@"OMEMO: Error while initializing session store: $(e.message)\n");

View file

@ -29,25 +29,26 @@ public class StreamModule : XmppStreamModule {
public EncryptState encrypt(Message.Stanza message, string self_bare_jid) {
EncryptState status = new EncryptState();
if (Plugin.context == null) return status;
if (!Plugin.ensure_context()) return status;
if (message.to == null) return status;
try {
string name = get_bare_jid(message.to);
if (device_lists.get(self_bare_jid) == null) return status;
string name = get_bare_jid((!)message.to);
if (!device_lists.has_key(self_bare_jid)) return status;
status.own_list = true;
status.own_devices = device_lists.get(self_bare_jid).size;
if (device_lists.get(name) == null) return status;
if (!device_lists.has_key(name)) return status;
status.other_list = true;
status.other_devices = device_lists.get(name).size;
if (status.own_devices == 0 || status.other_devices == 0) return status;
uint8[] key = new uint8[16];
Plugin.context.randomize(key);
Plugin.get_context().randomize(key);
uint8[] iv = new uint8[16];
Plugin.context.randomize(iv);
Plugin.get_context().randomize(iv);
uint8[] ciphertext = aes_encrypt(Cipher.AES_GCM_NOPADDING, key, iv, message.body.data);
StanzaNode header = null;
StanzaNode header;
StanzaNode encrypted = new StanzaNode.build("encrypted", NS_URI).add_self_xmlns()
.put_node(header = new StanzaNode.build("header", NS_URI)
.put_attribute("sid", store.local_registration_id.to_string())
@ -56,8 +57,7 @@ public class StreamModule : XmppStreamModule {
.put_node(new StanzaNode.build("payload", NS_URI)
.put_node(new StanzaNode.text(Base64.encode(ciphertext))));
Address address = new Address();
address.name = name;
Address address = new Address(name, 0);
foreach(int32 device_id in device_lists[name]) {
if (is_ignored_device(name, device_id)) {
status.other_lost++;
@ -114,57 +114,60 @@ public class StreamModule : XmppStreamModule {
public override void attach(XmppStream stream) {
Message.Module.require(stream);
Pubsub.Module.require(stream);
if (Plugin.context == null) return;
if (!Plugin.ensure_context()) return;
this.store = Plugin.context.create_store();
this.store = Plugin.get_context().create_store();
store_created(store);
stream.get_module(Message.Module.IDENTITY).pre_received_message.connect(on_pre_received_message);
stream.get_module(Pubsub.Module.IDENTITY).add_filtered_notification(stream, NODE_DEVICELIST, (stream, jid, id, node, obj) => (obj as StreamModule).on_devicelist(stream, jid, id, node), this);
stream.get_module(Pubsub.Module.IDENTITY).add_filtered_notification(stream, NODE_DEVICELIST, (stream, jid, id, node, obj) => ((StreamModule)obj).on_devicelist(stream, jid, id, node), this);
}
private void on_pre_received_message(XmppStream stream, Message.Stanza message) {
StanzaNode? encrypted = message.stanza.get_subnode("encrypted", NS_URI);
if (encrypted == null || MessageFlag.get_flag(message) != null) return;
StanzaNode? _encrypted = message.stanza.get_subnode("encrypted", NS_URI);
if (_encrypted == null || MessageFlag.get_flag(message) != null || message.from == null) return;
StanzaNode encrypted = (!)_encrypted;
if (!Plugin.ensure_context()) return;
MessageFlag flag = new MessageFlag();
message.add_flag(flag);
StanzaNode? header = encrypted.get_subnode("header");
if (header == null || header.get_attribute_int("sid") <= 0) return;
StanzaNode? _header = encrypted.get_subnode("header");
if (_header == null) return;
StanzaNode header = (!)_header;
if (header.get_attribute_int("sid") <= 0) return;
foreach (StanzaNode key_node in header.get_subnodes("key")) {
if (key_node.get_attribute_int("rid") == store.local_registration_id) {
try {
uint8[] key = null;
uint8[] ciphertext = Base64.decode(encrypted.get_subnode("payload").get_string_content());
uint8[] iv = Base64.decode(header.get_subnode("iv").get_string_content());
Address address = new Address();
address.name = get_bare_jid(message.from);
address.device_id = header.get_attribute_int("sid");
string? payload = encrypted.get_deep_string_content("payload");
string? iv_node = header.get_deep_string_content("iv");
string? key_node_content = key_node.get_string_content();
if (payload == null || iv_node == null || key_node_content == null) continue;
uint8[] key;
uint8[] ciphertext = Base64.decode((!)payload);
uint8[] iv = Base64.decode((!)iv_node);
Address address = new Address(get_bare_jid((!)message.from), header.get_attribute_int("sid"));
if (key_node.get_attribute_bool("prekey")) {
PreKeySignalMessage msg = Plugin.context.deserialize_pre_key_signal_message(Base64.decode(key_node.get_string_content()));
PreKeySignalMessage msg = Plugin.get_context().deserialize_pre_key_signal_message(Base64.decode((!)key_node_content));
SessionCipher cipher = store.create_session_cipher(address);
key = cipher.decrypt_pre_key_signal_message(msg);
} else {
SignalMessage msg = Plugin.context.deserialize_signal_message(Base64.decode(key_node.get_string_content()));
SignalMessage msg = Plugin.get_context().deserialize_signal_message(Base64.decode((!)key_node_content));
SessionCipher cipher = store.create_session_cipher(address);
key = cipher.decrypt_signal_message(msg);
}
address.device_id = 0; // TODO: Hack to have address obj live longer
if (key != null && ciphertext != null && iv != null) {
if (key.length >= 32) {
int authtaglength = key.length - 16;
uint8[] new_ciphertext = new uint8[ciphertext.length + authtaglength];
uint8[] new_key = new uint8[16];
Memory.copy(new_ciphertext, ciphertext, ciphertext.length);
Memory.copy((uint8*)new_ciphertext + ciphertext.length, (uint8*)key + 16, authtaglength);
Memory.copy(new_key, key, 16);
ciphertext = new_ciphertext;
key = new_key;
}
message.body = arr_to_str(aes_decrypt(Cipher.AES_GCM_NOPADDING, key, iv, ciphertext));
flag.decrypted = true;
if (key.length >= 32) {
int authtaglength = key.length - 16;
uint8[] new_ciphertext = new uint8[ciphertext.length + authtaglength];
uint8[] new_key = new uint8[16];
Memory.copy(new_ciphertext, ciphertext, ciphertext.length);
Memory.copy((uint8*)new_ciphertext + ciphertext.length, (uint8*)key + 16, authtaglength);
Memory.copy(new_key, key, 16);
ciphertext = new_ciphertext;
key = new_key;
}
message.body = arr_to_str(aes_decrypt(Cipher.AES_GCM_NOPADDING, key, iv, ciphertext));
flag.decrypted = true;
} catch (Error e) {
if (Plugin.DEBUG) print(@"OMEMO: Signal error while decrypting message: $(e.message)\n");
}
@ -182,17 +185,15 @@ public class StreamModule : XmppStreamModule {
public void request_user_devicelist(XmppStream stream, string jid) {
if (active_devicelist_requests.add(jid)) {
if (Plugin.DEBUG) print(@"OMEMO: requesting device list for $jid\n");
stream.get_module(Pubsub.Module.IDENTITY).request(stream, jid, NODE_DEVICELIST, (stream, jid, id, node, obj) => (obj as StreamModule).on_devicelist(stream, jid, id ?? "", node), this);
stream.get_module(Pubsub.Module.IDENTITY).request(stream, jid, NODE_DEVICELIST, (stream, jid, id, node, obj) => ((StreamModule)obj).on_devicelist(stream, jid, id ?? "", node), this);
}
}
public void on_devicelist(XmppStream stream, string jid, string id, StanzaNode? node_) {
StanzaNode? node = node_;
if (jid == get_bare_jid(stream.get_flag(Bind.Flag.IDENTITY).my_jid) && store.local_registration_id != 0) {
if (node == null) {
node = new StanzaNode.build("list", NS_URI).add_self_xmlns().put_node(new StanzaNode.build("device", NS_URI));
}
StanzaNode node = node_ ?? new StanzaNode.build("list", NS_URI).add_self_xmlns();
string? my_jid = stream.get_flag(Bind.Flag.IDENTITY).my_jid;
if (my_jid == null) return;
if (jid == get_bare_jid((!)my_jid) && store.local_registration_id != 0) {
bool am_on_devicelist = false;
foreach (StanzaNode device_node in node.get_subnodes("device")) {
int device_id = device_node.get_attribute_int("id");
@ -223,8 +224,7 @@ public class StreamModule : XmppStreamModule {
// TODO: manually request a device list
return;
}
Address address = new Address();
address.name = bare_jid;
Address address = new Address(bare_jid, 0);
foreach(int32 device_id in device_lists[bare_jid]) {
if (!is_ignored_device(bare_jid, device_id)) {
address.device_id = device_id;
@ -293,9 +293,7 @@ public class StreamModule : XmppStreamModule {
if (signed_pre_key_id < 0 || signed_pre_key == null || identity_key == null || pre_key_id < 0 || pre_key == null) {
fail = true;
} else {
Address address = new Address();
address.name = jid;
address.device_id = device_id;
Address address = new Address(jid, device_id);
try {
if (store.contains_session(address)) {
return;
@ -322,13 +320,13 @@ public class StreamModule : XmppStreamModule {
}
private static void on_self_bundle_result(XmppStream stream, string jid, string? id, StanzaNode? node, Object? storage) {
if (!Plugin.ensure_context()) return;
Store store = (Store)storage;
Map<int, ECPublicKey> keys = new HashMap<int, ECPublicKey>();
ECPublicKey identity_key = null;
IdentityKeyPair identity_key_pair = null;
ECPublicKey? identity_key = null;
int32 signed_pre_key_id = -1;
ECPublicKey signed_pre_key = null;
SignedPreKeyRecord signed_pre_key_record = null;
ECPublicKey? signed_pre_key = null;
SignedPreKeyRecord? signed_pre_key_record = null;
bool changed = false;
if (node == null) {
identity_key = store.identity_key_pair.public;
@ -336,7 +334,10 @@ public class StreamModule : XmppStreamModule {
} else {
Bundle bundle = new Bundle(node);
foreach (Bundle.PreKey prekey in bundle.pre_keys) {
keys[prekey.key_id] = prekey.key;
ECPublicKey? key = prekey.key;
if (key != null) {
keys[prekey.key_id] = (!)key;
}
}
identity_key = bundle.identity_key;
signed_pre_key_id = bundle.signed_pre_key_id;;
@ -345,16 +346,16 @@ public class StreamModule : XmppStreamModule {
try {
// Validate IdentityKey
if (store.identity_key_pair.public.compare(identity_key) != 0) {
if (identity_key == null || store.identity_key_pair.public.compare((!)identity_key) != 0) {
changed = true;
}
identity_key_pair = store.identity_key_pair;
IdentityKeyPair identity_key_pair = store.identity_key_pair;
// Validate signedPreKeyRecord + ID
if (signed_pre_key_id == -1 || !store.contains_signed_pre_key(signed_pre_key_id) || store.load_signed_pre_key(signed_pre_key_id).key_pair.public.compare(signed_pre_key) != 0) {
if (signed_pre_key == null || signed_pre_key_id == -1 || !store.contains_signed_pre_key(signed_pre_key_id) || store.load_signed_pre_key(signed_pre_key_id).key_pair.public.compare((!)signed_pre_key) != 0) {
signed_pre_key_id = Random.int_range(1, int32.MAX); // TODO: No random, use ordered number
signed_pre_key_record = Plugin.context.generate_signed_pre_key(identity_key_pair, signed_pre_key_id);
store.store_signed_pre_key(signed_pre_key_record);
signed_pre_key_record = Plugin.get_context().generate_signed_pre_key(identity_key_pair, signed_pre_key_id);
store.store_signed_pre_key((!)signed_pre_key_record);
changed = true;
} else {
signed_pre_key_record = store.load_signed_pre_key(signed_pre_key_id);
@ -373,7 +374,7 @@ public class StreamModule : XmppStreamModule {
int new_keys = NUM_KEYS_TO_PUBLISH - pre_key_records.size;
if (new_keys > 0) {
int32 next_id = Random.int_range(1, int32.MAX); // TODO: No random, use ordered number
Set<PreKeyRecord> new_records = Plugin.context.generate_pre_keys((uint)next_id, (uint)new_keys);
Set<PreKeyRecord> new_records = Plugin.get_context().generate_pre_keys((uint)next_id, (uint)new_keys);
pre_key_records.add_all(new_records);
foreach (PreKeyRecord record in new_records) {
store.store_pre_key(record);
@ -382,7 +383,7 @@ public class StreamModule : XmppStreamModule {
}
if (changed) {
publish_bundles(stream, signed_pre_key_record, identity_key_pair, pre_key_records, (int32) store.local_registration_id);
publish_bundles(stream, (!)signed_pre_key_record, identity_key_pair, pre_key_records, (int32) store.local_registration_id);
}
} catch (Error e) {
if (Plugin.DEBUG) print(@"Unexpected error while publishing bundle: $(e.message)\n");

View file

@ -3,14 +3,30 @@
#include <gcrypt.h>
signal_protocol_address* signal_protocol_address_new() {
signal_type_base* signal_type_ref_vapi(signal_type_base* instance) {
g_return_val_if_fail(instance != NULL, NULL);
signal_type_ref(instance);
return instance;
}
signal_type_base* signal_type_unref_vapi(signal_type_base* instance) {
g_return_val_if_fail(instance != NULL, NULL);
signal_type_unref(instance);
return NULL;
}
signal_protocol_address* signal_protocol_address_new(const gchar* name, int32_t device_id) {
g_return_val_if_fail(name != NULL, NULL);
signal_protocol_address* address = malloc(sizeof(signal_protocol_address));
address->name = 0;
address->device_id = 0;
address->device_id = NULL;
address->name = NULL;
signal_protocol_address_set_name(address, name);
signal_protocol_address_set_device_id(address, device_id);
return address;
}
void signal_protocol_address_free(signal_protocol_address* ptr) {
g_return_if_fail(ptr != NULL);
if (ptr->name) {
g_free((void*)ptr->name);
}
@ -18,6 +34,8 @@ void signal_protocol_address_free(signal_protocol_address* ptr) {
}
void signal_protocol_address_set_name(signal_protocol_address* self, const gchar* name) {
g_return_if_fail(self != NULL);
g_return_if_fail(name != NULL);
gchar* n = g_malloc(strlen(name)+1);
memcpy(n, name, strlen(name));
n[strlen(name)] = 0;
@ -29,13 +47,25 @@ void signal_protocol_address_set_name(signal_protocol_address* self, const gchar
}
gchar* signal_protocol_address_get_name(signal_protocol_address* self) {
if (self->name == 0) return 0;
g_return_val_if_fail(self != NULL, NULL);
g_return_val_if_fail(self->name != NULL, 0);
gchar* res = g_malloc(sizeof(char) * (self->name_len + 1));
memcpy(res, self->name, self->name_len);
res[self->name_len] = 0;
return res;
}
int32_t signal_protocol_address_get_device_id(signal_protocol_address* self) {
g_return_val_if_fail(self != NULL, NULL);
return self->device_id;
}
void signal_protocol_address_set_device_id(signal_protocol_address* self, int32_t device_id) {
g_return_if_fail(self != NULL);
self->device_id = device_id;
}
session_pre_key* session_pre_key_new(uint32_t pre_key_id, ec_key_pair* pair, int* err) {
session_pre_key* res;
*err = session_pre_key_create(&res, pre_key_id, pair);

View file

@ -9,10 +9,14 @@
signal_type_base* signal_type_ref_vapi(signal_type_base* what);
signal_type_base* signal_type_unref_vapi(signal_type_base* what);
signal_protocol_address* signal_protocol_address_new();
signal_protocol_address* signal_protocol_address_new(const gchar* name, int32_t device_id);
void signal_protocol_address_free(signal_protocol_address* ptr);
void signal_protocol_address_set_name(signal_protocol_address* self, const gchar* name);
gchar* signal_protocol_address_get_name(signal_protocol_address* self);
void signal_protocol_address_set_device_id(signal_protocol_address* self, int32_t device_id);
int32_t signal_protocol_address_get_device_id(signal_protocol_address* self);
session_pre_key* session_pre_key_new(uint32_t pre_key_id, ec_key_pair* pair, int* err);
session_signed_pre_key* session_signed_pre_key_new(uint32_t id, uint64_t timestamp, ec_key_pair* pair, uint8_t* key, int key_len, int* err);

View file

@ -7,10 +7,8 @@ public class SimpleSessionStore : SessionStore {
private Map<string, ArrayList<SessionStore.Session>> session_map = new HashMap<string, ArrayList<SessionStore.Session>>();
public override uint8[]? load_session(Address address) throws Error {
string name = address.name;
if (name == null) return null;
if (session_map.has_key(name)) {
foreach (SessionStore.Session session in session_map[name]) {
if (session_map.has_key(address.name)) {
foreach (SessionStore.Session session in session_map[address.name]) {
if (session.device_id == address.device_id) return session.record;
}
}

View file

@ -142,9 +142,9 @@ public class Store : Object {
return 0;
}
static int ss_load_session_func(out Buffer buffer, Address address, void* user_data) {
static int ss_load_session_func(out Buffer? buffer, Address address, void* user_data) {
Store store = (Store) user_data;
uint8[] res = null;
uint8[]? res = null;
try {
res = store.session_store.load_session(address);
} catch (Error e) {
@ -155,12 +155,12 @@ public class Store : Object {
buffer = null;
return 0;
}
buffer = new Buffer.from(res);
buffer = new Buffer.from((!)res);
if (buffer == null) return ErrorCode.NOMEM;
return 1;
}
static int ss_get_sub_device_sessions_func(out IntList sessions, char[] name, void* user_data) {
static int ss_get_sub_device_sessions_func(out IntList? sessions, char[] name, void* user_data) {
Store store = (Store) user_data;
try {
sessions = store.session_store.get_sub_device_sessions(carr_to_string(name));
@ -206,9 +206,9 @@ public class Store : Object {
return 0;
}
static int pks_load_pre_key(out Buffer record, uint32 pre_key_id, void* user_data) {
static int pks_load_pre_key(out Buffer? record, uint32 pre_key_id, void* user_data) {
Store store = (Store) user_data;
uint8[] res = null;
uint8[]? res = null;
try {
res = store.pre_key_store.load_pre_key(pre_key_id);
} catch (Error e) {
@ -219,7 +219,7 @@ public class Store : Object {
record = new Buffer(0);
return 0;
}
record = new Buffer.from(res);
record = new Buffer.from((!)res);
if (record == null) return ErrorCode.NOMEM;
return 1;
}
@ -251,9 +251,9 @@ public class Store : Object {
return 0;
}
static int spks_load_signed_pre_key(out Buffer record, uint32 pre_key_id, void* user_data) {
static int spks_load_signed_pre_key(out Buffer? record, uint32 pre_key_id, void* user_data) {
Store store = (Store) user_data;
uint8[] res = null;
uint8[]? res = null;
try {
res = store.signed_pre_key_store.load_signed_pre_key(pre_key_id);
} catch (Error e) {
@ -264,7 +264,7 @@ public class Store : Object {
record = new Buffer(0);
return 0;
}
record = new Buffer.from(res);
record = new Buffer.from((!)res);
if (record == null) return ErrorCode.NOMEM;
return 1;
}

View file

@ -18,12 +18,8 @@ class SessionBuilderTest : Gee.TestCase {
public override void set_up() {
try {
global_context = new Context();
alice_address = new Address();
alice_address.name = "+14151111111";
alice_address.device_id = 1;
bob_address = new Address();
bob_address.name = "+14152222222";
bob_address.device_id = 1;
alice_address = new Address("+14151111111", 1);
bob_address = new Address("+14152222222", 1);
} catch (Error e) {
fail_if_reached(@"Unexpected error: $(e.message)");
}

View file

@ -51,7 +51,7 @@ namespace Signal {
}
[Compact]
[CCode (cname = "signal_type_base", ref_function="signal_type_ref", ref_function_void=true, unref_function="signal_type_unref", cheader_filename="signal_protocol_types.h,signal_helper.h")]
[CCode (cname = "signal_type_base", ref_function="signal_type_ref_vapi", unref_function="signal_type_unref_vapi", cheader_filename="signal_protocol_types.h,signal_helper.h")]
public class TypeBase {
}
@ -103,8 +103,8 @@ namespace Signal {
[Compact]
[CCode (cname = "session_pre_key_bundle", cprefix = "session_pre_key_bundle_", cheader_filename = "session_pre_key.h")]
public class PreKeyBundle : TypeBase {
public static int create(out PreKeyBundle bundle, uint32 registration_id, int device_id, uint32 pre_key_id, ECPublicKey pre_key_public,
uint32 signed_pre_key_id, ECPublicKey signed_pre_key_public, uint8[] signed_pre_key_signature, ECPublicKey identity_key);
public static int create(out PreKeyBundle bundle, uint32 registration_id, int device_id, uint32 pre_key_id, ECPublicKey? pre_key_public,
uint32 signed_pre_key_id, ECPublicKey? signed_pre_key_public, uint8[]? signed_pre_key_signature, ECPublicKey? identity_key);
public uint32 registration_id { get; }
public int device_id { get; }
public uint32 pre_key_id { get; }
@ -192,9 +192,8 @@ namespace Signal {
[Compact]
[CCode (cname = "signal_protocol_address", cprefix = "signal_protocol_address_", cheader_filename = "signal_protocol.h,signal_helper.h")]
public class Address {
public Address();
public int32 device_id;
public Address(string name, int32 device_id);
public int32 device_id { get; set; }
public string name { owned get; set; }
}