diff --git a/src/jmclient/configure.py b/src/jmclient/configure.py index ac0b29f6b..14482ff8d 100644 --- a/src/jmclient/configure.py +++ b/src/jmclient/configure.py @@ -272,6 +272,11 @@ def jm_single() -> AttributeDict: # scripts can use the command line flag `-g` instead. gaplimit = 6 +# Disable the caching of addresses and scripts when +# syncing the wallet. You DO NOT need to set this to 'true', +# unless there is an issue of file corruption or a code bug. +wallet_caching_disabled = false + # The fee estimate is based on a projection of how many sats/kilo-vbyte # are needed to get in one of the next N blocks. N is set here as # the value of 'tx_fees'. This cost estimate is high if you set diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index 691dafb15..dc2937a48 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -388,7 +388,7 @@ class BaseWallet(object): ADDRESS_TYPE_INTERNAL = 1 def __init__(self, storage, gap_limit=6, merge_algorithm_name=None, - mixdepth=None): + mixdepth=None, load_cache=True): # to be defined by inheriting classes assert self.TYPE is not None assert self._ENGINE is not None @@ -410,7 +410,7 @@ def __init__(self, storage, gap_limit=6, merge_algorithm_name=None, # {address: path}, should always hold mappings for all "known" keys self._addr_map = {} - self._load_storage() + self._load_storage(load_cache=load_cache) assert self._utxos is not None assert self._cache is not None @@ -440,7 +440,7 @@ def max_mix_depth(self): def gaplimit(self): return self.gap_limit - def _load_storage(self): + def _load_storage(self, load_cache: bool = True) -> None: """ load data from storage """ @@ -450,7 +450,10 @@ def _load_storage(self): self.network = self._storage.data[b'network'].decode('ascii') self._utxos = UTXOManager(self._storage, self.merge_algorithm) self._addr_labels = AddressLabelsManager(self._storage) - self._cache = self._storage.data.setdefault(b'cache', {}) + if load_cache: + self._cache = self._storage.data.setdefault(b'cache', {}) + else: + self._cache = {} def get_storage_location(self): """ Return the location of the @@ -1893,8 +1896,8 @@ def __init__(self, storage, **kwargs): # path is (_IMPORTED_ROOT_PATH, mixdepth, key_index) super().__init__(storage, **kwargs) - def _load_storage(self): - super()._load_storage() + def _load_storage(self, load_cache: bool = True) -> None: + super()._load_storage(load_cache=load_cache) self._imported = collections.defaultdict(list) for md, keys in self._storage.data[self._IMPORTED_STORAGE_KEY].items(): md = int(md) @@ -2070,8 +2073,8 @@ class BIP39WalletMixin(object): _BIP39_EXTENSION_KEY = b'seed_extension' MNEMONIC_LANG = 'english' - def _load_storage(self): - super()._load_storage() + def _load_storage(self, load_cache: bool = True) -> None: + super()._load_storage(load_cache=load_cache) self._entropy_extension = self._storage.data.get(self._BIP39_EXTENSION_KEY) @classmethod @@ -2177,8 +2180,8 @@ def initialize(cls, storage, network, max_mixdepth=2, timestamp=None, if write: storage.save() - def _load_storage(self): - super()._load_storage() + def _load_storage(self, load_cache: bool = True) -> None: + super()._load_storage(load_cache=load_cache) self._entropy = self._storage.data[self._STORAGE_ENTROPY_KEY] self._index_cache = collections.defaultdict( diff --git a/src/jmclient/wallet_utils.py b/src/jmclient/wallet_utils.py index 786637fcf..43210e1a8 100644 --- a/src/jmclient/wallet_utils.py +++ b/src/jmclient/wallet_utils.py @@ -1547,8 +1547,11 @@ def open_wallet(path, ask_for_password=True, password=None, read_only=False, else: storage = Storage(path, password, read_only=read_only) + load_cache = True + if jm_single().config.get("POLICY", "wallet_caching_disabled") == "true": + load_cache = False wallet_cls = get_wallet_cls_from_storage(storage) - wallet = wallet_cls(storage, **kwargs) + wallet = wallet_cls(storage, load_cache=load_cache, **kwargs) wallet_sanity_check(wallet) return wallet diff --git a/test/jmclient/test_wallet.py b/test/jmclient/test_wallet.py index 45b23fa8e..5dffa71b8 100644 --- a/test/jmclient/test_wallet.py +++ b/test/jmclient/test_wallet.py @@ -12,7 +12,7 @@ SegwitLegacyWallet,BIP32Wallet, BIP49Wallet, LegacyWallet,\ VolatileStorage, get_network, cryptoengine, WalletError,\ SegwitWallet, WalletService, SegwitWalletFidelityBonds,\ - create_wallet, open_test_wallet_maybe, \ + create_wallet, open_test_wallet_maybe, open_wallet, \ FidelityBondMixin, FidelityBondWatchonlyWallet,\ wallet_gettimelockaddress, UnknownAddressForLabel from test_blockchaininterface import sync_test_wallet @@ -23,7 +23,7 @@ testdir = os.path.dirname(os.path.realpath(__file__)) test_create_wallet_filename = "testwallet_for_create_wallet_test" - +test_cache_cleared_filename = "testwallet_for_cache_clear_test" log = get_log() @@ -764,6 +764,50 @@ def test_wallet_id(setup_wallet): assert wallet1.get_wallet_id() == wallet2.get_wallet_id() +def test_cache_cleared(setup_wallet): + # test plan: + # 1. create a new wallet and sync from scratch + # 2. read its cache as an object + # 3. close the wallet, reopen it, sync it. + # 4. corrupt its cache and save. + # 5. Re open the wallet with recoversync + # and check that the corrupted data is not present. + if os.path.exists(test_cache_cleared_filename): + os.remove(test_cache_cleared_filename) + wallet = create_wallet(test_cache_cleared_filename, + b"hunter2", 2, SegwitWallet) + # note: we use the WalletService as an encapsulation + # of the wallet here because we want to be able to sync, + # but we do not actually start the service and go into + # the monitoring loop. + wallet_service = WalletService(wallet) + # default fast sync, no coins, so no loop + wallet_service.sync_wallet() + wallet_service.update_blockheight() + # to get the cache to save, we need to + # use an address: + addr = wallet_service.get_new_addr(0,0) + jm_single().bc_interface.grab_coins(addr, 1.0) + wallet_service.transaction_monitor() + path_to_corrupt = list(wallet._cache.keys())[0] + # we'll just corrupt the first address and script: + entry_to_corrupt = wallet._cache[path_to_corrupt][b"84'"][b"1'"][b"0'"][b'0'][b'0'] + entry_to_corrupt[b'A'] = "notanaddress" + entry_to_corrupt[b'S'] = "notascript" + wallet_service.wallet.save() + wallet_service.wallet.close() + jm_single().config.set("POLICY", "wallet_caching_disabled", "true") + wallet2 = open_wallet(test_cache_cleared_filename, + ask_for_password=False, + password=b"hunter2") + jm_single().config.set("POLICY", "wallet_caching_disabled", "false") + wallet_service2 = WalletService(wallet2) + while not wallet_service2.synced: + wallet_service2.sync_wallet(fast=False) + wallet_service.transaction_monitor() + # we ignored the corrupt cache? + assert wallet_service2.get_balance_at_mixdepth(0) == 10 ** 8 + def test_addr_script_conversion(setup_wallet): wallet = get_populated_wallet(num=1) @@ -1016,4 +1060,6 @@ def setup_wallet(request): def teardown(): if os.path.exists(test_create_wallet_filename): os.remove(test_create_wallet_filename) + if os.path.exists(test_cache_cleared_filename): + os.remove(test_cache_cleared_filename) request.addfinalizer(teardown)