Skip to content

Commit

Permalink
fix fee return and melt quote max allowed amount check during creatio…
Browse files Browse the repository at this point in the history
…n of melt quote
  • Loading branch information
callebtc committed May 30, 2024
1 parent e360502 commit 333b0da
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 47 deletions.
83 changes: 47 additions & 36 deletions cashu/mint/ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,8 @@ async def _invalidate_proofs(

async def _generate_change_promises(
self,
input_amount: int,
output_amount: int,
output_fee_paid: int,
fee_provided: int,
fee_paid: int,
outputs: Optional[List[BlindedMessage]],
keyset: Optional[MintKeyset] = None,
) -> List[BlindedSignature]:
Expand All @@ -329,34 +328,35 @@ async def _generate_change_promises(
List[BlindedSignature]: Signatures on the outputs.
"""
# we make sure that the fee is positive
user_fee_paid = input_amount - output_amount
overpaid_fee = user_fee_paid - output_fee_paid
overpaid_fee = fee_provided - fee_paid

if overpaid_fee == 0 or outputs is None:
return []

logger.debug(
f"Lightning fee was: {output_fee_paid}. User paid: {user_fee_paid}. "
f"Lightning fee was: {fee_paid}. User provided: {fee_provided}. "
f"Returning difference: {overpaid_fee}."
)

if overpaid_fee > 0 and outputs is not None:
return_amounts = amount_split(overpaid_fee)

# We return at most as many outputs as were provided or as many as are
# required to pay back the overpaid fee.
n_return_outputs = min(len(outputs), len(return_amounts))

# we only need as many outputs as we have change to return
outputs = outputs[:n_return_outputs]
# we sort the return_amounts in descending order so we only
# take the largest values in the next step
return_amounts_sorted = sorted(return_amounts, reverse=True)
# we need to imprint these amounts into the blanket outputs
for i in range(len(outputs)):
outputs[i].amount = return_amounts_sorted[i] # type: ignore
if not self._verify_no_duplicate_outputs(outputs):
raise TransactionError("duplicate promises.")
return_promises = await self._generate_promises(outputs, keyset)
return return_promises
else:
return []
return_amounts = amount_split(overpaid_fee)

# We return at most as many outputs as were provided or as many as are
# required to pay back the overpaid fee.
n_return_outputs = min(len(outputs), len(return_amounts))

# we only need as many outputs as we have change to return
outputs = outputs[:n_return_outputs]

# we sort the return_amounts in descending order so we only
# take the largest values in the next step
return_amounts_sorted = sorted(return_amounts, reverse=True)
# we need to imprint these amounts into the blanket outputs
for i in range(len(outputs)):
outputs[i].amount = return_amounts_sorted[i] # type: ignore
if not self._verify_no_duplicate_outputs(outputs):
raise TransactionError("duplicate promises.")
return_promises = await self._generate_promises(outputs, keyset)
return return_promises

# ------- TRANSACTIONS -------

Expand Down Expand Up @@ -593,6 +593,15 @@ async def melt_quote(
if not payment_quote.fee.unit == unit:
raise TransactionError("payment quote fee units do not match")

# verify that the amount of the proofs is not larger than the maximum allowed
if (
settings.mint_max_peg_out
and payment_quote.amount.to(unit).amount > settings.mint_max_peg_out
):
raise NotAllowedError(
f"Maximum melt amount is {settings.mint_max_peg_out} sat."
)

# We assume that the request is a bolt11 invoice, this works since we
# support only the bol11 method for now.
invoice_obj = bolt11.decode(melt_quote.request)
Expand Down Expand Up @@ -782,15 +791,18 @@ async def melt(

# verify that the amount of the input proofs is equal to the amount of the quote
total_provided = sum_proofs(proofs)
total_needed = (
melt_quote.amount
+ melt_quote.fee_reserve
+ self.get_fees_for_proofs(proofs)
)
if not total_provided >= total_needed:
input_fees = self.get_fees_for_proofs(proofs)
total_needed = melt_quote.amount + melt_quote.fee_reserve + input_fees
# we need the fees specifically for lightning to return the overpaid fees
fee_reserve_provided = total_provided - melt_quote.amount - input_fees
if total_provided < total_needed:
raise TransactionError(
f"not enough inputs provided for melt. Provided: {total_provided}, needed: {total_needed}"
)
if fee_reserve_provided < melt_quote.fee_reserve:
raise TransactionError(
f"not enough fee reserve provided for melt. Provided fee reserve: {fee_reserve_provided}, needed: {melt_quote.fee_reserve}"
)

# verify that the amount of the proofs is not larger than the maximum allowed
if settings.mint_max_peg_out and total_provided > settings.mint_max_peg_out:
Expand Down Expand Up @@ -840,9 +852,8 @@ async def melt(
return_promises: List[BlindedSignature] = []
if outputs:
return_promises = await self._generate_change_promises(
input_amount=total_provided,
output_amount=melt_quote.amount,
output_fee_paid=melt_quote.fee_paid,
fee_provided=fee_reserve_provided,
fee_paid=melt_quote.fee_paid,
outputs=outputs,
keyset=self.keysets[outputs[0].id],
)
Expand Down
20 changes: 9 additions & 11 deletions tests/test_mint.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ async def test_generate_promises(ledger: Ledger):
async def test_generate_change_promises(ledger: Ledger):
# Example slightly adapted from NUT-08 because we want to ensure the dynamic change
# token amount works: `n_blank_outputs != n_returned_promises != 4`.
invoice_amount = 100_000
# invoice_amount = 100_000
fee_reserve = 2_000
total_provided = invoice_amount + fee_reserve
# total_provided = invoice_amount + fee_reserve
actual_fee = 100

expected_returned_promises = 7 # Amounts = [4, 8, 32, 64, 256, 512, 1024]
Expand All @@ -150,7 +150,7 @@ async def test_generate_change_promises(ledger: Ledger):
]

promises = await ledger._generate_change_promises(
total_provided, invoice_amount, actual_fee, outputs
fee_provided=fee_reserve, fee_paid=actual_fee, outputs=outputs
)

assert len(promises) == expected_returned_promises
Expand All @@ -161,9 +161,9 @@ async def test_generate_change_promises(ledger: Ledger):
async def test_generate_change_promises_legacy_wallet(ledger: Ledger):
# Check if mint handles a legacy wallet implementation (always sends 4 blank
# outputs) as well.
invoice_amount = 100_000
# invoice_amount = 100_000
fee_reserve = 2_000
total_provided = invoice_amount + fee_reserve
# total_provided = invoice_amount + fee_reserve
actual_fee = 100

expected_returned_promises = 4 # Amounts = [64, 256, 512, 1024]
Expand All @@ -180,24 +180,22 @@ async def test_generate_change_promises_legacy_wallet(ledger: Ledger):
for b, _ in blinded_msgs
]

promises = await ledger._generate_change_promises(
total_provided, invoice_amount, actual_fee, outputs
)
promises = await ledger._generate_change_promises(fee_reserve, actual_fee, outputs)

assert len(promises) == expected_returned_promises
assert sum([promise.amount for promise in promises]) == expected_returned_fees


@pytest.mark.asyncio
async def test_generate_change_promises_returns_empty_if_no_outputs(ledger: Ledger):
invoice_amount = 100_000
# invoice_amount = 100_000
fee_reserve = 1_000
total_provided = invoice_amount + fee_reserve
# total_provided = invoice_amount + fee_reserve
actual_fee_msat = 100_000
outputs = None

promises = await ledger._generate_change_promises(
total_provided, invoice_amount, actual_fee_msat, outputs
fee_reserve, actual_fee_msat, outputs
)
assert len(promises) == 0

Expand Down

0 comments on commit 333b0da

Please sign in to comment.