From 404953856e33d23d9205ff1aab1497156f9f24d7 Mon Sep 17 00:00:00 2001 From: callebtc <93376500+callebtc@users.noreply.github.com> Date: Fri, 24 May 2024 19:56:48 +0200 Subject: [PATCH] wallet works with units --- cashu/core/base.py | 3 +++ cashu/wallet/cli/cli.py | 15 ++++++++++----- cashu/wallet/cli/cli_helpers.py | 9 +++++---- cashu/wallet/helpers.py | 27 +++++++++++++++------------ cashu/wallet/proofs.py | 8 ++++++-- cashu/wallet/wallet.py | 12 ++++-------- 6 files changed, 43 insertions(+), 31 deletions(-) diff --git a/cashu/core/base.py b/cashu/core/base.py index 0c528235..a386887d 100644 --- a/cashu/core/base.py +++ b/cashu/core/base.py @@ -669,11 +669,14 @@ class TokenV3(BaseModel): token: List[TokenV3Token] = [] memo: Optional[str] = None + unit: Optional[str] = None def to_dict(self, include_dleq=False): return_dict = dict(token=[t.to_dict(include_dleq) for t in self.token]) if self.memo: return_dict.update(dict(memo=self.memo)) # type: ignore + if self.unit: + return_dict.update(dict(unit=self.unit)) # type: ignore return return_dict def get_proofs(self): diff --git a/cashu/wallet/cli/cli.py b/cashu/wallet/cli/cli.py index 7e38c959..173992fa 100644 --- a/cashu/wallet/cli/cli.py +++ b/cashu/wallet/cli/cli.py @@ -138,7 +138,8 @@ async def cli(ctx: Context, host: str, walletname: str, unit: str, tests: bool): ctx.ensure_object(dict) ctx.obj["HOST"] = host or settings.mint_url - ctx.obj["UNIT"] = unit + ctx.obj["UNIT"] = unit or settings.wallet_unit + unit = ctx.obj["UNIT"] ctx.obj["WALLET_NAME"] = walletname settings.wallet_name = walletname @@ -147,16 +148,18 @@ async def cli(ctx: Context, host: str, walletname: str, unit: str, tests: bool): # otherwise it will create a mnemonic and store it in the database if ctx.invoked_subcommand == "restore": wallet = await Wallet.with_db( - ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True + ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True, unit=unit ) else: # # we need to run the migrations before we load the wallet for the first time # # otherwise the wallet will not be able to generate a new private key and store it wallet = await Wallet.with_db( - ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True + ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True, unit=unit ) # now with the migrations done, we can load the wallet and generate a new mnemonic if needed - wallet = await Wallet.with_db(ctx.obj["HOST"], db_path, name=walletname) + wallet = await Wallet.with_db( + ctx.obj["HOST"], db_path, name=walletname, unit=unit + ) assert wallet, "Wallet not found." ctx.obj["WALLET"] = wallet @@ -514,7 +517,9 @@ async def receive_cli( # ask the user if they want to trust the new mints for mint_url in set([t.mint for t in tokenObj.token if t.mint]): mint_wallet = Wallet( - mint_url, os.path.join(settings.cashu_dir, wallet.name) + mint_url, + os.path.join(settings.cashu_dir, wallet.name), + unit=tokenObj.unit or wallet.unit.name, ) await verify_mint(mint_wallet, mint_url) receive_wallet = await receive(wallet, tokenObj) diff --git a/cashu/wallet/cli/cli_helpers.py b/cashu/wallet/cli/cli_helpers.py index be6ee102..c0c02df9 100644 --- a/cashu/wallet/cli/cli_helpers.py +++ b/cashu/wallet/cli/cli_helpers.py @@ -24,11 +24,11 @@ async def get_unit_wallet(ctx: Context, force_select: bool = False): force_select (bool, optional): Force the user to select a unit. Defaults to False. """ wallet: Wallet = ctx.obj["WALLET"] - await wallet.load_proofs(reload=True) + await wallet.load_proofs(reload=False) # show balances per unit unit_balances = wallet.balance_per_unit() - if ctx.obj["UNIT"] in [u.name for u in unit_balances] and not force_select: - wallet.unit = Unit[ctx.obj["UNIT"]] + if wallet.unit in [unit_balances.keys()] and not force_select: + return wallet elif len(unit_balances) > 1 and not ctx.obj["UNIT"]: print(f"You have balances in {len(unit_balances)} units:") print("") @@ -68,7 +68,7 @@ async def get_mint_wallet(ctx: Context, force_select: bool = False): """ # we load a dummy wallet so we can check the balance per mint wallet: Wallet = ctx.obj["WALLET"] - await wallet.load_proofs(reload=True) + await wallet.load_proofs(reload=False) mint_balances = await wallet.balance_per_minturl() if ctx.obj["HOST"] not in mint_balances and not force_select: @@ -102,6 +102,7 @@ async def get_mint_wallet(ctx: Context, force_select: bool = False): mint_url, os.path.join(settings.cashu_dir, ctx.obj["WALLET_NAME"]), name=wallet.name, + unit=wallet.unit.name, ) await mint_wallet.load_proofs(reload=True) diff --git a/cashu/wallet/helpers.py b/cashu/wallet/helpers.py index 6e370cdc..5c8ac075 100644 --- a/cashu/wallet/helpers.py +++ b/cashu/wallet/helpers.py @@ -40,23 +40,26 @@ async def redeem_TokenV3_multimint(wallet: Wallet, token: TokenV3) -> Wallet: Helper function to iterate thruogh a token with multiple mints and redeem them from these mints one keyset at a time. """ + if not token.unit: + # load unit from wallet keyset db + keysets = await get_keysets(id=token.token[0].proofs[0].id, db=wallet.db) + if keysets: + token.unit = keysets[0].unit.name + for t in token.token: assert t.mint, Exception( "redeem_TokenV3_multimint: multimint redeem without URL" ) mint_wallet = await Wallet.with_db( - t.mint, os.path.join(settings.cashu_dir, wallet.name) + t.mint, + os.path.join(settings.cashu_dir, wallet.name), + unit=token.unit or wallet.unit.name, ) keyset_ids = mint_wallet._get_proofs_keysets(t.proofs) logger.trace(f"Keysets in tokens: {' '.join(set(keyset_ids))}") - # loop over all keysets - for keyset_id in set(keyset_ids): - await mint_wallet.load_mint(keyset_id) - mint_wallet.unit = mint_wallet.keysets[keyset_id].unit - # redeem proofs of this keyset - redeem_proofs = [p for p in t.proofs if p.id == keyset_id] - _, _ = await mint_wallet.redeem(redeem_proofs) - print(f"Received {mint_wallet.unit.str(sum_proofs(redeem_proofs))}") + await mint_wallet.load_mint() + _, _ = await mint_wallet.redeem(t.proofs) + print(f"Received {mint_wallet.unit.str(sum_proofs(t.proofs))}") # return the last mint_wallet return mint_wallet @@ -137,19 +140,19 @@ async def receive( ) else: # this is very legacy code, virtually any token should have mint information - # no mint information present, we extract the proofs and use wallet's default mint - # first we load the mint URL from the DB + # no mint information present, we extract the proofs find the mint and unit from the db keyset_in_token = proofs[0].id assert keyset_in_token # we get the keyset from the db mint_keysets = await get_keysets(id=keyset_in_token, db=wallet.db) assert mint_keysets, Exception(f"we don't know this keyset: {keyset_in_token}") - mint_keyset = mint_keysets[0] + mint_keyset = [k for k in mint_keysets if k.id == keyset_in_token][0] assert mint_keyset.mint_url, Exception("we don't know this mint's URL") # now we have the URL mint_wallet = await Wallet.with_db( mint_keyset.mint_url, os.path.join(settings.cashu_dir, wallet.name), + unit=mint_keyset.unit.name or wallet.unit.name, ) await mint_wallet.load_mint(keyset_in_token) _, _ = await mint_wallet.redeem(proofs) diff --git a/cashu/wallet/proofs.py b/cashu/wallet/proofs.py index 08ef7c81..d3bd4cf4 100644 --- a/cashu/wallet/proofs.py +++ b/cashu/wallet/proofs.py @@ -70,7 +70,7 @@ def _get_proofs_keysets(self, proofs: List[Proof]) -> List[str]: Args: proofs (List[Proof]): List of proofs to get the keyset id's of """ - keysets: List[str] = [proof.id for proof in proofs if proof.id] + keysets: List[str] = [proof.id for proof in proofs] return keysets async def _get_keyset_urls(self, keysets: List[str]) -> Dict[str, List[str]]: @@ -92,7 +92,9 @@ async def _get_keyset_urls(self, keysets: List[str]) -> Dict[str, List[str]]: ) return mint_urls - async def _make_token(self, proofs: List[Proof], include_mints=True) -> TokenV3: + async def _make_token( + self, proofs: List[Proof], include_mints=True, include_unit=True + ) -> TokenV3: """ Takes list of proofs and produces a TokenV3 by looking up the mint URLs by the keyset id from the database. @@ -105,6 +107,8 @@ async def _make_token(self, proofs: List[Proof], include_mints=True) -> TokenV3: TokenV3: TokenV3 object """ token = TokenV3() + if include_unit: + token.unit = self.unit.name if include_mints: # we create a map from mint url to keyset id and then group diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index 26f24bd4..ce37cc49 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -692,12 +692,7 @@ class Wallet( bip32: BIP32 # private_key: Optional[PrivateKey] = None - def __init__( - self, - url: str, - db: str, - name: str = "no_name", - ): + def __init__(self, url: str, db: str, name: str = "no_name", unit: str = "sat"): """A Cashu wallet. Args: @@ -708,7 +703,7 @@ def __init__( self.db = Database("wallet", db) self.proofs: List[Proof] = [] self.name = name - self.unit = Unit[settings.wallet_unit] + self.unit = Unit[unit] super().__init__(url=url, db=self.db) logger.debug("Wallet initialized") @@ -723,6 +718,7 @@ async def with_db( db: str, name: str = "no_name", skip_db_read: bool = False, + unit: str = "sat", ): """Initializes a wallet with a database and initializes the private key. @@ -738,7 +734,7 @@ async def with_db( Wallet: Initialized wallet. """ logger.trace(f"Initializing wallet with database: {db}") - self = cls(url=url, db=db, name=name) + self = cls(url=url, db=db, name=name, unit=unit) await self._migrate_database() if not skip_db_read: logger.trace("Mint init: loading private key and keysets from db.")