From 3b740c04b62c4ec9b034aa8eb9810fcbe72b7b94 Mon Sep 17 00:00:00 2001 From: lollerfirst Date: Sat, 26 Oct 2024 18:20:34 +0200 Subject: [PATCH 1/6] sort proofs --- cashu/wallet/wallet.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index d8915ed5..47ad443f 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -667,6 +667,9 @@ async def split( # make sure we're operating on an independent copy of proofs proofs = copy.copy(proofs) + # sort proof + proofs.sort(key=lambda p: p.amount) + # potentially add witnesses to unlock provided proofs (if they indicate one) proofs = await self.add_witnesses_to_proofs(proofs) From e27ff0d9cca777c0261a962b8779b461881a9918 Mon Sep 17 00:00:00 2001 From: lollerfirst Date: Mon, 28 Oct 2024 13:53:29 +0100 Subject: [PATCH 2/6] outputs-ordering --- cashu/wallet/wallet.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index 47ad443f..e5aae91f 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -667,9 +667,6 @@ async def split( # make sure we're operating on an independent copy of proofs proofs = copy.copy(proofs) - # sort proof - proofs.sort(key=lambda p: p.amount) - # potentially add witnesses to unlock provided proofs (if they indicate one) proofs = await self.add_witnesses_to_proofs(proofs) @@ -705,12 +702,31 @@ async def split( # potentially add witnesses to outputs based on what requirement the proofs indicate outputs = await self.add_witnesses_to_outputs(proofs, outputs) + logger.debug(f"unsorted outputs: {[o.amount for o in outputs]}") + + # sort outputs, remember original order + order_and_sorted_outputs = sorted(enumerate(outputs), key=lambda p: p[1].amount) + order = [x[0] for x in order_and_sorted_outputs] + outputs = [x[1] for x in order_and_sorted_outputs] + + logger.debug(f"{order = }") + logger.debug(f"sorted outputs: {[o.amount for o in outputs]}") + # Call swap API promises = await super().split(proofs, outputs) + logger.debug(f"sorted promises: {[p.amount for p in promises]}") + + # unsort promises + unsorted_promises = [None] * len(promises) + for i, j in enumerate(order): + unsorted_promises[j] = promises[i] + + logger.debug(f"unsorted promises: {[p.amount for p in unsorted_promises]}") + # Construct proofs from returned promises (i.e., unblind the signatures) new_proofs = await self._construct_proofs( - promises, secrets, rs, derivation_paths + unsorted_promises, secrets, rs, derivation_paths ) await self.invalidate(proofs) From 3988f96d8e12b984522bf2ea6ca84dd0c88374ba Mon Sep 17 00:00:00 2001 From: lollerfirst Date: Mon, 28 Oct 2024 14:03:27 +0100 Subject: [PATCH 3/6] mypy fix --- cashu/wallet/wallet.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index e5aae91f..da86a3e4 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -722,11 +722,9 @@ async def split( for i, j in enumerate(order): unsorted_promises[j] = promises[i] - logger.debug(f"unsorted promises: {[p.amount for p in unsorted_promises]}") - # Construct proofs from returned promises (i.e., unblind the signatures) new_proofs = await self._construct_proofs( - unsorted_promises, secrets, rs, derivation_paths + unsorted_promises, secrets, rs, derivation_paths # type: ignore[arg-type] ) await self.invalidate(proofs) From 58da7a077eb9cbe531acf4bcdab90c36f70ab945 Mon Sep 17 00:00:00 2001 From: callebtc <93376500+callebtc@users.noreply.github.com> Date: Tue, 5 Nov 2024 14:00:24 +0100 Subject: [PATCH 4/6] clean up --- cashu/wallet/wallet.py | 37 +++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index da86a3e4..830a2461 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -464,7 +464,7 @@ def split_wallet_state(self, amount: int) -> List[int]: # sort by increasing amount amounts_we_want.sort() - logger.debug( + logger.trace( f"Amounts we have: {[(a, amounts_we_have.count(a)) for a in set(amounts_we_have)]}" ) amounts: list[int] = [] @@ -477,7 +477,7 @@ def split_wallet_state(self, amount: int) -> List[int]: if remaining_amount > 0: amounts += amount_split(remaining_amount) - logger.debug(f"Amounts we want: {amounts}") + logger.trace(f"Amounts we want: {amounts}") if sum(amounts) != amount: raise Exception(f"Amounts do not sum to {amount}.") @@ -671,7 +671,7 @@ async def split( proofs = await self.add_witnesses_to_proofs(proofs) input_fees = self.get_fees_for_proofs(proofs) - logger.debug(f"Input fees: {input_fees}") + logger.trace(f"Input fees: {input_fees}") # create a suitable amounts to keep and send. keep_outputs, send_outputs = self.determine_output_amounts( proofs, @@ -702,29 +702,26 @@ async def split( # potentially add witnesses to outputs based on what requirement the proofs indicate outputs = await self.add_witnesses_to_outputs(proofs, outputs) - logger.debug(f"unsorted outputs: {[o.amount for o in outputs]}") - - # sort outputs, remember original order - order_and_sorted_outputs = sorted(enumerate(outputs), key=lambda p: p[1].amount) - order = [x[0] for x in order_and_sorted_outputs] - outputs = [x[1] for x in order_and_sorted_outputs] - - logger.debug(f"{order = }") - logger.debug(f"sorted outputs: {[o.amount for o in outputs]}") + # sort outputs by amount, remember original order + sorted_outputs_with_indices = sorted( + enumerate(outputs), key=lambda p: p[1].amount + ) + original_indices, sorted_outputs = zip(*sorted_outputs_with_indices) # Call swap API - promises = await super().split(proofs, outputs) + sorted_promises = await super().split(proofs, sorted_outputs) - logger.debug(f"sorted promises: {[p.amount for p in promises]}") - - # unsort promises - unsorted_promises = [None] * len(promises) - for i, j in enumerate(order): - unsorted_promises[j] = promises[i] + # sort promises back to original order + promises = [ + promise + for _, promise in sorted( + zip(original_indices, sorted_promises), key=lambda x: x[0] + ) + ] # Construct proofs from returned promises (i.e., unblind the signatures) new_proofs = await self._construct_proofs( - unsorted_promises, secrets, rs, derivation_paths # type: ignore[arg-type] + promises, secrets, rs, derivation_paths ) await self.invalidate(proofs) From 67fd0575992ee4886696e981fa4ca46ae5c83e83 Mon Sep 17 00:00:00 2001 From: callebtc <93376500+callebtc@users.noreply.github.com> Date: Tue, 5 Nov 2024 14:48:42 +0100 Subject: [PATCH 5/6] test if output amounts are sorted --- tests/test_wallet_requests.py | 113 ++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 tests/test_wallet_requests.py diff --git a/tests/test_wallet_requests.py b/tests/test_wallet_requests.py new file mode 100644 index 00000000..08a3c7ee --- /dev/null +++ b/tests/test_wallet_requests.py @@ -0,0 +1,113 @@ +import json +from typing import List, Union + +import pytest +import pytest_asyncio +import respx +from httpx import Request, Response + +from cashu.core.base import BlindedSignature, Proof +from cashu.core.crypto.b_dhke import hash_to_curve +from cashu.core.errors import CashuError +from cashu.wallet.wallet import Wallet +from cashu.wallet.wallet import Wallet as Wallet1 +from cashu.wallet.wallet import Wallet as Wallet2 +from tests.conftest import SERVER_ENDPOINT +from tests.helpers import pay_if_regtest + + +async def assert_err(f, msg: Union[str, CashuError]): + """Compute f() and expect an error message 'msg'.""" + try: + await f + except Exception as exc: + error_message: str = str(exc.args[0]) + if isinstance(msg, CashuError): + if msg.detail not in error_message: + raise Exception( + f"CashuError. Expected error: {msg.detail}, got: {error_message}" + ) + return + if msg not in error_message: + raise Exception(f"Expected error: {msg}, got: {error_message}") + return + raise Exception(f"Expected error: {msg}, got no error") + + +async def assert_err_multiple(f, msgs: List[str]): + """Compute f() and expect an error message 'msg'.""" + try: + await f + except Exception as exc: + for msg in msgs: + if msg in str(exc.args[0]): + return + raise Exception(f"Expected error: {msgs}, got: {exc.args[0]}") + raise Exception(f"Expected error: {msgs}, got no error") + + +def assert_amt(proofs: List[Proof], expected: int): + """Assert amounts the proofs contain.""" + assert sum([p.amount for p in proofs]) == expected + + +async def reset_wallet_db(wallet: Wallet): + await wallet.db.execute("DELETE FROM proofs") + await wallet.db.execute("DELETE FROM proofs_used") + await wallet.db.execute("DELETE FROM keysets") + await wallet.load_mint() + + +@pytest_asyncio.fixture(scope="function") +async def wallet1(mint): + wallet1 = await Wallet1.with_db( + url=SERVER_ENDPOINT, + db="test_data/wallet1", + name="wallet1", + ) + await wallet1.load_mint() + yield wallet1 + + +@pytest_asyncio.fixture(scope="function") +async def wallet2(): + wallet2 = await Wallet2.with_db( + url=SERVER_ENDPOINT, + db="test_data/wallet2", + name="wallet2", + ) + await wallet2.load_mint() + yield wallet2 + + +@pytest.mark.asyncio +async def test_swap_outputs_are_sorted(wallet1: Wallet): + await wallet1.load_mint() + mint_quote = await wallet1.request_mint(16) + await pay_if_regtest(mint_quote.request) + await wallet1.mint(16, quote_id=mint_quote.quote, split=[16]) + assert wallet1.balance == 16 + + test_url = f"{wallet1.url}/v1/swap" + key = hash_to_curve("test".encode("utf-8")) + mock_blind_signature = BlindedSignature( + id=wallet1.keyset_id, + amount=8, + C_=key.serialize().hex(), + ) + mock_response_data = {"signatures": [mock_blind_signature.dict()]} + with respx.mock() as mock: + route = mock.post(test_url).mock( + return_value=Response(200, json=mock_response_data) + ) + await wallet1.select_to_send(wallet1.proofs, 5) + + assert route.called + assert route.call_count == 1 + request: Request = route.calls[0].request + assert request.method == "POST" + assert request.url == test_url + request_data = json.loads(request.content.decode("utf-8")) + output_amounts = [o["amount"] for o in request_data["outputs"]] + # assert that output amounts are sorted + assert output_amounts == sorted(output_amounts) From a586a2d9ae8edaaaff9e1fe1db5d63f44ccbf07a Mon Sep 17 00:00:00 2001 From: callebtc <93376500+callebtc@users.noreply.github.com> Date: Tue, 5 Nov 2024 14:49:42 +0100 Subject: [PATCH 6/6] clean up test --- tests/test_wallet_requests.py | 58 +---------------------------------- 1 file changed, 1 insertion(+), 57 deletions(-) diff --git a/tests/test_wallet_requests.py b/tests/test_wallet_requests.py index 08a3c7ee..3f2d6ea8 100644 --- a/tests/test_wallet_requests.py +++ b/tests/test_wallet_requests.py @@ -1,63 +1,18 @@ import json -from typing import List, Union import pytest import pytest_asyncio import respx from httpx import Request, Response -from cashu.core.base import BlindedSignature, Proof +from cashu.core.base import BlindedSignature from cashu.core.crypto.b_dhke import hash_to_curve -from cashu.core.errors import CashuError from cashu.wallet.wallet import Wallet from cashu.wallet.wallet import Wallet as Wallet1 -from cashu.wallet.wallet import Wallet as Wallet2 from tests.conftest import SERVER_ENDPOINT from tests.helpers import pay_if_regtest -async def assert_err(f, msg: Union[str, CashuError]): - """Compute f() and expect an error message 'msg'.""" - try: - await f - except Exception as exc: - error_message: str = str(exc.args[0]) - if isinstance(msg, CashuError): - if msg.detail not in error_message: - raise Exception( - f"CashuError. Expected error: {msg.detail}, got: {error_message}" - ) - return - if msg not in error_message: - raise Exception(f"Expected error: {msg}, got: {error_message}") - return - raise Exception(f"Expected error: {msg}, got no error") - - -async def assert_err_multiple(f, msgs: List[str]): - """Compute f() and expect an error message 'msg'.""" - try: - await f - except Exception as exc: - for msg in msgs: - if msg in str(exc.args[0]): - return - raise Exception(f"Expected error: {msgs}, got: {exc.args[0]}") - raise Exception(f"Expected error: {msgs}, got no error") - - -def assert_amt(proofs: List[Proof], expected: int): - """Assert amounts the proofs contain.""" - assert sum([p.amount for p in proofs]) == expected - - -async def reset_wallet_db(wallet: Wallet): - await wallet.db.execute("DELETE FROM proofs") - await wallet.db.execute("DELETE FROM proofs_used") - await wallet.db.execute("DELETE FROM keysets") - await wallet.load_mint() - - @pytest_asyncio.fixture(scope="function") async def wallet1(mint): wallet1 = await Wallet1.with_db( @@ -69,17 +24,6 @@ async def wallet1(mint): yield wallet1 -@pytest_asyncio.fixture(scope="function") -async def wallet2(): - wallet2 = await Wallet2.with_db( - url=SERVER_ENDPOINT, - db="test_data/wallet2", - name="wallet2", - ) - await wallet2.load_mint() - yield wallet2 - - @pytest.mark.asyncio async def test_swap_outputs_are_sorted(wallet1: Wallet): await wallet1.load_mint()