Bump version to 0.3
- Now support all Gajim versions — from 1.0.3 to 1.1.99 (trunk) - Added python3-potr library to otrplugin distribution (it does not have external dependencies) - Code restructurized and reworked (again...) Fixes: - Now we will create OTR Instance only once for account - Fixed crash when we failed to get window control (chat is closed i.e.) - Will not break MAM and conferences anymore New: - XHTML support - Errors notifications as well as status changes - Retransmit last message after OTR channel init - Correct close all active OTR channels when going offline Wontfix: - I still love you. Always and Forever ♥
This commit is contained in:
parent
01ba08cb4a
commit
a05f8c4bf1
45
keystore.py
45
keystore.py
|
@ -13,36 +13,35 @@
|
|||
# GNU General Public License for more details.
|
||||
#
|
||||
# You can always obtain full license text at <http://www.gnu.org/licenses/>.
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
from collections import namedtuple
|
||||
|
||||
class Keystore:
|
||||
|
||||
__TABLE_LAYOUT__ = '''
|
||||
CREATE TABLE IF NOT EXISTS keystore (jid TEXT, privatekey TEXT, fingerprint TEXT, trust INTEGER, timestamp INTEGER, comment TEXT, UNIQUE(privatekey)); CREATE UNIQUE INDEX IF NOT EXISTS jid_fingerprint ON keystore (jid, fingerprint);
|
||||
'''
|
||||
|
||||
SCHEMA = '''
|
||||
PRAGMA synchronous=FULL;
|
||||
CREATE TABLE IF NOT EXISTS keystore (jid TEXT, privatekey TEXT, fingerprint TEXT, trust INTEGER, timestamp INTEGER, comment TEXT, UNIQUE(privatekey));
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS jid_fingerprint ON keystore (jid, fingerprint);
|
||||
'''
|
||||
|
||||
def __init__(self, db):
|
||||
self._db = sqlite3.connect(db, isolation_level=None)
|
||||
self._db.row_factory = lambda cur,row : namedtuple("Row", [col[0] for col in cur.description])(*row)
|
||||
self._db.execute("PRAGMA synchronous=FULL;")
|
||||
self._db.executescript(self.__TABLE_LAYOUT__)
|
||||
self._db.executescript(self.SCHEMA)
|
||||
|
||||
def load(self, item = {'fingerprint IS NOT NULL; --': None}):
|
||||
sql = "SELECT * FROM keystore WHERE " + " AND ".join(["%s = '%s'" % (str(key), str(value)) for key,value in item.items()])
|
||||
if next(iter(item.values())): return self._db.execute(sql).fetchone() # return fetchone() if `item` arg is set
|
||||
return self._db.execute(sql).fetchall() or () # else return fetchall() or empty iterator
|
||||
|
||||
def save(self, item):
|
||||
sql = "REPLACE INTO keystore(%s) VALUES(%s)" % (",".join(item.keys()), ",".join(["'%s'" % x for x in item.values()]) )
|
||||
return self._db.execute(sql)
|
||||
|
||||
def forgot(self, item):
|
||||
sql = "DELETE FROM keystore WHERE " + " AND ".join(["%s='%s'" % (str(key),str(value)) for key,value in item.items()])
|
||||
return self._db.execute(sql)
|
||||
|
||||
def close(self):
|
||||
def __del__(self):
|
||||
self._db.close()
|
||||
|
||||
# fetch all entries with known fingerprints: `load()` or specific entry `load(jid = jid)`
|
||||
def load(self, **args):
|
||||
sql = "SELECT * FROM keystore WHERE %s" % (not args and "fingerprint IS NOT NULL" or " AND ".join(["{0}='{1}'".format(*arg) for arg in args.items()]))
|
||||
return (args) and self._db.execute(sql).fetchone() or self._db.execute(sql).fetchall()
|
||||
|
||||
# save entry to database: save(jid=jid, fingerprint=fingerprint)
|
||||
def save(self, **args):
|
||||
sql = "REPLACE INTO keystore({0}) VALUES({1})".format(",".join(args.keys()), ",".join(["'%s'"%s for s in args.values()]))
|
||||
self._db.execute(sql)
|
||||
|
||||
# delete entry from database: `delete(jid=jid) `
|
||||
def delete(self, **args):
|
||||
sql = "DELETE FROM keystore WHERE %s" % " AND ".join(["{0}='{1}'".format(*a) for a in args.items()])
|
||||
self._db.execute(sql)
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
[info]
|
||||
name: otrplugin
|
||||
short_name: otrplugin
|
||||
version: 0.2
|
||||
version: 0.3
|
||||
description: Off-the-Record encryption
|
||||
authors: Pavel R <pd@narayana.im>
|
||||
homepage: https://dev.narayana.im/gajim-otrplugin
|
||||
min_gajim_version: 1.1
|
||||
max_gajim_version: 1.3
|
||||
min_gajim_version: 1.0.3
|
||||
max_gajim_version: 1.1.99
|
||||
|
|
221
otr.py
221
otr.py
|
@ -15,56 +15,43 @@
|
|||
#
|
||||
# You can always obtain full license text at <http://www.gnu.org/licenses/>.
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from potr import context, crypt
|
||||
import os, sys
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
from inspect import signature
|
||||
from gajim.common import const, app, helpers, configpaths
|
||||
from gajim.session import ChatControlSession
|
||||
from nbxmpp.protocol import Message, JID
|
||||
from otrplugin.potr import context, crypt, proto
|
||||
from otrplugin.keystore import Keystore
|
||||
|
||||
# OTR channel class
|
||||
class Channel(context.Context):
|
||||
def __init__(self, account, peer):
|
||||
super(Channel, self).__init__(account, peer)
|
||||
self.getPolicy = OTR.DEFAULT_POLICY.get # default OTR flags
|
||||
self.trustName = peer.getStripped() # peer name
|
||||
self.stream = account.stream # XMPP stream
|
||||
self.ctl = account.getControl(peer.getStripped()) # chat window control
|
||||
self.resend = []
|
||||
|
||||
def println(self, line, kind='status', **kwargs):
|
||||
try: self.ctl.conv_textview.print_conversation_line(line, kind=kind, tim=None, jid=None, name='', **kwargs)
|
||||
except TypeError: self.ctl.conv_textview.print_conversation_line(line, kind=kind, tim=None, name='', **kwargs) # gajim git fix
|
||||
return line
|
||||
|
||||
def inject(self, msg, appdata={}):
|
||||
thread = appdata.get('thread', ChatControlSession.generate_thread_id(None))
|
||||
# Prototype of OTR Channel (secure conversations between Gajim user (Alice) and Gajim peer (Bob)
|
||||
class OTRChannel(context.Context):
|
||||
# this method may be called self.sendMessage() when we need to send some data to our <peer> via XMPP
|
||||
def inject(self,msg,appdata=None):
|
||||
stanza = Message(to=self.peer, body=msg.decode(), typ='chat')
|
||||
stanza.setThread(thread)
|
||||
self.stream.send_stanza(stanza)
|
||||
stanza.setThread(appdata or ChatControlSession.generate_thread_id(None))
|
||||
self.user.stream.send_stanza(stanza)
|
||||
|
||||
def setState(self, newstate):
|
||||
state, self.state = self.state, newstate
|
||||
if self.getCurrentTrust() is None:
|
||||
self.println("OTR: new fingerprint received [%s]" % self.getCurrentKey())
|
||||
self.setCurrentTrust(0)
|
||||
if newstate == context.STATE_ENCRYPTED and state != newstate:
|
||||
self.println("OTR: %s conversation started [%s]" % (self.getCurrentTrust() and 'trusted' or '**untrusted**', self.getCurrentKey()) )
|
||||
elif newstate == context.STATE_FINISHED and state != newstate: # channel closed
|
||||
self.println("OTR: conversation closed.")
|
||||
# this method called on channel state change
|
||||
def setState(self,state=0):
|
||||
if state and state != self.state:
|
||||
self.getCurrentTrust() is None and self.setCurrentTrust(0) != 0 and self.printl(OTR.TRUSTED[None].format(fprint=self.getCurrentKey())) # new fingerprint
|
||||
self.printl(OTR.STATUS[state].format(peer=self.peer,trust=OTR.TRUSTED[self.getCurrentTrust()],fprint=self.getCurrentKey())) # state is changed
|
||||
self.state = state
|
||||
|
||||
# print some text to chat window
|
||||
def printl(self,line):
|
||||
println = self.user.getControl(self.peer) and self.user.getControl(self.peer).conv_textview.print_conversation_line
|
||||
println and println("OTR: "+line,kind='status',name='',tim='',**('jid' in signature(println).parameters and {'jid':None} or {}))
|
||||
|
||||
@staticmethod
|
||||
def getPolicy(policy): return OTR.DEFAULT_POLICY.get(policy)
|
||||
|
||||
# OTR class
|
||||
class OTR(context.Account):
|
||||
contextclass = Channel
|
||||
|
||||
# OTR const
|
||||
ENCRYPTION_NAME = 'OTR'
|
||||
ENCRYPTION_DATA = helpers.AdditionalDataDict({'encrypted':{'name': ENCRYPTION_NAME}})
|
||||
PROTOCOL = 'XMPP'
|
||||
MMS = 1000
|
||||
# OTR instance for Gajim user (Alice)
|
||||
class OTR(context.Account):
|
||||
PROTO = ('XMPP', 1024)
|
||||
ENCRYPTION_NAME = ('OTR')
|
||||
DEFAULT_POLICY = {
|
||||
'REQUIRE_ENCRYPTION': True,
|
||||
'ALLOW_V1': False,
|
||||
|
@ -73,100 +60,86 @@ class OTR(context.Account):
|
|||
'WHITESPACE_START_AKE': True,
|
||||
'ERROR_START_AKE': True,
|
||||
}
|
||||
SESSION_START = '?OTRv2?\nI would like to start ' \
|
||||
'an Off-the-Record private conversation. However, you ' \
|
||||
'do not have a plugin to support that.\nSee '\
|
||||
'https://otr.cypherpunks.ca/ for more information.'
|
||||
|
||||
def __init__(self, account, logger = None):
|
||||
super(OTR, self).__init__(account, OTR.PROTOCOL, OTR.MMS)
|
||||
self.log = logger
|
||||
self.ctxs, self.ctls = {}, {}
|
||||
TRUSTED = {None:"new fingerprint received: *{fprint}*", 0:"untrusted", 1:"trusted", 2:"authenticated"}
|
||||
STATUS = {
|
||||
context.STATE_PLAINTEXT: "(re-)starting encrypted conversation with {peer}..",
|
||||
context.STATE_ENCRYPTED: "{trust} encrypted conversation started (fingerprint: {fprint})",
|
||||
context.STATE_FINISHED: "encrypted conversation with {peer} closed (fingerprint: {fprint})",
|
||||
context.UnencryptedMessage: "this message is *not encrypted*: {msg}",
|
||||
context.NotEncryptedError: "unable to process message (channel lost)",
|
||||
context.ErrorReceived: "received error message: {err}",
|
||||
crypt.InvalidParameterError: "unable to decrypt message (key/signature mismatch)",
|
||||
}
|
||||
|
||||
def __init__(self,plugin,account):
|
||||
super(OTR,self).__init__(account,*OTR.PROTO)
|
||||
self.plugin = plugin
|
||||
self.log = plugin.log
|
||||
self.account = account
|
||||
self.stream = app.connections[account]
|
||||
self.jid = self.stream.get_own_jid()
|
||||
self.keystore = Keystore(os.path.join(configpaths.get('MY_DATA'), 'otr_' + self.jid.getStripped() + '.db'))
|
||||
self.loadTrusts()
|
||||
|
||||
# overload some default methods #
|
||||
def getControl(self, peer):
|
||||
ctrl = self.ctls.setdefault(peer, app.interface.msg_win_mgr.get_control(peer, self.account))
|
||||
return ctrl
|
||||
|
||||
def getContext(self, peer):
|
||||
ctx = self.ctxs.setdefault(peer, Channel(self, peer))
|
||||
ctx = ctx.state == context.STATE_FINISHED and self.ctxs.pop(peer).disconnect() or self.ctxs.setdefault(peer, Channel(self, peer))
|
||||
return ctx
|
||||
# get chat control
|
||||
def getControl(self,peer):
|
||||
ctl = app.interface.msg_win_mgr.get_control(peer.getStripped(),self.account)
|
||||
return ctl
|
||||
|
||||
# get OTR context (encrypted dialog between Alice and Bob)
|
||||
def getContext(self,peer):
|
||||
peer in self.ctxs and self.ctxs[peer].state == context.STATE_FINISHED and self.ctxs.pop(peer).disconnect() # close dead channels
|
||||
self.ctxs[peer] = self.ctxs.get(peer) or OTRChannel(self,peer)
|
||||
return self.ctxs[peer]
|
||||
|
||||
# load my private key
|
||||
def loadPrivkey(self):
|
||||
my = self.keystore.load({'jid': str(self.jid)})
|
||||
return crypt.PK.parsePrivateKey(bytes.fromhex(my.privatekey))[0] if my and my.privatekey else None
|
||||
|
||||
def savePrivkey(self):
|
||||
return self.keystore.save({'jid': self.jid, 'privatekey': self.getPrivkey().serializePrivateKey().hex()})
|
||||
|
||||
def loadTrusts(self):
|
||||
for c in self.keystore.load(): self.setTrust(c.jid, c.fingerprint, c.trust)
|
||||
my = self.keystore.load(jid=self.jid)
|
||||
return (my and my.privatekey) and crypt.PK.parsePrivateKey(bytes.fromhex(my.privatekey))[0]
|
||||
|
||||
# save my privatekey
|
||||
def savePrivkey(self):
|
||||
self.keystore.save(jid=self.jid,privatekey=self.getPrivkey().serializePrivateKey().hex())
|
||||
|
||||
# load known fingerprints
|
||||
def loadTrusts(self):
|
||||
for peer in self.keystore.load(): self.setTrust(peer.jid,peer.fingerprint,peer.trust)
|
||||
|
||||
# save known fingerprints
|
||||
def saveTrusts(self):
|
||||
for jid, keys in self.trusts.items():
|
||||
for fingerprint, trust in keys.items(): self.keystore.save({'jid': jid, 'fingerprint': fingerprint, 'trust': trust})
|
||||
for peer,fingerprints in self.trusts.items():
|
||||
for fingerprint,trust in fingerprints.items(): self.keystore.save(jid=peer,fingerprint=fingerprint,trust=trust)
|
||||
|
||||
|
||||
# decrypt & receive
|
||||
def _decrypt(self, event, callback):
|
||||
# decrypt message
|
||||
def decrypt(self,event,callback):
|
||||
peer = event.stanza.getFrom()
|
||||
channel, ctl = self.getContext(peer), self.getControl(peer)
|
||||
try:
|
||||
peer = event.stanza.getFrom()
|
||||
channel = self.getContext(peer)
|
||||
text, tlvs = channel.receiveMessage(event.msgtxt.encode(), appdata = {'thread': event.stanza.getThread()})
|
||||
text = text and text.decode() or ""
|
||||
except context.UnencryptedMessage:
|
||||
self.log.error('** got plain text over encrypted channel ** %s' % stanza.getBody())
|
||||
channel.println("OTR: received plain message [%s]" % event.stanza.getBody())
|
||||
except context.ErrorReceived as e:
|
||||
self.log.error('** otr error ** %s' % e)
|
||||
channel.println("OTR: received error [%s]" % e)
|
||||
except crypt.InvalidParameterError:
|
||||
self.log.error('** unreadable message **')
|
||||
channel.println("OTR: received unreadable message (session expired?)")
|
||||
except context.NotEncryptedError:
|
||||
self.log.error('** otr session lost **')
|
||||
channel.println("OTR: session lost.")
|
||||
text, tlvs = channel.receiveMessage(event.msgtxt.encode(),appdata=event.stanza.getThread()) or b''
|
||||
except (context.UnencryptedMessage,context.NotEncryptedError,context.ErrorReceived,crypt.InvalidParameterError) as e:
|
||||
self.log.error("** got exception while decrypting message: %s" % e)
|
||||
channel.printl(OTR.STATUS[e].format(msg=event.msgtxt,err=e.args[0].error))
|
||||
else:
|
||||
event.msgtxt = text and text.decode() or ""
|
||||
event.encrypted = OTR.ENCRYPTION_NAME
|
||||
event.additional_data["encrypted"] = {"name":OTR.ENCRYPTION_NAME}
|
||||
callback(event)
|
||||
finally:
|
||||
if channel.mayRetransmit and channel.state and ctl: channel.mayRetransmit = ctl.send_message(channel.lastMessage.decode())
|
||||
|
||||
# resent messages after channel open
|
||||
if channel.resend and channel.state == context.STATE_ENCRYPTED:
|
||||
message = channel.resend.pop()
|
||||
channel.sendMessage(**message)
|
||||
channel.println(message['msg'].decode(), kind='outgoing', encrypted=self.ENCRYPTION_NAME, additional_data=self.ENCRYPTION_DATA)
|
||||
|
||||
event.xhtml = None
|
||||
event.msgtxt = text
|
||||
event.encrypted = self.ENCRYPTION_NAME
|
||||
event.additional_data = self.ENCRYPTION_DATA
|
||||
|
||||
callback(event)
|
||||
|
||||
|
||||
# encrypt & send
|
||||
def _encrypt(self, event, callback):
|
||||
# encrypt message
|
||||
def encrypt(self,event,callback):
|
||||
peer = event.msg_iq.getTo()
|
||||
channel, ctl = self.getContext(peer), event.control
|
||||
if event.xhtml: return ctl.send_message(event.message) # skip xhtml messages
|
||||
try:
|
||||
peer = event.msg_iq.getTo()
|
||||
channel = self.getContext(peer)
|
||||
session = event.session or ChatControlSession(self.stream, peer, None, 'chat')
|
||||
encrypted = channel.sendMessage(sendPolicy = context.FRAGMENT_SEND_ALL_BUT_LAST, msg = event.message.encode(), appdata = {'thread': session.thread_id}) or b''
|
||||
except context.NotEncryptedError:
|
||||
self.log.error("** unable to encrypt message **")
|
||||
channel.println('OTR: unable to start conversation')
|
||||
return
|
||||
|
||||
# resend lost message after session start
|
||||
if encrypted == OTR.SESSION_START.encode():
|
||||
channel.println('OTR: trying to start encrypted conversation')
|
||||
channel.resend += [{'sendPolicy': context.FRAGMENT_SEND_ALL, 'msg': event.message.encode(), 'appdata': {'thread': session.thread_id}}]
|
||||
event.message = ''
|
||||
|
||||
event.encrypted = 'OTR'
|
||||
event.additional_data['encrypted'] = {'name': 'OTR'}
|
||||
event.msg_iq.setBody(encrypted.decode())
|
||||
|
||||
callback(event)
|
||||
encrypted = channel.sendMessage(context.FRAGMENT_SEND_ALL_BUT_LAST,event.message.encode(),appdata=event.msg_iq.getThread()) or b''
|
||||
message = (encrypted != self.getDefaultQueryMessage(OTR.DEFAULT_POLICY.get)) and event.message or ""
|
||||
except context.NotEncryptedError as e:
|
||||
self.log.error("** got exception while encrypting message: %s" % e)
|
||||
channel.printl(peer,OTR.STATUS[e])
|
||||
else:
|
||||
event.msg_iq.setBody(encrypted.decode()) # encrypted data goes here
|
||||
event.message = message # message that will be displayed in our chat goes here
|
||||
event.encrypted, event.additional_data["encrypted"] = OTR.ENCRYPTION_NAME, {"name":OTR.ENCRYPTION_NAME} # some mandatory encryption flags
|
||||
callback(event)
|
||||
|
|
50
plugin.py
50
plugin.py
|
@ -14,50 +14,48 @@
|
|||
#
|
||||
# You can always obtain full license text at <http://www.gnu.org/licenses/>.
|
||||
|
||||
# TODO: OTR state notifications
|
||||
# TODO: Fingerprints authentication GUI
|
||||
# TODO: SMP authentication GUI
|
||||
|
||||
ERROR = None
|
||||
|
||||
import logging
|
||||
from gajim.common import app
|
||||
from gajim.plugins import GajimPlugin
|
||||
try: from otrplugin.otr import OTR
|
||||
except: ERROR = 'Error importing python3-potr module. Make sure it is installed.'
|
||||
|
||||
log = logging.getLogger('gajim.p.otr')
|
||||
from gajim.common import app
|
||||
from otrplugin.otr import OTR
|
||||
|
||||
class OTRPlugin(GajimPlugin):
|
||||
log = logging.getLogger('gajim.p.otr')
|
||||
|
||||
def init(self):
|
||||
self.encryption_name = OTR.ENCRYPTION_NAME
|
||||
self.encryption_name = 'OTR'
|
||||
self.description = 'Provides Off-the-Record encryption'
|
||||
self.activatable = not ERROR
|
||||
self.available_text = ERROR
|
||||
self.sessions = {}
|
||||
self.session = lambda account: self.sessions.setdefault(account, OTR(account, logger = log))
|
||||
self.instances = {}
|
||||
self.get_instance = lambda acct: self.instances.get(acct) or self.instances.setdefault(acct,OTR(self,acct))
|
||||
self.events_handlers = {
|
||||
'before-change-show': (10, self._on_status_change),
|
||||
}
|
||||
self.gui_extension_points = {
|
||||
'encrypt' + self.encryption_name: (self._encrypt_message, None),
|
||||
'decrypt': (self._decrypt_message, None),
|
||||
}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def activate_encryption(ctl):
|
||||
return True
|
||||
|
||||
|
||||
@staticmethod
|
||||
def encrypt_file(file, account, callback):
|
||||
callback(file)
|
||||
|
||||
def _encrypt_message(self, con, event, callback):
|
||||
if not event.message or event.type_ != 'chat': return # drop empty or non-chat messages
|
||||
log.debug('encrypting message: %s' % event)
|
||||
otr = self.session(event.account)
|
||||
otr._encrypt(event, callback)
|
||||
def _encrypt_message(self,con,event,callback):
|
||||
if not event.message or event.type_ != 'chat': return # drop empty and non-chat messages
|
||||
self.get_instance(event.conn.name).encrypt(event,callback)
|
||||
|
||||
def _decrypt_message(self, con, event, callback):
|
||||
if event.name == 'mam-message-received': event.msgtxt = '' # drop mam messages because we cannot decrypt it post-factum
|
||||
if not event.msgtxt or not event.msgtxt.startswith("?OTR"): return # drop messages without OTR tag
|
||||
log.debug('received otr message: %s' % event)
|
||||
otr = self.session(event.account)
|
||||
otr._decrypt(event, callback)
|
||||
def _decrypt_message(self,con,event,callback):
|
||||
if (event.encrypted) or (event.name[0:2] == 'gc') or not (event.msgtxt or '').startswith('?OTR'): return # skip non-OTR messages..
|
||||
if (event.name[0:3] == 'mam'): return setattr(event,'msgtxt','') # skip MAM messages (we can not decrypt OTR out of session)..
|
||||
if (app.config.get_per('encryption','%s-%s'%(event.conn.name,event.jid),'encryption')!=self.encryption_name): return # skip all when encryption not set to OTR
|
||||
self.get_instance(event.conn.name).decrypt(event,callback)
|
||||
|
||||
def _on_status_change(self,event):
|
||||
if event.show == 'offline':
|
||||
for ctx in self.get_instance(event.conn.name).ctxs.values(): ctx.state and ctx.disconnect()
|
||||
|
|
27
potr/__init__.py
Normal file
27
potr/__init__.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2011-2012 Kjell Braden <afflux@pentabarf.de>
|
||||
#
|
||||
# This file is part of the python-potr library.
|
||||
#
|
||||
# python-potr is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation; either version 3 of the License, or
|
||||
# any later version.
|
||||
#
|
||||
# python-potr is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# some python3 compatibilty
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from potr import context
|
||||
from potr import proto
|
||||
from potr.utils import human_hash
|
||||
|
||||
''' version is: (major, minor, patch, sub) with sub being one of 'alpha',
|
||||
'beta', 'final' '''
|
||||
VERSION = (1, 0, 3, 'alpha')
|
21
potr/compatcrypto/__init__.py
Normal file
21
potr/compatcrypto/__init__.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
# Copyright 2012 Kjell Braden <afflux@pentabarf.de>
|
||||
#
|
||||
# This file is part of the python-potr library.
|
||||
#
|
||||
# python-potr is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation; either version 3 of the License, or
|
||||
# any later version.
|
||||
#
|
||||
# python-potr is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
|
||||
from potr.compatcrypto.common import *
|
||||
|
||||
from potr.compatcrypto.pycrypto import *
|
BIN
potr/compatcrypto/__pycache__/__init__.cpython-37.pyc
Normal file
BIN
potr/compatcrypto/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
potr/compatcrypto/__pycache__/common.cpython-37.pyc
Normal file
BIN
potr/compatcrypto/__pycache__/common.cpython-37.pyc
Normal file
Binary file not shown.
BIN
potr/compatcrypto/__pycache__/pycrypto.cpython-37.pyc
Normal file
BIN
potr/compatcrypto/__pycache__/pycrypto.cpython-37.pyc
Normal file
Binary file not shown.
108
potr/compatcrypto/common.py
Normal file
108
potr/compatcrypto/common.py
Normal file
|
@ -0,0 +1,108 @@
|
|||
# Copyright 2012 Kjell Braden <afflux@pentabarf.de>
|
||||
#
|
||||
# This file is part of the python-potr library.
|
||||
#
|
||||
# python-potr is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation; either version 3 of the License, or
|
||||
# any later version.
|
||||
#
|
||||
# python-potr is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# some python3 compatibilty
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from potr.utils import human_hash, bytes_to_long, unpack, pack_mpi
|
||||
|
||||
DEFAULT_KEYTYPE = 0x0000
|
||||
pkTypes = {}
|
||||
def registerkeytype(cls):
|
||||
if cls.keyType is None:
|
||||
raise TypeError('registered key class needs a type value')
|
||||
pkTypes[cls.keyType] = cls
|
||||
return cls
|
||||
|
||||
def generateDefaultKey():
|
||||
return pkTypes[DEFAULT_KEYTYPE].generate()
|
||||
|
||||
class PK(object):
|
||||
keyType = None
|
||||
|
||||
@classmethod
|
||||
def generate(cls):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def parsePayload(cls, data, private=False):
|
||||
raise NotImplementedError
|
||||
|
||||
def sign(self, data):
|
||||
raise NotImplementedError
|
||||
def verify(self, data):
|
||||
raise NotImplementedError
|
||||
def fingerprint(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def serializePublicKey(self):
|
||||
return struct.pack(b'!H', self.keyType) \
|
||||
+ self.getSerializedPublicPayload()
|
||||
|
||||
def getSerializedPublicPayload(self):
|
||||
buf = b''
|
||||
for x in self.getPublicPayload():
|
||||
buf += pack_mpi(x)
|
||||
return buf
|
||||
|
||||
def getPublicPayload(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def serializePrivateKey(self):
|
||||
return struct.pack(b'!H', self.keyType) \
|
||||
+ self.getSerializedPrivatePayload()
|
||||
|
||||
def getSerializedPrivatePayload(self):
|
||||
buf = b''
|
||||
for x in self.getPrivatePayload():
|
||||
buf += pack_mpi(x)
|
||||
return buf
|
||||
|
||||
def getPrivatePayload(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def cfingerprint(self):
|
||||
return '{0:040x}'.format(bytes_to_long(self.fingerprint()))
|
||||
|
||||
@classmethod
|
||||
def parsePrivateKey(cls, data):
|
||||
implCls, data = cls.getImplementation(data)
|
||||
logging.debug('Got privkey of type %r', implCls)
|
||||
return implCls.parsePayload(data, private=True)
|
||||
|
||||
@classmethod
|
||||
def parsePublicKey(cls, data):
|
||||
implCls, data = cls.getImplementation(data)
|
||||
logging.debug('Got pubkey of type %r', implCls)
|
||||
return implCls.parsePayload(data)
|
||||
|
||||
def __str__(self):
|
||||
return human_hash(self.cfingerprint())
|
||||
def __repr__(self):
|
||||
return '<{cls}(fpr=\'{fpr}\')>'.format(
|
||||
cls=self.__class__.__name__, fpr=str(self))
|
||||
|
||||
@staticmethod
|
||||
def getImplementation(data):
|
||||
typeid, data = unpack(b'!H', data)
|
||||
cls = pkTypes.get(typeid, None)
|
||||
if cls is None:
|
||||
raise NotImplementedError('unknown typeid %r' % typeid)
|
||||
return cls, data
|
149
potr/compatcrypto/pycrypto.py
Normal file
149
potr/compatcrypto/pycrypto.py
Normal file
|
@ -0,0 +1,149 @@
|
|||
# Copyright 2012 Kjell Braden <afflux@pentabarf.de>
|
||||
#
|
||||
# This file is part of the python-potr library.
|
||||
#
|
||||
# python-potr is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation; either version 3 of the License, or
|
||||
# any later version.
|
||||
#
|
||||
# python-potr is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
try:
|
||||
import Crypto
|
||||
except ImportError:
|
||||
import crypto as Crypto
|
||||
|
||||
from Crypto import Cipher
|
||||
from Crypto.Hash import SHA256 as _SHA256
|
||||
from Crypto.Hash import SHA as _SHA1
|
||||
from Crypto.Hash import HMAC as _HMAC
|
||||
from Crypto.PublicKey import DSA
|
||||
import Crypto.Random.random
|
||||
from numbers import Number
|
||||
|
||||
from potr.compatcrypto import common
|
||||
from potr.utils import read_mpi, bytes_to_long, long_to_bytes
|
||||
|
||||
def SHA256(data):
|
||||
return _SHA256.new(data).digest()
|
||||
|
||||
def SHA1(data):
|
||||
return _SHA1.new(data).digest()
|
||||
|
||||
def SHA1HMAC(key, data):
|
||||
return _HMAC.new(key, msg=data, digestmod=_SHA1).digest()
|
||||
|
||||
def SHA256HMAC(key, data):
|
||||
return _HMAC.new(key, msg=data, digestmod=_SHA256).digest()
|
||||
|
||||
def AESCTR(key, counter=0):
|
||||
if isinstance(counter, Number):
|
||||
counter = Counter(counter)
|
||||
if not isinstance(counter, Counter):
|
||||
raise TypeError
|
||||
return Cipher.AES.new(key, Cipher.AES.MODE_CTR, counter=counter)
|
||||
|
||||
class Counter(object):
|
||||
def __init__(self, prefix):
|
||||
self.prefix = prefix
|
||||
self.val = 0
|
||||
|
||||
def inc(self):
|
||||
self.prefix += 1
|
||||
self.val = 0
|
||||
|
||||
def __setattr__(self, attr, val):
|
||||
if attr == 'prefix':
|
||||
self.val = 0
|
||||
super(Counter, self).__setattr__(attr, val)
|
||||
|
||||
def __repr__(self):
|
||||
return '<Counter(p={p!r},v={v!r})>'.format(p=self.prefix, v=self.val)
|
||||
|
||||
def byteprefix(self):
|
||||
return long_to_bytes(self.prefix, 8)
|
||||
|
||||
def __call__(self):
|
||||
bytesuffix = long_to_bytes(self.val, 8)
|
||||
self.val += 1
|
||||
return self.byteprefix() + bytesuffix
|
||||
|
||||
@common.registerkeytype
|
||||
class DSAKey(common.PK):
|
||||
keyType = 0x0000
|
||||
|
||||
def __init__(self, key=None, private=False):
|
||||
self.priv = self.pub = None
|
||||
|
||||
if not isinstance(key, tuple):
|
||||
raise TypeError('4/5-tuple required for key')
|
||||
|
||||
if len(key) == 5 and private:
|
||||
self.priv = DSA.construct(key)
|
||||
self.pub = self.priv.publickey()
|
||||
elif len(key) == 4 and not private:
|
||||
self.pub = DSA.construct(key)
|
||||
else:
|
||||
raise TypeError('wrong number of arguments for ' \
|
||||
'private={0!r}: got {1} '
|
||||
.format(private, len(key)))
|
||||
|
||||
def getPublicPayload(self):
|
||||
return (self.pub.p, self.pub.q, self.pub.g, self.pub.y)
|
||||
|
||||
def getPrivatePayload(self):
|
||||
return (self.priv.p, self.priv.q, self.priv.g, self.priv.y, self.priv.x)
|
||||
|
||||
def fingerprint(self):
|
||||
return SHA1(self.getSerializedPublicPayload())
|
||||
|
||||
def sign(self, data):
|
||||
# 2 <= K <= q
|
||||
K = randrange(2, self.priv.q)
|
||||
r, s = self.priv.sign(data, K)
|
||||
return long_to_bytes(r, 20) + long_to_bytes(s, 20)
|
||||
|
||||
def verify(self, data, sig):
|
||||
r, s = bytes_to_long(sig[:20]), bytes_to_long(sig[20:])
|
||||
return self.pub.verify(data, (r, s))
|
||||
|
||||
def __hash__(self):
|
||||
return bytes_to_long(self.fingerprint())
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, type(self)):
|
||||
return False
|
||||
return self.fingerprint() == other.fingerprint()
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
@classmethod
|
||||
def generate(cls):
|
||||
privkey = DSA.generate(1024)
|
||||
return cls((privkey.key.y, privkey.key.g, privkey.key.p, privkey.key.q,
|
||||
privkey.key.x), private=True)
|
||||
|
||||
@classmethod
|
||||
def parsePayload(cls, data, private=False):
|
||||
p, data = read_mpi(data)
|
||||
q, data = read_mpi(data)
|
||||
g, data = read_mpi(data)
|
||||
y, data = read_mpi(data)
|
||||
if private:
|
||||
x, data = read_mpi(data)
|
||||
return cls((y, g, p, q, x), private=True), data
|
||||
return cls((y, g, p, q), private=False), data
|
||||
|
||||
def getrandbits(k):
|
||||
return Crypto.Random.random.getrandbits(k)
|
||||
|
||||
def randrange(start, stop):
|
||||
return Crypto.Random.random.randrange(start, stop)
|
574
potr/context.py
Normal file
574
potr/context.py
Normal file
|
@ -0,0 +1,574 @@
|
|||
# Copyright 2011-2012 Kjell Braden <afflux@pentabarf.de>
|
||||
#
|
||||
# This file is part of the python-potr library.
|
||||
#
|
||||
# python-potr is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation; either version 3 of the License, or
|
||||
# any later version.
|
||||
#
|
||||
# python-potr is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# some python3 compatibilty
|
||||
from __future__ import unicode_literals
|
||||
|
||||
try:
|
||||
type(basestring)
|
||||
except NameError:
|
||||
# all strings are unicode in python3k
|
||||
basestring = str
|
||||
unicode = str
|
||||
|
||||
# callable is not available in python 3.0 and 3.1
|
||||
try:
|
||||
type(callable)
|
||||
except NameError:
|
||||
from collections import Callable
|
||||
def callable(x):
|
||||
return isinstance(x, Callable)
|
||||
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import struct
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from potr import crypt
|
||||
from potr import proto
|
||||
from potr import compatcrypto
|
||||
|
||||
from time import time
|
||||
|
||||
EXC_UNREADABLE_MESSAGE = 1
|
||||
EXC_FINISHED = 2
|
||||
|
||||
HEARTBEAT_INTERVAL = 60
|
||||
STATE_PLAINTEXT = 0
|
||||
STATE_ENCRYPTED = 1
|
||||
STATE_FINISHED = 2
|
||||
FRAGMENT_SEND_ALL = 0
|
||||
FRAGMENT_SEND_ALL_BUT_FIRST = 1
|
||||
FRAGMENT_SEND_ALL_BUT_LAST = 2
|
||||
|
||||
OFFER_NOTSENT = 0
|
||||
OFFER_SENT = 1
|
||||
OFFER_REJECTED = 2
|
||||
OFFER_ACCEPTED = 3
|
||||
|
||||
class Context(object):
|
||||
def __init__(self, account, peername):
|
||||
self.user = account
|
||||
self.peer = peername
|
||||
self.policy = {}
|
||||
self.crypto = crypt.CryptEngine(self)
|
||||
self.tagOffer = OFFER_NOTSENT
|
||||
self.mayRetransmit = 0
|
||||
self.lastSend = 0
|
||||
self.lastMessage = None
|
||||
self.state = STATE_PLAINTEXT
|
||||
self.trustName = self.peer
|
||||
|
||||
self.fragmentInfo = None
|
||||
self.fragment = None
|
||||
self.discardFragment()
|
||||
|
||||
def getPolicy(self, key):
|
||||
raise NotImplementedError
|
||||
|
||||
def inject(self, msg, appdata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def policyOtrEnabled(self):
|
||||
return self.getPolicy('ALLOW_V2') or self.getPolicy('ALLOW_V1')
|
||||
|
||||
def discardFragment(self):
|
||||
self.fragmentInfo = (0, 0)
|
||||
self.fragment = []
|
||||
|
||||
def fragmentAccumulate(self, message):
|
||||
'''Accumulate a fragmented message. Returns None if the fragment is
|
||||
to be ignored, returns a string if the message is ready for further
|
||||
processing'''
|
||||
|
||||
params = message.split(b',')
|
||||
if len(params) < 5 or not params[1].isdigit() or not params[2].isdigit():
|
||||
logger.warning('invalid formed fragmented message: %r', params)
|
||||
self.discardFragment()
|
||||
return message
|
||||
|
||||
|
||||
K, N = self.fragmentInfo
|
||||
try:
|
||||
k = int(params[1])
|
||||
n = int(params[2])
|
||||
except ValueError:
|
||||
logger.warning('invalid formed fragmented message: %r', params)
|
||||
self.discardFragment()
|
||||
return message
|
||||
|
||||
fragData = params[3]
|
||||
|
||||
logger.debug(params)
|
||||
|
||||
if n >= k == 1:
|
||||
# first fragment
|
||||
self.discardFragment()
|
||||
self.fragmentInfo = (k, n)
|
||||
self.fragment.append(fragData)
|
||||
elif N == n >= k > 1 and k == K+1:
|
||||
# accumulate
|
||||
self.fragmentInfo = (k, n)
|
||||
self.fragment.append(fragData)
|
||||
else:
|
||||
# bad, discard
|
||||
self.discardFragment()
|
||||
logger.warning('invalid fragmented message: %r', params)
|
||||
return message
|
||||
|
||||
if n == k > 0:
|
||||
assembled = b''.join(self.fragment)
|
||||
self.discardFragment()
|
||||
return assembled
|
||||
|
||||
return None
|
||||
|
||||
def removeFingerprint(self, fingerprint):
|
||||
self.user.removeFingerprint(self.trustName, fingerprint)
|
||||
|
||||
def setTrust(self, fingerprint, trustLevel):
|
||||
''' sets the trust level for the given fingerprint.
|
||||
trust is usually:
|
||||
- the empty string for known but untrusted keys
|
||||
- 'verified' for manually verified keys
|
||||
- 'smp' for smp-style verified keys '''
|
||||
self.user.setTrust(self.trustName, fingerprint, trustLevel)
|
||||
|
||||
def getTrust(self, fingerprint, default=None):
|
||||
return self.user.getTrust(self.trustName, fingerprint, default)
|
||||
|
||||
def setCurrentTrust(self, trustLevel):
|
||||
self.setTrust(self.crypto.theirPubkey.cfingerprint(), trustLevel)
|
||||
|
||||
def getCurrentKey(self):
|
||||
return self.crypto.theirPubkey
|
||||
|
||||
def getCurrentTrust(self):
|
||||
''' returns a 2-tuple: first element is the current fingerprint,
|
||||
second is:
|
||||
- None if the key is unknown yet
|
||||
- a non-empty string if the key is trusted
|
||||
- an empty string if the key is untrusted '''
|
||||
if self.crypto.theirPubkey is None:
|
||||
return None
|
||||
return self.getTrust(self.crypto.theirPubkey.cfingerprint(), None)
|
||||
|
||||
def receiveMessage(self, messageData, appdata=None):
|
||||
IGN = None, []
|
||||
|
||||
if not self.policyOtrEnabled():
|
||||
raise NotOTRMessage(messageData)
|
||||
|
||||
message = self.parse(messageData)
|
||||
|
||||
if message is None:
|
||||
# nothing to see. move along.
|
||||
return IGN
|
||||
|
||||
logger.debug(repr(message))
|
||||
|
||||
if self.getPolicy('SEND_TAG'):
|
||||
if isinstance(message, basestring):
|
||||
# received a plaintext message without tag
|
||||
# we should not tag anymore
|
||||
self.tagOffer = OFFER_REJECTED
|
||||
else:
|
||||
# got something OTR-ish, cool!
|
||||
self.tagOffer = OFFER_ACCEPTED
|
||||
|
||||
if isinstance(message, proto.Query):
|
||||
self.handleQuery(message, appdata=appdata)
|
||||
|
||||
if isinstance(message, proto.TaggedPlaintext):
|
||||
# it's actually a plaintext message
|
||||
if self.state != STATE_PLAINTEXT or \
|
||||
self.getPolicy('REQUIRE_ENCRYPTION'):
|
||||
# but we don't want plaintexts
|
||||
raise UnencryptedMessage(message.msg)
|
||||
|
||||
raise NotOTRMessage(message.msg)
|
||||
|
||||
return IGN
|
||||
|
||||
if isinstance(message, proto.AKEMessage):
|
||||
self.crypto.handleAKE(message, appdata=appdata)
|
||||
return IGN
|
||||
|
||||
if isinstance(message, proto.DataMessage):
|
||||
ignore = message.flags & proto.MSGFLAGS_IGNORE_UNREADABLE
|
||||
|
||||
if self.state != STATE_ENCRYPTED:
|
||||
self.sendInternal(proto.Error(
|
||||
'You sent encrypted data, but I wasn\'t expecting it.'
|
||||
.encode('utf-8')), appdata=appdata)
|
||||
if ignore:
|
||||
return IGN
|
||||
raise NotEncryptedError(EXC_UNREADABLE_MESSAGE)
|
||||
|
||||
try:
|
||||
plaintext, tlvs = self.crypto.handleDataMessage(message)
|
||||
self.processTLVs(tlvs, appdata=appdata)
|
||||
if plaintext and self.lastSend < time() - HEARTBEAT_INTERVAL:
|
||||
self.sendInternal(b'', appdata=appdata)
|
||||
return plaintext or None, tlvs
|
||||
except crypt.InvalidParameterError:
|
||||
if ignore:
|
||||
return IGN
|
||||
logger.exception('decryption failed')
|
||||
raise
|
||||
if isinstance(message, basestring):
|
||||
if self.state != STATE_PLAINTEXT or \
|
||||
self.getPolicy('REQUIRE_ENCRYPTION'):
|
||||
raise UnencryptedMessage(message)
|
||||
|
||||
if isinstance(message, proto.Error):
|
||||
raise ErrorReceived(message)
|
||||
|
||||
raise NotOTRMessage(messageData)
|
||||
|
||||
def sendInternal(self, msg, tlvs=[], appdata=None):
|
||||
self.sendMessage(FRAGMENT_SEND_ALL, msg, tlvs=tlvs, appdata=appdata,
|
||||
flags=proto.MSGFLAGS_IGNORE_UNREADABLE)
|
||||
|
||||
def sendMessage(self, sendPolicy, msg, flags=0, tlvs=[], appdata=None):
|
||||
if self.policyOtrEnabled():
|
||||
self.lastSend = time()
|
||||
|
||||
if isinstance(msg, proto.OTRMessage):
|
||||
# we want to send a protocol message (probably internal)
|
||||
# so we don't need further protocol encryption
|
||||
# also we can't add TLVs to arbitrary protocol messages
|
||||
if tlvs:
|
||||
raise TypeError('can\'t add tlvs to protocol message')
|
||||
else:
|
||||
# we got plaintext to send. encrypt it
|
||||
msg = self.processOutgoingMessage(msg, flags, tlvs)
|
||||
|
||||
if isinstance(msg, proto.OTRMessage) \
|
||||
and not isinstance(msg, proto.Query):
|
||||
# if it's a query message, it must not get fragmented
|
||||
return self.sendFragmented(bytes(msg), policy=sendPolicy, appdata=appdata)
|
||||
else:
|
||||
msg = bytes(msg)
|
||||
return msg
|
||||
|
||||
def processOutgoingMessage(self, msg, flags, tlvs=[]):
|
||||
isQuery = self.parseExplicitQuery(msg) is not None
|
||||
if isQuery:
|
||||
return self.user.getDefaultQueryMessage(self.getPolicy)
|
||||
|
||||
if self.state == STATE_PLAINTEXT:
|
||||
if self.getPolicy('REQUIRE_ENCRYPTION'):
|
||||
if not isQuery:
|
||||
self.lastMessage = msg
|
||||
self.lastSend = time()
|
||||
self.mayRetransmit = 2
|
||||
# TODO notify
|
||||
msg = self.user.getDefaultQueryMessage(self.getPolicy)
|
||||
return msg
|
||||
if self.getPolicy('SEND_TAG') and \
|
||||
self.tagOffer != OFFER_REJECTED and \
|
||||
self.shouldTagMessage(msg):
|
||||
self.tagOffer = OFFER_SENT
|
||||
versions = set()
|
||||
if self.getPolicy('ALLOW_V1'):
|
||||
versions.add(1)
|
||||
if self.getPolicy('ALLOW_V2'):
|
||||
versions.add(2)
|
||||
return proto.TaggedPlaintext(msg, versions)
|
||||
return msg
|
||||
if self.state == STATE_ENCRYPTED:
|
||||
msg = self.crypto.createDataMessage(msg, flags, tlvs)
|
||||
self.lastSend = time()
|
||||
return msg
|
||||
if self.state == STATE_FINISHED:
|
||||
raise NotEncryptedError(EXC_FINISHED)
|
||||
|
||||
def disconnect(self, appdata=None):
|
||||
if self.state != STATE_FINISHED:
|
||||
self.sendInternal(b'', tlvs=[proto.DisconnectTLV()], appdata=appdata)
|
||||
self.setState(STATE_PLAINTEXT)
|
||||
self.crypto.finished()
|
||||
else:
|
||||
self.setState(STATE_PLAINTEXT)
|
||||
|
||||
def setState(self, newstate):
|
||||
self.state = newstate
|
||||
|
||||
def _wentEncrypted(self):
|
||||
self.setState(STATE_ENCRYPTED)
|
||||
|
||||
def sendFragmented(self, msg, policy=FRAGMENT_SEND_ALL, appdata=None):
|
||||
mms = self.maxMessageSize(appdata)
|
||||
msgLen = len(msg)
|
||||
if mms != 0 and msgLen > mms:
|
||||
fms = mms - 19
|
||||
fragments = [ msg[i:i+fms] for i in range(0, msgLen, fms) ]
|
||||
|
||||
fc = len(fragments)
|
||||
|
||||
if fc > 65535:
|
||||
raise OverflowError('too many fragments')
|
||||
|
||||
for fi in range(len(fragments)):
|
||||
ctr = unicode(fi+1) + ',' + unicode(fc) + ','
|
||||
fragments[fi] = b'?OTR,' + ctr.encode('ascii') \
|
||||
+ fragments[fi] + b','
|
||||
|
||||
if policy == FRAGMENT_SEND_ALL:
|
||||
for f in fragments:
|
||||
self.inject(f, appdata=appdata)
|
||||
return None
|
||||
elif policy == FRAGMENT_SEND_ALL_BUT_FIRST:
|
||||
for f in fragments[1:]:
|
||||
self.inject(f, appdata=appdata)
|
||||
return fragments[0]
|
||||
elif policy == FRAGMENT_SEND_ALL_BUT_LAST:
|
||||
for f in fragments[:-1]:
|
||||
self.inject(f, appdata=appdata)
|
||||
return fragments[-1]
|
||||
|
||||
else:
|
||||
if policy == FRAGMENT_SEND_ALL:
|
||||
self.inject(msg, appdata=appdata)
|
||||
return None
|
||||
else:
|
||||
return msg
|
||||
|
||||
def processTLVs(self, tlvs, appdata=None):
|
||||
for tlv in tlvs:
|
||||
if isinstance(tlv, proto.DisconnectTLV):
|
||||
logger.info('got disconnect tlv, forcing finished state')
|
||||
self.setState(STATE_FINISHED)
|
||||
self.crypto.finished()
|
||||
# TODO cleanup
|
||||
continue
|
||||
if isinstance(tlv, proto.SMPTLV):
|
||||
self.crypto.smpHandle(tlv, appdata=appdata)
|
||||
continue
|
||||
logger.info('got unhandled tlv: {0!r}'.format(tlv))
|
||||
|
||||
def smpAbort(self, appdata=None):
|
||||
if self.state != STATE_ENCRYPTED:
|
||||
raise NotEncryptedError
|
||||
self.crypto.smpAbort(appdata=appdata)
|
||||
|
||||
def smpIsValid(self):
|
||||
return self.crypto.smp and self.crypto.smp.prog != crypt.SMPPROG_CHEATED
|
||||
|
||||
def smpIsSuccess(self):
|
||||
return self.crypto.smp.prog == crypt.SMPPROG_SUCCEEDED \
|
||||
if self.crypto.smp else None
|
||||
|
||||
def smpGotSecret(self, secret, question=None, appdata=None):
|
||||
if self.state != STATE_ENCRYPTED:
|
||||
raise NotEncryptedError
|
||||
self.crypto.smpSecret(secret, question=question, appdata=appdata)
|
||||
|
||||
def smpInit(self, secret, question=None, appdata=None):
|
||||
if self.state != STATE_ENCRYPTED:
|
||||
raise NotEncryptedError
|
||||
self.crypto.smp = None
|
||||
self.crypto.smpSecret(secret, question=question, appdata=appdata)
|
||||
|
||||
def handleQuery(self, message, appdata=None):
|
||||
if 2 in message.versions and self.getPolicy('ALLOW_V2'):
|
||||
self.authStartV2(appdata=appdata)
|
||||
elif 1 in message.versions and self.getPolicy('ALLOW_V1'):
|
||||
self.authStartV1(appdata=appdata)
|
||||
|
||||
def authStartV1(self, appdata=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def authStartV2(self, appdata=None):
|
||||
self.crypto.startAKE(appdata=appdata)
|
||||
|
||||
def parseExplicitQuery(self, message):
|
||||
otrTagPos = message.find(proto.OTRTAG)
|
||||
|
||||
if otrTagPos == -1:
|
||||
return None
|
||||
|
||||
indexBase = otrTagPos + len(proto.OTRTAG)
|
||||
|
||||
if len(message) <= indexBase:
|
||||
return None
|
||||
|
||||
compare = message[indexBase]
|
||||
|
||||
hasq = compare == b'?'[0]
|
||||
hasv = compare == b'v'[0]
|
||||
|
||||
if not hasq and not hasv:
|
||||
return None
|
||||
|
||||
hasv |= len(message) > indexBase+1 and message[indexBase+1] == b'v'[0]
|
||||
if hasv:
|
||||
end = message.find(b'?', indexBase+1)
|
||||
else:
|
||||
end = indexBase+1
|
||||
return message[indexBase:end]
|
||||
|
||||
def parse(self, message, nofragment=False):
|
||||
otrTagPos = message.find(proto.OTRTAG)
|
||||
if otrTagPos == -1:
|
||||
if proto.MESSAGE_TAG_BASE in message:
|
||||
return proto.TaggedPlaintext.parse(message)
|
||||
else:
|
||||
return message
|
||||
|
||||
indexBase = otrTagPos + len(proto.OTRTAG)
|
||||
|
||||
if len(message) <= indexBase:
|
||||
return message
|
||||
|
||||
compare = message[indexBase]
|
||||
|
||||
if nofragment is False and compare == b','[0]:
|
||||
message = self.fragmentAccumulate(message[indexBase:])
|
||||
if message is None:
|
||||
return None
|
||||
else:
|
||||
return self.parse(message, nofragment=True)
|
||||
else:
|
||||
self.discardFragment()
|
||||
|
||||
queryPayload = self.parseExplicitQuery(message)
|
||||
if queryPayload is not None:
|
||||
return proto.Query.parse(queryPayload)
|
||||
|
||||
if compare == b':'[0] and len(message) > indexBase + 4:
|
||||
try:
|
||||
infoTag = base64.b64decode(message[indexBase+1:indexBase+5])
|
||||
classInfo = struct.unpack(b'!HB', infoTag)
|
||||
|
||||
cls = proto.messageClasses.get(classInfo, None)
|
||||
if cls is None:
|
||||
return message
|
||||
|
||||
logger.debug('{user} got msg {typ!r}' \
|
||||
.format(user=self.user.name, typ=cls))
|
||||
return cls.parsePayload(message[indexBase+5:])
|
||||
except (TypeError, struct.error):
|
||||
logger.exception('could not parse OTR message %s', message)
|
||||
return message
|
||||
|
||||
if message[indexBase:indexBase+7] == b' Error:':
|
||||
return proto.Error(message[indexBase+7:])
|
||||
|
||||
return message
|
||||
|
||||
def maxMessageSize(self, appdata=None):
|
||||
"""Return the max message size for this context."""
|
||||
return self.user.maxMessageSize
|
||||
|
||||
def getExtraKey(self, extraKeyAppId=None, extraKeyAppData=None, appdata=None):
|
||||
""" retrieves the generated extra symmetric key.
|
||||
|
||||
if extraKeyAppId is set, notifies the chat partner about intended
|
||||
usage (additional application specific information can be supplied in
|
||||
extraKeyAppData).
|
||||
|
||||
returns the 256 bit symmetric key """
|
||||
|
||||
if self.state != STATE_ENCRYPTED:
|
||||
raise NotEncryptedError
|
||||
if extraKeyAppId is not None:
|
||||
tlvs = [proto.ExtraKeyTLV(extraKeyAppId, extraKeyAppData)]
|
||||
self.sendInternal(b'', tlvs=tlvs, appdata=appdata)
|
||||
return self.crypto.extraKey
|
||||
|
||||
def shouldTagMessage(self, msg):
|
||||
"""Hook to decide whether to tag a message based on its contents."""
|
||||
return True
|
||||
|
||||
class Account(object):
|
||||
contextclass = Context
|
||||
def __init__(self, name, protocol, maxMessageSize, privkey=None):
|
||||
self.name = name
|
||||
self.privkey = privkey
|
||||
self.policy = {}
|
||||
self.protocol = protocol
|
||||
self.ctxs = {}
|
||||
self.trusts = {}
|
||||
self.maxMessageSize = maxMessageSize
|
||||
self.defaultQuery = '?OTRv{versions}?\nI would like to start ' \
|
||||
'an Off-the-Record private conversation. However, you ' \
|
||||
'do not have a plugin to support that.\nSee '\
|
||||
'https://otr.cypherpunks.ca/ for more information.'
|
||||
|
||||
def __repr__(self):
|
||||
return '<{cls}(name={name!r})>'.format(cls=self.__class__.__name__,
|
||||
name=self.name)
|
||||
|
||||
def getPrivkey(self, autogen=True):
|
||||
if self.privkey is None:
|
||||
self.privkey = self.loadPrivkey()
|
||||
if self.privkey is None:
|
||||
if autogen is True:
|
||||
self.privkey = compatcrypto.generateDefaultKey()
|
||||
self.savePrivkey()
|
||||
else:
|
||||
raise LookupError
|
||||
return self.privkey
|
||||
|
||||
def loadPrivkey(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def savePrivkey(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def saveTrusts(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def getContext(self, uid, newCtxCb=None):
|
||||
if uid not in self.ctxs:
|
||||
self.ctxs[uid] = self.contextclass(self, uid)
|
||||
if callable(newCtxCb):
|
||||
newCtxCb(self.ctxs[uid])
|
||||
return self.ctxs[uid]
|
||||
|
||||
def getDefaultQueryMessage(self, policy):
|
||||
v = '2' if policy('ALLOW_V2') else ''
|
||||
msg = self.defaultQuery.format(versions=v)
|
||||
return msg.encode('ascii')
|
||||
|
||||
def setTrust(self, key, fingerprint, trustLevel):
|
||||
if key not in self.trusts:
|
||||
self.trusts[key] = {}
|
||||
self.trusts[key][fingerprint] = trustLevel
|
||||
self.saveTrusts()
|
||||
|
||||
def getTrust(self, key, fingerprint, default=None):
|
||||
if key not in self.trusts:
|
||||
return default
|
||||
return self.trusts[key].get(fingerprint, default)
|
||||
|
||||
def removeFingerprint(self, key, fingerprint):
|
||||
if key in self.trusts and fingerprint in self.trusts[key]:
|
||||
del self.trusts[key][fingerprint]
|
||||
|
||||
class NotEncryptedError(RuntimeError):
|
||||
pass
|
||||
class UnencryptedMessage(RuntimeError):
|
||||
pass
|
||||
class ErrorReceived(RuntimeError):
|
||||
pass
|
||||
class NotOTRMessage(RuntimeError):
|
||||
pass
|
801
potr/crypt.py
Normal file
801
potr/crypt.py
Normal file
|
@ -0,0 +1,801 @@
|
|||
# Copyright 2011-2012 Kjell Braden <afflux@pentabarf.de>
|
||||
#
|
||||
# This file is part of the python-potr library.
|
||||
#
|
||||
# python-potr is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation; either version 3 of the License, or
|
||||
# any later version.
|
||||
#
|
||||
# python-potr is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# some python3 compatibilty
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import logging
|
||||
import struct
|
||||
|
||||
|
||||
from potr.compatcrypto import SHA256, SHA1, SHA1HMAC, SHA256HMAC, \
|
||||
Counter, AESCTR, PK, getrandbits, randrange
|
||||
from potr.utils import bytes_to_long, long_to_bytes, pack_mpi, read_mpi
|
||||
from potr import proto
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
STATE_NONE = 0
|
||||
STATE_AWAITING_DHKEY = 1
|
||||
STATE_AWAITING_REVEALSIG = 2
|
||||
STATE_AWAITING_SIG = 4
|
||||
STATE_V1_SETUP = 5
|
||||
|
||||
|
||||
DH_MODULUS = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919
|
||||
DH_MODULUS_2 = DH_MODULUS-2
|
||||
DH_GENERATOR = 2
|
||||
DH_BITS = 1536
|
||||
DH_MAX = 2**DH_BITS
|
||||
SM_ORDER = (DH_MODULUS - 1) // 2
|
||||
|
||||
def check_group(n):
|
||||
return 2 <= n <= DH_MODULUS_2
|
||||
def check_exp(n):
|
||||
return 1 <= n < SM_ORDER
|
||||
|
||||
def SHA256HMAC160(key, data):
|
||||
return SHA256HMAC(key, data)[:20]
|
||||
|
||||
class DH(object):
|
||||
@classmethod
|
||||
def set_params(cls, prime, gen):
|
||||
cls.prime = prime
|
||||
cls.gen = gen
|
||||
|
||||
def __init__(self):
|
||||
self.priv = randrange(2, 2**320)
|
||||
self.pub = pow(self.gen, self.priv, self.prime)
|
||||
|
||||
DH.set_params(DH_MODULUS, DH_GENERATOR)
|
||||
|
||||
class DHSession(object):
|
||||
def __init__(self, sendenc, sendmac, rcvenc, rcvmac):
|
||||
self.sendenc = sendenc
|
||||
self.sendmac = sendmac
|
||||
self.rcvenc = rcvenc
|
||||
self.rcvmac = rcvmac
|
||||
self.sendctr = Counter(0)
|
||||
self.rcvctr = Counter(0)
|
||||
self.sendmacused = False
|
||||
self.rcvmacused = False
|
||||
|
||||
def __repr__(self):
|
||||
return '<{cls}(send={s!r},rcv={r!r})>' \
|
||||
.format(cls=self.__class__.__name__,
|
||||
s=self.sendmac, r=self.rcvmac)
|
||||
|
||||
@classmethod
|
||||
def create(cls, dh, y):
|
||||
s = pow(y, dh.priv, DH_MODULUS)
|
||||
sb = pack_mpi(s)
|
||||
|
||||
if dh.pub > y:
|
||||
sendbyte = b'\1'
|
||||
rcvbyte = b'\2'
|
||||
else:
|
||||
sendbyte = b'\2'
|
||||
rcvbyte = b'\1'
|
||||
|
||||
sendenc = SHA1(sendbyte + sb)[:16]
|
||||
sendmac = SHA1(sendenc)
|
||||
rcvenc = SHA1(rcvbyte + sb)[:16]
|
||||
rcvmac = SHA1(rcvenc)
|
||||
return cls(sendenc, sendmac, rcvenc, rcvmac)
|
||||
|
||||
class CryptEngine(object):
|
||||
def __init__(self, ctx):
|
||||
self.ctx = ctx
|
||||
self.ake = None
|
||||
|
||||
self.sessionId = None
|
||||
self.sessionIdHalf = False
|
||||
self.theirKeyid = 0
|
||||
self.theirY = None
|
||||
self.theirOldY = None
|
||||
|
||||
self.ourOldDHKey = None
|
||||
self.ourDHKey = None
|
||||
self.ourKeyid = 0
|
||||
|
||||
self.sessionkeys = {0:{0:None, 1:None}, 1:{0:None, 1:None}}
|
||||
self.theirPubkey = None
|
||||
self.savedMacKeys = []
|
||||
|
||||
self.smp = None
|
||||
self.extraKey = None
|
||||
|
||||
def revealMacs(self, ours=True):
|
||||
if ours:
|
||||
dhs = self.sessionkeys[1].values()
|
||||
else:
|
||||
dhs = ( v[1] for v in self.sessionkeys.values() )
|
||||
for v in dhs:
|
||||
if v is not None:
|
||||
if v.rcvmacused:
|
||||
self.savedMacKeys.append(v.rcvmac)
|
||||
if v.sendmacused:
|
||||
self.savedMacKeys.append(v.sendmac)
|
||||
|
||||
def rotateDHKeys(self):
|
||||
self.revealMacs(ours=True)
|
||||
self.ourOldDHKey = self.ourDHKey
|
||||
self.sessionkeys[1] = self.sessionkeys[0].copy()
|
||||
self.ourDHKey = DH()
|
||||
self.ourKeyid += 1
|
||||
|
||||
self.sessionkeys[0][0] = None if self.theirY is None else \
|
||||
DHSession.create(self.ourDHKey, self.theirY)
|
||||
self.sessionkeys[0][1] = None if self.theirOldY is None else \
|
||||
DHSession.create(self.ourDHKey, self.theirOldY)
|
||||
|
||||
logger.debug('{0}: Refreshing ourkey to {1} {2}'.format(
|
||||
self.ctx.user.name, self.ourKeyid, self.sessionkeys))
|
||||
|
||||
def rotateYKeys(self, new_y):
|
||||
self.theirOldY = self.theirY
|
||||
self.revealMacs(ours=False)
|
||||
self.sessionkeys[0][1] = self.sessionkeys[0][0]
|
||||
self.sessionkeys[1][1] = self.sessionkeys[1][0]
|
||||
self.theirY = new_y
|
||||
self.theirKeyid += 1
|
||||
|
||||
self.sessionkeys[0][0] = DHSession.create(self.ourDHKey, self.theirY)
|
||||
self.sessionkeys[1][0] = DHSession.create(self.ourOldDHKey, self.theirY)
|
||||
|
||||
logger.debug('{0}: Refreshing theirkey to {1} {2}'.format(
|
||||
self.ctx.user.name, self.theirKeyid, self.sessionkeys))
|
||||
|
||||
def handleDataMessage(self, msg):
|
||||
if self.saneKeyIds(msg) is False:
|
||||
raise InvalidParameterError
|
||||
|
||||
sesskey = self.sessionkeys[self.ourKeyid - msg.rkeyid] \
|
||||
[self.theirKeyid - msg.skeyid]
|
||||
|
||||
logger.debug('sesskeys: {0!r}, our={1}, r={2}, their={3}, s={4}' \
|
||||
.format(self.sessionkeys, self.ourKeyid, msg.rkeyid,
|
||||
self.theirKeyid, msg.skeyid))
|
||||
|
||||
if msg.mac != SHA1HMAC(sesskey.rcvmac, msg.getMacedData()):
|
||||
logger.error('HMACs don\'t match')
|
||||
raise InvalidParameterError
|
||||
sesskey.rcvmacused = True
|
||||
|
||||
newCtrPrefix = bytes_to_long(msg.ctr)
|
||||
if newCtrPrefix <= sesskey.rcvctr.prefix:
|
||||
logger.error('CTR must increase (old %r, new %r)',
|
||||
sesskey.rcvctr.prefix, newCtrPrefix)
|
||||
raise InvalidParameterError
|
||||
|
||||
sesskey.rcvctr.prefix = newCtrPrefix
|
||||
|
||||
logger.debug('handle: enc={0!r} mac={1!r} ctr={2!r}' \
|
||||
.format(sesskey.rcvenc, sesskey.rcvmac, sesskey.rcvctr))
|
||||
|
||||
plaintextData = AESCTR(sesskey.rcvenc, sesskey.rcvctr) \
|
||||
.decrypt(msg.encmsg)
|
||||
|
||||
if b'\0' in plaintextData:
|
||||
plaintext, tlvData = plaintextData.split(b'\0', 1)
|
||||
tlvs = proto.TLV.parse(tlvData)
|
||||
else:
|
||||
plaintext = plaintextData
|
||||
tlvs = []
|
||||
|
||||
if msg.rkeyid == self.ourKeyid:
|
||||
self.rotateDHKeys()
|
||||
if msg.skeyid == self.theirKeyid:
|
||||
self.rotateYKeys(bytes_to_long(msg.dhy))
|
||||
|
||||
return plaintext, tlvs
|
||||
|
||||
def smpSecret(self, secret, question=None, appdata=None):
|
||||
if self.smp is None:
|
||||
logger.debug('Creating SMPHandler')
|
||||
self.smp = SMPHandler(self)
|
||||
|
||||
self.smp.gotSecret(secret, question=question, appdata=appdata)
|
||||
|
||||
def smpHandle(self, tlv, appdata=None):
|
||||
if self.smp is None:
|
||||
logger.debug('Creating SMPHandler')
|
||||
self.smp = SMPHandler(self)
|
||||
self.smp.handle(tlv, appdata=appdata)
|
||||
|
||||
def smpAbort(self, appdata=None):
|
||||
if self.smp is None:
|
||||
logger.debug('Creating SMPHandler')
|
||||
self.smp = SMPHandler(self)
|
||||
self.smp.abort(appdata=appdata)
|
||||
|
||||
def createDataMessage(self, message, flags=0, tlvs=None):
|
||||
# check MSGSTATE
|
||||
if self.theirKeyid == 0:
|
||||
raise InvalidParameterError
|
||||
|
||||
if tlvs is None:
|
||||
tlvs = []
|
||||
|
||||
sess = self.sessionkeys[1][0]
|
||||
sess.sendctr.inc()
|
||||
|
||||
logger.debug('create: enc={0!r} mac={1!r} ctr={2!r}' \
|
||||
.format(sess.sendenc, sess.sendmac, sess.sendctr))
|
||||
|
||||
# plaintext + TLVS
|
||||
plainBuf = message + b'\0' + b''.join([ bytes(t) for t in tlvs])
|
||||
encmsg = AESCTR(sess.sendenc, sess.sendctr).encrypt(plainBuf)
|
||||
|
||||
msg = proto.DataMessage(flags, self.ourKeyid-1, self.theirKeyid,
|
||||
long_to_bytes(self.ourDHKey.pub), sess.sendctr.byteprefix(),
|
||||
encmsg, b'', b''.join(self.savedMacKeys))
|
||||
|
||||
self.savedMacKeys = []
|
||||
|
||||
msg.mac = SHA1HMAC(sess.sendmac, msg.getMacedData())
|
||||
return msg
|
||||
|
||||
def saneKeyIds(self, msg):
|
||||
anyzero = self.theirKeyid == 0 or msg.skeyid == 0 or msg.rkeyid == 0
|
||||
if anyzero or (msg.skeyid != self.theirKeyid and \
|
||||
msg.skeyid != self.theirKeyid - 1) or \
|
||||
(msg.rkeyid != self.ourKeyid and msg.rkeyid != self.ourKeyid - 1):
|
||||
return False
|
||||
if self.theirOldY is None and msg.skeyid == self.theirKeyid - 1:
|
||||
return False
|
||||
return True
|
||||
|
||||
def startAKE(self, appdata=None):
|
||||
self.ake = AuthKeyExchange(self.ctx.user.getPrivkey(), self.goEncrypted)
|
||||
outMsg = self.ake.startAKE()
|
||||
self.ctx.sendInternal(outMsg, appdata=appdata)
|
||||
|
||||
def handleAKE(self, inMsg, appdata=None):
|
||||
outMsg = None
|
||||
|
||||
if not self.ctx.getPolicy('ALLOW_V2'):
|
||||
return
|
||||
|
||||
if isinstance(inMsg, proto.DHCommit):
|
||||
if self.ake is None or self.ake.state != STATE_AWAITING_REVEALSIG:
|
||||
self.ake = AuthKeyExchange(self.ctx.user.getPrivkey(),
|
||||
self.goEncrypted)
|
||||
outMsg = self.ake.handleDHCommit(inMsg)
|
||||
|
||||
elif isinstance(inMsg, proto.DHKey):
|
||||
if self.ake is None:
|
||||
return # ignore
|
||||
outMsg = self.ake.handleDHKey(inMsg)
|
||||
|
||||
elif isinstance(inMsg, proto.RevealSig):
|
||||
if self.ake is None:
|
||||
return # ignore
|
||||
outMsg = self.ake.handleRevealSig(inMsg)
|
||||
|
||||
elif isinstance(inMsg, proto.Signature):
|
||||
if self.ake is None:
|
||||
return # ignore
|
||||
self.ake.handleSignature(inMsg)
|
||||
|
||||
if outMsg is not None:
|
||||
self.ctx.sendInternal(outMsg, appdata=appdata)
|
||||
|
||||
def goEncrypted(self, ake):
|
||||
if ake.dh.pub == ake.gy:
|
||||
logger.warning('We are receiving our own messages')
|
||||
raise InvalidParameterError
|
||||
|
||||
# TODO handle new fingerprint
|
||||
self.theirPubkey = ake.theirPubkey
|
||||
|
||||
self.sessionId = ake.sessionId
|
||||
self.sessionIdHalf = ake.sessionIdHalf
|
||||
self.theirKeyid = ake.theirKeyid
|
||||
self.ourKeyid = ake.ourKeyid
|
||||
self.theirY = ake.gy
|
||||
self.theirOldY = None
|
||||
self.extraKey = ake.extraKey
|
||||
|
||||
if self.ourKeyid != ake.ourKeyid + 1 or self.ourOldDHKey != ake.dh.pub:
|
||||
self.ourDHKey = ake.dh
|
||||
self.sessionkeys[0][0] = DHSession.create(self.ourDHKey, self.theirY)
|
||||
self.rotateDHKeys()
|
||||
|
||||
# we don't need the AKE anymore, free the reference
|
||||
self.ake = None
|
||||
|
||||
self.ctx._wentEncrypted()
|
||||
logger.info('went encrypted with {0}'.format(self.theirPubkey))
|
||||
|
||||
def finished(self):
|
||||
self.smp = None
|
||||
|
||||
class AuthKeyExchange(object):
|
||||
def __init__(self, privkey, onSuccess):
|
||||
self.privkey = privkey
|
||||
self.state = STATE_NONE
|
||||
self.r = None
|
||||
self.encgx = None
|
||||
self.hashgx = None
|
||||
self.ourKeyid = 1
|
||||
self.theirPubkey = None
|
||||
self.theirKeyid = 1
|
||||
self.enc_c = None
|
||||
self.enc_cp = None
|
||||
self.mac_m1 = None
|
||||
self.mac_m1p = None
|
||||
self.mac_m2 = None
|
||||
self.mac_m2p = None
|
||||
self.sessionId = None
|
||||
self.sessionIdHalf = False
|
||||
self.dh = DH()
|
||||
self.onSuccess = onSuccess
|
||||
self.gy = None
|
||||
self.extraKey = None
|
||||
self.lastmsg = None
|
||||
|
||||
def startAKE(self):
|
||||
self.r = long_to_bytes(getrandbits(128), 16)
|
||||
|
||||
gxmpi = pack_mpi(self.dh.pub)
|
||||
|
||||
self.hashgx = SHA256(gxmpi)
|
||||
self.encgx = AESCTR(self.r).encrypt(gxmpi)
|
||||
|
||||
self.state = STATE_AWAITING_DHKEY
|
||||
|
||||
return proto.DHCommit(self.encgx, self.hashgx)
|
||||
|
||||
def handleDHCommit(self, msg):
|
||||
self.encgx = msg.encgx
|
||||
self.hashgx = msg.hashgx
|
||||
|
||||
self.state = STATE_AWAITING_REVEALSIG
|
||||
return proto.DHKey(long_to_bytes(self.dh.pub))
|
||||
|
||||
def handleDHKey(self, msg):
|
||||
if self.state == STATE_AWAITING_DHKEY:
|
||||
self.gy = bytes_to_long(msg.gy)
|
||||
|
||||
# check 2 <= g**y <= p-2
|
||||
if not check_group(self.gy):
|
||||
logger.error('Invalid g**y received: %r', self.gy)
|
||||
return
|
||||
|
||||
self.createAuthKeys()
|
||||
|
||||
aesxb = self.calculatePubkeyAuth(self.enc_c, self.mac_m1)
|
||||
|
||||
self.state = STATE_AWAITING_SIG
|
||||
|
||||
self.lastmsg = proto.RevealSig(self.r, aesxb, b'')
|
||||
self.lastmsg.mac = SHA256HMAC160(self.mac_m2,
|
||||
self.lastmsg.getMacedData())
|
||||
return self.lastmsg
|
||||
|
||||
elif self.state == STATE_AWAITING_SIG:
|
||||
logger.info('received DHKey while not awaiting DHKEY')
|
||||
if msg.gy == self.gy:
|
||||
logger.info('resending revealsig')
|
||||
return self.lastmsg
|
||||
else:
|
||||
logger.info('bad state for DHKey')
|
||||
|
||||
def handleRevealSig(self, msg):
|
||||
if self.state != STATE_AWAITING_REVEALSIG:
|
||||
logger.error('bad state for RevealSig')
|
||||
raise InvalidParameterError
|
||||
|
||||
self.r = msg.rkey
|
||||
gxmpi = AESCTR(self.r).decrypt(self.encgx)
|
||||
if SHA256(gxmpi) != self.hashgx:
|
||||
logger.error('Hashes don\'t match')
|
||||
logger.info('r=%r, hashgx=%r, computed hash=%r, gxmpi=%r',
|
||||
self.r, self.hashgx, SHA256(gxmpi), gxmpi)
|
||||
raise InvalidParameterError
|
||||
|
||||
self.gy = read_mpi(gxmpi)[0]
|
||||
self.createAuthKeys()
|
||||
|
||||
if msg.mac != SHA256HMAC160(self.mac_m2, msg.getMacedData()):
|
||||
logger.error('HMACs don\'t match')
|
||||
logger.info('mac=%r, mac_m2=%r, data=%r', msg.mac, self.mac_m2,
|
||||
msg.getMacedData())
|
||||
raise InvalidParameterError
|
||||
|
||||
self.checkPubkeyAuth(self.enc_c, self.mac_m1, msg.encsig)
|
||||
|
||||
aesxb = self.calculatePubkeyAuth(self.enc_cp, self.mac_m1p)
|
||||
self.sessionIdHalf = True
|
||||
|
||||
self.onSuccess(self)
|
||||
|
||||
self.ourKeyid = 0
|
||||
self.state = STATE_NONE
|
||||
|
||||
cmpmac = struct.pack(b'!I', len(aesxb)) + aesxb
|
||||
|
||||
return proto.Signature(aesxb, SHA256HMAC160(self.mac_m2p, cmpmac))
|
||||
|
||||
def handleSignature(self, msg):
|
||||
if self.state != STATE_AWAITING_SIG:
|
||||
logger.error('bad state (%d) for Signature', self.state)
|
||||
raise InvalidParameterError
|
||||
|
||||
if msg.mac != SHA256HMAC160(self.mac_m2p, msg.getMacedData()):
|
||||
logger.error('HMACs don\'t match')
|
||||
raise InvalidParameterError
|
||||
|
||||
self.checkPubkeyAuth(self.enc_cp, self.mac_m1p, msg.encsig)
|
||||
|
||||
self.sessionIdHalf = False
|
||||
|
||||
self.onSuccess(self)
|
||||
|
||||
self.ourKeyid = 0
|
||||
self.state = STATE_NONE
|
||||
|
||||
def createAuthKeys(self):
|
||||
s = pow(self.gy, self.dh.priv, DH_MODULUS)
|
||||
sbyte = pack_mpi(s)
|
||||
self.sessionId = SHA256(b'\x00' + sbyte)[:8]
|
||||
enc = SHA256(b'\x01' + sbyte)
|
||||
self.enc_c = enc[:16]
|
||||
self.enc_cp = enc[16:]
|
||||
self.mac_m1 = SHA256(b'\x02' + sbyte)
|
||||
self.mac_m2 = SHA256(b'\x03' + sbyte)
|
||||
self.mac_m1p = SHA256(b'\x04' + sbyte)
|
||||
self.mac_m2p = SHA256(b'\x05' + sbyte)
|
||||
self.extraKey = SHA256(b'\xff' + sbyte)
|
||||
|
||||
def calculatePubkeyAuth(self, key, mackey):
|
||||
pubkey = self.privkey.serializePublicKey()
|
||||
buf = pack_mpi(self.dh.pub)
|
||||
buf += pack_mpi(self.gy)
|
||||
buf += pubkey
|
||||
buf += struct.pack(b'!I', self.ourKeyid)
|
||||
MB = self.privkey.sign(SHA256HMAC(mackey, buf))
|
||||
|
||||
buf = pubkey
|
||||
buf += struct.pack(b'!I', self.ourKeyid)
|
||||
buf += MB
|
||||
return AESCTR(key).encrypt(buf)
|
||||
|
||||
def checkPubkeyAuth(self, key, mackey, encsig):
|
||||
auth = AESCTR(key).decrypt(encsig)
|
||||
self.theirPubkey, auth = PK.parsePublicKey(auth)
|
||||
|
||||
receivedKeyid, auth = proto.unpack(b'!I', auth)
|
||||
if receivedKeyid == 0:
|
||||
raise InvalidParameterError
|
||||
|
||||
authbuf = pack_mpi(self.gy)
|
||||
authbuf += pack_mpi(self.dh.pub)
|
||||
authbuf += self.theirPubkey.serializePublicKey()
|
||||
authbuf += struct.pack(b'!I', receivedKeyid)
|
||||
|
||||
if self.theirPubkey.verify(SHA256HMAC(mackey, authbuf), auth) is False:
|
||||
raise InvalidParameterError
|
||||
self.theirKeyid = receivedKeyid
|
||||
|
||||
SMPPROG_OK = 0
|
||||
SMPPROG_CHEATED = -2
|
||||
SMPPROG_FAILED = -1
|
||||
SMPPROG_SUCCEEDED = 1
|
||||
|
||||
class SMPHandler:
|
||||
def __init__(self, crypto):
|
||||
self.crypto = crypto
|
||||
self.state = 1
|
||||
self.g1 = DH_GENERATOR
|
||||
self.g2 = None
|
||||
self.g3 = None
|
||||
self.g3o = None
|
||||
self.x2 = None
|
||||
self.x3 = None
|
||||
self.prog = SMPPROG_OK
|
||||
self.pab = None
|
||||
self.qab = None
|
||||
self.questionReceived = False
|
||||
self.secret = None
|
||||
self.p = None
|
||||
self.q = None
|
||||
|
||||
def abort(self, appdata=None):
|
||||
self.state = 1
|
||||
self.sendTLV(proto.SMPABORTTLV(), appdata=appdata)
|
||||
|
||||
def sendTLV(self, tlv, appdata=None):
|
||||
self.crypto.ctx.sendInternal(b'', tlvs=[tlv], appdata=appdata)
|
||||
|
||||
def handle(self, tlv, appdata=None):
|
||||
logger.debug('handling TLV {0.__class__.__name__}'.format(tlv))
|
||||
self.prog = SMPPROG_CHEATED
|
||||
if isinstance(tlv, proto.SMPABORTTLV):
|
||||
self.state = 1
|
||||
return
|
||||
is1qTlv = isinstance(tlv, proto.SMP1QTLV)
|
||||
if isinstance(tlv, proto.SMP1TLV) or is1qTlv:
|
||||
if self.state != 1:
|
||||
self.abort(appdata=appdata)
|
||||
return
|
||||
|
||||
msg = tlv.mpis
|
||||
|
||||
if not check_group(msg[0]) or not check_group(msg[3]) \
|
||||
or not check_exp(msg[2]) or not check_exp(msg[5]) \
|
||||
or not check_known_log(msg[1], msg[2], self.g1, msg[0], 1) \
|
||||
or not check_known_log(msg[4], msg[5], self.g1, msg[3], 2):
|
||||
logger.error('invalid SMP1TLV received')
|
||||
self.abort(appdata=appdata)
|
||||
return
|
||||
|
||||
self.questionReceived = is1qTlv
|
||||
|
||||
self.g3o = msg[3]
|
||||
|
||||
self.x2 = randrange(2, DH_MAX)
|
||||
self.x3 = randrange(2, DH_MAX)
|
||||
|
||||
self.g2 = pow(msg[0], self.x2, DH_MODULUS)
|
||||
self.g3 = pow(msg[3], self.x3, DH_MODULUS)
|
||||
|
||||
self.prog = SMPPROG_OK
|
||||
self.state = 0
|
||||
return
|
||||
if isinstance(tlv, proto.SMP2TLV):
|
||||
if self.state != 2:
|
||||
self.abort(appdata=appdata)
|
||||
return
|
||||
|
||||
msg = tlv.mpis
|
||||
mp = msg[6]
|
||||
mq = msg[7]
|
||||
|
||||
if not check_group(msg[0]) or not check_group(msg[3]) \
|
||||
or not check_group(msg[6]) or not check_group(msg[7]) \
|
||||
or not check_exp(msg[2]) or not check_exp(msg[5]) \
|
||||
or not check_exp(msg[9]) or not check_exp(msg[10]) \
|
||||
or not check_known_log(msg[1], msg[2], self.g1, msg[0], 3) \
|
||||
or not check_known_log(msg[4], msg[5], self.g1, msg[3], 4):
|
||||
logger.error('invalid SMP2TLV received')
|
||||
self.abort(appdata=appdata)
|
||||
return
|
||||
|
||||
self.g3o = msg[3]
|
||||
self.g2 = pow(msg[0], self.x2, DH_MODULUS)
|
||||
self.g3 = pow(msg[3], self.x3, DH_MODULUS)
|
||||
|
||||
if not self.check_equal_coords(msg[6:11], 5):
|
||||
logger.error('invalid SMP2TLV received')
|
||||
self.abort(appdata=appdata)
|
||||
return
|
||||
|
||||
r = randrange(2, DH_MAX)
|
||||
self.p = pow(self.g3, r, DH_MODULUS)
|
||||
msg = [self.p]
|
||||
qa1 = pow(self.g1, r, DH_MODULUS)
|
||||
qa2 = pow(self.g2, self.secret, DH_MODULUS)
|
||||
self.q = qa1*qa2 % DH_MODULUS
|
||||
msg.append(self.q)
|
||||
msg += self.proof_equal_coords(r, 6)
|
||||
|
||||
inv = invMod(mp)
|
||||
self.pab = self.p * inv % DH_MODULUS
|
||||
inv = invMod(mq)
|
||||
self.qab = self.q * inv % DH_MODULUS
|
||||
|
||||
msg.append(pow(self.qab, self.x3, DH_MODULUS))
|
||||
msg += self.proof_equal_logs(7)
|
||||
|
||||
self.state = 4
|
||||
self.prog = SMPPROG_OK
|
||||
self.sendTLV(proto.SMP3TLV(msg), appdata=appdata)
|
||||
return
|
||||
if isinstance(tlv, proto.SMP3TLV):
|
||||
if self.state != 3:
|
||||
self.abort(appdata=appdata)
|
||||
return
|
||||
|
||||
msg = tlv.mpis
|
||||
|
||||
if not check_group(msg[0]) or not check_group(msg[1]) \
|
||||
or not check_group(msg[5]) or not check_exp(msg[3]) \
|
||||
or not check_exp(msg[4]) or not check_exp(msg[7]) \
|
||||
or not self.check_equal_coords(msg[:5], 6):
|
||||
logger.error('invalid SMP3TLV received')
|
||||
self.abort(appdata=appdata)
|
||||
return
|
||||
|
||||
inv = invMod(self.p)
|
||||
self.pab = msg[0] * inv % DH_MODULUS
|
||||
inv = invMod(self.q)
|
||||
self.qab = msg[1] * inv % DH_MODULUS
|
||||
|
||||
if not self.check_equal_logs(msg[5:8], 7):
|
||||
logger.error('invalid SMP3TLV received')
|
||||
self.abort(appdata=appdata)
|
||||
return
|
||||
|
||||
md = msg[5]
|
||||
msg = [pow(self.qab, self.x3, DH_MODULUS)]
|
||||
msg += self.proof_equal_logs(8)
|
||||
|
||||
rab = pow(md, self.x3, DH_MODULUS)
|
||||
self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED
|
||||
|
||||
if self.prog != SMPPROG_SUCCEEDED:
|
||||
logger.error('secrets don\'t match')
|
||||
self.abort(appdata=appdata)
|
||||
self.crypto.ctx.setCurrentTrust('')
|
||||
return
|
||||
|
||||
logger.info('secrets matched')
|
||||
if not self.questionReceived:
|
||||
self.crypto.ctx.setCurrentTrust('smp')
|
||||
self.state = 1
|
||||
self.sendTLV(proto.SMP4TLV(msg), appdata=appdata)
|
||||
return
|
||||
if isinstance(tlv, proto.SMP4TLV):
|
||||
if self.state != 4:
|
||||
self.abort(appdata=appdata)
|
||||
return
|
||||
|
||||
msg = tlv.mpis
|
||||
|
||||
if not check_group(msg[0]) or not check_exp(msg[2]) \
|
||||
or not self.check_equal_logs(msg[:3], 8):
|
||||
logger.error('invalid SMP4TLV received')
|
||||
self.abort(appdata=appdata)
|
||||
return
|
||||
|
||||
rab = pow(msg[0], self.x3, DH_MODULUS)
|
||||
|
||||
self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED
|
||||
|
||||
if self.prog != SMPPROG_SUCCEEDED:
|
||||
logger.error('secrets don\'t match')
|
||||
self.abort(appdata=appdata)
|
||||
self.crypto.ctx.setCurrentTrust('')
|
||||
return
|
||||
|
||||
logger.info('secrets matched')
|
||||
self.crypto.ctx.setCurrentTrust('smp')
|
||||
self.state = 1
|
||||
return
|
||||
|
||||
def gotSecret(self, secret, question=None, appdata=None):
|
||||
ourFP = self.crypto.ctx.user.getPrivkey().fingerprint()
|
||||
if self.state == 1:
|
||||
# first secret -> SMP1TLV
|
||||
combSecret = SHA256(b'\1' + ourFP +
|
||||
self.crypto.theirPubkey.fingerprint() +
|
||||
self.crypto.sessionId + secret)
|
||||
|
||||
self.secret = bytes_to_long(combSecret)
|
||||
|
||||
self.x2 = randrange(2, DH_MAX)
|
||||
self.x3 = randrange(2, DH_MAX)
|
||||
|
||||
msg = [pow(self.g1, self.x2, DH_MODULUS)]
|
||||
msg += proof_known_log(self.g1, self.x2, 1)
|
||||
msg.append(pow(self.g1, self.x3, DH_MODULUS))
|
||||
msg += proof_known_log(self.g1, self.x3, 2)
|
||||
|
||||
self.prog = SMPPROG_OK
|
||||
self.state = 2
|
||||
if question is None:
|
||||
self.sendTLV(proto.SMP1TLV(msg), appdata=appdata)
|
||||
else:
|
||||
self.sendTLV(proto.SMP1QTLV(question, msg), appdata=appdata)
|
||||
if self.state == 0:
|
||||
# response secret -> SMP2TLV
|
||||
combSecret = SHA256(b'\1' + self.crypto.theirPubkey.fingerprint() +
|
||||
ourFP + self.crypto.sessionId + secret)
|
||||
|
||||
self.secret = bytes_to_long(combSecret)
|
||||
|
||||
msg = [pow(self.g1, self.x2, DH_MODULUS)]
|
||||
msg += proof_known_log(self.g1, self.x2, 3)
|
||||
msg.append(pow(self.g1, self.x3, DH_MODULUS))
|
||||
msg += proof_known_log(self.g1, self.x3, 4)
|
||||
|
||||
r = randrange(2, DH_MAX)
|
||||
|
||||
self.p = pow(self.g3, r, DH_MODULUS)
|
||||
msg.append(self.p)
|
||||
|
||||
qb1 = pow(self.g1, r, DH_MODULUS)
|
||||
qb2 = pow(self.g2, self.secret, DH_MODULUS)
|
||||
self.q = qb1 * qb2 % DH_MODULUS
|
||||
msg.append(self.q)
|
||||
|
||||
msg += self.proof_equal_coords(r, 5)
|
||||
|
||||
self.state = 3
|
||||
self.sendTLV(proto.SMP2TLV(msg), appdata=appdata)
|
||||
|
||||
def proof_equal_coords(self, r, v):
|
||||
r1 = randrange(2, DH_MAX)
|
||||
r2 = randrange(2, DH_MAX)
|
||||
temp2 = pow(self.g1, r1, DH_MODULUS) \
|
||||
* pow(self.g2, r2, DH_MODULUS) % DH_MODULUS
|
||||
temp1 = pow(self.g3, r1, DH_MODULUS)
|
||||
|
||||
cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
||||
c = bytes_to_long(cb)
|
||||
|
||||
temp1 = r * c % SM_ORDER
|
||||
d1 = (r1-temp1) % SM_ORDER
|
||||
|
||||
temp1 = self.secret * c % SM_ORDER
|
||||
d2 = (r2 - temp1) % SM_ORDER
|
||||
return c, d1, d2
|
||||
|
||||
def check_equal_coords(self, coords, v):
|
||||
(p, q, c, d1, d2) = coords
|
||||
temp1 = pow(self.g3, d1, DH_MODULUS) * pow(p, c, DH_MODULUS) \
|
||||
% DH_MODULUS
|
||||
|
||||
temp2 = pow(self.g1, d1, DH_MODULUS) \
|
||||
* pow(self.g2, d2, DH_MODULUS) \
|
||||
* pow(q, c, DH_MODULUS) % DH_MODULUS
|
||||
|
||||
cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
||||
|
||||
return long_to_bytes(c, 32) == cprime
|
||||
|
||||
def proof_equal_logs(self, v):
|
||||
r = randrange(2, DH_MAX)
|
||||
temp1 = pow(self.g1, r, DH_MODULUS)
|
||||
temp2 = pow(self.qab, r, DH_MODULUS)
|
||||
|
||||
cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
||||
c = bytes_to_long(cb)
|
||||
temp1 = self.x3 * c % SM_ORDER
|
||||
d = (r - temp1) % SM_ORDER
|
||||
return c, d
|
||||
|
||||
def check_equal_logs(self, logs, v):
|
||||
(r, c, d) = logs
|
||||
temp1 = pow(self.g1, d, DH_MODULUS) \
|
||||
* pow(self.g3o, c, DH_MODULUS) % DH_MODULUS
|
||||
|
||||
temp2 = pow(self.qab, d, DH_MODULUS) \
|
||||
* pow(r, c, DH_MODULUS) % DH_MODULUS
|
||||
|
||||
cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
||||
return long_to_bytes(c, 32) == cprime
|
||||
|
||||
def proof_known_log(g, x, v):
|
||||
r = randrange(2, DH_MAX)
|
||||
c = bytes_to_long(SHA256(struct.pack(b'B', v) + pack_mpi(pow(g, r, DH_MODULUS))))
|
||||
temp = x * c % SM_ORDER
|
||||
return c, (r-temp) % SM_ORDER
|
||||
|
||||
def check_known_log(c, d, g, x, v):
|
||||
gd = pow(g, d, DH_MODULUS)
|
||||
xc = pow(x, c, DH_MODULUS)
|
||||
gdxc = gd * xc % DH_MODULUS
|
||||
return SHA256(struct.pack(b'B', v) + pack_mpi(gdxc)) == long_to_bytes(c, 32)
|
||||
|
||||
def invMod(n):
|
||||
return pow(n, DH_MODULUS_2, DH_MODULUS)
|
||||
|
||||
class InvalidParameterError(RuntimeError):
|
||||
pass
|
471
potr/proto.py
Normal file
471
potr/proto.py
Normal file
|
@ -0,0 +1,471 @@
|
|||
# Copyright 2011-2012 Kjell Braden <afflux@pentabarf.de>
|
||||
#
|
||||
# This file is part of the python-potr library.
|
||||
#
|
||||
# python-potr is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation; either version 3 of the License, or
|
||||
# any later version.
|
||||
#
|
||||
# python-potr is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# some python3 compatibilty
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import base64
|
||||
import struct
|
||||
from potr.utils import pack_mpi, read_mpi, pack_data, read_data, unpack
|
||||
|
||||
OTRTAG = b'?OTR'
|
||||
MESSAGE_TAG_BASE = b' \t \t\t\t\t \t \t \t '
|
||||
MESSAGE_TAGS = {
|
||||
1:b' \t \t \t ',
|
||||
2:b' \t\t \t ',
|
||||
3:b' \t\t \t\t',
|
||||
}
|
||||
|
||||
MSGTYPE_NOTOTR = 0
|
||||
MSGTYPE_TAGGEDPLAINTEXT = 1
|
||||
MSGTYPE_QUERY = 2
|
||||
MSGTYPE_DH_COMMIT = 3
|
||||
MSGTYPE_DH_KEY = 4
|
||||
MSGTYPE_REVEALSIG = 5
|
||||
MSGTYPE_SIGNATURE = 6
|
||||
MSGTYPE_V1_KEYEXCH = 7
|
||||
MSGTYPE_DATA = 8
|
||||
MSGTYPE_ERROR = 9
|
||||
MSGTYPE_UNKNOWN = -1
|
||||
|
||||
MSGFLAGS_IGNORE_UNREADABLE = 1
|
||||
|
||||
tlvClasses = {}
|
||||
messageClasses = {}
|
||||
|
||||
hasByteStr = bytes == str
|
||||
def bytesAndStrings(cls):
|
||||
if hasByteStr:
|
||||
cls.__str__ = lambda self: self.__bytes__()
|
||||
else:
|
||||
cls.__str__ = lambda self: str(self.__bytes__(), 'utf-8', 'replace')
|
||||
return cls
|
||||
|
||||
def registermessage(cls):
|
||||
if not hasattr(cls, 'parsePayload'):
|
||||
raise TypeError('registered message types need parsePayload()')
|
||||
messageClasses[cls.version, cls.msgtype] = cls
|
||||
return cls
|
||||
|
||||
def registertlv(cls):
|
||||
if not hasattr(cls, 'parsePayload'):
|
||||
raise TypeError('registered tlv types need parsePayload()')
|
||||
if cls.typ is None:
|
||||
raise TypeError('registered tlv type needs type ID')
|
||||
tlvClasses[cls.typ] = cls
|
||||
return cls
|
||||
|
||||
|
||||
def getslots(cls, base):
|
||||
''' helper to collect all the message slots from ancestors '''
|
||||
clss = [cls]
|
||||
|
||||
for cls in clss:
|
||||
if cls == base:
|
||||
continue
|
||||
|
||||
clss.extend(cls.__bases__)
|
||||
|
||||
for slot in cls.__slots__:
|
||||
yield slot
|
||||
|
||||
@bytesAndStrings
|
||||
class OTRMessage(object):
|
||||
__slots__ = ['payload']
|
||||
version = 0x0002
|
||||
msgtype = 0
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
|
||||
for slot in getslots(self.__class__, OTRMessage):
|
||||
if getattr(self, slot) != getattr(other, slot):
|
||||
return False
|
||||
return True
|
||||
|
||||
def __neq__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
class Error(OTRMessage):
|
||||
__slots__ = ['error']
|
||||
def __init__(self, error):
|
||||
super(Error, self).__init__()
|
||||
self.error = error
|
||||
|
||||
def __repr__(self):
|
||||
return '<proto.Error(%r)>' % self.error
|
||||
|
||||
def __bytes__(self):
|
||||
return b'?OTR Error:' + self.error
|
||||
|
||||
class Query(OTRMessage):
|
||||
__slots__ = ['versions']
|
||||
def __init__(self, versions=set()):
|
||||
super(Query, self).__init__()
|
||||
self.versions = versions
|
||||
|
||||
@classmethod
|
||||
def parse(cls, data):
|
||||
if not isinstance(data, bytes):
|
||||
raise TypeError('can only parse bytes')
|
||||
udata = data.decode('ascii', 'replace')
|
||||
|
||||
versions = set()
|
||||
if len(udata) > 0 and udata[0] == '?':
|
||||
udata = udata[1:]
|
||||
versions.add(1)
|
||||
|
||||
if len(udata) > 0 and udata[0] == 'v':
|
||||
versions.update(( int(c) for c in udata if c.isdigit() ))
|
||||
return cls(versions)
|
||||
|
||||
def __repr__(self):
|
||||
return '<proto.Query(versions=%r)>' % (self.versions)
|
||||
|
||||
def __bytes__(self):
|
||||
d = b'?OTR'
|
||||
if 1 in self.versions:
|
||||
d += b'?'
|
||||
d += b'v'
|
||||
|
||||
# in python3 there is only int->unicode conversion
|
||||
# so I convert to unicode and encode it to a byte string
|
||||
versions = [ '%d' % v for v in self.versions if v != 1 ]
|
||||
d += ''.join(versions).encode('ascii')
|
||||
|
||||
d += b'?'
|
||||
return d
|
||||
|
||||
class TaggedPlaintext(Query):
|
||||
__slots__ = ['msg']
|
||||
def __init__(self, msg, versions):
|
||||
super(TaggedPlaintext, self).__init__(versions)
|
||||
self.msg = msg
|
||||
|
||||
def __bytes__(self):
|
||||
data = self.msg + MESSAGE_TAG_BASE
|
||||
for v in self.versions:
|
||||
data += MESSAGE_TAGS[v]
|
||||
return data
|
||||
|
||||
def __repr__(self):
|
||||
return '<proto.TaggedPlaintext(versions={versions!r},msg={msg!r})>' \
|
||||
.format(versions=self.versions, msg=self.msg)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, data):
|
||||
tagPos = data.find(MESSAGE_TAG_BASE)
|
||||
if tagPos < 0:
|
||||
raise TypeError(
|
||||
'this is not a tagged plaintext ({0!r:.20})'.format(data))
|
||||
|
||||
tags = [ data[i:i+8] for i in range(tagPos, len(data), 8) ]
|
||||
versions = set([ version for version, tag in MESSAGE_TAGS.items() if tag
|
||||
in tags ])
|
||||
|
||||
return TaggedPlaintext(data[:tagPos], versions)
|
||||
|
||||
class GenericOTRMessage(OTRMessage):
|
||||
__slots__ = ['data']
|
||||
fields = []
|
||||
|
||||
def __init__(self, *args):
|
||||
super(GenericOTRMessage, self).__init__()
|
||||
if len(args) != len(self.fields):
|
||||
raise TypeError('%s needs %d arguments, got %d' %
|
||||
(self.__class__.__name__, len(self.fields), len(args)))
|
||||
|
||||
super(GenericOTRMessage, self).__setattr__('data',
|
||||
dict(zip((f[0] for f in self.fields), args)))
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.data:
|
||||
return self.data[attr]
|
||||
raise AttributeError(
|
||||
"'{t!r}' object has no attribute '{attr!r}'".format(attr=attr,
|
||||
t=self.__class__.__name__))
|
||||
|
||||
def __setattr__(self, attr, val):
|
||||
if attr in self.__slots__:
|
||||
super(GenericOTRMessage, self).__setattr__(attr, val)
|
||||
else:
|
||||
self.__getattr__(attr) # existence check
|
||||
self.data[attr] = val
|
||||
|
||||
def __bytes__(self):
|
||||
data = struct.pack(b'!HB', self.version, self.msgtype) \
|
||||
+ self.getPayload()
|
||||
return b'?OTR:' + base64.b64encode(data) + b'.'
|
||||
|
||||
def __repr__(self):
|
||||
name = self.__class__.__name__
|
||||
data = ''
|
||||
for k, _ in self.fields:
|
||||
data += '%s=%r,' % (k, self.data[k])
|
||||
return '<proto.%s(%s)>' % (name, data)
|
||||
|
||||
@classmethod
|
||||
def parsePayload(cls, data):
|
||||
data = base64.b64decode(data)
|
||||
args = []
|
||||
for _, ftype in cls.fields:
|
||||
if ftype == 'data':
|
||||
value, data = read_data(data)
|
||||
elif isinstance(ftype, bytes):
|
||||
value, data = unpack(ftype, data)
|
||||
elif isinstance(ftype, int):
|
||||
value, data = data[:ftype], data[ftype:]
|
||||
args.append(value)
|
||||
return cls(*args)
|
||||
|
||||
def getPayload(self, *ffilter):
|
||||
payload = b''
|
||||
for k, ftype in self.fields:
|
||||
if k in ffilter:
|
||||
continue
|
||||
|
||||
if ftype == 'data':
|
||||
payload += pack_data(self.data[k])
|
||||
elif isinstance(ftype, bytes):
|
||||
payload += struct.pack(ftype, self.data[k])
|
||||
else:
|
||||
payload += self.data[k]
|
||||
return payload
|
||||
|
||||
class AKEMessage(GenericOTRMessage):
|
||||
__slots__ = []
|
||||
|
||||
@registermessage
|
||||
class DHCommit(AKEMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x02
|
||||
fields = [('encgx', 'data'), ('hashgx', 'data'), ]
|
||||
|
||||
@registermessage
|
||||
class DHKey(AKEMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x0a
|
||||
fields = [('gy', 'data'), ]
|
||||
|
||||
@registermessage
|
||||
class RevealSig(AKEMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x11
|
||||
fields = [('rkey', 'data'), ('encsig', 'data'), ('mac', 20),]
|
||||
|
||||
def getMacedData(self):
|
||||
p = self.encsig
|
||||
return struct.pack(b'!I', len(p)) + p
|
||||
|
||||
@registermessage
|
||||
class Signature(AKEMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x12
|
||||
fields = [('encsig', 'data'), ('mac', 20)]
|
||||
|
||||
def getMacedData(self):
|
||||
p = self.encsig
|
||||
return struct.pack(b'!I', len(p)) + p
|
||||
|
||||
@registermessage
|
||||
class DataMessage(GenericOTRMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x03
|
||||
fields = [('flags', b'!B'), ('skeyid', b'!I'), ('rkeyid', b'!I'),
|
||||
('dhy', 'data'), ('ctr', 8), ('encmsg', 'data'), ('mac', 20),
|
||||
('oldmacs', 'data'), ]
|
||||
|
||||
def getMacedData(self):
|
||||
return struct.pack(b'!HB', self.version, self.msgtype) + \
|
||||
self.getPayload('mac', 'oldmacs')
|
||||
|
||||
@bytesAndStrings
|
||||
class TLV(object):
|
||||
__slots__ = []
|
||||
typ = None
|
||||
|
||||
def getPayload(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
val = self.getPayload()
|
||||
return '<{cls}(typ={t},len={l},val={v!r})>'.format(t=self.typ,
|
||||
l=len(val), v=val, cls=self.__class__.__name__)
|
||||
|
||||
def __bytes__(self):
|
||||
val = self.getPayload()
|
||||
return struct.pack(b'!HH', self.typ, len(val)) + val
|
||||
|
||||
@classmethod
|
||||
def parse(cls, data):
|
||||
if not data:
|
||||
return []
|
||||
typ, length, data = unpack(b'!HH', data)
|
||||
if typ in tlvClasses:
|
||||
return [tlvClasses[typ].parsePayload(data[:length])] \
|
||||
+ cls.parse(data[length:])
|
||||
else:
|
||||
raise UnknownTLV(data)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
|
||||
for slot in getslots(self.__class__, TLV):
|
||||
if getattr(self, slot) != getattr(other, slot):
|
||||
return False
|
||||
return True
|
||||
|
||||
def __neq__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
@registertlv
|
||||
class PaddingTLV(TLV):
|
||||
typ = 0
|
||||
|
||||
__slots__ = ['padding']
|
||||
|
||||
def __init__(self, padding):
|
||||
super(PaddingTLV, self).__init__()
|
||||
self.padding = padding
|
||||
|
||||
def getPayload(self):
|
||||
return self.padding
|
||||
|
||||
@classmethod
|
||||
def parsePayload(cls, data):
|
||||
return cls(data)
|
||||
|
||||
@registertlv
|
||||
class DisconnectTLV(TLV):
|
||||
typ = 1
|
||||
def __init__(self):
|
||||
super(DisconnectTLV, self).__init__()
|
||||
|
||||
def getPayload(self):
|
||||
return b''
|
||||
|
||||
@classmethod
|
||||
def parsePayload(cls, data):
|
||||
if len(data) > 0:
|
||||
raise TypeError('DisconnectTLV must not contain data. got {0!r}'
|
||||
.format(data))
|
||||
return cls()
|
||||
|
||||
class SMPTLV(TLV):
|
||||
__slots__ = ['mpis']
|
||||
dlen = None
|
||||
|
||||
def __init__(self, mpis=None):
|
||||
super(SMPTLV, self).__init__()
|
||||
if mpis is None:
|
||||
mpis = []
|
||||
if self.dlen is None:
|
||||
raise TypeError('no amount of mpis specified in dlen')
|
||||
if len(mpis) != self.dlen:
|
||||
raise TypeError('expected {0} mpis, got {1}'
|
||||
.format(self.dlen, len(mpis)))
|
||||
self.mpis = mpis
|
||||
|
||||
def getPayload(self):
|
||||
d = struct.pack(b'!I', len(self.mpis))
|
||||
for n in self.mpis:
|
||||
d += pack_mpi(n)
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def parsePayload(cls, data):
|
||||
mpis = []
|
||||
if cls.dlen > 0:
|
||||
count, data = unpack(b'!I', data)
|
||||
for _ in range(count):
|
||||
n, data = read_mpi(data)
|
||||
mpis.append(n)
|
||||
if len(data) > 0:
|
||||
raise TypeError('too much data for {0} mpis'.format(cls.dlen))
|
||||
return cls(mpis)
|
||||
|
||||
@registertlv
|
||||
class SMP1TLV(SMPTLV):
|
||||
typ = 2
|
||||
dlen = 6
|
||||
|
||||
@registertlv
|
||||
class SMP1QTLV(SMPTLV):
|
||||
typ = 7
|
||||
dlen = 6
|
||||
__slots__ = ['msg']
|
||||
|
||||
def __init__(self, msg, mpis):
|
||||
self.msg = msg
|
||||
super(SMP1QTLV, self).__init__(mpis)
|
||||
|
||||
def getPayload(self):
|
||||
return self.msg + b'\0' + super(SMP1QTLV, self).getPayload()
|
||||
|
||||
@classmethod
|
||||
def parsePayload(cls, data):
|
||||
msg, data = data.split(b'\0', 1)
|
||||
mpis = SMP1TLV.parsePayload(data).mpis
|
||||
return cls(msg, mpis)
|
||||
|
||||
@registertlv
|
||||
class SMP2TLV(SMPTLV):
|
||||
typ = 3
|
||||
dlen = 11
|
||||
|
||||
@registertlv
|
||||
class SMP3TLV(SMPTLV):
|
||||
typ = 4
|
||||
dlen = 8
|
||||
|
||||
@registertlv
|
||||
class SMP4TLV(SMPTLV):
|
||||
typ = 5
|
||||
dlen = 3
|
||||
|
||||
@registertlv
|
||||
class SMPABORTTLV(SMPTLV):
|
||||
typ = 6
|
||||
dlen = 0
|
||||
|
||||
def getPayload(self):
|
||||
return b''
|
||||
|
||||
@registertlv
|
||||
class ExtraKeyTLV(TLV):
|
||||
typ = 8
|
||||
|
||||
__slots__ = ['appid', 'appdata']
|
||||
|
||||
def __init__(self, appid, appdata):
|
||||
super(ExtraKeyTLV, self).__init__()
|
||||
self.appid = appid
|
||||
self.appdata = appdata
|
||||
if appdata is None:
|
||||
self.appdata = b''
|
||||
|
||||
def getPayload(self):
|
||||
return self.appid + self.appdata
|
||||
|
||||
@classmethod
|
||||
def parsePayload(cls, data):
|
||||
return cls(data[:4], data[4:])
|
||||
|
||||
class UnknownTLV(RuntimeError):
|
||||
pass
|
66
potr/utils.py
Normal file
66
potr/utils.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2012 Kjell Braden <afflux@pentabarf.de>
|
||||
#
|
||||
# This file is part of the python-potr library.
|
||||
#
|
||||
# python-potr is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation; either version 3 of the License, or
|
||||
# any later version.
|
||||
#
|
||||
# python-potr is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# some python3 compatibilty
|
||||
from __future__ import unicode_literals
|
||||
|
||||
|
||||
import struct
|
||||
|
||||
def pack_mpi(n):
|
||||
return pack_data(long_to_bytes(n))
|
||||
def read_mpi(data):
|
||||
n, data = read_data(data)
|
||||
return bytes_to_long(n), data
|
||||
def pack_data(data):
|
||||
return struct.pack(b'!I', len(data)) + data
|
||||
def read_data(data):
|
||||
datalen, data = unpack(b'!I', data)
|
||||
return data[:datalen], data[datalen:]
|
||||
def unpack(fmt, buf):
|
||||
s = struct.Struct(fmt)
|
||||
return s.unpack(buf[:s.size]) + (buf[s.size:],)
|
||||
|
||||
|
||||
def bytes_to_long(b):
|
||||
l = len(b)
|
||||
s = 0
|
||||
for i in range(l):
|
||||
s += byte_to_long(b[i:i+1]) << 8*(l-i-1)
|
||||
return s
|
||||
|
||||
def long_to_bytes(l, n=0):
|
||||
b = b''
|
||||
while l != 0 or n > 0:
|
||||
b = long_to_byte(l & 0xff) + b
|
||||
l >>= 8
|
||||
n -= 1
|
||||
return b
|
||||
|
||||
def byte_to_long(b):
|
||||
return struct.unpack(b'B', b)[0]
|
||||
def long_to_byte(l):
|
||||
return struct.pack(b'B', l)
|
||||
|
||||
def human_hash(fp):
|
||||
fp = fp.upper()
|
||||
fplen = len(fp)
|
||||
wordsize = fplen//5
|
||||
buf = ''
|
||||
for w in range(0, fplen, wordsize):
|
||||
buf += '{0} '.format(fp[w:w+wordsize])
|
||||
return buf.rstrip()
|
Loading…
Reference in a new issue