diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index c761e507..f73e0079 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -99,9 +99,12 @@ async def wrapper(self, *args, **kwargs): class LedgerAPI(object): - keys: WalletKeyset # holds current keys of mint + keysets: Dict[str, WalletKeyset] # holds keysets of mint + mint_keyset_ids: List[str] # holds ids of keysets of mint + # keys: WalletKeyset # holds current keys of mint keyset_id: str # holds id of current keyset - public_keys: Dict[int, PublicKey] # holds public keys of + # public_keys: Dict[int, PublicKey] # holds public keys of current keyset + mint_info: GetInfoResponse # holds info about mint tor: TorProxy s: requests.Session @@ -111,6 +114,7 @@ def __init__(self, url: str, db: Database): self.url = url self.s = requests.Session() self.db = db + self.keysets = {} @async_set_requests async def _init_s(self): @@ -137,7 +141,7 @@ def raise_on_error(resp: Response) -> None: # raise for status if no error resp.raise_for_status() - async def _load_mint_keys(self, keyset_id: str = "") -> WalletKeyset: + async def _load_mint_keys(self, keyset_id: Optional[str] = None) -> None: """Loads keys from mint and stores them in the database. Args: @@ -163,6 +167,16 @@ async def _load_mint_keys(self, keyset_id: str = "") -> WalletKeyset: assert keyset.id assert len(keyset.public_keys) > 0, "did not receive keys from mint." + if keyset_id != keyset.id: + # NOTE: Because of the upcoming change of how to calculate keyset ids + # with version 0.14.0, we overwrite the calculated keyset id with the + # requested one. This is a temporary fix and should be removed once all + # ecash is transitioned to 0.14.0. + logger.debug( + "Keyset ID mismatch. This can happen due to a version upgrade." + ) + keyset.id = keyset_id or keyset.id + # check if current keyset is in db keyset_local: Optional[WalletKeyset] = await get_keyset(keyset.id, db=self.db) # if not, store it @@ -170,13 +184,12 @@ async def _load_mint_keys(self, keyset_id: str = "") -> WalletKeyset: logger.debug(f"Storing new mint keyset: {keyset.id}") await store_keyset(keyset=keyset, db=self.db) - self.keys = keyset - assert self.keys.public_keys - self.public_keys = self.keys.public_keys - assert self.keys.id - self.keyset_id = self.keys.id - logger.debug(f"Current mint keyset: {self.keys.id}") - return self.keys + # set current keyset id + self.keyset_id = keyset.id + logger.debug(f"Current mint keyset: {self.keyset_id}") + + # add keyset to keysets dict + self.keysets[keyset.id] = keyset async def _load_mint_keysets(self) -> List[str]: """Loads the keyset IDs of the mint. @@ -191,11 +204,13 @@ async def _load_mint_keysets(self) -> List[str]: try: mint_keysets = await self._get_keyset_ids(self.url) except Exception: - assert self.keys.id, "could not get keysets from mint, and do not have keys" + assert self.keysets[ + self.keyset_id + ].id, "could not get keysets from mint, and do not have keys" pass - self.keysets = mint_keysets or [self.keys.id] - logger.debug(f"Mint keysets: {self.keysets}") - return self.keysets + self.mint_keyset_ids = mint_keysets or [self.keysets[self.keyset_id].id] + logger.debug(f"Mint keysets: {self.mint_keyset_ids}") + return self.mint_keyset_ids async def _load_mint_info(self) -> GetInfoResponse: """Loads the mint info from the mint.""" @@ -207,7 +222,7 @@ async def _load_mint(self, keyset_id: str = "") -> None: """ Loads the public keys of the mint. Either gets the keys for the specified `keyset_id` or gets the keys of the active keyset from the mint. - Gets the active keyset ids of the mint and stores in `self.keysets`. + Gets the active keyset ids of the mint and stores in `self.mint_keyset_ids`. """ await self._load_mint_keys(keyset_id) await self._load_mint_keysets() @@ -218,7 +233,9 @@ async def _load_mint(self, keyset_id: str = "") -> None: pass if keyset_id: - assert keyset_id in self.keysets, f"keyset {keyset_id} not active on mint" + assert ( + keyset_id in self.mint_keyset_ids + ), f"keyset {keyset_id} not active on mint" async def _check_used_secrets(self, secrets): """Checks if any of the secrets have already been used""" @@ -568,7 +585,7 @@ async def _migrate_database(self): async def load_mint(self, keyset_id: str = ""): """Load a mint's keys with a given keyset_id if specified or else loads the active keyset of the mint into self.keys. - Also loads all keyset ids into self.keysets. + Also loads all keyset ids into self.mint_keyset_ids. Args: keyset_id (str, optional): _description_. Defaults to "". @@ -818,14 +835,14 @@ def verify_proofs_dleq(self, proofs: List[Proof]): logger.trace("No DLEQ proof in proof.") return logger.trace("Verifying DLEQ proof.") - assert self.keys.public_keys + assert proof.id if not b_dhke.carol_verify_dleq( secret_msg=proof.secret, C=PublicKey(bytes.fromhex(proof.C), raw=True), r=PrivateKey(bytes.fromhex(proof.dleq.r), raw=True), e=PrivateKey(bytes.fromhex(proof.dleq.e), raw=True), s=PrivateKey(bytes.fromhex(proof.dleq.s), raw=True), - A=self.keys.public_keys[proof.amount], + A=self.keysets[proof.id].public_keys[proof.amount], ): raise Exception("DLEQ proof invalid.") else: @@ -855,13 +872,15 @@ async def _construct_proofs( logger.trace("Constructing proofs.") proofs: List[Proof] = [] for promise, secret, r, path in zip(promises, secrets, rs, derivation_paths): - logger.trace(f"Creating proof with keyset {self.keyset_id} = {promise.id}") - assert ( - self.keyset_id == promise.id - ), "our keyset id does not match promise id." + if promise.id not in self.keysets: + # we don't have the keyset for this promise, so we load it + await self._load_mint_keys(promise.id) + assert promise.id in self.keysets, "Could not load keyset." C_ = PublicKey(bytes.fromhex(promise.C_), raw=True) - C = b_dhke.step3_alice(C_, r, self.public_keys[promise.amount]) + C = b_dhke.step3_alice( + C_, r, self.keysets[promise.id].public_keys[promise.amount] + ) B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs proof = Proof( @@ -1104,7 +1123,7 @@ async def _select_proofs_to_send( Rules: 1) Proofs that are not marked as reserved - 2) Proofs that have a keyset id that is in self.keysets (all active keysets of mint) + 2) Proofs that have a keyset id that is in self.mint_keyset_ids (all active keysets of mint) 3) Include all proofs that have an older keyset than the current keyset of the mint (to get rid of old epochs). 4) If the target amount is not reached, add proofs of the current keyset until it is. """ @@ -1114,19 +1133,23 @@ async def _select_proofs_to_send( proofs = [p for p in proofs if not p.reserved] # select proofs that are in the active keysets of the mint - proofs = [p for p in proofs if p.id in self.keysets or not p.id] + proofs = [p for p in proofs if p.id in self.mint_keyset_ids or not p.id] # check that enough spendable proofs exist if sum_proofs(proofs) < amount_to_send: raise Exception("balance too low.") # add all proofs that have an older keyset than the current keyset of the mint - proofs_old_epochs = [p for p in proofs if p.id != self.keys.id] + proofs_old_epochs = [ + p for p in proofs if p.id != self.keysets[self.keyset_id].id + ] send_proofs += proofs_old_epochs # coinselect based on amount only from the current keyset # start with the proofs with the largest amount and add them until the target amount is reached - proofs_current_epoch = [p for p in proofs if p.id == self.keys.id] + proofs_current_epoch = [ + p for p in proofs if p.id == self.keysets[self.keyset_id].id + ] sorted_proofs_of_current_keyset = sorted( proofs_current_epoch, key=lambda p: p.amount ) diff --git a/tests/test_wallet.py b/tests/test_wallet.py index b8519332..575e0bdf 100644 --- a/tests/test_wallet.py +++ b/tests/test_wallet.py @@ -89,8 +89,8 @@ async def wallet3(mint): @pytest.mark.asyncio async def test_get_keys(wallet1: Wallet): - assert wallet1.keys.public_keys - assert len(wallet1.keys.public_keys) == settings.max_order + assert wallet1.keysets[wallet1.keyset_id].public_keys + assert len(wallet1.keysets[wallet1.keyset_id].public_keys) == settings.max_order keyset = await wallet1._get_keys(wallet1.url) assert keyset.id is not None assert keyset.id == "1cCNIAZ2X/w1" @@ -100,8 +100,8 @@ async def test_get_keys(wallet1: Wallet): @pytest.mark.asyncio async def test_get_keyset(wallet1: Wallet): - assert wallet1.keys.public_keys - assert len(wallet1.keys.public_keys) == settings.max_order + assert wallet1.keysets[wallet1.keyset_id].public_keys + assert len(wallet1.keysets[wallet1.keyset_id].public_keys) == settings.max_order # let's get the keys first so we can get a keyset ID that we use later keys1 = await wallet1._get_keys(wallet1.url) # gets the keys of a specific keyset