From ed0d25dec76200d117b74552b5047f76021686f6 Mon Sep 17 00:00:00 2001 From: lollerfirst <43107113+lollerfirst@users.noreply.github.com> Date: Tue, 5 Nov 2024 15:00:37 +0100 Subject: [PATCH] [FIX] Wallet sort outputs before swapping (#648) * sort proofs * outputs-ordering * mypy fix * clean up * test if output amounts are sorted * clean up test --------- Co-authored-by: callebtc <93376500+callebtc@users.noreply.github.com> --- cashu/wallet/wallet.py | 22 +++++++++++--- tests/test_wallet_requests.py | 57 +++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 4 deletions(-) create mode 100644 tests/test_wallet_requests.py diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index 8d61e9e7..f0a9d76c 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -456,7 +456,7 @@ def split_wallet_state(self, amount: int) -> List[int]: # sort by increasing amount amounts_we_want.sort() - logger.debug( + logger.trace( f"Amounts we have: {[(a, amounts_we_have.count(a)) for a in set(amounts_we_have)]}" ) amounts: list[int] = [] @@ -470,7 +470,7 @@ def split_wallet_state(self, amount: int) -> List[int]: amounts += amount_split(remaining_amount) amounts.sort() - logger.debug(f"Amounts we want: {amounts}") + logger.trace(f"Amounts we want: {amounts}") if sum(amounts) != amount: raise Exception(f"Amounts do not sum to {amount}.") @@ -643,7 +643,7 @@ async def split( proofs = self.add_witnesses_to_proofs(proofs) input_fees = self.get_fees_for_proofs(proofs) - logger.debug(f"Input fees: {input_fees}") + logger.trace(f"Input fees: {input_fees}") # create a suitable amounts to keep and send. keep_outputs, send_outputs = self.determine_output_amounts( proofs, @@ -674,8 +674,22 @@ async def split( # potentially add witnesses to outputs based on what requirement the proofs indicate outputs = self.add_witnesses_to_outputs(proofs, outputs) + # sort outputs by amount, remember original order + sorted_outputs_with_indices = sorted( + enumerate(outputs), key=lambda p: p[1].amount + ) + original_indices, sorted_outputs = zip(*sorted_outputs_with_indices) + # Call swap API - promises = await super().split(proofs, outputs) + sorted_promises = await super().split(proofs, sorted_outputs) + + # sort promises back to original order + promises = [ + promise + for _, promise in sorted( + zip(original_indices, sorted_promises), key=lambda x: x[0] + ) + ] # Construct proofs from returned promises (i.e., unblind the signatures) new_proofs = await self._construct_proofs( diff --git a/tests/test_wallet_requests.py b/tests/test_wallet_requests.py new file mode 100644 index 00000000..3f2d6ea8 --- /dev/null +++ b/tests/test_wallet_requests.py @@ -0,0 +1,57 @@ +import json + +import pytest +import pytest_asyncio +import respx +from httpx import Request, Response + +from cashu.core.base import BlindedSignature +from cashu.core.crypto.b_dhke import hash_to_curve +from cashu.wallet.wallet import Wallet +from cashu.wallet.wallet import Wallet as Wallet1 +from tests.conftest import SERVER_ENDPOINT +from tests.helpers import pay_if_regtest + + +@pytest_asyncio.fixture(scope="function") +async def wallet1(mint): + wallet1 = await Wallet1.with_db( + url=SERVER_ENDPOINT, + db="test_data/wallet1", + name="wallet1", + ) + await wallet1.load_mint() + yield wallet1 + + +@pytest.mark.asyncio +async def test_swap_outputs_are_sorted(wallet1: Wallet): + await wallet1.load_mint() + mint_quote = await wallet1.request_mint(16) + await pay_if_regtest(mint_quote.request) + await wallet1.mint(16, quote_id=mint_quote.quote, split=[16]) + assert wallet1.balance == 16 + + test_url = f"{wallet1.url}/v1/swap" + key = hash_to_curve("test".encode("utf-8")) + mock_blind_signature = BlindedSignature( + id=wallet1.keyset_id, + amount=8, + C_=key.serialize().hex(), + ) + mock_response_data = {"signatures": [mock_blind_signature.dict()]} + with respx.mock() as mock: + route = mock.post(test_url).mock( + return_value=Response(200, json=mock_response_data) + ) + await wallet1.select_to_send(wallet1.proofs, 5) + + assert route.called + assert route.call_count == 1 + request: Request = route.calls[0].request + assert request.method == "POST" + assert request.url == test_url + request_data = json.loads(request.content.decode("utf-8")) + output_amounts = [o["amount"] for o in request_data["outputs"]] + # assert that output amounts are sorted + assert output_amounts == sorted(output_amounts)