Skip to content

Commit

Permalink
fix cli
Browse files Browse the repository at this point in the history
  • Loading branch information
callebtc committed Jul 10, 2024
1 parent 55c80a7 commit 55839a5
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 20 deletions.
14 changes: 12 additions & 2 deletions cashu/wallet/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,16 @@ async def swap(ctx: Context):
@coro
async def balance(ctx: Context, verbose):
wallet: Wallet = ctx.obj["WALLET"]
if verbose:
wallet = await wallet.with_db(
url=wallet.url,
db=wallet.db.db_location,
name=wallet.name,
skip_db_read=False,
unit=wallet.unit.name,
load_all_keysets=True,
)

unit_balances = wallet.balance_per_unit()
await wallet.load_proofs(reload=True)

Expand Down Expand Up @@ -597,13 +607,13 @@ async def receive_cli(
# verify that we trust the mint in this tokens
# ask the user if they want to trust the new mint
mint_url = token_obj.mint
mint_wallet = Wallet(
mint_wallet = await Wallet.with_db(
mint_url,
os.path.join(settings.cashu_dir, wallet.name),
unit=token_obj.unit,
)
await verify_mint(mint_wallet, mint_url)
receive_wallet = await receive(wallet, token_obj)
receive_wallet = await receive(mint_wallet, token_obj)
ctx.obj["WALLET"] = receive_wallet
elif nostr:
await receive_nostr(wallet)
Expand Down
36 changes: 22 additions & 14 deletions cashu/wallet/cli/cli_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@ async def get_unit_wallet(ctx: Context, force_select: bool = False):
await wallet.load_proofs(reload=False)
# show balances per unit
unit_balances = wallet.balance_per_unit()
if wallet.unit in [unit_balances.keys()] and not force_select:
return wallet
elif len(unit_balances) > 1 and not ctx.obj["UNIT"]:

logger.debug(f"Wallet URL: {wallet.url}")
logger.debug(f"Wallet unit: {wallet.unit}")
logger.debug(f"mint_balances: {unit_balances}")
logger.debug(f"ctx.obj['UNIT']: {ctx.obj['UNIT']}")

if len(unit_balances) > 1 and not ctx.obj["UNIT"]:
print(f"You have balances in {len(unit_balances)} units:")
print("")
for i, (k, v) in enumerate(unit_balances.items()):
Expand Down Expand Up @@ -68,14 +72,15 @@ async def get_mint_wallet(ctx: Context, force_select: bool = False):
"""
# we load a dummy wallet so we can check the balance per mint
wallet: Wallet = ctx.obj["WALLET"]
await wallet.load_proofs(reload=False)
mint_balances = await wallet.balance_per_minturl()

if ctx.obj["HOST"] not in mint_balances and not force_select:
mint_url = wallet.url
elif len(mint_balances) > 1:
await wallet.load_proofs(reload=True, all_keysets=True)
mint_balances = await wallet.balance_per_minturl(unit=wallet.unit)
logger.debug(f"Wallet URL: {wallet.url}")
logger.debug(f"Wallet unit: {wallet.unit}")
logger.debug(f"mint_balances: {mint_balances}")
logger.debug(f"ctx.obj['HOST']: {ctx.obj['HOST']}")
if len(mint_balances) > 1:
# if we have balances on more than one mint, we ask the user to select one
await print_mint_balances(wallet, show_mints=True)
await print_mint_balances(wallet, show_mints=True, mint_balances=mint_balances)

url_max = max(mint_balances, key=lambda v: mint_balances[v]["available"])
nr_max = list(mint_balances).index(url_max) + 1
Expand All @@ -92,10 +97,10 @@ async def get_mint_wallet(ctx: Context, force_select: bool = False):
mint_url = list(mint_balances.keys())[int(mint_nr_str) - 1]
else:
raise Exception("invalid input.")
elif ctx.obj["HOST"] and ctx.obj["HOST"] not in mint_balances.keys():
mint_url = ctx.obj["HOST"]
elif len(mint_balances) == 1:
mint_url = list(mint_balances.keys())[0]
else:
mint_url = wallet.url

# load this mint_url into a wallet
mint_wallet = await Wallet.with_db(
Expand All @@ -109,12 +114,15 @@ async def get_mint_wallet(ctx: Context, force_select: bool = False):
return mint_wallet


async def print_mint_balances(wallet: Wallet, show_mints: bool = False):
async def print_mint_balances(
wallet: Wallet, show_mints: bool = False, mint_balances=None
):
"""
Helper function that prints the balances for each mint URL that we have tokens from.
"""
# get balances per mint
mint_balances = await wallet.balance_per_minturl(unit=wallet.unit)
mint_balances = mint_balances or await wallet.balance_per_minturl(unit=wallet.unit)
logger.trace(mint_balances)
# if we have a balance on a non-default mint, we show its URL
keysets = [k for k, v in wallet.balance_per_keyset().items()]
for k in keysets:
Expand Down
11 changes: 7 additions & 4 deletions cashu/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ async def load_mint_keysets(self):
logger.trace("Loading mint keysets.")
mint_keysets_resp = await self._get_keysets()
mint_keysets_dict = {k.id: k for k in mint_keysets_resp}

# load all keysets of thisd mint from the db
keysets_in_db = await get_keysets(mint_url=self.url, db=self.db)

Expand Down Expand Up @@ -285,7 +284,7 @@ async def load_mint(self, keyset_id: str = "") -> None:
logger.debug(f"Could not load mint info: {e}")
pass

async def load_proofs(self, reload: bool = False) -> None:
async def load_proofs(self, reload: bool = False, all_keysets=False) -> None:
"""Load all proofs of the selected mint and unit (i.e. self.keysets) into memory."""

if self.proofs and not reload:
Expand All @@ -295,9 +294,13 @@ async def load_proofs(self, reload: bool = False) -> None:
self.proofs = []
await self.load_keysets_from_db()
async with self.db.connect() as conn:
for keyset_id in self.keysets:
proofs = await get_proofs(db=self.db, id=keyset_id, conn=conn)
if all_keysets:
proofs = await get_proofs(db=self.db, conn=conn)
self.proofs.extend(proofs)
else:
for keyset_id in self.keysets:
proofs = await get_proofs(db=self.db, id=keyset_id, conn=conn)
self.proofs.extend(proofs)
logger.trace(
f"Proofs loaded for keysets: {' '.join([k.id + f' ({k.unit})' for k in self.keysets.values()])}"
)
Expand Down

0 comments on commit 55839a5

Please sign in to comment.