diff --git a/setup.py b/setup.py index e19a0d2..f6c0ec4 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ name = 'PyOTA', description = 'IOTA API library for Python', url = 'https://github.com/iotaledger/iota.lib.py', - version = '1.1.2', + version = '1.1.3', packages = find_packages('src'), include_package_data = True, diff --git a/src/iota/crypto/addresses.py b/src/iota/crypto/addresses.py index 44db482..25ba900 100644 --- a/src/iota/crypto/addresses.py +++ b/src/iota/crypto/addresses.py @@ -6,15 +6,16 @@ from abc import ABCMeta, abstractmethod as abstract_method from contextlib import contextmanager as context_manager from threading import Lock + +from six import binary_type, with_metaclass from typing import Dict, Generator, Iterable, List, MutableSequence, \ - Optional, Tuple + Optional from iota import Address, TRITS_PER_TRYTE, TrytesCompatible from iota.crypto import Curl from iota.crypto.signing import KeyGenerator, KeyIterator -from iota.crypto.types import PrivateKey, Seed +from iota.crypto.types import Digest, PrivateKey, Seed from iota.exceptions import with_context -from six import binary_type, with_metaclass __all__ = [ 'AddressGenerator', @@ -253,19 +254,19 @@ def create_iterator(self, start=0, step=1): yield address @staticmethod - def address_from_digest_trits(digest_trits, key_index): - # type: (List[int], int) -> Address + def address_from_digest(digest): + # type: (Digest) -> Address """ Generates an address from a private key digest. """ address_trits = [0] * (Address.LEN * TRITS_PER_TRYTE) # type: MutableSequence[int] sponge = Curl() - sponge.absorb(digest_trits) + sponge.absorb(digest.as_trits()) sponge.squeeze(address_trits) address = Address.from_trits(address_trits) - address.key_index = key_index + address.key_index = digest.key_index return address @@ -276,16 +277,16 @@ def _generate_address(self, key_iterator): Used in the event of a cache miss. """ - return self.address_from_digest_trits(*self._get_digest_params(key_iterator)) + return self.address_from_digest(self._get_digest(key_iterator)) @staticmethod - def _get_digest_params(key_iterator): - # type: (KeyIterator) -> Tuple[List[int], int] + def _get_digest(key_iterator): + # type: (KeyIterator) -> Digest """ - Extracts parameters for :py:meth:`address_from_digest_trits`. + Extracts parameters for :py:meth:`address_from_digest`. Split into a separate method so that it can be mocked during unit tests. """ private_key = next(key_iterator) # type: PrivateKey - return private_key.get_digest_trits(), private_key.key_index + return private_key.get_digest() diff --git a/src/iota/crypto/types.py b/src/iota/crypto/types.py index 93a44d8..d7185df 100644 --- a/src/iota/crypto/types.py +++ b/src/iota/crypto/types.py @@ -4,20 +4,33 @@ from math import ceil from os import urandom -from typing import Callable, List, MutableSequence, Optional, Tuple from six import binary_type +from typing import Callable, MutableSequence, Optional, Tuple from iota import Hash, TryteString, TrytesCompatible from iota.crypto import Curl, FRAGMENT_LENGTH, HASH_LENGTH from iota.exceptions import with_context __all__ = [ + 'Digest', 'PrivateKey', 'Seed', ] +class Digest(TryteString): + """ + A private key digest. Basically the same thing as a regular + `TryteString`, except that it has a key index associated with it. + """ + def __init__(self, trytes, key_index): + # type: (TrytesCompatible, int) -> None + super(Digest, self).__init__(trytes) + + self.key_index = key_index + + class Seed(TryteString): """ A TryteString that acts as a seed for crypto functions. @@ -73,8 +86,8 @@ def __init__(self, trytes, key_index=None): self.key_index = key_index - def get_digest_trits(self): - # type: () -> List[int] + def get_digest(self): + # type: () -> Digest """ Generates the digest used to do the actual signing. @@ -119,4 +132,4 @@ def get_digest_trits(self): digest[fragment_start:fragment_end] = hash_trits - return digest + return Digest(TryteString.from_trits(digest), self.key_index) diff --git a/test/crypto/addresses_test.py b/test/crypto/addresses_test.py index 24d3679..8965f60 100644 --- a/test/crypto/addresses_test.py +++ b/test/crypto/addresses_test.py @@ -4,15 +4,14 @@ from threading import Thread from time import sleep -from typing import List, Tuple from unittest import TestCase from mock import Mock, patch -from iota import Address, Hash +from iota import Address from iota.crypto.addresses import AddressGenerator, MemoryAddressCache from iota.crypto.signing import KeyIterator -from iota.crypto.types import Seed +from iota.crypto.types import Digest, Seed class AddressGeneratorTestCase(TestCase): @@ -21,7 +20,7 @@ def setUp(self): super(AddressGeneratorTestCase, self).setUp() # Addresses that correspond to the digests defined in - # :py:meth:`_mock_get_digest_params`. + # :py:meth:`_mock_get_digest`. self.addy0 =\ Address( b'VOPYUSDRHYGGOHLAYDWCLLOFWBLK99PYYKENW9IQ' @@ -52,13 +51,16 @@ def test_address_from_digest(self): Generating an address from a private key digest. """ digest =\ - Hash( - b'ABQXVJNER9MPMXMBPNMFBMDGTXRWSYHNZKGAGUOI' - b'JKOJGZVGHCUXXGFZEMMGDSGWDCKJXO9ILLFAKGGZE' + Digest( + trytes = + b'ABQXVJNER9MPMXMBPNMFBMDGTXRWSYHNZKGAGUOI' + b'JKOJGZVGHCUXXGFZEMMGDSGWDCKJXO9ILLFAKGGZE', + + key_index = 0, ) self.assertEqual( - AddressGenerator.address_from_digest_trits(digest.as_trits(), 0), + AddressGenerator.address_from_digest(digest), Address( b'QLOEDSBXXOLLUJYLEGKEPYDRIJJTPIMEPKMFHUVJ' @@ -75,13 +77,13 @@ def test_get_addresses_single(self): ag = AddressGenerator(seed=b'') # noinspection PyUnresolvedReferences - with patch.object(ag, '_get_digest_params', self._mock_get_digest_params): + with patch.object(ag, '_get_digest', self._mock_get_digest): addresses = ag.get_addresses(start=0) self.assertListEqual(addresses, [self.addy0]) # noinspection PyUnresolvedReferences - with patch.object(ag, '_get_digest_params', self._mock_get_digest_params): + with patch.object(ag, '_get_digest', self._mock_get_digest): # You can provide any positive integer as the ``start`` value. addresses = ag.get_addresses(start=2) @@ -96,7 +98,7 @@ def test_get_addresses_multiple(self): ag = AddressGenerator(seed=b'') # noinspection PyUnresolvedReferences - with patch.object(ag, '_get_digest_params', self._mock_get_digest_params): + with patch.object(ag, '_get_digest', self._mock_get_digest): addresses = ag.get_addresses(start=1, count=2) self.assertListEqual(addresses, [self.addy1, self.addy2]) @@ -145,7 +147,7 @@ def test_get_addresses_step_negative(self): ag = AddressGenerator(seed=b'') # noinspection PyUnresolvedReferences - with patch.object(ag, '_get_digest_params', self._mock_get_digest_params): + with patch.object(ag, '_get_digest', self._mock_get_digest): addresses = ag.get_addresses(start=1, count=2, step=-1) self.assertListEqual( @@ -165,7 +167,7 @@ def test_generator(self): ag = AddressGenerator(seed=b'') # noinspection PyUnresolvedReferences - with patch.object(ag, '_get_digest_params', self._mock_get_digest_params): + with patch.object(ag, '_get_digest', self._mock_get_digest): generator = ag.create_iterator() self.assertEqual(next(generator), self.addy0) @@ -181,15 +183,15 @@ def test_generator_with_offset(self): ag = AddressGenerator(seed=b'') # noinspection PyUnresolvedReferences - with patch.object(ag, '_get_digest_params', self._mock_get_digest_params): + with patch.object(ag, '_get_digest', self._mock_get_digest): generator = ag.create_iterator(start=1, step=2) self.assertEqual(next(generator), self.addy1) self.assertEqual(next(generator), self.addy3) @staticmethod - def _mock_get_digest_params(key_iterator): - # type: (KeyIterator) -> Tuple[List[int], int] + def _mock_get_digest(key_iterator): + # type: (KeyIterator) -> Digest """ Mocks the behavior of :py:class:`KeyGenerator`, to speed up unit tests. @@ -220,7 +222,7 @@ def _mock_get_digest_params(key_iterator): # This should still behave like the real thing, so that we can # verify that :py:class`AddressGenerator` is invoking the key # generator correctly. - return Hash(digests[key_index]).as_trits(), key_index + return Digest(digests[key_index], key_index) class MemoryAddressCacheTestCase(TestCase): diff --git a/test/crypto/types_test.py b/test/crypto/types_test.py index b52ece0..ae851a9 100644 --- a/test/crypto/types_test.py +++ b/test/crypto/types_test.py @@ -10,9 +10,9 @@ # noinspection SpellCheckingInspection class PrivateKeyTestCase(TestCase): - def test_get_digest_trits_single_fragment(self): + def test_get_digest_single_fragment(self): """ - Generating digest trits from a PrivateKey 1 fragment long. + Generating digest from a PrivateKey 1 fragment long. """ key =\ PrivateKey( @@ -53,7 +53,7 @@ def test_get_digest_trits_single_fragment(self): ) self.assertEqual( - TryteString.from_trits(key.get_digest_trits()), + key.get_digest(), TryteString( b'ABQXVJNER9MPMXMBPNMFBMDGTXRWSYHNZKGAGUOI' @@ -61,9 +61,9 @@ def test_get_digest_trits_single_fragment(self): ), ) - def test_get_digest_trits_multiple_fragments(self): + def test_get_digest_multiple_fragments(self): """ - Generating digest trits from a PrivateKey longer than 1 fragment. + Generating digest from a PrivateKey longer than 1 fragment. """ key =\ PrivateKey( @@ -137,7 +137,7 @@ def test_get_digest_trits_multiple_fragments(self): ) self.assertEqual( - TryteString.from_trits(key.get_digest_trits()), + key.get_digest(), # Note that the digest is 2 hashes long, because the key # is 2 fragments long.