Skip to content

Commit

Permalink
Add get quote API to wallet + check proof states in batches (#637)
Browse files Browse the repository at this point in the history
* add get quote api to wallet

* wrong string

* test before pushing

* fix tests for deprecated api only

* sigh
  • Loading branch information
callebtc authored Oct 8, 2024
1 parent cd39e18 commit 4490cc6
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 22 deletions.
19 changes: 10 additions & 9 deletions cashu/wallet/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,12 @@ def mint_invoice_callback(msg: JSONRPCNotficationParams):
while time.time() < check_until and not paid:
await asyncio.sleep(5)
try:
await wallet.mint(amount, split=optional_split, id=invoice.id)
paid = True
mint_quote_resp = await wallet.get_mint_quote(invoice.id)
if mint_quote_resp.state == MintQuoteState.paid.value:
await wallet.mint(amount, split=optional_split, id=invoice.id)
paid = True
else:
print(".", end="", flush=True)
except Exception as e:
# TODO: user error codes!
if "not paid" in str(e):
Expand Down Expand Up @@ -710,12 +714,7 @@ async def burn(ctx: Context, token: str, all: bool, force: bool, delete: str):
if delete:
await wallet.invalidate(proofs)
else:
# invalidate proofs in batches
for _proofs in [
proofs[i : i + settings.proofs_batch_size]
for i in range(0, len(proofs), settings.proofs_batch_size)
]:
await wallet.invalidate(_proofs, check_spendable=True)
await wallet.invalidate(proofs, check_spendable=True)
await print_balance(ctx)


Expand Down Expand Up @@ -1024,7 +1023,9 @@ async def info(ctx: Context, mint: bool, mnemonic: bool):
if mint_info.get("time"):
print(f" - Server time: {mint_info['time']}")
if mint_info.get("nuts"):
nuts_str = ', '.join([f"NUT-{k}" for k in mint_info['nuts'].keys()])
nuts_str = ", ".join(
[f"NUT-{k}" for k in mint_info["nuts"].keys()]
)
print(f" - Supported NUTS: {nuts_str}")
print("")
except Exception as e:
Expand Down
38 changes: 37 additions & 1 deletion cashu/wallet/v1_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ async def _get_keys(self) -> List[WalletKeyset]:
keys_dict: dict = resp.json()
assert len(keys_dict), Exception("did not receive any keys")
keys = KeysResponse.parse_obj(keys_dict)
keysets_str = ' '.join([f"{k.id} ({k.unit})" for k in keys.keysets])
keysets_str = " ".join([f"{k.id} ({k.unit})" for k in keys.keysets])
logger.debug(f"Received {len(keys.keysets)} keysets from mint: {keysets_str}.")
ret = [
WalletKeyset(
Expand Down Expand Up @@ -312,6 +312,24 @@ async def mint_quote(
return_dict = resp.json()
return PostMintQuoteResponse.parse_obj(return_dict)

@async_set_httpx_client
@async_ensure_mint_loaded
async def get_mint_quote(self, quote: str) -> PostMintQuoteResponse:
"""Returns an existing mint quote from the server.
Args:
quote (str): Quote ID
Returns:
PostMintQuoteResponse: Mint Quote Response
"""
resp = await self.httpx.get(
join(self.url, f"/v1/mint/quote/bolt11/{quote}"),
)
self.raise_on_error_request(resp)
return_dict = resp.json()
return PostMintQuoteResponse.parse_obj(return_dict)

@async_set_httpx_client
@async_ensure_mint_loaded
async def mint(
Expand Down Expand Up @@ -400,6 +418,24 @@ async def melt_quote(
return_dict = resp.json()
return PostMeltQuoteResponse.parse_obj(return_dict)

@async_set_httpx_client
@async_ensure_mint_loaded
async def get_melt_quote(self, quote: str) -> PostMeltQuoteResponse:
"""Returns an existing melt quote from the server.
Args:
quote (str): Quote ID
Returns:
PostMeltQuoteResponse: Melt Quote Response
"""
resp = await self.httpx.get(
join(self.url, f"/v1/melt/quote/bolt11/{quote}"),
)
self.raise_on_error_request(resp)
return_dict = resp.json()
return PostMeltQuoteResponse.parse_obj(return_dict)

@async_set_httpx_client
@async_ensure_mint_loaded
async def melt(
Expand Down
27 changes: 17 additions & 10 deletions cashu/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ async def with_db(
self.keysets = {k.id: k for k in keysets_active_unit}
else:
self.keysets = {k.id: k for k in keysets_list}
keysets_str = ' '.join([f"{i} {k.unit}" for i, k in self.keysets.items()])
keysets_str = " ".join([f"{i} {k.unit}" for i, k in self.keysets.items()])
logger.debug(f"Loaded keysets: {keysets_str}")
return self

Expand Down Expand Up @@ -351,10 +351,9 @@ async def load_proofs(self, reload: bool = False, all_keysets=False) -> None:
for keyset_id in self.keysets:
proofs = await get_proofs(db=self.db, id=keyset_id, conn=conn)
self.proofs.extend(proofs)
keysets_str = ' '.join([f"{k.id} ({k.unit})" for k in self.keysets.values()])
keysets_str = " ".join([f"{k.id} ({k.unit})" for k in self.keysets.values()])
logger.trace(f"Proofs loaded for keysets: {keysets_str}")


async def load_keysets_from_db(
self, url: Union[str, None] = "", unit: Union[str, None] = ""
):
Expand Down Expand Up @@ -1020,10 +1019,15 @@ async def invalidate(
"""
invalidated_proofs: List[Proof] = []
if check_spendable:
proof_states = await self.check_proof_state(proofs)
for i, state in enumerate(proof_states.states):
if state.spent:
invalidated_proofs.append(proofs[i])
# checks proofs in batches
for _proofs in [
proofs[i : i + settings.proofs_batch_size]
for i in range(0, len(proofs), settings.proofs_batch_size)
]:
proof_states = await self.check_proof_state(proofs)
for i, state in enumerate(proof_states.states):
if state.spent:
invalidated_proofs.append(proofs[i])
else:
invalidated_proofs = proofs

Expand All @@ -1033,9 +1037,12 @@ async def invalidate(
f" {self.unit.str(sum_proofs(invalidated_proofs))}."
)

async with self.db.connect() as conn:
for p in invalidated_proofs:
await invalidate_proof(p, db=self.db, conn=conn)
for p in invalidated_proofs:
try:
# mark proof as spent
await invalidate_proof(p, db=self.db)
except Exception as e:
logger.error(f"DB error while invalidating proof: {e}")

invalidate_secrets = [p.secret for p in invalidated_proofs]
self.proofs = list(
Expand Down
26 changes: 25 additions & 1 deletion tests/test_mint_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from cashu.core.base import MeltQuoteState
from cashu.core.helpers import sum_proofs
from cashu.core.models import PostMeltQuoteRequest, PostMintQuoteRequest
from cashu.core.settings import settings
from cashu.mint.ledger import Ledger
from cashu.wallet.wallet import Wallet
from cashu.wallet.wallet import Wallet as Wallet1
Expand Down Expand Up @@ -55,6 +56,13 @@ async def test_melt_internal(wallet1: Wallet, ledger: Ledger):
assert melt_quote.amount == 64
assert melt_quote.fee_reserve == 0

if not settings.debug_mint_only_deprecated:
melt_quote_response_pre_payment = await wallet1.get_melt_quote(melt_quote.quote)
assert (
not melt_quote_response_pre_payment.state == MeltQuoteState.paid.value
), "melt quote should not be paid"
assert melt_quote_response_pre_payment.amount == 64

melt_quote_pre_payment = await ledger.get_melt_quote(melt_quote.quote)
assert not melt_quote_pre_payment.paid, "melt quote should not be paid"
assert melt_quote_pre_payment.unpaid
Expand Down Expand Up @@ -89,6 +97,13 @@ async def test_melt_external(wallet1: Wallet, ledger: Ledger):
PostMeltQuoteRequest(request=invoice_payment_request, unit="sat")
)

if not settings.debug_mint_only_deprecated:
melt_quote_response_pre_payment = await wallet1.get_melt_quote(melt_quote.quote)
assert (
melt_quote_response_pre_payment.state == MeltQuoteState.unpaid.value
), "melt quote should not be paid"
assert melt_quote_response_pre_payment.amount == melt_quote.amount

melt_quote_pre_payment = await ledger.get_melt_quote(melt_quote.quote)
assert not melt_quote_pre_payment.paid, "melt quote should not be paid"
assert melt_quote_pre_payment.unpaid
Expand All @@ -109,7 +124,12 @@ async def test_mint_internal(wallet1: Wallet, ledger: Ledger):
mint_quote = await ledger.get_mint_quote(invoice.id)

assert mint_quote.paid, "mint quote should be paid"
assert mint_quote.paid

if not settings.debug_mint_only_deprecated:
mint_quote_resp = await wallet1.get_mint_quote(invoice.id)
assert (
mint_quote_resp.state == MeltQuoteState.paid.value
), "mint quote should be paid"

output_amounts = [128]
secrets, rs, derivation_paths = await wallet1.generate_n_secrets(
Expand Down Expand Up @@ -139,6 +159,10 @@ async def test_mint_external(wallet1: Wallet, ledger: Ledger):
assert not mint_quote.paid, "mint quote already paid"
assert mint_quote.unpaid

if not settings.debug_mint_only_deprecated:
mint_quote_resp = await wallet1.get_mint_quote(quote.quote)
assert not mint_quote_resp.paid, "mint quote should not be paid"

await assert_err(
wallet1.mint(128, id=quote.quote),
"quote not paid",
Expand Down
11 changes: 10 additions & 1 deletion tests/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import pytest_asyncio

from cashu.core.base import Proof
from cashu.core.base import MintQuoteState, Proof
from cashu.core.errors import CashuError, KeysetNotFoundError
from cashu.core.helpers import sum_proofs
from cashu.core.settings import settings
Expand Down Expand Up @@ -168,6 +168,11 @@ async def test_request_mint(wallet1: Wallet):
async def test_mint(wallet1: Wallet):
invoice = await wallet1.request_mint(64)
await pay_if_regtest(invoice.bolt11)
if not settings.debug_mint_only_deprecated:
quote_resp = await wallet1.get_mint_quote(invoice.id)
assert quote_resp.request == invoice.bolt11
assert quote_resp.state == MintQuoteState.paid.value

expected_proof_amounts = wallet1.split_wallet_state(64)
await wallet1.mint(64, id=invoice.id)
assert wallet1.balance == 64
Expand Down Expand Up @@ -307,6 +312,10 @@ async def test_melt(wallet1: Wallet):
assert total_amount == 64
assert quote.fee_reserve == 0

if not settings.debug_mint_only_deprecated:
quote_resp = await wallet1.get_melt_quote(quote.quote)
assert quote_resp.amount == quote.amount

_, send_proofs = await wallet1.swap_to_send(wallet1.proofs, total_amount)

melt_response = await wallet1.melt(
Expand Down
3 changes: 3 additions & 0 deletions tests/test_wallet_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def test_balance(cli_prefix):

@pytest.mark.skipif(is_regtest, reason="only works with FakeWallet")
def test_invoice(mint, cli_prefix):
if settings.debug_mint_only_deprecated:
pytest.skip("only works with v1 API")

runner = CliRunner()
result = runner.invoke(
cli,
Expand Down

0 comments on commit 4490cc6

Please sign in to comment.