Skip to content

Commit

Permalink
Refactor conditions and fix HTLC multisig (#643)
Browse files Browse the repository at this point in the history
* refactor conditions and fix htlc multisig

* restore db/write.py

* safer check for P2PK secrets for SIG_ALL

* comment cleanup
  • Loading branch information
callebtc authored Oct 22, 2024
1 parent d12a8d1 commit 09d007e
Show file tree
Hide file tree
Showing 9 changed files with 349 additions and 177 deletions.
24 changes: 8 additions & 16 deletions cashu/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,7 @@ def pending(self) -> bool:

class HTLCWitness(BaseModel):
preimage: Optional[str] = None
signature: Optional[str] = None

@classmethod
def from_witness(cls, witness: str):
return cls(**json.loads(witness))


class P2SHWitness(BaseModel):
"""
Unlocks P2SH spending condition of a Proof
"""

script: str
signature: str
address: Union[str, None] = None
signatures: Optional[List[str]] = None

@classmethod
def from_witness(cls, witness: str):
Expand Down Expand Up @@ -206,10 +192,15 @@ def p2pksigs(self) -> List[str]:
return P2PKWitness.from_witness(self.witness).signatures

@property
def htlcpreimage(self) -> Union[str, None]:
def htlcpreimage(self) -> str | None:
assert self.witness, "Witness is missing for htlc preimage"
return HTLCWitness.from_witness(self.witness).preimage

@property
def htlcsigs(self) -> List[str] | None:
assert self.witness, "Witness is missing for htlc signatures"
return HTLCWitness.from_witness(self.witness).signatures


class Proofs(BaseModel):
# NOTE: not used in Pydantic validation
Expand Down Expand Up @@ -647,6 +638,7 @@ def deserialize(serialized: str) -> Dict[int, PublicKey]:
int(amount): PublicKey(bytes.fromhex(hex_key), raw=True)
for amount, hex_key in dict(json.loads(serialized)).items()
}

return cls(
id=row["id"],
unit=row["unit"],
Expand Down
18 changes: 18 additions & 0 deletions cashu/core/htlc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from enum import Enum
from typing import Union

from .secret import Secret, SecretKind


class SigFlags(Enum):
# require signatures only on the inputs (default signature flag)
SIG_INPUTS = "SIG_INPUTS"
# require signatures on inputs and outputs
SIG_ALL = "SIG_ALL"


class HTLCSecret(Secret):
@classmethod
def from_secret(cls, secret: Secret):
Expand All @@ -15,3 +23,13 @@ def from_secret(cls, secret: Secret):
def locktime(self) -> Union[None, int]:
locktime = self.tags.get_tag("locktime")
return int(locktime) if locktime else None

@property
def sigflag(self) -> Union[None, SigFlags]:
sigflag = self.tags.get_tag("sigflag")
return SigFlags(sigflag) if sigflag else None

@property
def n_sigs(self) -> Union[None, int]:
n_sigs = self.tags.get_tag("n_sigs")
return int(n_sigs) if n_sigs else None
10 changes: 4 additions & 6 deletions cashu/core/p2pk.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,16 @@ def n_sigs(self) -> Union[None, int]:
return int(n_sigs) if n_sigs else None


def sign_p2pk_sign(message: bytes, private_key: PrivateKey) -> bytes:
# ecdsa version
# signature = private_key.ecdsa_serialize(private_key.ecdsa_sign(message))
def schnorr_sign(message: bytes, private_key: PrivateKey) -> bytes:
signature = private_key.schnorr_sign(
hashlib.sha256(message).digest(), None, raw=True
)
return signature


def verify_p2pk_signature(message: bytes, pubkey: PublicKey, signature: bytes) -> bool:
# ecdsa version
# return pubkey.ecdsa_verify(message, pubkey.ecdsa_deserialize(signature))
def verify_schnorr_signature(
message: bytes, pubkey: PublicKey, signature: bytes
) -> bool:
return pubkey.schnorr_verify(
hashlib.sha256(message).digest(), signature, None, raw=True
)
Expand Down
155 changes: 72 additions & 83 deletions cashu/mint/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from loguru import logger

from ..core.base import BlindedMessage, HTLCWitness, Proof
from ..core.base import BlindedMessage, Proof
from ..core.crypto.secp import PublicKey
from ..core.errors import (
TransactionError,
Expand All @@ -13,7 +13,7 @@
from ..core.p2pk import (
P2PKSecret,
SigFlags,
verify_p2pk_signature,
verify_schnorr_signature,
)
from ..core.secret import Secret, SecretKind

Expand Down Expand Up @@ -50,62 +50,9 @@ def _verify_p2pk_spending_conditions(self, proof: Proof, secret: Secret) -> bool
if not pubkeys:
return True

assert len(set(pubkeys)) == len(pubkeys), "pubkeys must be unique."
logger.trace(f"pubkeys: {pubkeys}")

# verify that signatures are present
if not proof.p2pksigs:
# no signature present although secret indicates one
logger.error(f"no p2pk signatures in proof: {proof.p2pksigs}")
raise TransactionError("no p2pk signatures in proof.")

# we make sure that there are no duplicate signatures
if len(set(proof.p2pksigs)) != len(proof.p2pksigs):
raise TransactionError("p2pk signatures must be unique.")

# we parse the secret as a P2PK commitment
# assert len(proof.secret.split(":")) == 5, "p2pk secret format invalid."

# INPUTS: check signatures proof.p2pksigs against pubkey
# we expect the signature to be on the pubkey (=message) itself
n_sigs_required = p2pk_secret.n_sigs or 1
assert n_sigs_required > 0, "n_sigs must be positive."

# check if enough signatures are present
assert (
len(proof.p2pksigs) >= n_sigs_required
), f"not enough signatures provided: {len(proof.p2pksigs)} < {n_sigs_required}."

n_valid_sigs_per_output = 0
# loop over all signatures in output
for input_sig in proof.p2pksigs:
for pubkey in pubkeys:
logger.trace(f"verifying signature {input_sig} by pubkey {pubkey}.")
logger.trace(f"Message: {p2pk_secret.serialize().encode('utf-8')}")
if verify_p2pk_signature(
message=proof.secret.encode("utf-8"),
pubkey=PublicKey(bytes.fromhex(pubkey), raw=True),
signature=bytes.fromhex(input_sig),
):
n_valid_sigs_per_output += 1
logger.trace(
f"p2pk signature on input is valid: {input_sig} on {pubkey}."
)

# check if we have enough valid signatures
assert n_valid_sigs_per_output, "no valid signature provided for input."
assert n_valid_sigs_per_output >= n_sigs_required, (
f"signature threshold not met. {n_valid_sigs_per_output} <"
f" {n_sigs_required}."
)

logger.trace(
f"{n_valid_sigs_per_output} of {n_sigs_required} valid signatures found."
return self._verify_secret_signatures(
proof, pubkeys, proof.p2pksigs, p2pk_secret.n_sigs
)
logger.trace(proof.p2pksigs)
logger.trace("p2pk signature on inputs is valid.")

return True

def _verify_htlc_spending_conditions(self, proof: Proof, secret: Secret) -> bool:
"""
Expand Down Expand Up @@ -149,18 +96,9 @@ def _verify_htlc_spending_conditions(self, proof: Proof, secret: Secret) -> bool
if htlc_secret.locktime and htlc_secret.locktime < time.time():
refund_pubkeys = htlc_secret.tags.get_tag_all("refund")
if refund_pubkeys:
assert proof.witness, TransactionError("no HTLC refund signature.")
signature = HTLCWitness.from_witness(proof.witness).signature
assert signature, TransactionError("no HTLC refund signature provided")
for pubkey in refund_pubkeys:
if verify_p2pk_signature(
message=proof.secret.encode("utf-8"),
pubkey=PublicKey(bytes.fromhex(pubkey), raw=True),
signature=bytes.fromhex(signature),
):
# a signature matches
return True
raise TransactionError("HTLC refund signatures did not match.")
return self._verify_secret_signatures(
proof, refund_pubkeys, proof.p2pksigs, htlc_secret.n_sigs
)
# no pubkeys given in secret, anyone can spend
return True

Expand All @@ -173,23 +111,74 @@ def _verify_htlc_spending_conditions(self, proof: Proof, secret: Secret) -> bool
).digest() == bytes.fromhex(htlc_secret.data):
raise TransactionError("HTLC preimage does not match.")

# then we check whether a signature is required
# then we check whether signatures are required
hashlock_pubkeys = htlc_secret.tags.get_tag_all("pubkeys")
if hashlock_pubkeys:
assert proof.witness, TransactionError("no HTLC hash lock signature.")
signature = HTLCWitness.from_witness(proof.witness).signature
assert signature, TransactionError("HTLC no hash lock signatures provided.")
for pubkey in hashlock_pubkeys:
if verify_p2pk_signature(
if not hashlock_pubkeys:
# no pubkeys given in secret, anyone can spend
return True

return self._verify_secret_signatures(
proof, hashlock_pubkeys, proof.htlcsigs or [], htlc_secret.n_sigs
)

def _verify_secret_signatures(
self,
proof: Proof,
pubkeys: List[str],
signatures: List[str],
n_sigs_required: int | None = 1,
) -> bool:
assert len(set(pubkeys)) == len(pubkeys), "pubkeys must be unique."
logger.trace(f"pubkeys: {pubkeys}")

# verify that signatures are present
if not signatures:
# no signature present although secret indicates one
logger.error(f"no signatures in proof: {proof}")
raise TransactionError("no signatures in proof.")

# we make sure that there are no duplicate signatures
if len(set(signatures)) != len(signatures):
raise TransactionError("signatures must be unique.")

# INPUTS: check signatures against pubkey
# we expect the signature to be on the pubkey (=message) itself
n_sigs_required = n_sigs_required or 1
assert n_sigs_required > 0, "n_sigs must be positive."

# check if enough signatures are present
assert (
len(signatures) >= n_sigs_required
), f"not enough signatures provided: {len(signatures)} < {n_sigs_required}."

n_valid_sigs_per_output = 0
# loop over all signatures in input
for input_sig in signatures:
for pubkey in pubkeys:
logger.trace(f"verifying signature {input_sig} by pubkey {pubkey}.")
logger.trace(f"Message: {proof.secret}")
if verify_schnorr_signature(
message=proof.secret.encode("utf-8"),
pubkey=PublicKey(bytes.fromhex(pubkey), raw=True),
signature=bytes.fromhex(signature),
signature=bytes.fromhex(input_sig),
):
# a signature matches
return True
# none of the pubkeys had a match
raise TransactionError("HTLC hash lock signatures did not match.")
# no pubkeys were included, anyone can spend
n_valid_sigs_per_output += 1
logger.trace(
f"signature on input is valid: {input_sig} on {pubkey}."
)

# check if we have enough valid signatures
assert n_valid_sigs_per_output, "no valid signature provided for input."
assert n_valid_sigs_per_output >= n_sigs_required, (
f"signature threshold not met. {n_valid_sigs_per_output} <"
f" {n_sigs_required}."
)

logger.trace(
f"{n_valid_sigs_per_output} of {n_sigs_required} valid signatures found."
)
logger.trace("p2pk signature on inputs is valid.")

return True

def _verify_input_spending_conditions(self, proof: Proof) -> bool:
Expand Down Expand Up @@ -304,7 +293,7 @@ def _verify_output_p2pk_spending_conditions(
# loop over all signatures in output
for sig in p2pksigs:
for pubkey in pubkeys:
if verify_p2pk_signature(
if verify_schnorr_signature(
message=bytes.fromhex(output.B_),
pubkey=PublicKey(bytes.fromhex(pubkey), raw=True),
signature=bytes.fromhex(sig),
Expand Down
24 changes: 14 additions & 10 deletions cashu/wallet/htlc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import hashlib
from datetime import datetime, timedelta
from typing import List, Optional
from typing import List

from ..core.base import HTLCWitness, Proof
from ..core.db import Database
Expand All @@ -17,27 +17,31 @@ class WalletHTLC(SupportsDb):
async def create_htlc_lock(
self,
*,
preimage: Optional[str] = None,
preimage_hash: Optional[str] = None,
hashlock_pubkey: Optional[str] = None,
locktime_seconds: Optional[int] = None,
locktime_pubkey: Optional[str] = None,
preimage: str | None = None,
preimage_hash: str | None = None,
hashlock_pubkeys: List[str] | None = None,
hashlock_n_sigs: int | None = None,
locktime_seconds: int | None = None,
locktime_pubkeys: List[str] | None = None,
) -> HTLCSecret:
tags = Tags()
if locktime_seconds:
tags["locktime"] = str(
int((datetime.now() + timedelta(seconds=locktime_seconds)).timestamp())
)
if locktime_pubkey:
tags["refund"] = locktime_pubkey
if locktime_pubkeys:
tags["refund"] = locktime_pubkeys

if not preimage_hash and preimage:
preimage_hash = hashlib.sha256(bytes.fromhex(preimage)).hexdigest()

assert preimage_hash, "preimage_hash or preimage must be provided"

if hashlock_pubkey:
tags["pubkeys"] = hashlock_pubkey
if hashlock_pubkeys:
tags["pubkeys"] = hashlock_pubkeys

if hashlock_n_sigs:
tags["n_sigs"] = str(hashlock_n_sigs)

return HTLCSecret(
kind=SecretKind.HTLC.value,
Expand Down
Loading

0 comments on commit 09d007e

Please sign in to comment.