Skip to content

Commit

Permalink
Wallet: deprecate old h2c (#459)
Browse files Browse the repository at this point in the history
* wallet: deprecate old hash to curve

* fix order

* added migration: untested

* recompute Y always
  • Loading branch information
callebtc authored Feb 26, 2024
1 parent 53cd8ff commit 29be002
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 149 deletions.
22 changes: 10 additions & 12 deletions cashu/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,16 @@ class Proof(BaseModel):
time_created: Union[None, str] = ""
time_reserved: Union[None, str] = ""
derivation_path: Union[None, str] = "" # derivation path of the proof
mint_id: Union[None, str] = (
None # holds the id of the mint operation that created this proof
)
melt_id: Union[None, str] = (
None # holds the id of the melt operation that destroyed this proof
)
mint_id: Union[
None, str
] = None # holds the id of the mint operation that created this proof
melt_id: Union[
None, str
] = None # holds the id of the melt operation that destroyed this proof

def __init__(self, **data):
super().__init__(**data)
if not self.Y:
self.Y = hash_to_curve(self.secret.encode("utf-8")).serialize().hex()
self.Y = hash_to_curve(self.secret.encode("utf-8")).serialize().hex()

@classmethod
def from_dict(cls, proof_dict: dict):
Expand Down Expand Up @@ -274,7 +273,6 @@ class MintQuote(BaseModel):

@classmethod
def from_row(cls, row: Row):

try:
# SQLITE: row is timestamp (string)
created_time = int(row["created_time"]) if row["created_time"] else None
Expand Down Expand Up @@ -664,9 +662,9 @@ def __init__(
self.id = id

def serialize(self):
return json.dumps({
amount: key.serialize().hex() for amount, key in self.public_keys.items()
})
return json.dumps(
{amount: key.serialize().hex() for amount, key in self.public_keys.items()}
)

@classmethod
def from_row(cls, row: Row):
Expand Down
74 changes: 38 additions & 36 deletions cashu/core/crypto/b_dhke.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,26 +55,10 @@

from secp256k1 import PrivateKey, PublicKey


def hash_to_curve(message: bytes) -> PublicKey:
"""Generates a point from the message hash and checks if the point lies on the curve.
If it does not, iteratively tries to compute a new point from the hash."""
point = None
msg_to_hash = message
while point is None:
_hash = hashlib.sha256(msg_to_hash).digest()
try:
# will error if point does not lie on curve
point = PublicKey(b"\x02" + _hash, raw=True)
except Exception:
msg_to_hash = _hash
return point


DOMAIN_SEPARATOR = b"Secp256k1_HashToCurve_Cashu_"


def hash_to_curve_domain_separated(message: bytes) -> PublicKey:
def hash_to_curve(message: bytes) -> PublicKey:
"""Generates a secp256k1 point from a message.
The point is generated by hashing the message with a domain separator and then
Expand Down Expand Up @@ -110,15 +94,6 @@ def step1_alice(
return B_, r


def step1_alice_domain_separated(
secret_msg: str, blinding_factor: Optional[PrivateKey] = None
) -> tuple[PublicKey, PrivateKey]:
Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8"))
r = blinding_factor or PrivateKey()
B_: PublicKey = Y + r.pubkey # type: ignore
return B_, r


def step2_bob(B_: PublicKey, a: PrivateKey) -> Tuple[PublicKey, PrivateKey, PrivateKey]:
C_: PublicKey = B_.mult(a) # type: ignore
# produce dleq proof
Expand All @@ -136,17 +111,11 @@ def verify(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
valid = C == Y.mult(a) # type: ignore
# BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
if not valid:
valid = verify_domain_separated(a, C, secret_msg)
valid = verify_deprecated(a, C, secret_msg)
# END: BACKWARDS COMPATIBILITY < 0.15.1
return valid


def verify_domain_separated(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8"))
valid = C == Y.mult(a) # type: ignore
return valid


def hash_e(*publickeys: PublicKey) -> bytes:
e_ = ""
for p in publickeys:
Expand Down Expand Up @@ -202,20 +171,53 @@ def carol_verify_dleq(
valid = alice_verify_dleq(B_, C_, e, s, A)
# BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
if not valid:
return carol_verify_dleq_domain_separated(secret_msg, r, C, e, s, A)
return carol_verify_dleq_deprecated(secret_msg, r, C, e, s, A)
# END: BACKWARDS COMPATIBILITY < 0.15.1
return valid


def carol_verify_dleq_domain_separated(
# -------- Deprecated hash_to_curve before 0.15.0 --------


def hash_to_curve_deprecated(message: bytes) -> PublicKey:
"""Generates a point from the message hash and checks if the point lies on the curve.
If it does not, iteratively tries to compute a new point from the hash."""
point = None
msg_to_hash = message
while point is None:
_hash = hashlib.sha256(msg_to_hash).digest()
try:
# will error if point does not lie on curve
point = PublicKey(b"\x02" + _hash, raw=True)
except Exception:
msg_to_hash = _hash
return point


def step1_alice_deprecated(
secret_msg: str, blinding_factor: Optional[PrivateKey] = None
) -> tuple[PublicKey, PrivateKey]:
Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8"))
r = blinding_factor or PrivateKey()
B_: PublicKey = Y + r.pubkey # type: ignore
return B_, r


def verify_deprecated(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8"))
valid = C == Y.mult(a) # type: ignore
return valid


def carol_verify_dleq_deprecated(
secret_msg: str,
r: PrivateKey,
C: PublicKey,
e: PrivateKey,
s: PrivateKey,
A: PublicKey,
) -> bool:
Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8"))
Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8"))
C_: PublicKey = C + A.mult(r) # type: ignore
B_: PublicKey = Y + r.pubkey # type: ignore
valid = alice_verify_dleq(B_, C_, e, s, A)
Expand Down
2 changes: 1 addition & 1 deletion cashu/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class WalletSettings(CashuSettings):
mint_port: int = Field(default=3338)
wallet_name: str = Field(default="wallet")
wallet_unit: str = Field(default="sat")
wallet_domain_separation: bool = Field(default=False)
wallet_use_deprecated_h2c: bool = Field(default=False)
api_port: int = Field(default=4448)
api_host: str = Field(default="127.0.0.1")

Expand Down
Loading

0 comments on commit 29be002

Please sign in to comment.