Skip to content

Commit

Permalink
Merge pull request iotaledger#36 from iotaledger/release/1.1.2
Browse files Browse the repository at this point in the history
1.1.2
  • Loading branch information
todofixthis authored Mar 25, 2017
2 parents 93f67cf + ffcbf18 commit 59bde1b
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 22 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
name = 'PyOTA',
description = 'IOTA API library for Python',
url = 'https://github.com/iotaledger/iota.lib.py',
version = '1.1.1',
version = '1.1.2',

packages = find_packages('src'),
include_package_data = True,
Expand Down
7 changes: 5 additions & 2 deletions src/iota/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,8 +822,8 @@ def send_transfer(
minWeightMagnitude = min_weight_magnitude,
)

def send_trytes(self, trytes, depth, min_weight_magnitude=18):
# type: (Iterable[TransactionTrytes], int, int) -> dict
def send_trytes(self, trytes, depth, min_weight_magnitude=None):
# type: (Iterable[TransactionTrytes], int, Optional[int]) -> dict
"""
Attaches transaction trytes to the Tangle, then broadcasts and
stores them.
Expand Down Expand Up @@ -851,6 +851,9 @@ def send_trytes(self, trytes, depth, min_weight_magnitude=18):
References:
- https://github.com/iotaledger/wiki/blob/master/api-proposal.md#sendtrytes
"""
if min_weight_magnitude is None:
min_weight_magnitude = self.default_min_weight_magnitude

return extended.SendTrytesCommand(self.adapter)(
trytes = trytes,
depth = depth,
Expand Down
67 changes: 48 additions & 19 deletions src/iota/crypto/addresses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

import hashlib
from abc import ABCMeta, abstractmethod as abstract_method
from typing import Dict, Generator, Iterable, List, MutableSequence, Optional
from contextlib import contextmanager as context_manager
from threading import Lock
from typing import Dict, Generator, Iterable, List, MutableSequence, \
Optional, Tuple

from iota import Address, TRITS_PER_TRYTE, TrytesCompatible
from iota.crypto import Curl
Expand All @@ -23,6 +26,19 @@ class BaseAddressCache(with_metaclass(ABCMeta)):
"""
Base functionality for classes that cache generated addresses.
"""
LockType = Lock
"""
The type of locking mechanism used by :py:meth:`acquire_lock`.
Defaults to ``threading.Lock``, but you can change it if you want to
use a different mechanism (e.g., multithreading or distributed).
"""

def __init__(self):
super(BaseAddressCache, self).__init__()

self._lock = self.LockType()

@abstract_method
def get(self, seed, index):
# type: (Seed, int) -> Optional[Address]
Expand All @@ -34,6 +50,18 @@ def get(self, seed, index):
'Not implemented in {cls}.'.format(cls=type(self).__name__),
)

@context_manager
def acquire_lock(self):
"""
Acquires a lock on the cache instance, to prevent invalid cache
misses when multiple threads access the cache concurrently.
Note: Acquire lock before checking the cache, and do not release it
until after the cache hit/miss is resolved.
"""
with self._lock:
yield

@abstract_method
def set(self, seed, index, address):
# type: (Seed, int, Address) -> None
Expand All @@ -45,6 +73,17 @@ def set(self, seed, index, address):
'Not implemented in {cls}.'.format(cls=type(self).__name__),
)

@staticmethod
def _gen_cache_key(seed, index):
# type: (Seed, int) -> binary_type
"""
Generates an obfuscated cache key so that we're not storing seeds
in cleartext.
"""
h = hashlib.new('sha256')
h.update(binary_type(seed) + b':' + binary_type(index))
return h.digest()


class MemoryAddressCache(BaseAddressCache):
"""
Expand All @@ -63,17 +102,6 @@ def set(self, seed, index, address):
# type: (Seed, int, Address) -> None
self.cache[self._gen_cache_key(seed, index)] = address

@staticmethod
def _gen_cache_key(seed, index):
# type: (Seed, int) -> binary_type
"""
Generates an obfuscated cache key so that we're not storing seeds
in cleartext.
"""
h = hashlib.new('sha256')
h.update(binary_type(seed) + b':' + binary_type(index))
return h.digest()


class AddressGenerator(Iterable[Address]):
"""
Expand Down Expand Up @@ -213,18 +241,19 @@ def create_iterator(self, start=0, step=1):

while True:
if self.cache:
address = self.cache.get(self.seed, key_iterator.current)
with self.cache.acquire_lock():
address = self.cache.get(self.seed, key_iterator.current)

if not address:
address = self._generate_address(key_iterator)
self.cache.set(self.seed, address.key_index, address)
if not address:
address = self._generate_address(key_iterator)
self.cache.set(self.seed, address.key_index, address)
else:
address = self._generate_address(key_iterator)

yield address

@staticmethod
def address_from_digest(digest_trits, key_index):
def address_from_digest_trits(digest_trits, key_index):
# type: (List[int], int) -> Address
"""
Generates an address from a private key digest.
Expand All @@ -247,13 +276,13 @@ def _generate_address(self, key_iterator):
Used in the event of a cache miss.
"""
return self.address_from_digest(*self._get_digest_params(key_iterator))
return self.address_from_digest_trits(*self._get_digest_params(key_iterator))

@staticmethod
def _get_digest_params(key_iterator):
# type: (KeyIterator) -> Tuple[List[int], int]
"""
Extracts parameters for :py:meth:`address_from_digest`.
Extracts parameters for :py:meth:`address_from_digest_trits`.
Split into a separate method so that it can be mocked during unit
tests.
Expand Down
72 changes: 72 additions & 0 deletions test/crypto/addresses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from __future__ import absolute_import, division, print_function, \
unicode_literals

from threading import Thread
from time import sleep
from typing import List, Tuple
from unittest import TestCase

Expand Down Expand Up @@ -44,6 +46,26 @@ def setUp(self):
b'CFANWBQFGMFKITZBJDSYLGXYUIQVCMXFWSWFRNHRV'
)

# noinspection SpellCheckingInspection
def test_address_from_digest(self):
"""
Generating an address from a private key digest.
"""
digest =\
Hash(
b'ABQXVJNER9MPMXMBPNMFBMDGTXRWSYHNZKGAGUOI'
b'JKOJGZVGHCUXXGFZEMMGDSGWDCKJXO9ILLFAKGGZE'
)

self.assertEqual(
AddressGenerator.address_from_digest_trits(digest.as_trits(), 0),

Address(
b'QLOEDSBXXOLLUJYLEGKEPYDRIJJTPIMEPKMFHUVJ'
b'MPMLYYCLPQPANEVDSERQWPVNHCAXYRLAYMBHJLWWR'
),
)

def test_get_addresses_single(self):
"""
Generating a single address.
Expand Down Expand Up @@ -329,3 +351,53 @@ def test_cache_miss_seed(self):
generator2 = AddressGenerator(Seed.random())
generator2.get_addresses(42)
self.assertEqual(mock_generate_address.call_count, 2)

def test_thread_safety(self):
"""
Address cache is thread-safe, eliminating invalid cache misses when
multiple threads attempt to access the cache concurrently.
"""
AddressGenerator.cache = MemoryAddressCache()

seed = Seed.random()

generated = []

def get_address():
generator = AddressGenerator(seed)
generated.extend(generator.get_addresses(0))

# noinspection PyUnusedLocal
def mock_generate_address(address_generator, key_iterator):
# type: (AddressGenerator, KeyIterator) -> Address
# Insert a teensy delay, to make it more likely that multiple
# threads hit the cache concurrently.
sleep(0.01)

# Note that in this test, the address generator always returns a
# new instance.
return Address(self.addy, key_index=key_iterator.current)

with patch(
'iota.crypto.addresses.AddressGenerator._generate_address',
mock_generate_address,
):
threads = [Thread(target=get_address) for _ in range(100)]

for t in threads:
t.start()

for t in threads:
t.join()

# Quick sanity check.
self.assertEqual(len(generated), len(threads))

# If the cache is operating in a thread-safe manner, then it will
# always return the exact same instance, given the same seed and
# key index.
expected = generated[0]
for actual in generated[1:]:
# Compare `id` values instead of using ``self.assertIs`` because
# the failure message is a bit easier to understand.
self.assertEqual(id(actual), id(expected))

0 comments on commit 59bde1b

Please sign in to comment.