Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it work with KeepKey #8

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ python:
- "3.4"

install:
- pip install ecdsa ed25519 # test without trezorlib for now
- pip install ecdsa ed25519 semver # test without trezorlib for now
- pip install pylint coverage pep8

script:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
author_email='roman.zeyde@gmail.com',
url='http://github.com/romanz/trezor-agent',
packages=['trezor_agent', 'trezor_agent.trezor'],
install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'trezor>=0.6.6'],
install_requires=['ecdsa>=0.13', 'ed25519>=1.4', 'Cython>=0.23.4', 'trezor>=0.6.6', 'keepkey>=0.7.0', 'semver>=2.2'],
platforms=['POSIX'],
classifiers=[
'Development Status :: 4 - Beta',
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ deps=
pep8
coverage
pylint
semver
commands=
pep8 trezor_agent
pylint --reports=no --rcfile .pylintrc trezor_agent
Expand Down
52 changes: 15 additions & 37 deletions trezor_agent/tests/test_trezor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import io

import mock
import pytest

from .. import formats, util
from ..trezor import client
from ..trezor import client, factory

ADDR = [2147483661, 2810943954, 3938368396, 3454558782, 3848009040]
CURVE = 'nist256p1'
Expand All @@ -18,15 +17,7 @@

class ConnectionMock(object):

def __init__(self, version):
self.features = mock.Mock(spec=[])
self.features.device_id = '123456789'
self.features.label = 'mywallet'
self.features.vendor = 'mock'
self.features.major_version = version[0]
self.features.minor_version = version[1]
self.features.patch_version = version[2]
self.features.revision = b'456'
def __init__(self):
self.closed = False

def close(self):
Expand All @@ -49,21 +40,20 @@ def ping(self, msg):
return msg


class FactoryMock(object):
def identity_type(**kwargs):
result = mock.Mock(spec=[])
result.index = 0
result.proto = result.user = result.host = result.port = None
result.path = None
for k, v in kwargs.items():
setattr(result, k, v)
return result

@staticmethod
def client():
return ConnectionMock(version=(1, 3, 4))

@staticmethod
def identity_type(**kwargs):
result = mock.Mock(spec=[])
result.index = 0
result.proto = result.user = result.host = result.port = None
result.path = None
for k, v in kwargs.items():
setattr(result, k, v)
return result
def load_client():
return factory.ClientWrapper(connection=ConnectionMock(),
identity_type=identity_type,
device_name='DEVICE_NAME')


BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0'
Expand All @@ -82,7 +72,7 @@ def identity_type(**kwargs):

def test_ssh_agent():
label = 'localhost:22'
c = client.Client(factory=FactoryMock)
c = client.Client(loader=load_client)
ident = c.get_identity(label=label)
assert ident.host == 'localhost'
assert ident.proto == 'ssh'
Expand Down Expand Up @@ -129,15 +119,3 @@ def test_utils():

url = 'https://user@host:443/path'
assert client.identity_to_string(identity) == url


def test_old_version():

class OldFactoryMock(FactoryMock):

@staticmethod
def client():
return ConnectionMock(version=(1, 2, 3))

with pytest.raises(ValueError):
client.Client(factory=OldFactoryMock)
23 changes: 0 additions & 23 deletions trezor_agent/trezor/_factory.py

This file was deleted.

36 changes: 12 additions & 24 deletions trezor_agent/trezor/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,53 +4,41 @@
import re
import struct

from . import _factory as TrezorFactory
from . import factory
from .. import formats, util

log = logging.getLogger(__name__)


class Client(object):

MIN_VERSION = [1, 3, 4]

def __init__(self, factory=TrezorFactory, curve=formats.CURVE_NIST256):
def __init__(self, loader=factory.load, curve=formats.CURVE_NIST256):
client_wrapper = loader()
self.client = client_wrapper.connection
self.identity_type = client_wrapper.identity_type
self.device_name = client_wrapper.device_name
self.curve = curve
self.factory = factory
self.client = self.factory.client()
f = self.client.features
log.debug('connected to Trezor %s', f.device_id)
log.debug('label : %s', f.label)
log.debug('vendor : %s', f.vendor)
version = [f.major_version, f.minor_version, f.patch_version]
version_str = '.'.join([str(v) for v in version])
log.debug('version : %s', version_str)
log.debug('revision : %s', binascii.hexlify(f.revision))
if version < self.MIN_VERSION:
fmt = 'Please upgrade your TREZOR to v{}+ firmware'
version_str = '.'.join([str(v) for v in self.MIN_VERSION])
raise ValueError(fmt.format(version_str))

def __enter__(self):
msg = 'Hello World!'
assert self.client.ping(msg) == msg
return self

def __exit__(self, *args):
log.info('disconnected from Trezor')
log.info('disconnected from %s', self.device_name)
self.client.clear_session() # forget PIN and shutdown screen
self.client.close()

def get_identity(self, label):
identity = string_to_identity(label, self.factory.identity_type)
identity = string_to_identity(label, self.identity_type)
identity.proto = 'ssh'
return identity

def get_public_key(self, label):
identity = self.get_identity(label=label)
label = identity_to_string(identity) # canonize key label
log.info('getting "%s" public key (%s) from Trezor...',
label, self.curve)
log.info('getting "%s" public key (%s) from %s...',
label, self.curve, self.device_name)
addr = _get_address(identity)
node = self.client.get_public_node(n=addr,
ecdsa_curve_name=self.curve)
Expand All @@ -63,8 +51,8 @@ def sign_ssh_challenge(self, label, blob):
identity = self.get_identity(label=label)
msg = _parse_ssh_blob(blob)

log.info('please confirm user "%s" login to "%s" using Trezor...',
msg['user'], label)
log.info('please confirm user "%s" login to "%s" using %s...',
msg['user'], label, self.device_name)

visual = identity.path # not signed when proto='ssh'
result = self.client.sign_identity(identity=identity,
Expand Down
78 changes: 78 additions & 0 deletions trezor_agent/trezor/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
''' Thin wrapper around trezor/keepkey libraries. '''
import binascii
import collections
import logging

import semver

log = logging.getLogger(__name__)

ClientWrapper = collections.namedtuple(
'ClientWrapper',
['connection', 'identity_type', 'device_name'])


# pylint: disable=too-many-arguments
def _load_client(name, client_type, hid_transport,
passphrase_ack, identity_type, required_version):

def empty_passphrase_handler(_):
return passphrase_ack(passphrase='')

for d in hid_transport.enumerate():
connection = client_type(hid_transport(d))
connection.callback_PassphraseRequest = empty_passphrase_handler
f = connection.features
log.debug('connected to %s %s', name, f.device_id)
log.debug('label : %s', f.label)
log.debug('vendor : %s', f.vendor)
current_version = '{}.{}.{}'.format(f.major_version,
f.minor_version,
f.patch_version)
log.debug('version : %s', current_version)
log.debug('revision : %s', binascii.hexlify(f.revision))
if not semver.match(current_version, required_version):
fmt = 'Please upgrade your {} firmware to {} version (current: {})'
raise ValueError(fmt.format(name,
required_version,
current_version))
yield ClientWrapper(connection=connection,
identity_type=identity_type,
device_name=name)


def _load_trezor():
# pylint: disable=import-error
from trezorlib.client import TrezorClient
from trezorlib.transport_hid import HidTransport
from trezorlib.messages_pb2 import PassphraseAck
from trezorlib.types_pb2 import IdentityType
return _load_client(name='Trezor',
client_type=TrezorClient,
hid_transport=HidTransport,
passphrase_ack=PassphraseAck,
identity_type=IdentityType,
required_version='>=1.3.4')


def _load_keepkey():
# pylint: disable=import-error
from keepkeylib.client import KeepKeyClient
from keepkeylib.transport_hid import HidTransport
from keepkeylib.messages_pb2 import PassphraseAck
from keepkeylib.types_pb2 import IdentityType
return _load_client(name='KeepKey',
client_type=KeepKeyClient,
hid_transport=HidTransport,
passphrase_ack=PassphraseAck,
identity_type=IdentityType,
required_version='>=1.0.4')


def load():
devices = list(_load_trezor()) + list(_load_keepkey())
if len(devices) == 1:
return devices[0]

msg = '{:d} devices found'.format(len(devices))
raise IOError(msg)
9 changes: 8 additions & 1 deletion trezor_agent/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import io
import struct
import socket
import time


def send(conn, data, fmt=None):
Expand All @@ -21,7 +23,12 @@ def recv(conn, size):

res = io.BytesIO()
while size > 0:
buf = _read(size)
try:
buf = _read(size)
except socket.error as ex:
if str(ex) == "[Errno 35] Resource temporarily unavailable":
time.sleep(0)
continue
if not buf:
raise EOFError
size = size - len(buf)
Expand Down