Skip to content

Commit

Permalink
store dleq proofs in wallet db
Browse files Browse the repository at this point in the history
  • Loading branch information
callebtc committed Sep 23, 2023
1 parent 023fc81 commit 35f9a2b
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 148 deletions.
7 changes: 7 additions & 0 deletions cashu/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ class Proof(BaseModel):
time_reserved: Union[None, str] = ""
derivation_path: Union[None, str] = "" # derivation path of the proof

@classmethod
def from_dict(cls, proof_dict: dict):
if proof_dict.get("dleq"):
proof_dict["dleq"] = DLEQWallet(**json.loads(proof_dict["dleq"]))
c = cls(**proof_dict)
return c

def to_dict(self, include_dleq=False):
# dictionary without the fields that don't need to be send to Carol
if not include_dleq:
Expand Down
48 changes: 25 additions & 23 deletions cashu/wallet/crud.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import time
from typing import Any, List, Optional, Tuple

Expand All @@ -9,12 +10,12 @@ async def store_proof(
proof: Proof,
db: Database,
conn: Optional[Connection] = None,
):
) -> None:
await (conn or db).execute(
"""
INSERT INTO proofs
(id, amount, C, secret, time_created, derivation_path)
VALUES (?, ?, ?, ?, ?, ?)
(id, amount, C, secret, time_created, derivation_path, dleq)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
proof.id,
Expand All @@ -23,36 +24,37 @@ async def store_proof(
str(proof.secret),
int(time.time()),
proof.derivation_path,
json.dumps(proof.dleq.dict()) if proof.dleq else "",
),
)


async def get_proofs(
db: Database,
conn: Optional[Connection] = None,
):
) -> List[Proof]:
rows = await (conn or db).fetchall("""
SELECT * from proofs
""")
return [Proof(**dict(r)) for r in rows]
return [Proof.from_dict(dict(r)) for r in rows]


async def get_reserved_proofs(
db: Database,
conn: Optional[Connection] = None,
):
) -> List[Proof]:
rows = await (conn or db).fetchall("""
SELECT * from proofs
WHERE reserved
""")
return [Proof(**r) for r in rows]
return [Proof.from_dict(dict(r)) for r in rows]


async def invalidate_proof(
proof: Proof,
db: Database,
conn: Optional[Connection] = None,
):
) -> None:
await (conn or db).execute(
"""
DELETE FROM proofs
Expand Down Expand Up @@ -84,7 +86,7 @@ async def update_proof_reserved(
send_id: str = "",
db: Optional[Database] = None,
conn: Optional[Connection] = None,
):
) -> None:
clauses = []
values: List[Any] = []
clauses.append("reserved = ?")
Expand All @@ -109,7 +111,7 @@ async def secret_used(
secret: str,
db: Database,
conn: Optional[Connection] = None,
):
) -> bool:
rows = await (conn or db).fetchone(
"""
SELECT * from proofs
Expand All @@ -124,7 +126,7 @@ async def store_p2sh(
p2sh: P2SHScript,
db: Database,
conn: Optional[Connection] = None,
):
) -> None:
await (conn or db).execute(
"""
INSERT INTO p2sh
Expand All @@ -144,7 +146,7 @@ async def get_unused_locks(
address: str = "",
db: Optional[Database] = None,
conn: Optional[Connection] = None,
):
) -> List[P2SHScript]:
clause: List[str] = []
args: List[str] = []

Expand Down Expand Up @@ -173,7 +175,7 @@ async def update_p2sh_used(
used: bool,
db: Optional[Database] = None,
conn: Optional[Connection] = None,
):
) -> None:
clauses = []
values = []
clauses.append("used = ?")
Expand All @@ -190,7 +192,7 @@ async def store_keyset(
mint_url: str = "",
db: Optional[Database] = None,
conn: Optional[Connection] = None,
):
) -> None:
await (conn or db).execute( # type: ignore
"""
INSERT INTO keysets
Expand Down Expand Up @@ -243,7 +245,7 @@ async def store_lightning_invoice(
db: Database,
invoice: Invoice,
conn: Optional[Connection] = None,
):
) -> None:
await (conn or db).execute(
"""
INSERT INTO invoices
Expand All @@ -266,7 +268,7 @@ async def get_lightning_invoice(
db: Database,
hash: str = "",
conn: Optional[Connection] = None,
):
) -> Invoice:
clauses = []
values: List[Any] = []
if hash:
Expand All @@ -291,7 +293,7 @@ async def get_lightning_invoices(
db: Database,
paid: Optional[bool] = None,
conn: Optional[Connection] = None,
):
) -> List[Invoice]:
clauses: List[Any] = []
values: List[Any] = []

Expand Down Expand Up @@ -319,7 +321,7 @@ async def update_lightning_invoice(
paid: bool,
time_paid: Optional[int] = None,
conn: Optional[Connection] = None,
):
) -> None:
clauses = []
values: List[Any] = []
clauses.append("paid = ?")
Expand All @@ -344,7 +346,7 @@ async def bump_secret_derivation(
by: int = 1,
skip: bool = False,
conn: Optional[Connection] = None,
):
) -> int:
rows = await (conn or db).fetchone(
"SELECT counter from keysets WHERE id = ?", (keyset_id,)
)
Expand Down Expand Up @@ -374,7 +376,7 @@ async def set_secret_derivation(
keyset_id: str,
counter: int,
conn: Optional[Connection] = None,
):
) -> None:
await (conn or db).execute(
"UPDATE keysets SET counter = ? WHERE id = ?",
(
Expand All @@ -388,7 +390,7 @@ async def set_nostr_last_check_timestamp(
db: Database,
timestamp: int,
conn: Optional[Connection] = None,
):
) -> None:
await (conn or db).execute(
"UPDATE nostr SET last = ? WHERE type = ?",
(timestamp, "dm"),
Expand All @@ -398,7 +400,7 @@ async def set_nostr_last_check_timestamp(
async def get_nostr_last_check_timestamp(
db: Database,
conn: Optional[Connection] = None,
):
) -> Optional[int]:
row = await (conn or db).fetchone(
"""
SELECT last from nostr WHERE type = ?
Expand Down Expand Up @@ -432,7 +434,7 @@ async def store_seed_and_mnemonic(
seed: str,
mnemonic: str,
conn: Optional[Connection] = None,
):
) -> None:
await (conn or db).execute(
"""
INSERT INTO seed
Expand Down
7 changes: 7 additions & 0 deletions cashu/wallet/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,10 @@ async def m009_privatekey_and_determinstic_key_derivation(db: Database):
);
""")
# await db.execute("INSERT INTO secret_derivation (counter) VALUES (0)")


async def m010_add_proofs_dleq(db: Database):
"""
Columns to store DLEQ proofs for proofs.
"""
await db.execute("ALTER TABLE proofs ADD COLUMN dleq TEXT")
126 changes: 1 addition & 125 deletions cashu/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,135 +111,11 @@ def __init__(self, url: str, db: Database):
self.s = requests.Session()
self.db = db

# async def generate_n_secrets(
# self, n: int = 1, skip_bump: bool = False
# ) -> Tuple[List[str], List[PrivateKey], List[str]]:
# return await self.generate_n_secrets(n, skip_bump)

# async def _generate_secret(self, skip_bump: bool = False) -> str:
# return await self._generate_secret(skip_bump)

@async_set_requests
async def _init_s(self):
"""Dummy function that can be called from outside to use LedgerAPI.s"""
return

# # ---------- DLEQ PROOFS ----------

# def verify_proofs_dleq(self, proofs: List[Proof]):
# """Verifies DLEQ proofs in proofs."""
# for proof in proofs:
# if not proof.dleq:
# logger.trace("No DLEQ proof in proof.")
# return
# logger.trace("Verifying DLEQ proof.")
# assert self.keys.public_keys
# if not b_dhke.carol_verify_dleq(
# secret_msg=proof.secret,
# C=PublicKey(bytes.fromhex(proof.C), raw=True),
# r=PrivateKey(bytes.fromhex(proof.dleq.r), raw=True),
# e=PrivateKey(bytes.fromhex(proof.dleq.e), raw=True),
# s=PrivateKey(bytes.fromhex(proof.dleq.s), raw=True),
# A=self.keys.public_keys[proof.amount],
# ):
# raise Exception("DLEQ proof invalid.")
# else:
# logger.debug("DLEQ proof valid.")

# def _construct_proofs(
# self,
# promises: List[BlindedSignature],
# secrets: List[str],
# rs: List[PrivateKey],
# derivation_paths: List[str],
# ) -> List[Proof]:
# """Constructs proofs from promises, secrets, rs and derivation paths.

# This method is called after the user has received blind signatures from
# the mint. The results are proofs that can be used as ecash.

# Args:
# promises (List[BlindedSignature]): blind signatures from mint
# secrets (List[str]): secrets that were previously used to create blind messages (that turned into promises)
# rs (List[PrivateKey]): blinding factors that were previously used to create blind messages (that turned into promises)
# derivation_paths (List[str]): derivation paths that were used to generate secrets and blinding factors

# Returns:
# List[Proof]: list of proofs that can be used as ecash
# """
# logger.trace("Constructing proofs.")
# proofs: List[Proof] = []
# for promise, secret, r, path in zip(promises, secrets, rs, derivation_paths):
# logger.trace(f"Creating proof with keyset {self.keyset_id} = {promise.id}")
# assert (
# self.keyset_id == promise.id
# ), "our keyset id does not match promise id."

# C_ = PublicKey(bytes.fromhex(promise.C_), raw=True)
# C = b_dhke.step3_alice(C_, r, self.public_keys[promise.amount])
# B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs

# proof = Proof(
# id=promise.id,
# amount=promise.amount,
# C=C.serialize().hex(),
# secret=secret,
# derivation_path=path,
# )

# # if the mint returned a dleq proof, we add it to the proof
# if promise.dleq:
# proof.dleq = DLEQWallet(
# e=promise.dleq.e, s=promise.dleq.s, r=r.serialize()
# )

# proofs.append(proof)

# logger.trace(
# f"Created proof: {proof}, r: {r.serialize()} out of promise {promise}"
# )

# # DLEQ verify
# self.verify_proofs_dleq(proofs)

# logger.trace(f"Constructed {len(proofs)} proofs.")
# return proofs

# @staticmethod
# def _construct_outputs(
# amounts: List[int], secrets: List[str], rs: List[PrivateKey] = []
# ) -> Tuple[List[BlindedMessage], List[PrivateKey]]:
# """Takes a list of amounts and secrets and returns outputs.
# Outputs are blinded messages `outputs` and blinding factors `rs`

# Args:
# amounts (List[int]): list of amounts
# secrets (List[str]): list of secrets
# rs (List[PrivateKey], optional): list of blinding factors. If not given, `rs` are generated in step1_alice. Defaults to [].

# Returns:
# List[BlindedMessage]: list of blinded messages that can be sent to the mint
# List[PrivateKey]: list of blinding factors that can be used to construct proofs after receiving blind signatures from the mint

# Raises:
# AssertionError: if len(amounts) != len(secrets)
# """
# assert len(amounts) == len(
# secrets
# ), f"len(amounts)={len(amounts)} not equal to len(secrets)={len(secrets)}"
# outputs: List[BlindedMessage] = []

# rs_ = [None] * len(amounts) if not rs else rs
# rs_return: List[PrivateKey] = []
# for secret, amount, r in zip(secrets, amounts, rs_):
# B_, r = b_dhke.step1_alice(secret, r or None)
# rs_return.append(r)
# output = BlindedMessage(amount=amount, B_=B_.serialize().hex())
# outputs.append(output)
# logger.trace(f"Constructing output: {output}, r: {r.serialize()}")

# return outputs, rs_return

@staticmethod
def raise_on_error(resp: Response) -> None:
"""Raises an exception if the response from the mint contains an error.
Expand Down Expand Up @@ -1047,7 +923,7 @@ def _construct_outputs(

async def _store_proofs(self, proofs):
try:
async with self.db.connect() as conn:
async with self.db.connect() as conn: # type: ignore
for proof in proofs:
await store_proof(proof, db=self.db, conn=conn)
except Exception as e:
Expand Down

0 comments on commit 35f9a2b

Please sign in to comment.