Skip to content

Commit

Permalink
accept new keyset id calculation, cache keysets, remove duplicate rep…
Browse files Browse the repository at this point in the history
…resentations
  • Loading branch information
callebtc committed Oct 7, 2023
1 parent 1149533 commit 910a65a
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 32 deletions.
79 changes: 51 additions & 28 deletions cashu/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,12 @@ async def wrapper(self, *args, **kwargs):


class LedgerAPI(object):
keys: WalletKeyset # holds current keys of mint
keysets: Dict[str, WalletKeyset] # holds keysets of mint
mint_keyset_ids: List[str] # holds ids of keysets of mint
# keys: WalletKeyset # holds current keys of mint
keyset_id: str # holds id of current keyset
public_keys: Dict[int, PublicKey] # holds public keys of
# public_keys: Dict[int, PublicKey] # holds public keys of current keyset

mint_info: GetInfoResponse # holds info about mint
tor: TorProxy
s: requests.Session
Expand All @@ -111,6 +114,7 @@ def __init__(self, url: str, db: Database):
self.url = url
self.s = requests.Session()
self.db = db
self.keysets = {}

@async_set_requests
async def _init_s(self):
Expand All @@ -137,7 +141,7 @@ def raise_on_error(resp: Response) -> None:
# raise for status if no error
resp.raise_for_status()

async def _load_mint_keys(self, keyset_id: str = "") -> WalletKeyset:
async def _load_mint_keys(self, keyset_id: Optional[str] = None) -> None:
"""Loads keys from mint and stores them in the database.
Args:
Expand All @@ -163,20 +167,29 @@ async def _load_mint_keys(self, keyset_id: str = "") -> WalletKeyset:
assert keyset.id
assert len(keyset.public_keys) > 0, "did not receive keys from mint."

if keyset_id != keyset.id:
# NOTE: Because of the upcoming change of how to calculate keyset ids
# with version 0.14.0, we overwrite the calculated keyset id with the
# requested one. This is a temporary fix and should be removed once all
# ecash is transitioned to 0.14.0.
logger.debug(
"Keyset ID mismatch. This can happen due to a version upgrade."
)
keyset.id = keyset_id or keyset.id

# check if current keyset is in db
keyset_local: Optional[WalletKeyset] = await get_keyset(keyset.id, db=self.db)
# if not, store it
if keyset_local is None:
logger.debug(f"Storing new mint keyset: {keyset.id}")
await store_keyset(keyset=keyset, db=self.db)

self.keys = keyset
assert self.keys.public_keys
self.public_keys = self.keys.public_keys
assert self.keys.id
self.keyset_id = self.keys.id
logger.debug(f"Current mint keyset: {self.keys.id}")
return self.keys
# set current keyset id
self.keyset_id = keyset.id
logger.debug(f"Current mint keyset: {self.keyset_id}")

# add keyset to keysets dict
self.keysets[keyset.id] = keyset

async def _load_mint_keysets(self) -> List[str]:
"""Loads the keyset IDs of the mint.
Expand All @@ -191,11 +204,13 @@ async def _load_mint_keysets(self) -> List[str]:
try:
mint_keysets = await self._get_keyset_ids(self.url)
except Exception:
assert self.keys.id, "could not get keysets from mint, and do not have keys"
assert self.keysets[
self.keyset_id
].id, "could not get keysets from mint, and do not have keys"
pass
self.keysets = mint_keysets or [self.keys.id]
logger.debug(f"Mint keysets: {self.keysets}")
return self.keysets
self.mint_keyset_ids = mint_keysets or [self.keysets[self.keyset_id].id]
logger.debug(f"Mint keysets: {self.mint_keyset_ids}")
return self.mint_keyset_ids

async def _load_mint_info(self) -> GetInfoResponse:
"""Loads the mint info from the mint."""
Expand All @@ -207,7 +222,7 @@ async def _load_mint(self, keyset_id: str = "") -> None:
"""
Loads the public keys of the mint. Either gets the keys for the specified
`keyset_id` or gets the keys of the active keyset from the mint.
Gets the active keyset ids of the mint and stores in `self.keysets`.
Gets the active keyset ids of the mint and stores in `self.mint_keyset_ids`.
"""
await self._load_mint_keys(keyset_id)
await self._load_mint_keysets()
Expand All @@ -218,7 +233,9 @@ async def _load_mint(self, keyset_id: str = "") -> None:
pass

if keyset_id:
assert keyset_id in self.keysets, f"keyset {keyset_id} not active on mint"
assert (
keyset_id in self.mint_keyset_ids
), f"keyset {keyset_id} not active on mint"

async def _check_used_secrets(self, secrets):
"""Checks if any of the secrets have already been used"""
Expand Down Expand Up @@ -568,7 +585,7 @@ async def _migrate_database(self):
async def load_mint(self, keyset_id: str = ""):
"""Load a mint's keys with a given keyset_id if specified or else
loads the active keyset of the mint into self.keys.
Also loads all keyset ids into self.keysets.
Also loads all keyset ids into self.mint_keyset_ids.
Args:
keyset_id (str, optional): _description_. Defaults to "".
Expand Down Expand Up @@ -818,14 +835,14 @@ def verify_proofs_dleq(self, proofs: List[Proof]):
logger.trace("No DLEQ proof in proof.")
return
logger.trace("Verifying DLEQ proof.")
assert self.keys.public_keys
assert proof.id
if not b_dhke.carol_verify_dleq(
secret_msg=proof.secret,
C=PublicKey(bytes.fromhex(proof.C), raw=True),
r=PrivateKey(bytes.fromhex(proof.dleq.r), raw=True),
e=PrivateKey(bytes.fromhex(proof.dleq.e), raw=True),
s=PrivateKey(bytes.fromhex(proof.dleq.s), raw=True),
A=self.keys.public_keys[proof.amount],
A=self.keysets[proof.id].public_keys[proof.amount],
):
raise Exception("DLEQ proof invalid.")
else:
Expand Down Expand Up @@ -855,13 +872,15 @@ async def _construct_proofs(
logger.trace("Constructing proofs.")
proofs: List[Proof] = []
for promise, secret, r, path in zip(promises, secrets, rs, derivation_paths):
logger.trace(f"Creating proof with keyset {self.keyset_id} = {promise.id}")
assert (
self.keyset_id == promise.id
), "our keyset id does not match promise id."
if promise.id not in self.keysets:
# we don't have the keyset for this promise, so we load it
await self._load_mint_keys(promise.id)
assert promise.id in self.keysets, "Could not load keyset."

C_ = PublicKey(bytes.fromhex(promise.C_), raw=True)
C = b_dhke.step3_alice(C_, r, self.public_keys[promise.amount])
C = b_dhke.step3_alice(
C_, r, self.keysets[promise.id].public_keys[promise.amount]
)
B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs

proof = Proof(
Expand Down Expand Up @@ -1104,7 +1123,7 @@ async def _select_proofs_to_send(
Rules:
1) Proofs that are not marked as reserved
2) Proofs that have a keyset id that is in self.keysets (all active keysets of mint)
2) Proofs that have a keyset id that is in self.mint_keyset_ids (all active keysets of mint)
3) Include all proofs that have an older keyset than the current keyset of the mint (to get rid of old epochs).
4) If the target amount is not reached, add proofs of the current keyset until it is.
"""
Expand All @@ -1114,19 +1133,23 @@ async def _select_proofs_to_send(
proofs = [p for p in proofs if not p.reserved]

# select proofs that are in the active keysets of the mint
proofs = [p for p in proofs if p.id in self.keysets or not p.id]
proofs = [p for p in proofs if p.id in self.mint_keyset_ids or not p.id]

# check that enough spendable proofs exist
if sum_proofs(proofs) < amount_to_send:
raise Exception("balance too low.")

# add all proofs that have an older keyset than the current keyset of the mint
proofs_old_epochs = [p for p in proofs if p.id != self.keys.id]
proofs_old_epochs = [
p for p in proofs if p.id != self.keysets[self.keyset_id].id
]
send_proofs += proofs_old_epochs

# coinselect based on amount only from the current keyset
# start with the proofs with the largest amount and add them until the target amount is reached
proofs_current_epoch = [p for p in proofs if p.id == self.keys.id]
proofs_current_epoch = [
p for p in proofs if p.id == self.keysets[self.keyset_id].id
]
sorted_proofs_of_current_keyset = sorted(
proofs_current_epoch, key=lambda p: p.amount
)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ async def wallet3(mint):

@pytest.mark.asyncio
async def test_get_keys(wallet1: Wallet):
assert wallet1.keys.public_keys
assert len(wallet1.keys.public_keys) == settings.max_order
assert wallet1.keysets[wallet1.keyset_id].public_keys
assert len(wallet1.keysets[wallet1.keyset_id].public_keys) == settings.max_order
keyset = await wallet1._get_keys(wallet1.url)
assert keyset.id is not None
assert keyset.id == "1cCNIAZ2X/w1"
Expand All @@ -100,8 +100,8 @@ async def test_get_keys(wallet1: Wallet):

@pytest.mark.asyncio
async def test_get_keyset(wallet1: Wallet):
assert wallet1.keys.public_keys
assert len(wallet1.keys.public_keys) == settings.max_order
assert wallet1.keysets[wallet1.keyset_id].public_keys
assert len(wallet1.keysets[wallet1.keyset_id].public_keys) == settings.max_order
# let's get the keys first so we can get a keyset ID that we use later
keys1 = await wallet1._get_keys(wallet1.url)
# gets the keys of a specific keyset
Expand Down

0 comments on commit 910a65a

Please sign in to comment.