Skip to content

Commit

Permalink
Fix race condition (#586)
Browse files Browse the repository at this point in the history
* `_set_proofs_pending` performs DB related "proofs are spendable" check inside the lock.

* move _verify_spent_proofs_and_set_pending to write.py

* edit logging

---------

Co-authored-by: callebtc <[email protected]>
  • Loading branch information
lollerfirst and callebtc authored Jul 17, 2024
1 parent 71580a5 commit efdfecc
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 38 deletions.
17 changes: 17 additions & 0 deletions cashu/mint/db/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from ...core.base import Proof, ProofSpentState, ProofState
from ...core.db import Connection, Database
from ...core.errors import TokenAlreadySpentError
from ..crud import LedgerCrud


Expand Down Expand Up @@ -74,3 +75,19 @@ async def get_proofs_states(
)
)
return states

async def _verify_proofs_spendable(
self, proofs: List[Proof], conn: Optional[Connection] = None
):
"""Checks the database to see if any of the proofs are already spent.
Args:
proofs (List[Proof]): Proofs to verify
conn (Optional[Connection]): Database connection to use. Defaults to None.
Raises:
TokenAlreadySpentError: If any of the proofs are already spent
"""
async with self.db.get_connection(conn) as conn:
if not len(await self._get_proofs_spent([p.Y for p in proofs], conn)) == 0:
raise TokenAlreadySpentError()
33 changes: 19 additions & 14 deletions cashu/mint/db/write.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import random
from typing import List, Optional, Union

from loguru import logger
Expand All @@ -18,54 +17,60 @@
)
from ..crud import LedgerCrud
from ..events.events import LedgerEventManager
from .read import DbReadHelper


class DbWriteHelper:
db: Database
crud: LedgerCrud
events: LedgerEventManager
db_read: DbReadHelper

def __init__(
self, db: Database, crud: LedgerCrud, events: LedgerEventManager
self,
db: Database,
crud: LedgerCrud,
events: LedgerEventManager,
db_read: DbReadHelper,
) -> None:
self.db = db
self.crud = crud
self.events = events
self.db_read = db_read

async def _set_proofs_pending(
async def _verify_spent_proofs_and_set_pending(
self, proofs: List[Proof], quote_id: Optional[str] = None
) -> None:
"""If none of the proofs is in the pending table (_validate_proofs_pending), adds proofs to
the list of pending proofs or removes them. Used as a mutex for proofs.
"""
Method to check if proofs are already spent. If they are not spent, we check if they are pending.
If they are not pending, we set them as pending.
Args:
proofs (List[Proof]): Proofs to add to pending table.
quote_id (Optional[str]): Melt quote ID. If it is not set, we assume the pending tokens to be from a swap.
Raises:
Exception: At least one proof already in pending table.
TransactionError: If any one of the proofs is already spent or pending.
"""
# first we check whether these proofs are pending already
random_id = random.randint(0, 1000000)
try:
logger.debug("trying to set proofs pending")
logger.trace(f"get_connection: random_id: {random_id}")
logger.trace("_verify_spent_proofs_and_set_pending acquiring lock")
async with self.db.get_connection(
lock_table="proofs_pending",
lock_timeout=1,
) as conn:
logger.trace(f"get_connection: got connection {random_id}")
logger.trace("checking whether proofs are already spent")
await self.db_read._verify_proofs_spendable(proofs, conn)
logger.trace("checking whether proofs are already pending")
await self._validate_proofs_pending(proofs, conn)
for p in proofs:
logger.trace(f"crud: setting proof {p.Y} as PENDING")
await self.crud.set_proof_pending(
proof=p, db=self.db, quote_id=quote_id, conn=conn
)
logger.trace(f"crud: set proof {p.Y} as PENDING")
logger.trace("_verify_spent_proofs_and_set_pending released lock")
except Exception as e:
logger.error(f"Failed to set proofs pending: {e}")
raise TransactionError(f"Failed to set proofs pending: {str(e)}")
logger.trace("_set_proofs_pending released lock")
raise e
for p in proofs:
await self.events.submit(ProofState(Y=p.Y, state=ProofSpentState.pending))

Expand Down
8 changes: 5 additions & 3 deletions cashu/mint/ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(
self.backends = backends
self.pubkey = derive_pubkey(self.seed)
self.db_read = DbReadHelper(self.db, self.crud)
self.db_write = DbWriteHelper(self.db, self.crud, self.events)
self.db_write = DbWriteHelper(self.db, self.crud, self.events, self.db_read)

# ------- STARTUP -------

Expand Down Expand Up @@ -905,7 +905,9 @@ async def melt(
await self.verify_inputs_and_outputs(proofs=proofs)

# set proofs to pending to avoid race conditions
await self.db_write._set_proofs_pending(proofs, quote_id=melt_quote.quote)
await self.db_write._verify_spent_proofs_and_set_pending(
proofs, quote_id=melt_quote.quote
)
try:
# settle the transaction internally if there is a mint quote with the same payment request
melt_quote = await self.melt_mint_settle_internally(melt_quote, proofs)
Expand Down Expand Up @@ -985,7 +987,7 @@ async def swap(
logger.trace("swap called")
# verify spending inputs, outputs, and spending conditions
await self.verify_inputs_and_outputs(proofs=proofs, outputs=outputs)
await self.db_write._set_proofs_pending(proofs)
await self.db_write._verify_spent_proofs_and_set_pending(proofs)
try:
async with self.db.get_connection(lock_table="proofs_pending") as conn:
await self._invalidate_proofs(proofs=proofs, conn=conn)
Expand Down
7 changes: 0 additions & 7 deletions cashu/mint/verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
NoSecretInProofsError,
NotAllowedError,
SecretTooLongError,
TokenAlreadySpentError,
TransactionError,
TransactionUnitError,
)
Expand Down Expand Up @@ -67,12 +66,6 @@ async def verify_inputs_and_outputs(
# Verify inputs
if not proofs:
raise TransactionError("no proofs provided.")
# Verify proofs are spendable
if (
not len(await self.db_read._get_proofs_spent([p.Y for p in proofs], conn))
== 0
):
raise TokenAlreadySpentError()
# Verify amounts of inputs
if not all([self._verify_amount(p.amount) for p in proofs]):
raise TransactionError("invalid amount.")
Expand Down
26 changes: 14 additions & 12 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ async def get_connection():


@pytest.mark.asyncio
async def test_db_set_proofs_pending_race_condition(wallet: Wallet, ledger: Ledger):
async def test_db_verify_spent_proofs_and_set_pending_race_condition(
wallet: Wallet, ledger: Ledger
):
# fill wallet
invoice = await wallet.request_mint(64)
await pay_if_regtest(invoice.bolt11)
Expand All @@ -193,8 +195,8 @@ async def test_db_set_proofs_pending_race_condition(wallet: Wallet, ledger: Ledg

await assert_err_multiple(
asyncio.gather(
ledger.db_write._set_proofs_pending(wallet.proofs),
ledger.db_write._set_proofs_pending(wallet.proofs),
ledger.db_write._verify_spent_proofs_and_set_pending(wallet.proofs),
ledger.db_write._verify_spent_proofs_and_set_pending(wallet.proofs),
),
[
"failed to acquire database lock",
Expand All @@ -204,7 +206,7 @@ async def test_db_set_proofs_pending_race_condition(wallet: Wallet, ledger: Ledg


@pytest.mark.asyncio
async def test_db_set_proofs_pending_delayed_no_race_condition(
async def test_db_verify_spent_proofs_and_set_pending_delayed_no_race_condition(
wallet: Wallet, ledger: Ledger
):
# fill wallet
Expand All @@ -213,21 +215,21 @@ async def test_db_set_proofs_pending_delayed_no_race_condition(
await wallet.mint(64, id=invoice.id)
assert wallet.balance == 64

async def delayed_set_proofs_pending():
async def delayed_verify_spent_proofs_and_set_pending():
await asyncio.sleep(0.1)
await ledger.db_write._set_proofs_pending(wallet.proofs)
await ledger.db_write._verify_spent_proofs_and_set_pending(wallet.proofs)

await assert_err(
asyncio.gather(
ledger.db_write._set_proofs_pending(wallet.proofs),
delayed_set_proofs_pending(),
ledger.db_write._verify_spent_proofs_and_set_pending(wallet.proofs),
delayed_verify_spent_proofs_and_set_pending(),
),
"proofs are pending",
)


@pytest.mark.asyncio
async def test_db_set_proofs_pending_no_race_condition_different_proofs(
async def test_db_verify_spent_proofs_and_set_pending_no_race_condition_different_proofs(
wallet: Wallet, ledger: Ledger
):
# fill wallet
Expand All @@ -238,8 +240,8 @@ async def test_db_set_proofs_pending_no_race_condition_different_proofs(
assert len(wallet.proofs) == 2

asyncio.gather(
ledger.db_write._set_proofs_pending(wallet.proofs[:1]),
ledger.db_write._set_proofs_pending(wallet.proofs[1:]),
ledger.db_write._verify_spent_proofs_and_set_pending(wallet.proofs[:1]),
ledger.db_write._verify_spent_proofs_and_set_pending(wallet.proofs[1:]),
)


Expand Down Expand Up @@ -300,6 +302,6 @@ async def test_db_lock_table(wallet: Wallet, ledger: Ledger):
async with ledger.db.connect(lock_table="proofs_pending", lock_timeout=0.1) as conn:
assert isinstance(conn, Connection)
await assert_err(
ledger.db_write._set_proofs_pending(wallet.proofs),
ledger.db_write._verify_spent_proofs_and_set_pending(wallet.proofs),
"failed to acquire database lock",
)
2 changes: 1 addition & 1 deletion tests/test_mint_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def test_mint_proofs_pending(wallet1: Wallet, ledger: Ledger):
[s.state == ProofSpentState.unspent for s in proofs_states_before_split.states]
)

await ledger.db_write._set_proofs_pending(proofs)
await ledger.db_write._verify_spent_proofs_and_set_pending(proofs)

proof_states = await wallet1.check_proof_state(proofs)
assert all([s.state == ProofSpentState.pending for s in proof_states.states])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ async def test_double_spend(wallet1: Wallet):
await wallet1.split(wallet1.proofs, 20)
await assert_err(
wallet1.split(doublespend, 20),
"Mint Error: Token already spent.",
"Token already spent.",
)
assert wallet1.balance == 64
assert wallet1.available_balance == 64
Expand Down

0 comments on commit efdfecc

Please sign in to comment.