Skip to content

Commit

Permalink
Token state check with Y
Browse files Browse the repository at this point in the history
  • Loading branch information
callebtc committed Mar 7, 2024
1 parent ff1e759 commit dae1251
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 88 deletions.
8 changes: 6 additions & 2 deletions cashu/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,10 @@ class PostSplitResponse_Very_Deprecated(BaseModel):


class PostCheckStateRequest(BaseModel):
secrets: List[str] = Field(..., max_items=settings.mint_max_request_length)
Ys: List[str] = Field(..., max_items=settings.mint_max_request_length)
secrets: Optional[List[str]] = Field(
default=None, max_items=settings.mint_max_request_length
) # deprecated since 0.15.1


class SpentState(Enum):
Expand All @@ -499,9 +502,10 @@ def __str__(self):


class ProofState(BaseModel):
secret: str
Y: str
state: SpentState
witness: Optional[str] = None
secret: Optional[str] = None # deprecated since 0.15.1


class PostCheckStateResponse(BaseModel):
Expand Down
74 changes: 48 additions & 26 deletions cashu/mint/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,26 @@ async def get_keyset(
derivation_path: str = "",
seed: str = "",
conn: Optional[Connection] = None,
) -> List[MintKeyset]: ...
) -> List[MintKeyset]:
...

@abstractmethod
async def get_spent_proofs(
self,
*,
db: Database,
conn: Optional[Connection] = None,
) -> List[Proof]: ...
) -> List[Proof]:
...

async def get_proof_used(
self,
*,
Y: str,
db: Database,
conn: Optional[Connection] = None,
) -> Optional[Proof]: ...
) -> Optional[Proof]:
...

@abstractmethod
async def invalidate_proof(
Expand All @@ -59,16 +62,18 @@ async def invalidate_proof(
db: Database,
proof: Proof,
conn: Optional[Connection] = None,
) -> None: ...
) -> None:
...

@abstractmethod
async def get_proofs_pending(
self,
*,
proofs: List[Proof],
Ys: List[str],
db: Database,
conn: Optional[Connection] = None,
) -> List[Proof]: ...
) -> List[Proof]:
...

@abstractmethod
async def set_proof_pending(
Expand All @@ -77,12 +82,14 @@ async def set_proof_pending(
db: Database,
proof: Proof,
conn: Optional[Connection] = None,
) -> None: ...
) -> None:
...

@abstractmethod
async def unset_proof_pending(
self, *, proof: Proof, db: Database, conn: Optional[Connection] = None
) -> None: ...
) -> None:
...

@abstractmethod
async def store_keyset(
Expand All @@ -91,14 +98,16 @@ async def store_keyset(
db: Database,
keyset: MintKeyset,
conn: Optional[Connection] = None,
) -> None: ...
) -> None:
...

@abstractmethod
async def get_balance(
self,
db: Database,
conn: Optional[Connection] = None,
) -> int: ...
) -> int:
...

@abstractmethod
async def store_promise(
Expand All @@ -112,7 +121,8 @@ async def store_promise(
e: str = "",
s: str = "",
conn: Optional[Connection] = None,
) -> None: ...
) -> None:
...

@abstractmethod
async def get_promise(
Expand All @@ -121,7 +131,8 @@ async def get_promise(
db: Database,
B_: str,
conn: Optional[Connection] = None,
) -> Optional[BlindedSignature]: ...
) -> Optional[BlindedSignature]:
...

@abstractmethod
async def store_mint_quote(
Expand All @@ -130,7 +141,8 @@ async def store_mint_quote(
quote: MintQuote,
db: Database,
conn: Optional[Connection] = None,
) -> None: ...
) -> None:
...

@abstractmethod
async def get_mint_quote(
Expand All @@ -139,7 +151,8 @@ async def get_mint_quote(
quote_id: str,
db: Database,
conn: Optional[Connection] = None,
) -> Optional[MintQuote]: ...
) -> Optional[MintQuote]:
...

@abstractmethod
async def get_mint_quote_by_checking_id(
Expand All @@ -148,7 +161,8 @@ async def get_mint_quote_by_checking_id(
checking_id: str,
db: Database,
conn: Optional[Connection] = None,
) -> Optional[MintQuote]: ...
) -> Optional[MintQuote]:
...

@abstractmethod
async def update_mint_quote(
Expand All @@ -157,7 +171,8 @@ async def update_mint_quote(
quote: MintQuote,
db: Database,
conn: Optional[Connection] = None,
) -> None: ...
) -> None:
...

# @abstractmethod
# async def update_mint_quote_paid(
Expand All @@ -176,7 +191,8 @@ async def store_melt_quote(
quote: MeltQuote,
db: Database,
conn: Optional[Connection] = None,
) -> None: ...
) -> None:
...

@abstractmethod
async def get_melt_quote(
Expand All @@ -186,7 +202,8 @@ async def get_melt_quote(
db: Database,
checking_id: Optional[str] = None,
conn: Optional[Connection] = None,
) -> Optional[MeltQuote]: ...
) -> Optional[MeltQuote]:
...

@abstractmethod
async def update_melt_quote(
Expand All @@ -195,7 +212,8 @@ async def update_melt_quote(
quote: MeltQuote,
db: Database,
conn: Optional[Connection] = None,
) -> None: ...
) -> None:
...


class LedgerCrudSqlite(LedgerCrud):
Expand Down Expand Up @@ -256,9 +274,11 @@ async def get_spent_proofs(
db: Database,
conn: Optional[Connection] = None,
) -> List[Proof]:
rows = await (conn or db).fetchall(f"""
rows = await (conn or db).fetchall(
f"""
SELECT * from {table_with_schema(db, 'proofs_used')}
""")
"""
)
return [Proof(**r) for r in rows] if rows else []

async def invalidate_proof(
Expand Down Expand Up @@ -289,16 +309,16 @@ async def invalidate_proof(
async def get_proofs_pending(
self,
*,
proofs: List[Proof],
Ys: List[str],
db: Database,
conn: Optional[Connection] = None,
) -> List[Proof]:
rows = await (conn or db).fetchall(
f"""
SELECT * from {table_with_schema(db, 'proofs_pending')}
WHERE Y IN ({','.join(['?']*len(proofs))})
WHERE Y IN ({','.join(['?']*len(Ys))})
""",
tuple(proof.Y for proof in proofs),
tuple(Ys),
)
return [Proof(**r) for r in rows]

Expand Down Expand Up @@ -549,9 +569,11 @@ async def get_balance(
db: Database,
conn: Optional[Connection] = None,
) -> int:
row = await (conn or db).fetchone(f"""
row = await (conn or db).fetchone(
f"""
SELECT * from {table_with_schema(db, 'balance')}
""")
"""
)
assert row, "Balance not found"
return int(row[0])

Expand Down
32 changes: 14 additions & 18 deletions cashu/mint/ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,7 @@ async def load_used_proofs(self) -> None:
logger.debug(f"Loaded {len(spent_proofs_list)} used proofs")
self.spent_proofs = {p.Y: p for p in spent_proofs_list}

async def check_proofs_state(self, secrets: List[str]) -> List[ProofState]:
async def check_proofs_state(self, Ys: List[str]) -> List[ProofState]:
"""Checks if provided proofs are spend or are pending.
Used by wallets to check if their proofs have been redeemed by a receiver or they are still in-flight in a transaction.
Expand All @@ -895,32 +895,26 @@ async def check_proofs_state(self, secrets: List[str]) -> List[ProofState]:
and which isn't.
Args:
proofs (List[Proof]): List of proofs to check.
Ys (List[str]): List of Y's of proofs to check
Returns:
List[bool]: List of which proof is still spendable (True if still spendable, else False)
List[bool]: List of which proof are pending (True if pending, else False)
"""
states: List[ProofState] = []
proofs_spent_idx_secret = await self._get_proofs_spent_idx_secret(secrets)
proofs_pending_idx_secret = await self._get_proofs_pending_idx_secret(secrets)
for secret in secrets:
if (
secret not in proofs_spent_idx_secret
and secret not in proofs_pending_idx_secret
):
states.append(ProofState(secret=secret, state=SpentState.unspent))
elif (
secret not in proofs_spent_idx_secret
and secret in proofs_pending_idx_secret
):
states.append(ProofState(secret=secret, state=SpentState.pending))
proofs_spent = await self._get_proofs_spent(Ys)
proofs_pending = await self._get_proofs_pending(Ys)
for Y in Ys:
if Y not in proofs_spent and Y not in proofs_pending:
states.append(ProofState(Y=Y, state=SpentState.unspent))
elif Y not in proofs_spent and Y in proofs_pending:
states.append(ProofState(Y=Y, state=SpentState.pending))
else:
states.append(
ProofState(
secret=secret,
Y=Y,
state=SpentState.spent,
witness=proofs_spent_idx_secret[secret].witness,
witness=proofs_spent[Y].witness,
)
)
return states
Expand Down Expand Up @@ -971,7 +965,9 @@ async def _validate_proofs_pending(
"""
assert (
len(
await self.crud.get_proofs_pending(proofs=proofs, db=self.db, conn=conn)
await self.crud.get_proofs_pending(
Ys=[p.Y for p in proofs], db=self.db, conn=conn
)
)
== 0
), TransactionError("proofs are pending.")
21 changes: 20 additions & 1 deletion cashu/mint/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from fastapi import APIRouter
from loguru import logger

from cashu.core.crypto import b_dhke

from ..core.base import (
GetInfoResponse,
KeysetsResponse,
Expand Down Expand Up @@ -338,7 +340,24 @@ async def check_state(
) -> PostCheckStateResponse:
"""Check whether a secret has been spent already or not."""
logger.trace(f"> POST /v1/checkstate: {payload}")
proof_states = await ledger.check_proofs_state(payload.secrets)

# BEGIN BACKWARDS COMPATIBILITY < 0.15.1
# If the request includes "secret", compuate Ys from them and continue request
if payload.secrets:
payload.Ys = [
b_dhke.hash_to_curve(s.encode()).serialize().hex() for s in payload.secrets
]
# END BACKWARDS COMPATIBILITY < 0.15.1

proof_states = await ledger.check_proofs_state(payload.Ys)

# BEGIN BACKWARDS COMPATIBILITY < 0.15.1
# If the request includes "secret", remove add the secret to the response
if payload.secrets:
for i, state in enumerate(proof_states):
state.secret = payload.secrets[i]
# END BACKWARDS COMPATIBILITY < 0.15.1

return PostCheckStateResponse(states=proof_states)


Expand Down
2 changes: 1 addition & 1 deletion cashu/mint/router_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ async def check_spendable_deprecated(
) -> CheckSpendableResponse_deprecated:
"""Check whether a secret has been spent already or not."""
logger.trace(f"> POST /check: {payload}")
proofs_state = await ledger.check_proofs_state([p.secret for p in payload.proofs])
proofs_state = await ledger.check_proofs_state([p.Y for p in payload.proofs])
spendableList: List[bool] = []
pendingList: List[bool] = []
for proof_state in proofs_state:
Expand Down
Loading

0 comments on commit dae1251

Please sign in to comment.