gajim-otrplugin/potr/proto.py

472 lines
12 KiB
Python
Raw Normal View History

# Copyright 2011-2012 Kjell Braden <afflux@pentabarf.de>
#
# This file is part of the python-potr library.
#
# python-potr is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# any later version.
#
# python-potr is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
# some python3 compatibilty
from __future__ import unicode_literals
import base64
import struct
from potr.utils import pack_mpi, read_mpi, pack_data, read_data, unpack
OTRTAG = b'?OTR'
MESSAGE_TAG_BASE = b' \t \t\t\t\t \t \t \t '
MESSAGE_TAGS = {
1:b' \t \t \t ',
2:b' \t\t \t ',
3:b' \t\t \t\t',
}
MSGTYPE_NOTOTR = 0
MSGTYPE_TAGGEDPLAINTEXT = 1
MSGTYPE_QUERY = 2
MSGTYPE_DH_COMMIT = 3
MSGTYPE_DH_KEY = 4
MSGTYPE_REVEALSIG = 5
MSGTYPE_SIGNATURE = 6
MSGTYPE_V1_KEYEXCH = 7
MSGTYPE_DATA = 8
MSGTYPE_ERROR = 9
MSGTYPE_UNKNOWN = -1
MSGFLAGS_IGNORE_UNREADABLE = 1
tlvClasses = {}
messageClasses = {}
hasByteStr = bytes == str
def bytesAndStrings(cls):
if hasByteStr:
cls.__str__ = lambda self: self.__bytes__()
else:
cls.__str__ = lambda self: str(self.__bytes__(), 'utf-8', 'replace')
return cls
def registermessage(cls):
if not hasattr(cls, 'parsePayload'):
raise TypeError('registered message types need parsePayload()')
messageClasses[cls.version, cls.msgtype] = cls
return cls
def registertlv(cls):
if not hasattr(cls, 'parsePayload'):
raise TypeError('registered tlv types need parsePayload()')
if cls.typ is None:
raise TypeError('registered tlv type needs type ID')
tlvClasses[cls.typ] = cls
return cls
def getslots(cls, base):
''' helper to collect all the message slots from ancestors '''
clss = [cls]
for cls in clss:
if cls == base:
continue
clss.extend(cls.__bases__)
for slot in cls.__slots__:
yield slot
@bytesAndStrings
class OTRMessage(object):
__slots__ = ['payload']
version = 0x0002
msgtype = 0
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
for slot in getslots(self.__class__, OTRMessage):
if getattr(self, slot) != getattr(other, slot):
return False
return True
def __neq__(self, other):
return not self.__eq__(other)
class Error(OTRMessage):
__slots__ = ['error']
def __init__(self, error):
super(Error, self).__init__()
self.error = error
def __repr__(self):
return '<proto.Error(%r)>' % self.error
def __bytes__(self):
return b'?OTR Error:' + self.error
class Query(OTRMessage):
__slots__ = ['versions']
def __init__(self, versions=set()):
super(Query, self).__init__()
self.versions = versions
@classmethod
def parse(cls, data):
if not isinstance(data, bytes):
raise TypeError('can only parse bytes')
udata = data.decode('ascii', 'replace')
versions = set()
if len(udata) > 0 and udata[0] == '?':
udata = udata[1:]
versions.add(1)
if len(udata) > 0 and udata[0] == 'v':
versions.update(( int(c) for c in udata if c.isdigit() ))
return cls(versions)
def __repr__(self):
return '<proto.Query(versions=%r)>' % (self.versions)
def __bytes__(self):
d = b'?OTR'
if 1 in self.versions:
d += b'?'
d += b'v'
# in python3 there is only int->unicode conversion
# so I convert to unicode and encode it to a byte string
versions = [ '%d' % v for v in self.versions if v != 1 ]
d += ''.join(versions).encode('ascii')
d += b'?'
return d
class TaggedPlaintext(Query):
__slots__ = ['msg']
def __init__(self, msg, versions):
super(TaggedPlaintext, self).__init__(versions)
self.msg = msg
def __bytes__(self):
data = self.msg + MESSAGE_TAG_BASE
for v in self.versions:
data += MESSAGE_TAGS[v]
return data
def __repr__(self):
return '<proto.TaggedPlaintext(versions={versions!r},msg={msg!r})>' \
.format(versions=self.versions, msg=self.msg)
@classmethod
def parse(cls, data):
tagPos = data.find(MESSAGE_TAG_BASE)
if tagPos < 0:
raise TypeError(
'this is not a tagged plaintext ({0!r:.20})'.format(data))
tags = [ data[i:i+8] for i in range(tagPos, len(data), 8) ]
versions = set([ version for version, tag in MESSAGE_TAGS.items() if tag
in tags ])
return TaggedPlaintext(data[:tagPos], versions)
class GenericOTRMessage(OTRMessage):
__slots__ = ['data']
fields = []
def __init__(self, *args):
super(GenericOTRMessage, self).__init__()
if len(args) != len(self.fields):
raise TypeError('%s needs %d arguments, got %d' %
(self.__class__.__name__, len(self.fields), len(args)))
super(GenericOTRMessage, self).__setattr__('data',
dict(zip((f[0] for f in self.fields), args)))
def __getattr__(self, attr):
if attr in self.data:
return self.data[attr]
raise AttributeError(
"'{t!r}' object has no attribute '{attr!r}'".format(attr=attr,
t=self.__class__.__name__))
def __setattr__(self, attr, val):
if attr in self.__slots__:
super(GenericOTRMessage, self).__setattr__(attr, val)
else:
self.__getattr__(attr) # existence check
self.data[attr] = val
def __bytes__(self):
data = struct.pack(b'!HB', self.version, self.msgtype) \
+ self.getPayload()
return b'?OTR:' + base64.b64encode(data) + b'.'
def __repr__(self):
name = self.__class__.__name__
data = ''
for k, _ in self.fields:
data += '%s=%r,' % (k, self.data[k])
return '<proto.%s(%s)>' % (name, data)
@classmethod
def parsePayload(cls, data):
data = base64.b64decode(data)
args = []
for _, ftype in cls.fields:
if ftype == 'data':
value, data = read_data(data)
elif isinstance(ftype, bytes):
value, data = unpack(ftype, data)
elif isinstance(ftype, int):
value, data = data[:ftype], data[ftype:]
args.append(value)
return cls(*args)
def getPayload(self, *ffilter):
payload = b''
for k, ftype in self.fields:
if k in ffilter:
continue
if ftype == 'data':
payload += pack_data(self.data[k])
elif isinstance(ftype, bytes):
payload += struct.pack(ftype, self.data[k])
else:
payload += self.data[k]
return payload
class AKEMessage(GenericOTRMessage):
__slots__ = []
@registermessage
class DHCommit(AKEMessage):
__slots__ = []
msgtype = 0x02
fields = [('encgx', 'data'), ('hashgx', 'data'), ]
@registermessage
class DHKey(AKEMessage):
__slots__ = []
msgtype = 0x0a
fields = [('gy', 'data'), ]
@registermessage
class RevealSig(AKEMessage):
__slots__ = []
msgtype = 0x11
fields = [('rkey', 'data'), ('encsig', 'data'), ('mac', 20),]
def getMacedData(self):
p = self.encsig
return struct.pack(b'!I', len(p)) + p
@registermessage
class Signature(AKEMessage):
__slots__ = []
msgtype = 0x12
fields = [('encsig', 'data'), ('mac', 20)]
def getMacedData(self):
p = self.encsig
return struct.pack(b'!I', len(p)) + p
@registermessage
class DataMessage(GenericOTRMessage):
__slots__ = []
msgtype = 0x03
fields = [('flags', b'!B'), ('skeyid', b'!I'), ('rkeyid', b'!I'),
('dhy', 'data'), ('ctr', 8), ('encmsg', 'data'), ('mac', 20),
('oldmacs', 'data'), ]
def getMacedData(self):
return struct.pack(b'!HB', self.version, self.msgtype) + \
self.getPayload('mac', 'oldmacs')
@bytesAndStrings
class TLV(object):
__slots__ = []
typ = None
def getPayload(self):
raise NotImplementedError
def __repr__(self):
val = self.getPayload()
return '<{cls}(typ={t},len={l},val={v!r})>'.format(t=self.typ,
l=len(val), v=val, cls=self.__class__.__name__)
def __bytes__(self):
val = self.getPayload()
return struct.pack(b'!HH', self.typ, len(val)) + val
@classmethod
def parse(cls, data):
if not data:
return []
typ, length, data = unpack(b'!HH', data)
if typ in tlvClasses:
return [tlvClasses[typ].parsePayload(data[:length])] \
+ cls.parse(data[length:])
else:
raise UnknownTLV(data)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
for slot in getslots(self.__class__, TLV):
if getattr(self, slot) != getattr(other, slot):
return False
return True
def __neq__(self, other):
return not self.__eq__(other)
@registertlv
class PaddingTLV(TLV):
typ = 0
__slots__ = ['padding']
def __init__(self, padding):
super(PaddingTLV, self).__init__()
self.padding = padding
def getPayload(self):
return self.padding
@classmethod
def parsePayload(cls, data):
return cls(data)
@registertlv
class DisconnectTLV(TLV):
typ = 1
def __init__(self):
super(DisconnectTLV, self).__init__()
def getPayload(self):
return b''
@classmethod
def parsePayload(cls, data):
if len(data) > 0:
raise TypeError('DisconnectTLV must not contain data. got {0!r}'
.format(data))
return cls()
class SMPTLV(TLV):
__slots__ = ['mpis']
dlen = None
def __init__(self, mpis=None):
super(SMPTLV, self).__init__()
if mpis is None:
mpis = []
if self.dlen is None:
raise TypeError('no amount of mpis specified in dlen')
if len(mpis) != self.dlen:
raise TypeError('expected {0} mpis, got {1}'
.format(self.dlen, len(mpis)))
self.mpis = mpis
def getPayload(self):
d = struct.pack(b'!I', len(self.mpis))
for n in self.mpis:
d += pack_mpi(n)
return d
@classmethod
def parsePayload(cls, data):
mpis = []
if cls.dlen > 0:
count, data = unpack(b'!I', data)
for _ in range(count):
n, data = read_mpi(data)
mpis.append(n)
if len(data) > 0:
raise TypeError('too much data for {0} mpis'.format(cls.dlen))
return cls(mpis)
@registertlv
class SMP1TLV(SMPTLV):
typ = 2
dlen = 6
@registertlv
class SMP1QTLV(SMPTLV):
typ = 7
dlen = 6
__slots__ = ['msg']
def __init__(self, msg, mpis):
self.msg = msg
super(SMP1QTLV, self).__init__(mpis)
def getPayload(self):
return self.msg + b'\0' + super(SMP1QTLV, self).getPayload()
@classmethod
def parsePayload(cls, data):
msg, data = data.split(b'\0', 1)
mpis = SMP1TLV.parsePayload(data).mpis
return cls(msg, mpis)
@registertlv
class SMP2TLV(SMPTLV):
typ = 3
dlen = 11
@registertlv
class SMP3TLV(SMPTLV):
typ = 4
dlen = 8
@registertlv
class SMP4TLV(SMPTLV):
typ = 5
dlen = 3
@registertlv
class SMPABORTTLV(SMPTLV):
typ = 6
dlen = 0
def getPayload(self):
return b''
@registertlv
class ExtraKeyTLV(TLV):
typ = 8
__slots__ = ['appid', 'appdata']
def __init__(self, appid, appdata):
super(ExtraKeyTLV, self).__init__()
self.appid = appid
self.appdata = appdata
if appdata is None:
self.appdata = b''
def getPayload(self):
return self.appid + self.appdata
@classmethod
def parsePayload(cls, data):
return cls(data[:4], data[4:])
class UnknownTLV(RuntimeError):
pass