diff --git a/cashu/core/base.py b/cashu/core/base.py index c2694ec9..072d2a8e 100644 --- a/cashu/core/base.py +++ b/cashu/core/base.py @@ -535,6 +535,12 @@ class CheckFeesResponse_deprecated(BaseModel): # ------- API: RESTORE ------- +class PostRestoreRequest(BaseModel): + outputs: List[BlindedMessage] = Field( + ..., max_items=settings.mint_max_request_length + ) + + class PostRestoreResponse(BaseModel): outputs: List[BlindedMessage] = [] signatures: List[BlindedSignature] = [] diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index 3bb0cb8a..cff10b4d 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -325,8 +325,11 @@ async def mint_quote(self, quote_request: PostMintQuoteRequest) -> MintQuote: ) if settings.mint_peg_out_only: raise NotAllowedError("Mint does not allow minting new tokens.") - unit = Unit[quote_request.unit] - method = Method.bolt11 + + unit, method = self._verify_and_get_unit_method( + quote_request.unit, Method.bolt11.name + ) + if settings.mint_max_balance: balance = await self.get_balance() if balance + quote_request.amount > settings.mint_max_balance: @@ -387,10 +390,10 @@ async def get_mint_quote(self, quote_id: str) -> MintQuote: MintQuote: Mint quote object. """ quote = await self.crud.get_mint_quote(quote_id=quote_id, db=self.db) - assert quote, "quote not found" - assert quote.method == Method.bolt11.name, "only bolt11 supported" - unit = Unit[quote.unit] - method = Method[quote.method] + if not quote: + raise Exception("quote not found") + + unit, method = self._verify_and_get_unit_method(quote.unit, quote.method) if not quote.paid: assert quote.checking_id, "quote has no checking id" @@ -471,8 +474,10 @@ async def melt_quote( Returns: PostMeltQuoteResponse: Melt quote response. """ - unit = Unit[melt_quote.unit] - method = Method.bolt11 + unit, method = self._verify_and_get_unit_method( + melt_quote.unit, Method.bolt11.name + ) + # NOTE: we normalize the request to lowercase to avoid case sensitivity # This works with Lightning but might not work with other methods request = melt_quote.request.lower() @@ -557,10 +562,12 @@ async def get_melt_quote( MeltQuote: Melt quote object. """ melt_quote = await self.crud.get_melt_quote(quote_id=quote_id, db=self.db) - assert melt_quote, "quote not found" - assert melt_quote.method == Method.bolt11.name, "only bolt11 supported" - unit = Unit[melt_quote.unit] - method = Method[melt_quote.method] + if not melt_quote: + raise Exception("quote not found") + + unit, method = self._verify_and_get_unit_method( + melt_quote.unit, melt_quote.method + ) # we only check the state with the backend if there is no associated internal # mint quote for this melt quote @@ -664,8 +671,11 @@ async def melt( """ # get melt quote and check if it was already paid melt_quote = await self.get_melt_quote(quote_id=quote) - method = Method[melt_quote.method] - unit = Unit[melt_quote.unit] + + unit, method = self._verify_and_get_unit_method( + melt_quote.unit, melt_quote.method + ) + assert not melt_quote.paid, "melt quote already paid" # make sure that the outputs (for fee return) are in the same unit as the quote diff --git a/cashu/mint/protocols.py b/cashu/mint/protocols.py index 47bf618e..04d24c0c 100644 --- a/cashu/mint/protocols.py +++ b/cashu/mint/protocols.py @@ -1,6 +1,6 @@ -from typing import Dict, Protocol +from typing import Dict, Mapping, Protocol -from ..core.base import MintKeyset, Unit +from ..core.base import Method, MintKeyset, Unit from ..core.db import Database from ..lightning.base import LightningBackend from ..mint.crud import LedgerCrud @@ -11,8 +11,8 @@ class SupportsKeysets(Protocol): keysets: Dict[str, MintKeyset] -class SupportLightning(Protocol): - lightning: Dict[Unit, LightningBackend] +class SupportsBackends(Protocol): + backends: Mapping[Method, Mapping[Unit, LightningBackend]] = {} class SupportsDb(Protocol): diff --git a/cashu/mint/router.py b/cashu/mint/router.py index 50c41c7f..19e6f46b 100644 --- a/cashu/mint/router.py +++ b/cashu/mint/router.py @@ -20,6 +20,7 @@ PostMintQuoteResponse, PostMintRequest, PostMintResponse, + PostRestoreRequest, PostRestoreResponse, PostSplitRequest, PostSplitResponse, @@ -358,14 +359,14 @@ async def check_state( @router.post( "/v1/restore", name="Restore", - summary="Restores a blinded signature from a secret", + summary="Restores blind signature for a set of outputs.", response_model=PostRestoreResponse, response_description=( "Two lists with the first being the list of the provided outputs that " "have an associated blinded signature which is given in the second list." ), ) -async def restore(payload: PostMintRequest) -> PostRestoreResponse: +async def restore(payload: PostRestoreRequest) -> PostRestoreResponse: assert payload.outputs, Exception("no outputs provided.") outputs, signatures = await ledger.restore(payload.outputs) return PostRestoreResponse(outputs=outputs, signatures=signatures) diff --git a/cashu/mint/router_deprecated.py b/cashu/mint/router_deprecated.py index f71a8a34..e1901d7b 100644 --- a/cashu/mint/router_deprecated.py +++ b/cashu/mint/router_deprecated.py @@ -19,6 +19,7 @@ PostMintQuoteRequest, PostMintRequest_deprecated, PostMintResponse_deprecated, + PostRestoreRequest, PostRestoreResponse, PostSplitRequest_Deprecated, PostSplitResponse_Deprecated, @@ -357,7 +358,7 @@ async def check_spendable_deprecated( ), deprecated=True, ) -async def restore(payload: PostMintRequest_deprecated) -> PostRestoreResponse: +async def restore(payload: PostRestoreRequest) -> PostRestoreResponse: assert payload.outputs, Exception("no outputs provided.") outputs, promises = await ledger.restore(payload.outputs) return PostRestoreResponse(outputs=outputs, signatures=promises) diff --git a/cashu/mint/verification.py b/cashu/mint/verification.py index 754c8b1c..de38dca3 100644 --- a/cashu/mint/verification.py +++ b/cashu/mint/verification.py @@ -1,12 +1,14 @@ -from typing import Dict, List, Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Tuple, Union from loguru import logger from ..core.base import ( BlindedMessage, BlindedSignature, + Method, MintKeyset, Proof, + Unit, ) from ..core.crypto import b_dhke from ..core.crypto.secp import PublicKey @@ -19,12 +21,15 @@ TransactionError, ) from ..core.settings import settings +from ..lightning.base import LightningBackend from ..mint.crud import LedgerCrud from .conditions import LedgerSpendingConditions -from .protocols import SupportsDb, SupportsKeysets +from .protocols import SupportsBackends, SupportsDb, SupportsKeysets -class LedgerVerification(LedgerSpendingConditions, SupportsKeysets, SupportsDb): +class LedgerVerification( + LedgerSpendingConditions, SupportsKeysets, SupportsDb, SupportsBackends +): """Verification functions for the ledger.""" keyset: MintKeyset @@ -32,6 +37,7 @@ class LedgerVerification(LedgerSpendingConditions, SupportsKeysets, SupportsDb): spent_proofs: Dict[str, Proof] crud: LedgerCrud db: Database + lightning: Dict[Unit, LightningBackend] async def verify_inputs_and_outputs( self, *, proofs: List[Proof], outputs: Optional[List[BlindedMessage]] = None @@ -240,6 +246,22 @@ def _verify_equation_balanced( """ sum_inputs = sum(self._verify_amount(p.amount) for p in proofs) sum_outputs = sum(self._verify_amount(p.amount) for p in outs) - assert sum_outputs - sum_inputs == 0, TransactionError( - "inputs do not have same amount as outputs." - ) + if not sum_outputs - sum_inputs == 0: + raise TransactionError("inputs do not have same amount as outputs.") + + def _verify_and_get_unit_method( + self, unit_str: str, method_str: str + ) -> Tuple[Unit, Method]: + """Verify that the unit is supported by the ledger.""" + method = Method[method_str] + unit = Unit[unit_str] + + if not any([unit == k.unit for k in self.keysets.values()]): + raise NotAllowedError(f"unit '{unit.name}' not supported in any keyset.") + + if not self.backends.get(method) or unit not in self.backends[method]: + raise NotAllowedError( + f"no support for method '{method.name}' with unit '{unit.name}'." + ) + + return unit, method diff --git a/tests/test_mint_api.py b/tests/test_mint_api.py index fab370a6..3a7acbd3 100644 --- a/tests/test_mint_api.py +++ b/tests/test_mint_api.py @@ -8,7 +8,7 @@ MintMeltMethodSetting, PostCheckStateRequest, PostCheckStateResponse, - PostMintRequest, + PostRestoreRequest, PostRestoreResponse, SpentState, ) @@ -430,7 +430,7 @@ async def test_api_restore(ledger: Ledger, wallet: Wallet): ) outputs, rs = wallet._construct_outputs([64], secrets, rs) - payload = PostMintRequest(outputs=outputs, quote="placeholder") + payload = PostRestoreRequest(outputs=outputs) response = httpx.post( f"{BASE_URL}/v1/restore", json=payload.dict(), diff --git a/tests/test_mint_api_deprecated.py b/tests/test_mint_api_deprecated.py index 478b5ba4..fc40589c 100644 --- a/tests/test_mint_api_deprecated.py +++ b/tests/test_mint_api_deprecated.py @@ -5,7 +5,7 @@ from cashu.core.base import ( CheckSpendableRequest_deprecated, CheckSpendableResponse_deprecated, - PostMintRequest, + PostRestoreRequest, PostRestoreResponse, Proof, ) @@ -340,7 +340,7 @@ async def test_api_restore(ledger: Ledger, wallet: Wallet): ) outputs, rs = wallet._construct_outputs([64], secrets, rs) - payload = PostMintRequest(outputs=outputs, quote="placeholder") + payload = PostRestoreRequest(outputs=outputs) response = httpx.post( f"{BASE_URL}/restore", json=payload.dict(),