Skip to content

Commit

Permalink
wallet works with units
Browse files Browse the repository at this point in the history
  • Loading branch information
callebtc committed May 24, 2024
1 parent 281985d commit 4049538
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 31 deletions.
3 changes: 3 additions & 0 deletions cashu/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,11 +669,14 @@ class TokenV3(BaseModel):

token: List[TokenV3Token] = []
memo: Optional[str] = None
unit: Optional[str] = None

def to_dict(self, include_dleq=False):
return_dict = dict(token=[t.to_dict(include_dleq) for t in self.token])
if self.memo:
return_dict.update(dict(memo=self.memo)) # type: ignore
if self.unit:
return_dict.update(dict(unit=self.unit)) # type: ignore
return return_dict

def get_proofs(self):
Expand Down
15 changes: 10 additions & 5 deletions cashu/wallet/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ async def cli(ctx: Context, host: str, walletname: str, unit: str, tests: bool):

ctx.ensure_object(dict)
ctx.obj["HOST"] = host or settings.mint_url
ctx.obj["UNIT"] = unit
ctx.obj["UNIT"] = unit or settings.wallet_unit
unit = ctx.obj["UNIT"]
ctx.obj["WALLET_NAME"] = walletname
settings.wallet_name = walletname

Expand All @@ -147,16 +148,18 @@ async def cli(ctx: Context, host: str, walletname: str, unit: str, tests: bool):
# otherwise it will create a mnemonic and store it in the database
if ctx.invoked_subcommand == "restore":
wallet = await Wallet.with_db(
ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True
ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True, unit=unit
)
else:
# # we need to run the migrations before we load the wallet for the first time
# # otherwise the wallet will not be able to generate a new private key and store it
wallet = await Wallet.with_db(
ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True
ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True, unit=unit
)
# now with the migrations done, we can load the wallet and generate a new mnemonic if needed
wallet = await Wallet.with_db(ctx.obj["HOST"], db_path, name=walletname)
wallet = await Wallet.with_db(
ctx.obj["HOST"], db_path, name=walletname, unit=unit
)

assert wallet, "Wallet not found."
ctx.obj["WALLET"] = wallet
Expand Down Expand Up @@ -514,7 +517,9 @@ async def receive_cli(
# ask the user if they want to trust the new mints
for mint_url in set([t.mint for t in tokenObj.token if t.mint]):
mint_wallet = Wallet(
mint_url, os.path.join(settings.cashu_dir, wallet.name)
mint_url,
os.path.join(settings.cashu_dir, wallet.name),
unit=tokenObj.unit or wallet.unit.name,
)
await verify_mint(mint_wallet, mint_url)
receive_wallet = await receive(wallet, tokenObj)
Expand Down
9 changes: 5 additions & 4 deletions cashu/wallet/cli/cli_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ async def get_unit_wallet(ctx: Context, force_select: bool = False):
force_select (bool, optional): Force the user to select a unit. Defaults to False.
"""
wallet: Wallet = ctx.obj["WALLET"]
await wallet.load_proofs(reload=True)
await wallet.load_proofs(reload=False)
# show balances per unit
unit_balances = wallet.balance_per_unit()
if ctx.obj["UNIT"] in [u.name for u in unit_balances] and not force_select:
wallet.unit = Unit[ctx.obj["UNIT"]]
if wallet.unit in [unit_balances.keys()] and not force_select:
return wallet
elif len(unit_balances) > 1 and not ctx.obj["UNIT"]:
print(f"You have balances in {len(unit_balances)} units:")
print("")
Expand Down Expand Up @@ -68,7 +68,7 @@ async def get_mint_wallet(ctx: Context, force_select: bool = False):
"""
# we load a dummy wallet so we can check the balance per mint
wallet: Wallet = ctx.obj["WALLET"]
await wallet.load_proofs(reload=True)
await wallet.load_proofs(reload=False)
mint_balances = await wallet.balance_per_minturl()

if ctx.obj["HOST"] not in mint_balances and not force_select:
Expand Down Expand Up @@ -102,6 +102,7 @@ async def get_mint_wallet(ctx: Context, force_select: bool = False):
mint_url,
os.path.join(settings.cashu_dir, ctx.obj["WALLET_NAME"]),
name=wallet.name,
unit=wallet.unit.name,
)
await mint_wallet.load_proofs(reload=True)

Expand Down
27 changes: 15 additions & 12 deletions cashu/wallet/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,26 @@ async def redeem_TokenV3_multimint(wallet: Wallet, token: TokenV3) -> Wallet:
Helper function to iterate thruogh a token with multiple mints and redeem them from
these mints one keyset at a time.
"""
if not token.unit:
# load unit from wallet keyset db
keysets = await get_keysets(id=token.token[0].proofs[0].id, db=wallet.db)
if keysets:
token.unit = keysets[0].unit.name

for t in token.token:
assert t.mint, Exception(
"redeem_TokenV3_multimint: multimint redeem without URL"
)
mint_wallet = await Wallet.with_db(
t.mint, os.path.join(settings.cashu_dir, wallet.name)
t.mint,
os.path.join(settings.cashu_dir, wallet.name),
unit=token.unit or wallet.unit.name,
)
keyset_ids = mint_wallet._get_proofs_keysets(t.proofs)
logger.trace(f"Keysets in tokens: {' '.join(set(keyset_ids))}")
# loop over all keysets
for keyset_id in set(keyset_ids):
await mint_wallet.load_mint(keyset_id)
mint_wallet.unit = mint_wallet.keysets[keyset_id].unit
# redeem proofs of this keyset
redeem_proofs = [p for p in t.proofs if p.id == keyset_id]
_, _ = await mint_wallet.redeem(redeem_proofs)
print(f"Received {mint_wallet.unit.str(sum_proofs(redeem_proofs))}")
await mint_wallet.load_mint()
_, _ = await mint_wallet.redeem(t.proofs)
print(f"Received {mint_wallet.unit.str(sum_proofs(t.proofs))}")

# return the last mint_wallet
return mint_wallet
Expand Down Expand Up @@ -137,19 +140,19 @@ async def receive(
)
else:
# this is very legacy code, virtually any token should have mint information
# no mint information present, we extract the proofs and use wallet's default mint
# first we load the mint URL from the DB
# no mint information present, we extract the proofs find the mint and unit from the db
keyset_in_token = proofs[0].id
assert keyset_in_token
# we get the keyset from the db
mint_keysets = await get_keysets(id=keyset_in_token, db=wallet.db)
assert mint_keysets, Exception(f"we don't know this keyset: {keyset_in_token}")
mint_keyset = mint_keysets[0]
mint_keyset = [k for k in mint_keysets if k.id == keyset_in_token][0]
assert mint_keyset.mint_url, Exception("we don't know this mint's URL")
# now we have the URL
mint_wallet = await Wallet.with_db(
mint_keyset.mint_url,
os.path.join(settings.cashu_dir, wallet.name),
unit=mint_keyset.unit.name or wallet.unit.name,
)
await mint_wallet.load_mint(keyset_in_token)
_, _ = await mint_wallet.redeem(proofs)
Expand Down
8 changes: 6 additions & 2 deletions cashu/wallet/proofs.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _get_proofs_keysets(self, proofs: List[Proof]) -> List[str]:
Args:
proofs (List[Proof]): List of proofs to get the keyset id's of
"""
keysets: List[str] = [proof.id for proof in proofs if proof.id]
keysets: List[str] = [proof.id for proof in proofs]
return keysets

async def _get_keyset_urls(self, keysets: List[str]) -> Dict[str, List[str]]:
Expand All @@ -92,7 +92,9 @@ async def _get_keyset_urls(self, keysets: List[str]) -> Dict[str, List[str]]:
)
return mint_urls

async def _make_token(self, proofs: List[Proof], include_mints=True) -> TokenV3:
async def _make_token(
self, proofs: List[Proof], include_mints=True, include_unit=True
) -> TokenV3:
"""
Takes list of proofs and produces a TokenV3 by looking up
the mint URLs by the keyset id from the database.
Expand All @@ -105,6 +107,8 @@ async def _make_token(self, proofs: List[Proof], include_mints=True) -> TokenV3:
TokenV3: TokenV3 object
"""
token = TokenV3()
if include_unit:
token.unit = self.unit.name

if include_mints:
# we create a map from mint url to keyset id and then group
Expand Down
12 changes: 4 additions & 8 deletions cashu/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,12 +692,7 @@ class Wallet(
bip32: BIP32
# private_key: Optional[PrivateKey] = None

def __init__(
self,
url: str,
db: str,
name: str = "no_name",
):
def __init__(self, url: str, db: str, name: str = "no_name", unit: str = "sat"):
"""A Cashu wallet.
Args:
Expand All @@ -708,7 +703,7 @@ def __init__(
self.db = Database("wallet", db)
self.proofs: List[Proof] = []
self.name = name
self.unit = Unit[settings.wallet_unit]
self.unit = Unit[unit]

super().__init__(url=url, db=self.db)
logger.debug("Wallet initialized")
Expand All @@ -723,6 +718,7 @@ async def with_db(
db: str,
name: str = "no_name",
skip_db_read: bool = False,
unit: str = "sat",
):
"""Initializes a wallet with a database and initializes the private key.
Expand All @@ -738,7 +734,7 @@ async def with_db(
Wallet: Initialized wallet.
"""
logger.trace(f"Initializing wallet with database: {db}")
self = cls(url=url, db=db, name=name)
self = cls(url=url, db=db, name=name, unit=unit)
await self._migrate_database()
if not skip_db_read:
logger.trace("Mint init: loading private key and keysets from db.")
Expand Down

0 comments on commit 4049538

Please sign in to comment.