Skip to content

Commit

Permalink
Wallet: add flag --force-swap to send command
Browse files Browse the repository at this point in the history
  • Loading branch information
callebtc committed Jul 11, 2024
1 parent 77697c5 commit 10ed146
Show file tree
Hide file tree
Showing 18 changed files with 123 additions and 107 deletions.
2 changes: 1 addition & 1 deletion cashu/wallet/api/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ async def swap(
if outgoing_wallet.available_balance < total_amount:
raise Exception("balance too low")

_, send_proofs = await outgoing_wallet.split_to_send(
_, send_proofs = await outgoing_wallet.swap_to_send(
outgoing_wallet.proofs, total_amount, set_reserved=True
)
await outgoing_wallet.melt(
Expand Down
10 changes: 10 additions & 0 deletions cashu/wallet/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,14 @@ async def balance(ctx: Context, verbose):
help="Include fees for receiving token.",
type=bool,
)
@click.option(
"--force-swap",
"-s",
default=False,
is_flag=True,
help="Force swap token.",
type=bool,
)
@click.pass_context
@coro
async def send_command(
Expand All @@ -562,6 +570,7 @@ async def send_command(
yes: bool,
offline: bool,
include_fees: bool,
force_swap: bool,
):
wallet: Wallet = ctx.obj["WALLET"]
amount = int(amount * 100) if wallet.unit in [Unit.usd, Unit.eur] else int(amount)
Expand All @@ -575,6 +584,7 @@ async def send_command(
include_dleq=dleq,
include_fees=include_fees,
memo=memo,
force_swap=force_swap,
)
else:
await send_nostr(wallet, amount=amount, pubkey=nostr, verbose=verbose, yes=yes)
Expand Down
6 changes: 3 additions & 3 deletions cashu/wallet/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ async def send(
include_dleq: bool = False,
include_fees: bool = False,
memo: Optional[str] = None,
force_swap: bool = False,
):
"""
Prints token to send to stdout.
Expand Down Expand Up @@ -144,13 +145,12 @@ async def send(
await wallet.load_proofs()

await wallet.load_mint()
if secret_lock:
_, send_proofs = await wallet.split_to_send(
if secret_lock or force_swap:
_, send_proofs = await wallet.swap_to_send(
wallet.proofs,
amount,
set_reserved=False, # we set reserved later
secret_lock=secret_lock,
include_fees=include_fees,
)
else:
send_proofs, fees = await wallet.select_to_send(
Expand Down
2 changes: 1 addition & 1 deletion cashu/wallet/lightning/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def pay_invoice(self, pr: str) -> PaymentResponse:
if self.available_balance < total_amount:
print("Error: Balance too low.")
return PaymentResponse(ok=False)
_, send_proofs = await self.split_to_send(self.proofs, total_amount)
_, send_proofs = await self.swap_to_send(self.proofs, total_amount)
try:
resp = await self.melt(send_proofs, pr, quote.fee_reserve, quote.quote)
if resp.change:
Expand Down
2 changes: 1 addition & 1 deletion cashu/wallet/nostr.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def send_nostr(
pubkey = await nip5_to_pubkey(wallet, pubkey)
await wallet.load_mint()
await wallet.load_proofs()
_, send_proofs = await wallet.split_to_send(
_, send_proofs = await wallet.swap_to_send(
wallet.proofs, amount, set_reserved=True, include_fees=False
)
token = await wallet.serialize_proofs(send_proofs, include_dleq=include_dleq)
Expand Down
1 change: 1 addition & 0 deletions cashu/wallet/proofs.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ async def serialize_proofs(
try:
_ = [bytes.fromhex(p.id) for p in proofs]
except ValueError:
logger.debug("Proof with base64 keyset, using legacy token serialization")
legacy = True

if legacy:
Expand Down
105 changes: 55 additions & 50 deletions cashu/wallet/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,55 +36,63 @@ def get_fees_for_proofs(self, proofs: List[Proof]) -> int:
def get_fees_for_proofs_ppk(self, proofs: List[Proof]) -> int:
return sum([self.keysets[p.id].input_fee_ppk for p in proofs])

async def _select_proofs_to_send_(
self, proofs: List[Proof], amount_to_send: int, tolerance: int = 0
) -> List[Proof]:
send_proofs: List[Proof] = []
NO_SELECTION: List[Proof] = []

logger.trace(f"proofs: {[p.amount for p in proofs]}")
# sort proofs by amount (descending)
sorted_proofs = sorted(proofs, key=lambda p: p.amount, reverse=True)
# only consider proofs smaller than the amount we want to send (+ tolerance) for coin selection
fee_for_single_proof = self.get_fees_for_proofs([sorted_proofs[0]])
sorted_proofs = [
p
for p in sorted_proofs
if p.amount <= amount_to_send + tolerance + fee_for_single_proof
]
if not sorted_proofs:
logger.info(
f"no small-enough proofs to send. Have: {[p.amount for p in proofs]}"
)
return NO_SELECTION

target_amount = amount_to_send

# compose the target amount from the remaining_proofs
logger.debug(f"sorted_proofs: {[p.amount for p in sorted_proofs]}")
for p in sorted_proofs:
# logger.debug(f"send_proofs: {[p.amount for p in send_proofs]}")
# logger.debug(f"target_amount: {target_amount}")
# logger.debug(f"p.amount: {p.amount}")
if sum_proofs(send_proofs) + p.amount <= target_amount + tolerance:
send_proofs.append(p)
target_amount = amount_to_send + self.get_fees_for_proofs(send_proofs)

if sum_proofs(send_proofs) < amount_to_send:
logger.info("could not select proofs to reach target amount (too little).")
return NO_SELECTION

fees = self.get_fees_for_proofs(send_proofs)
logger.debug(f"Selected sum of proofs: {sum_proofs(send_proofs)}, fees: {fees}")
return send_proofs
# async def _select_proofs_to_send_legacy(
# self, proofs: List[Proof], amount_to_send: int, tolerance: int = 0
# ) -> List[Proof]:
# send_proofs: List[Proof] = []
# NO_SELECTION: List[Proof] = []

# logger.trace(f"proofs: {[p.amount for p in proofs]}")
# # sort proofs by amount (descending)
# sorted_proofs = sorted(proofs, key=lambda p: p.amount, reverse=True)
# # only consider proofs smaller than the amount we want to send (+ tolerance) for coin selection
# fee_for_single_proof = self.get_fees_for_proofs([sorted_proofs[0]])
# sorted_proofs = [
# p
# for p in sorted_proofs
# if p.amount <= amount_to_send + tolerance + fee_for_single_proof
# ]
# if not sorted_proofs:
# logger.info(
# f"no small-enough proofs to send. Have: {[p.amount for p in proofs]}"
# )
# return NO_SELECTION

# target_amount = amount_to_send

# # compose the target amount from the remaining_proofs
# logger.debug(f"sorted_proofs: {[p.amount for p in sorted_proofs]}")
# for p in sorted_proofs:
# if sum_proofs(send_proofs) + p.amount <= target_amount + tolerance:
# send_proofs.append(p)
# target_amount = amount_to_send + self.get_fees_for_proofs(send_proofs)

# if sum_proofs(send_proofs) < amount_to_send:
# logger.info("could not select proofs to reach target amount (too little).")
# return NO_SELECTION

# fees = self.get_fees_for_proofs(send_proofs)
# logger.debug(f"Selected sum of proofs: {sum_proofs(send_proofs)}, fees: {fees}")
# return send_proofs

async def _select_proofs_to_send(
self,
proofs: List[Proof],
amount_to_send: Union[int, float],
*,
include_fees: bool = True,
include_fees: bool = False,
) -> List[Proof]:
"""Select proofs to send based on the amount to send and the proofs available. Implements a simple coin selection algorithm.
Can be used for selecting proofs to send an offline transaction.
Args:
proofs (List[Proof]): List of proofs to select from
amount_to_send (Union[int, float]): Amount to select proofs for
include_fees (bool, optional): Whether to include fees necessary to redeem the tokens in the selection. Defaults to False.
Returns:
List[Proof]: _description_
"""
# check that enough spendable proofs exist
if sum_proofs(proofs) < amount_to_send:
return []
Expand Down Expand Up @@ -147,9 +155,8 @@ async def _select_proofs_to_split(
Rules:
1) Proofs that are not marked as reserved
2) Proofs that have a different keyset than the activated keyset_id of the mint
3) Include all proofs that have an older keyset than the current keyset of the mint (to get rid of old epochs).
4) If the target amount is not reached, add proofs of the current keyset until it is.
2) Include all proofs from inactive keysets (old epochs) to get rid of them
3) If the target amount is not reached, add proofs of the current keyset until it is.
Args:
proofs (List[Proof]): List of proofs to select from
Expand All @@ -171,11 +178,9 @@ async def _select_proofs_to_split(
if sum_proofs(proofs) < amount_to_send:
raise Exception("balance too low.")

# add all proofs that have an older keyset than the current keyset of the mint
proofs_old_epochs = [
p for p in proofs if p.id != self.keysets[self.keyset_id].id
]
send_proofs += proofs_old_epochs
# add all proofs from inactive keysets
proofs_inactive_keysets = [p for p in proofs if not self.keysets[p.id].active]
send_proofs += proofs_inactive_keysets

# coinselect based on amount only from the current keyset
# start with the proofs with the largest amount and add them until the target amount is reached
Expand Down
4 changes: 2 additions & 2 deletions cashu/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,7 +1049,7 @@ async def select_to_send(
if not offline:
logger.debug("Offline coin selection unsuccessful. Splitting proofs.")
# we set the proofs as reserved later
_, send_proofs = await self.split_to_send(
_, send_proofs = await self.swap_to_send(
proofs, amount, set_reserved=False
)
else:
Expand All @@ -1061,7 +1061,7 @@ async def select_to_send(
await self.set_reserved(send_proofs, reserved=True)
return send_proofs, fees

async def split_to_send(
async def swap_to_send(
self,
proofs: List[Proof],
amount: int,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_mint_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ async def test_melt_external(ledger: Ledger, wallet: Wallet):
assert quote.amount == 62
assert quote.fee_reserve == 2

keep, send = await wallet.split_to_send(wallet.proofs, 64)
keep, send = await wallet.swap_to_send(wallet.proofs, 64)
inputs_payload = [p.to_dict() for p in send]

# outputs for change
Expand Down
6 changes: 3 additions & 3 deletions tests/test_mint_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ async def test_startup_regtest_pending_quote_pending(wallet: Wallet, ledger: Led
# wallet pays the invoice
quote = await wallet.melt_quote(invoice_payment_request)
total_amount = quote.amount + quote.fee_reserve
_, send_proofs = await wallet.split_to_send(wallet.proofs, total_amount)
_, send_proofs = await wallet.swap_to_send(wallet.proofs, total_amount)
asyncio.create_task(
wallet.melt(
proofs=send_proofs,
Expand Down Expand Up @@ -297,7 +297,7 @@ async def test_startup_regtest_pending_quote_success(wallet: Wallet, ledger: Led
# wallet pays the invoice
quote = await wallet.melt_quote(invoice_payment_request)
total_amount = quote.amount + quote.fee_reserve
_, send_proofs = await wallet.split_to_send(wallet.proofs, total_amount)
_, send_proofs = await wallet.swap_to_send(wallet.proofs, total_amount)
asyncio.create_task(
wallet.melt(
proofs=send_proofs,
Expand Down Expand Up @@ -347,7 +347,7 @@ async def test_startup_regtest_pending_quote_failure(wallet: Wallet, ledger: Led
# wallet pays the invoice
quote = await wallet.melt_quote(invoice_payment_request)
total_amount = quote.amount + quote.fee_reserve
_, send_proofs = await wallet.split_to_send(wallet.proofs, total_amount)
_, send_proofs = await wallet.swap_to_send(wallet.proofs, total_amount)
asyncio.create_task(
wallet.melt(
proofs=send_proofs,
Expand Down
12 changes: 6 additions & 6 deletions tests/test_mint_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def test_melt_internal(wallet1: Wallet, ledger: Ledger):
assert not melt_quote_pre_payment.paid, "melt quote should not be paid"
assert melt_quote_pre_payment.state == MeltQuoteState.unpaid

keep_proofs, send_proofs = await wallet1.split_to_send(wallet1.proofs, 64)
keep_proofs, send_proofs = await wallet1.swap_to_send(wallet1.proofs, 64)
await ledger.melt(proofs=send_proofs, quote=melt_quote.quote)

melt_quote_post_payment = await ledger.get_melt_quote(melt_quote.quote)
Expand All @@ -85,7 +85,7 @@ async def test_melt_external(wallet1: Wallet, ledger: Ledger):
assert mint_quote.state == MeltQuoteState.unpaid.value

total_amount = mint_quote.amount + mint_quote.fee_reserve
keep_proofs, send_proofs = await wallet1.split_to_send(wallet1.proofs, total_amount)
keep_proofs, send_proofs = await wallet1.swap_to_send(wallet1.proofs, total_amount)
melt_quote = await ledger.melt_quote(
PostMeltQuoteRequest(request=invoice_payment_request, unit="sat")
)
Expand Down Expand Up @@ -169,7 +169,7 @@ async def test_split(wallet1: Wallet, ledger: Ledger):
await pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, id=invoice.id)

keep_proofs, send_proofs = await wallet1.split_to_send(wallet1.proofs, 10)
keep_proofs, send_proofs = await wallet1.swap_to_send(wallet1.proofs, 10)
secrets, rs, derivation_paths = await wallet1.generate_n_secrets(len(send_proofs))
outputs, rs = wallet1._construct_outputs(
[p.amount for p in send_proofs], secrets, rs
Expand All @@ -185,7 +185,7 @@ async def test_split_with_no_outputs(wallet1: Wallet, ledger: Ledger):
invoice = await wallet1.request_mint(64)
await pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, id=invoice.id)
_, send_proofs = await wallet1.split_to_send(wallet1.proofs, 10, set_reserved=False)
_, send_proofs = await wallet1.swap_to_send(wallet1.proofs, 10, set_reserved=False)
await assert_err(
ledger.split(proofs=send_proofs, outputs=[]),
"no outputs provided",
Expand All @@ -198,7 +198,7 @@ async def test_split_with_input_less_than_outputs(wallet1: Wallet, ledger: Ledge
await pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, id=invoice.id)

keep_proofs, send_proofs = await wallet1.split_to_send(
keep_proofs, send_proofs = await wallet1.swap_to_send(
wallet1.proofs, 10, set_reserved=False
)

Expand Down Expand Up @@ -396,7 +396,7 @@ async def test_check_proof_state(wallet1: Wallet, ledger: Ledger):
await pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, id=invoice.id)

keep_proofs, send_proofs = await wallet1.split_to_send(wallet1.proofs, 10)
keep_proofs, send_proofs = await wallet1.swap_to_send(wallet1.proofs, 10)

proof_states = await ledger.db_read.get_proofs_states(Ys=[p.Y for p in send_proofs])
assert all([p.state.value == "UNSPENT" for p in proof_states])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_mint_regtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def test_regtest_pending_quote(wallet: Wallet, ledger: Ledger):
# wallet pays the invoice
quote = await wallet.melt_quote(invoice_payment_request)
total_amount = quote.amount + quote.fee_reserve
_, send_proofs = await wallet.split_to_send(wallet.proofs, total_amount)
_, send_proofs = await wallet.swap_to_send(wallet.proofs, total_amount)
asyncio.create_task(ledger.melt(proofs=send_proofs, quote=quote.quote))
# asyncio.create_task(
# wallet.melt(
Expand Down
12 changes: 6 additions & 6 deletions tests/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,14 @@ async def test_split(wallet1: Wallet):


@pytest.mark.asyncio
async def test_split_to_send(wallet1: Wallet):
async def test_swap_to_send(wallet1: Wallet):
invoice = await wallet1.request_mint(64)
await pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, id=invoice.id)
assert wallet1.balance == 64

# this will select 32 sats and them (nothing to keep)
keep_proofs, send_proofs = await wallet1.split_to_send(
keep_proofs, send_proofs = await wallet1.swap_to_send(
wallet1.proofs, 32, set_reserved=True
)
assert_amt(send_proofs, 32)
Expand Down Expand Up @@ -307,7 +307,7 @@ async def test_melt(wallet1: Wallet):
assert total_amount == 64
assert quote.fee_reserve == 0

_, send_proofs = await wallet1.split_to_send(wallet1.proofs, total_amount)
_, send_proofs = await wallet1.swap_to_send(wallet1.proofs, total_amount)

melt_response = await wallet1.melt(
proofs=send_proofs,
Expand Down Expand Up @@ -343,12 +343,12 @@ async def test_melt(wallet1: Wallet):


@pytest.mark.asyncio
async def test_split_to_send_more_than_balance(wallet1: Wallet):
async def test_swap_to_send_more_than_balance(wallet1: Wallet):
invoice = await wallet1.request_mint(64)
await pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, id=invoice.id)
await assert_err(
wallet1.split_to_send(wallet1.proofs, 128, set_reserved=True),
wallet1.swap_to_send(wallet1.proofs, 128, set_reserved=True),
"balance too low.",
)
assert wallet1.balance == 64
Expand Down Expand Up @@ -405,7 +405,7 @@ async def test_send_and_redeem(wallet1: Wallet, wallet2: Wallet):
invoice = await wallet1.request_mint(64)
await pay_if_regtest(invoice.bolt11)
await wallet1.mint(64, id=invoice.id)
_, spendable_proofs = await wallet1.split_to_send(
_, spendable_proofs = await wallet1.swap_to_send(
wallet1.proofs, 32, set_reserved=True
)
await wallet2.redeem(spendable_proofs)
Expand Down
Loading

0 comments on commit 10ed146

Please sign in to comment.