From 023fc810e912f203141dc83931eeab29579bdc1f Mon Sep 17 00:00:00 2001 From: callebtc <93376500+callebtc@users.noreply.github.com> Date: Sat, 23 Sep 2023 18:22:12 +0200 Subject: [PATCH] refactor proof invalidation --- cashu/wallet/wallet.py | 400 ++++++++++++++++++++++++++--------------- tests/test_wallet.py | 1 + 2 files changed, 256 insertions(+), 145 deletions(-) diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index edf9a5df..de09009f 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -124,121 +124,121 @@ async def _init_s(self): """Dummy function that can be called from outside to use LedgerAPI.s""" return - # ---------- DLEQ PROOFS ---------- - - def verify_proofs_dleq(self, proofs: List[Proof]): - """Verifies DLEQ proofs in proofs.""" - for proof in proofs: - if not proof.dleq: - logger.trace("No DLEQ proof in proof.") - return - logger.trace("Verifying DLEQ proof.") - assert self.keys.public_keys - if not b_dhke.carol_verify_dleq( - secret_msg=proof.secret, - C=PublicKey(bytes.fromhex(proof.C), raw=True), - r=PrivateKey(bytes.fromhex(proof.dleq.r), raw=True), - e=PrivateKey(bytes.fromhex(proof.dleq.e), raw=True), - s=PrivateKey(bytes.fromhex(proof.dleq.s), raw=True), - A=self.keys.public_keys[proof.amount], - ): - raise Exception("DLEQ proof invalid.") - else: - logger.debug("DLEQ proof valid.") - - def _construct_proofs( - self, - promises: List[BlindedSignature], - secrets: List[str], - rs: List[PrivateKey], - derivation_paths: List[str], - ) -> List[Proof]: - """Constructs proofs from promises, secrets, rs and derivation paths. - - This method is called after the user has received blind signatures from - the mint. The results are proofs that can be used as ecash. - - Args: - promises (List[BlindedSignature]): blind signatures from mint - secrets (List[str]): secrets that were previously used to create blind messages (that turned into promises) - rs (List[PrivateKey]): blinding factors that were previously used to create blind messages (that turned into promises) - derivation_paths (List[str]): derivation paths that were used to generate secrets and blinding factors - - Returns: - List[Proof]: list of proofs that can be used as ecash - """ - logger.trace("Constructing proofs.") - proofs: List[Proof] = [] - for promise, secret, r, path in zip(promises, secrets, rs, derivation_paths): - logger.trace(f"Creating proof with keyset {self.keyset_id} = {promise.id}") - assert ( - self.keyset_id == promise.id - ), "our keyset id does not match promise id." - - C_ = PublicKey(bytes.fromhex(promise.C_), raw=True) - C = b_dhke.step3_alice(C_, r, self.public_keys[promise.amount]) - B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs - - proof = Proof( - id=promise.id, - amount=promise.amount, - C=C.serialize().hex(), - secret=secret, - derivation_path=path, - ) - - # if the mint returned a dleq proof, we add it to the proof - if promise.dleq: - proof.dleq = DLEQWallet( - e=promise.dleq.e, s=promise.dleq.s, r=r.serialize() - ) - - proofs.append(proof) - - logger.trace( - f"Created proof: {proof}, r: {r.serialize()} out of promise {promise}" - ) - - # DLEQ verify - self.verify_proofs_dleq(proofs) - - logger.trace(f"Constructed {len(proofs)} proofs.") - return proofs - - @staticmethod - def _construct_outputs( - amounts: List[int], secrets: List[str], rs: List[PrivateKey] = [] - ) -> Tuple[List[BlindedMessage], List[PrivateKey]]: - """Takes a list of amounts and secrets and returns outputs. - Outputs are blinded messages `outputs` and blinding factors `rs` - - Args: - amounts (List[int]): list of amounts - secrets (List[str]): list of secrets - rs (List[PrivateKey], optional): list of blinding factors. If not given, `rs` are generated in step1_alice. Defaults to []. - - Returns: - List[BlindedMessage]: list of blinded messages that can be sent to the mint - List[PrivateKey]: list of blinding factors that can be used to construct proofs after receiving blind signatures from the mint - - Raises: - AssertionError: if len(amounts) != len(secrets) - """ - assert len(amounts) == len( - secrets - ), f"len(amounts)={len(amounts)} not equal to len(secrets)={len(secrets)}" - outputs: List[BlindedMessage] = [] - - rs_ = [None] * len(amounts) if not rs else rs - rs_return: List[PrivateKey] = [] - for secret, amount, r in zip(secrets, amounts, rs_): - B_, r = b_dhke.step1_alice(secret, r or None) - rs_return.append(r) - output = BlindedMessage(amount=amount, B_=B_.serialize().hex()) - outputs.append(output) - logger.trace(f"Constructing output: {output}, r: {r.serialize()}") - - return outputs, rs_return + # # ---------- DLEQ PROOFS ---------- + + # def verify_proofs_dleq(self, proofs: List[Proof]): + # """Verifies DLEQ proofs in proofs.""" + # for proof in proofs: + # if not proof.dleq: + # logger.trace("No DLEQ proof in proof.") + # return + # logger.trace("Verifying DLEQ proof.") + # assert self.keys.public_keys + # if not b_dhke.carol_verify_dleq( + # secret_msg=proof.secret, + # C=PublicKey(bytes.fromhex(proof.C), raw=True), + # r=PrivateKey(bytes.fromhex(proof.dleq.r), raw=True), + # e=PrivateKey(bytes.fromhex(proof.dleq.e), raw=True), + # s=PrivateKey(bytes.fromhex(proof.dleq.s), raw=True), + # A=self.keys.public_keys[proof.amount], + # ): + # raise Exception("DLEQ proof invalid.") + # else: + # logger.debug("DLEQ proof valid.") + + # def _construct_proofs( + # self, + # promises: List[BlindedSignature], + # secrets: List[str], + # rs: List[PrivateKey], + # derivation_paths: List[str], + # ) -> List[Proof]: + # """Constructs proofs from promises, secrets, rs and derivation paths. + + # This method is called after the user has received blind signatures from + # the mint. The results are proofs that can be used as ecash. + + # Args: + # promises (List[BlindedSignature]): blind signatures from mint + # secrets (List[str]): secrets that were previously used to create blind messages (that turned into promises) + # rs (List[PrivateKey]): blinding factors that were previously used to create blind messages (that turned into promises) + # derivation_paths (List[str]): derivation paths that were used to generate secrets and blinding factors + + # Returns: + # List[Proof]: list of proofs that can be used as ecash + # """ + # logger.trace("Constructing proofs.") + # proofs: List[Proof] = [] + # for promise, secret, r, path in zip(promises, secrets, rs, derivation_paths): + # logger.trace(f"Creating proof with keyset {self.keyset_id} = {promise.id}") + # assert ( + # self.keyset_id == promise.id + # ), "our keyset id does not match promise id." + + # C_ = PublicKey(bytes.fromhex(promise.C_), raw=True) + # C = b_dhke.step3_alice(C_, r, self.public_keys[promise.amount]) + # B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs + + # proof = Proof( + # id=promise.id, + # amount=promise.amount, + # C=C.serialize().hex(), + # secret=secret, + # derivation_path=path, + # ) + + # # if the mint returned a dleq proof, we add it to the proof + # if promise.dleq: + # proof.dleq = DLEQWallet( + # e=promise.dleq.e, s=promise.dleq.s, r=r.serialize() + # ) + + # proofs.append(proof) + + # logger.trace( + # f"Created proof: {proof}, r: {r.serialize()} out of promise {promise}" + # ) + + # # DLEQ verify + # self.verify_proofs_dleq(proofs) + + # logger.trace(f"Constructed {len(proofs)} proofs.") + # return proofs + + # @staticmethod + # def _construct_outputs( + # amounts: List[int], secrets: List[str], rs: List[PrivateKey] = [] + # ) -> Tuple[List[BlindedMessage], List[PrivateKey]]: + # """Takes a list of amounts and secrets and returns outputs. + # Outputs are blinded messages `outputs` and blinding factors `rs` + + # Args: + # amounts (List[int]): list of amounts + # secrets (List[str]): list of secrets + # rs (List[PrivateKey], optional): list of blinding factors. If not given, `rs` are generated in step1_alice. Defaults to []. + + # Returns: + # List[BlindedMessage]: list of blinded messages that can be sent to the mint + # List[PrivateKey]: list of blinding factors that can be used to construct proofs after receiving blind signatures from the mint + + # Raises: + # AssertionError: if len(amounts) != len(secrets) + # """ + # assert len(amounts) == len( + # secrets + # ), f"len(amounts)={len(amounts)} not equal to len(secrets)={len(secrets)}" + # outputs: List[BlindedMessage] = [] + + # rs_ = [None] * len(amounts) if not rs else rs + # rs_return: List[PrivateKey] = [] + # for secret, amount, r in zip(secrets, amounts, rs_): + # B_, r = b_dhke.step1_alice(secret, r or None) + # rs_return.append(r) + # output = BlindedMessage(amount=amount, B_=B_.serialize().hex()) + # outputs.append(output) + # logger.trace(f"Constructing output: {output}, r: {r.serialize()}") + + # return outputs, rs_return @staticmethod def raise_on_error(resp: Response) -> None: @@ -761,16 +761,12 @@ async def mint( await bump_secret_derivation( db=self.db, keyset_id=self.keyset_id, by=len(amounts) ) - proofs = self._construct_proofs(promises, secrets, rs, derivation_paths) + proofs = await self._construct_proofs(promises, secrets, rs, derivation_paths) - if proofs == []: - raise Exception("received no proofs.") - await self._store_proofs(proofs) if hash: await update_lightning_invoice( db=self.db, hash=hash, paid=True, time_paid=int(time.time()) ) - self.proofs += proofs return proofs async def redeem( @@ -861,18 +857,11 @@ async def split( promises = await super().split(proofs, outputs) # Construct proofs from returned promises (i.e., unblind the signatures) - new_proofs = self._construct_proofs(promises, secrets, rs, derivation_paths) + new_proofs = await self._construct_proofs( + promises, secrets, rs, derivation_paths + ) - # remove used proofs from wallet and add new ones - used_secrets = [p.secret for p in proofs] - self.proofs = list(filter(lambda p: p.secret not in used_secrets, self.proofs)) - # add new proofs to wallet - self.proofs += new_proofs - # store new proofs in database - await self._store_proofs(new_proofs) - # invalidate used proofs in database - for proof in proofs: - await invalidate_proof(proof, db=self.db) + await self.invalidate(proofs) keep_proofs = new_proofs[: len(frst_outputs)] send_proofs = new_proofs[len(frst_outputs) :] @@ -901,7 +890,6 @@ async def pay_lightning( if status.paid: # the payment was successful - await self.invalidate(proofs) invoice_obj = Invoice( amount=-sum_proofs(proofs), pr=invoice, @@ -916,14 +904,15 @@ async def pay_lightning( # handle change and produce proofs if status.change: - change_proofs = self._construct_proofs( + change_proofs = await self._construct_proofs( status.change, secrets[: len(status.change)], rs[: len(status.change)], derivation_paths[: len(status.change)], ) logger.debug(f"Received change: {sum_proofs(change_proofs)} sat") - await self._store_proofs(change_proofs) + + await self.invalidate(proofs) else: raise Exception("could not pay invoice.") @@ -934,10 +923,137 @@ async def check_proof_state(self, proofs): # ---------- TOKEN MECHANICS ---------- + # ---------- DLEQ PROOFS ---------- + + def verify_proofs_dleq(self, proofs: List[Proof]): + """Verifies DLEQ proofs in proofs.""" + for proof in proofs: + if not proof.dleq: + logger.trace("No DLEQ proof in proof.") + return + logger.trace("Verifying DLEQ proof.") + assert self.keys.public_keys + if not b_dhke.carol_verify_dleq( + secret_msg=proof.secret, + C=PublicKey(bytes.fromhex(proof.C), raw=True), + r=PrivateKey(bytes.fromhex(proof.dleq.r), raw=True), + e=PrivateKey(bytes.fromhex(proof.dleq.e), raw=True), + s=PrivateKey(bytes.fromhex(proof.dleq.s), raw=True), + A=self.keys.public_keys[proof.amount], + ): + raise Exception("DLEQ proof invalid.") + else: + logger.debug("DLEQ proof valid.") + + async def _construct_proofs( + self, + promises: List[BlindedSignature], + secrets: List[str], + rs: List[PrivateKey], + derivation_paths: List[str], + ) -> List[Proof]: + """Constructs proofs from promises, secrets, rs and derivation paths. + + This method is called after the user has received blind signatures from + the mint. The results are proofs that can be used as ecash. + + Args: + promises (List[BlindedSignature]): blind signatures from mint + secrets (List[str]): secrets that were previously used to create blind messages (that turned into promises) + rs (List[PrivateKey]): blinding factors that were previously used to create blind messages (that turned into promises) + derivation_paths (List[str]): derivation paths that were used to generate secrets and blinding factors + + Returns: + List[Proof]: list of proofs that can be used as ecash + """ + logger.trace("Constructing proofs.") + proofs: List[Proof] = [] + for promise, secret, r, path in zip(promises, secrets, rs, derivation_paths): + logger.trace(f"Creating proof with keyset {self.keyset_id} = {promise.id}") + assert ( + self.keyset_id == promise.id + ), "our keyset id does not match promise id." + + C_ = PublicKey(bytes.fromhex(promise.C_), raw=True) + C = b_dhke.step3_alice(C_, r, self.public_keys[promise.amount]) + B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs + + proof = Proof( + id=promise.id, + amount=promise.amount, + C=C.serialize().hex(), + secret=secret, + derivation_path=path, + ) + + # if the mint returned a dleq proof, we add it to the proof + if promise.dleq: + proof.dleq = DLEQWallet( + e=promise.dleq.e, s=promise.dleq.s, r=r.serialize() + ) + + proofs.append(proof) + + logger.trace( + f"Created proof: {proof}, r: {r.serialize()} out of promise {promise}" + ) + + # DLEQ verify + self.verify_proofs_dleq(proofs) + + logger.trace(f"Constructed {len(proofs)} proofs.") + + # add new proofs to wallet + self.proofs += proofs + # store new proofs in database + await self._store_proofs(proofs) + + return proofs + + @staticmethod + def _construct_outputs( + amounts: List[int], secrets: List[str], rs: List[PrivateKey] = [] + ) -> Tuple[List[BlindedMessage], List[PrivateKey]]: + """Takes a list of amounts and secrets and returns outputs. + Outputs are blinded messages `outputs` and blinding factors `rs` + + Args: + amounts (List[int]): list of amounts + secrets (List[str]): list of secrets + rs (List[PrivateKey], optional): list of blinding factors. If not given, `rs` are generated in step1_alice. Defaults to []. + + Returns: + List[BlindedMessage]: list of blinded messages that can be sent to the mint + List[PrivateKey]: list of blinding factors that can be used to construct proofs after receiving blind signatures from the mint + + Raises: + AssertionError: if len(amounts) != len(secrets) + """ + assert len(amounts) == len( + secrets + ), f"len(amounts)={len(amounts)} not equal to len(secrets)={len(secrets)}" + outputs: List[BlindedMessage] = [] + + rs_ = [None] * len(amounts) if not rs else rs + rs_return: List[PrivateKey] = [] + for secret, amount, r in zip(secrets, amounts, rs_): + B_, r = b_dhke.step1_alice(secret, r or None) + rs_return.append(r) + output = BlindedMessage(amount=amount, B_=B_.serialize().hex()) + outputs.append(output) + logger.trace(f"Constructing output: {output}, r: {r.serialize()}") + + return outputs, rs_return + async def _store_proofs(self, proofs): - async with self.db.connect() as conn: - for proof in proofs: - await store_proof(proof, db=self.db, conn=conn) + try: + async with self.db.connect() as conn: + for proof in proofs: + await store_proof(proof, db=self.db, conn=conn) + except Exception as e: + logger.error(f"Could not store proofs in database: {e}") + logger.error(proofs) + raise e @staticmethod def _get_proofs_per_keyset(proofs: List[Proof]): @@ -1172,7 +1288,7 @@ async def invalidate( invalidated_proofs = proofs if invalidated_proofs: - logger.debug( + logger.trace( f"Invalidating {len(invalidated_proofs)} proofs worth" f" {sum_proofs(invalidated_proofs)} sat." ) @@ -1379,14 +1495,8 @@ async def restore_promises( secrets = [secrets[i] for i in matching_indices] rs = [rs[i] for i in matching_indices] # now we can construct the proofs with the secrets and rs - proofs = self._construct_proofs( + proofs = await self._construct_proofs( restored_promises, secrets, rs, derivation_paths ) logger.debug(f"Restored {len(restored_promises)} promises") - await self._store_proofs(proofs) - - # append proofs to proofs in memory so the balance updates - for proof in proofs: - if proof.secret not in [p.secret for p in self.proofs]: - self.proofs.append(proof) return proofs diff --git a/tests/test_wallet.py b/tests/test_wallet.py index f4ec78b6..b4e5cb07 100644 --- a/tests/test_wallet.py +++ b/tests/test_wallet.py @@ -172,6 +172,7 @@ async def test_mint_amounts_wrong_order(wallet1: Wallet): @pytest.mark.asyncio async def test_split(wallet1: Wallet): await wallet1.mint(64) + assert wallet1.balance == 64 p1, p2 = await wallet1.split(wallet1.proofs, 20) assert wallet1.balance == 64 assert sum_proofs(p1) == 44