From b06d93c5ff3c5ac29937a5eb9c4a2033e446f12f Mon Sep 17 00:00:00 2001 From: callebtc <93376500+callebtc@users.noreply.github.com> Date: Mon, 26 Feb 2024 00:24:58 +0100 Subject: [PATCH] Wallet: deprecate old hash to curve (#457) * wallet: deprecate old hash to curve * fix order --- cashu/core/crypto/b_dhke.py | 74 ++++++++++---------- cashu/core/settings.py | 2 +- cashu/wallet/wallet.py | 26 ++++--- tests/test_crypto.py | 133 ++++++++++++++++++++---------------- 4 files changed, 130 insertions(+), 105 deletions(-) diff --git a/cashu/core/crypto/b_dhke.py b/cashu/core/crypto/b_dhke.py index 78b3510f..d098c236 100644 --- a/cashu/core/crypto/b_dhke.py +++ b/cashu/core/crypto/b_dhke.py @@ -55,26 +55,10 @@ from secp256k1 import PrivateKey, PublicKey - -def hash_to_curve(message: bytes) -> PublicKey: - """Generates a point from the message hash and checks if the point lies on the curve. - If it does not, iteratively tries to compute a new point from the hash.""" - point = None - msg_to_hash = message - while point is None: - _hash = hashlib.sha256(msg_to_hash).digest() - try: - # will error if point does not lie on curve - point = PublicKey(b"\x02" + _hash, raw=True) - except Exception: - msg_to_hash = _hash - return point - - DOMAIN_SEPARATOR = b"Secp256k1_HashToCurve_Cashu_" -def hash_to_curve_domain_separated(message: bytes) -> PublicKey: +def hash_to_curve(message: bytes) -> PublicKey: """Generates a secp256k1 point from a message. The point is generated by hashing the message with a domain separator and then @@ -110,15 +94,6 @@ def step1_alice( return B_, r -def step1_alice_domain_separated( - secret_msg: str, blinding_factor: Optional[PrivateKey] = None -) -> tuple[PublicKey, PrivateKey]: - Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8")) - r = blinding_factor or PrivateKey() - B_: PublicKey = Y + r.pubkey # type: ignore - return B_, r - - def step2_bob(B_: PublicKey, a: PrivateKey) -> Tuple[PublicKey, PrivateKey, PrivateKey]: C_: PublicKey = B_.mult(a) # type: ignore # produce dleq proof @@ -136,17 +111,11 @@ def verify(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool: valid = C == Y.mult(a) # type: ignore # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 if not valid: - valid = verify_domain_separated(a, C, secret_msg) + valid = verify_deprecated(a, C, secret_msg) # END: BACKWARDS COMPATIBILITY < 0.15.1 return valid -def verify_domain_separated(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool: - Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8")) - valid = C == Y.mult(a) # type: ignore - return valid - - def hash_e(*publickeys: PublicKey) -> bytes: e_ = "" for p in publickeys: @@ -202,12 +171,45 @@ def carol_verify_dleq( valid = alice_verify_dleq(B_, C_, e, s, A) # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 if not valid: - return carol_verify_dleq_domain_separated(secret_msg, r, C, e, s, A) + return carol_verify_dleq_deprecated(secret_msg, r, C, e, s, A) # END: BACKWARDS COMPATIBILITY < 0.15.1 return valid -def carol_verify_dleq_domain_separated( +# -------- Deprecated hash_to_curve before 0.15.0 -------- + + +def hash_to_curve_deprecated(message: bytes) -> PublicKey: + """Generates a point from the message hash and checks if the point lies on the curve. + If it does not, iteratively tries to compute a new point from the hash.""" + point = None + msg_to_hash = message + while point is None: + _hash = hashlib.sha256(msg_to_hash).digest() + try: + # will error if point does not lie on curve + point = PublicKey(b"\x02" + _hash, raw=True) + except Exception: + msg_to_hash = _hash + return point + + +def step1_alice_deprecated( + secret_msg: str, blinding_factor: Optional[PrivateKey] = None +) -> tuple[PublicKey, PrivateKey]: + Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8")) + r = blinding_factor or PrivateKey() + B_: PublicKey = Y + r.pubkey # type: ignore + return B_, r + + +def verify_deprecated(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool: + Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8")) + valid = C == Y.mult(a) # type: ignore + return valid + + +def carol_verify_dleq_deprecated( secret_msg: str, r: PrivateKey, C: PublicKey, @@ -215,7 +217,7 @@ def carol_verify_dleq_domain_separated( s: PrivateKey, A: PublicKey, ) -> bool: - Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8")) + Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8")) C_: PublicKey = C + A.mult(r) # type: ignore B_: PublicKey = Y + r.pubkey # type: ignore valid = alice_verify_dleq(B_, C_, e, s, A) diff --git a/cashu/core/settings.py b/cashu/core/settings.py index 77bf6cb6..d66d6ff7 100644 --- a/cashu/core/settings.py +++ b/cashu/core/settings.py @@ -123,7 +123,7 @@ class WalletSettings(CashuSettings): mint_port: int = Field(default=3338) wallet_name: str = Field(default="wallet") wallet_unit: str = Field(default="sat") - wallet_domain_separation: bool = Field(default=False) + wallet_use_deprecated_h2c: bool = Field(default=False) api_port: int = Field(default=4448) api_host: str = Field(default="127.0.0.1") diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index aef3dc13..9dd07b82 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -998,9 +998,11 @@ async def pay_lightning( # NUT-08, the mint will imprint these outputs with a value depending on the # amount of fees we overpaid. n_change_outputs = calculate_number_of_blank_outputs(fee_reserve_sat) - change_secrets, change_rs, change_derivation_paths = ( - await self.generate_n_secrets(n_change_outputs) - ) + ( + change_secrets, + change_rs, + change_derivation_paths, + ) = await self.generate_n_secrets(n_change_outputs) change_outputs, change_rs = self._construct_outputs( n_change_outputs * [1], change_secrets, change_rs ) @@ -1126,14 +1128,15 @@ async def _construct_proofs( C = b_dhke.step3_alice( C_, r, self.keysets[promise.id].public_keys[promise.amount] ) - # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 - if not settings.wallet_domain_separation: + + if not settings.wallet_use_deprecated_h2c: B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs - # END: BACKWARDS COMPATIBILITY < 0.15.1 + # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 else: - B_, r = b_dhke.step1_alice_domain_separated( + B_, r = b_dhke.step1_alice_deprecated( secret, r ) # recompute B_ for dleq proofs + # END: BACKWARDS COMPATIBILITY < 0.15.1 proof = Proof( id=promise.id, @@ -1196,12 +1199,13 @@ def _construct_outputs( rs_ = [None] * len(amounts) if not rs else rs rs_return: List[PrivateKey] = [] for secret, amount, r in zip(secrets, amounts, rs_): - # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 - if not settings.wallet_domain_separation: + if not settings.wallet_use_deprecated_h2c: B_, r = b_dhke.step1_alice(secret, r or None) - # END: BACKWARDS COMPATIBILITY < 0.15.1 + # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 else: - B_, r = b_dhke.step1_alice_domain_separated(secret, r or None) + B_, r = b_dhke.step1_alice_deprecated(secret, r or None) + # END: BACKWARDS COMPATIBILITY < 0.15.1 + rs_return.append(r) output = BlindedMessage( amount=amount, B_=B_.serialize().hex(), id=self.keyset_id diff --git a/tests/test_crypto.py b/tests/test_crypto.py index 279145a6..c031521f 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -1,12 +1,11 @@ from cashu.core.crypto.b_dhke import ( alice_verify_dleq, carol_verify_dleq, - carol_verify_dleq_domain_separated, hash_e, hash_to_curve, - hash_to_curve_domain_separated, + hash_to_curve_deprecated, step1_alice, - step1_alice_domain_separated, + step1_alice_deprecated, step2_bob, step2_bob_dleq, step3_alice, @@ -22,30 +21,19 @@ def test_hash_to_curve(): ) assert ( result.serialize().hex() - == "0266687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925" - ) - - result = hash_to_curve( - bytes.fromhex( - "0000000000000000000000000000000000000000000000000000000000000001" - ) - ) - assert ( - result.serialize().hex() - == "02ec4916dd28fc4c10d78e287ca5d9cc51ee1ae73cbfde08c6b37324cbfaac8bc5" + == "024cce997d3b518f739663b757deaec95bcd9473c30a14ac2fd04023a739d1a725" ) def test_hash_to_curve_iteration(): - """This input causes multiple rounds of the hash_to_curve algorithm.""" result = hash_to_curve( bytes.fromhex( - "0000000000000000000000000000000000000000000000000000000000000002" + "0000000000000000000000000000000000000000000000000000000000000001" ) ) assert ( result.serialize().hex() - == "02076c988b353fcbb748178ecb286bc9d0b4acf474d4ba31ba62334e46c97c416a" + == "022e7158e11c9506f1aa4248bf531298daa7febd6194f003edcd9b93ade6253acf" ) @@ -62,7 +50,7 @@ def test_step1(): assert ( B_.serialize().hex() - == "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" + == "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" ) assert blinding_factor.private_key == bytes.fromhex( "0000000000000000000000000000000000000000000000000000000000000001" @@ -88,7 +76,7 @@ def test_step2(): C_, e, s = step2_bob(B_, a) assert ( C_.serialize().hex() - == "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" + == "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" ) @@ -97,7 +85,7 @@ def test_step3(): # C_ from test_step2 C_ = PublicKey( bytes.fromhex( - "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" + "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" ), raw=True, ) @@ -118,7 +106,7 @@ def test_step3(): assert ( C.serialize().hex() - == "03c724d7e6a5443b39ac8acf11f40420adc4f99a02e7cc1b57703d9391f6d129cd" + == "0271bf0d702dbad86cbe0af3ab2bfba70a0338f22728e412d88a830ed0580b9de4" ) @@ -176,11 +164,11 @@ def test_dleq_step2_bob_dleq(): e, s = step2_bob_dleq(B_, a, p_bytes) assert ( e.serialize() - == "9818e061ee51d5c8edc3342369a554998ff7b4381c8652d724cdf46429be73d9" + == "a608ae30a54c6d878c706240ee35d4289b68cfe99454bbfa6578b503bce2dbe1" ) assert ( s.serialize() - == "9818e061ee51d5c8edc3342369a554998ff7b4381c8652d724cdf46429be73da" + == "a608ae30a54c6d878c706240ee35d4289b68cfe99454bbfa6578b503bce2dbe2" ) # differs from e only in least significant byte because `a = 0x1` # change `a` @@ -193,11 +181,11 @@ def test_dleq_step2_bob_dleq(): e, s = step2_bob_dleq(B_, a, p_bytes) assert ( e.serialize() - == "df1984d5c22f7e17afe33b8669f02f530f286ae3b00a1978edaf900f4721f65e" + == "076cbdda4f368053c33056c438df014d1875eb3c8b28120bece74b6d0e6381bb" ) assert ( s.serialize() - == "828404170c86f240c50ae0f5fc17bb6b82612d46b355e046d7cd84b0a3c934a0" + == "b6d41ac1e12415862bf8cace95e5355e9262eab8a11d201dadd3b6e41584ea6e" ) @@ -306,36 +294,47 @@ def test_dleq_carol_verify_from_bob(): assert carol_verify_dleq(secret_msg=secret_msg, C=C, r=r, e=e, s=s, A=A) -# TESTS FOR DOMAIN SEPARATED HASH TO CURVE +# TESTS FOR DEPRECATED HASH TO CURVE -def test_hash_to_curve_domain_separated(): - result = hash_to_curve_domain_separated( +def test_hash_to_curve_deprecated(): + result = hash_to_curve_deprecated( bytes.fromhex( "0000000000000000000000000000000000000000000000000000000000000000" ) ) assert ( result.serialize().hex() - == "024cce997d3b518f739663b757deaec95bcd9473c30a14ac2fd04023a739d1a725" + == "0266687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925" ) - -def test_hash_to_curve_domain_separated_iterative(): - result = hash_to_curve_domain_separated( + result = hash_to_curve_deprecated( bytes.fromhex( "0000000000000000000000000000000000000000000000000000000000000001" ) ) assert ( result.serialize().hex() - == "022e7158e11c9506f1aa4248bf531298daa7febd6194f003edcd9b93ade6253acf" + == "02ec4916dd28fc4c10d78e287ca5d9cc51ee1ae73cbfde08c6b37324cbfaac8bc5" + ) + + +def test_hash_to_curve_iteration_deprecated(): + """This input causes multiple rounds of the hash_to_curve algorithm.""" + result = hash_to_curve_deprecated( + bytes.fromhex( + "0000000000000000000000000000000000000000000000000000000000000002" + ) + ) + assert ( + result.serialize().hex() + == "02076c988b353fcbb748178ecb286bc9d0b4acf474d4ba31ba62334e46c97c416a" ) -def test_step1_domain_separated(): +def test_step1_deprecated(): secret_msg = "test_message" - B_, blinding_factor = step1_alice_domain_separated( + B_, blinding_factor = step1_alice_deprecated( secret_msg, blinding_factor=PrivateKey( privkey=bytes.fromhex( @@ -346,15 +345,15 @@ def test_step1_domain_separated(): assert ( B_.serialize().hex() - == "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" + == "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" ) assert blinding_factor.private_key == bytes.fromhex( "0000000000000000000000000000000000000000000000000000000000000001" ) -def test_step2_domain_separated(): - B_, _ = step1_alice_domain_separated( +def test_step2_deprecated(): + B_, _ = step1_alice_deprecated( "test_message", blinding_factor=PrivateKey( privkey=bytes.fromhex( @@ -372,16 +371,16 @@ def test_step2_domain_separated(): C_, e, s = step2_bob(B_, a) assert ( C_.serialize().hex() - == "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" + == "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" ) -def test_step3_domain_separated(): +def test_step3_deprecated(): # C = C_ - A.mult(r) - # C_ from test_step2 + # C_ from test_step2_deprecated C_ = PublicKey( bytes.fromhex( - "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" + "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" ), raw=True, ) @@ -402,32 +401,52 @@ def test_step3_domain_separated(): assert ( C.serialize().hex() - == "0271bf0d702dbad86cbe0af3ab2bfba70a0338f22728e412d88a830ed0580b9de4" + == "03c724d7e6a5443b39ac8acf11f40420adc4f99a02e7cc1b57703d9391f6d129cd" ) -def test_dleq_carol_verify_from_bob_domain_separated(): +def test_dleq_step2_bob_dleq_deprecated(): + B_, _ = step1_alice_deprecated( + "test_message", + blinding_factor=PrivateKey( + privkey=bytes.fromhex( + "0000000000000000000000000000000000000000000000000000000000000001" + ), + raw=True, + ), + ) a = PrivateKey( privkey=bytes.fromhex( "0000000000000000000000000000000000000000000000000000000000000001" ), raw=True, ) - A = a.pubkey - assert A - secret_msg = "test_message" - r = PrivateKey( + p_bytes = bytes.fromhex( + "0000000000000000000000000000000000000000000000000000000000000001" + ) # 32 bytes + e, s = step2_bob_dleq(B_, a, p_bytes) + assert ( + e.serialize() + == "9818e061ee51d5c8edc3342369a554998ff7b4381c8652d724cdf46429be73d9" + ) + assert ( + s.serialize() + == "9818e061ee51d5c8edc3342369a554998ff7b4381c8652d724cdf46429be73da" + ) # differs from e only in least significant byte because `a = 0x1` + + # change `a` + a = PrivateKey( privkey=bytes.fromhex( - "0000000000000000000000000000000000000000000000000000000000000001" + "0000000000000000000000000000000000000000000000000000000000001111" ), raw=True, ) - B_, _ = step1_alice_domain_separated(secret_msg, r) - C_, e, s = step2_bob(B_, a) - assert alice_verify_dleq(B_, C_, e, s, A) - C = step3_alice(C_, r, A) - - # carol does not know B_ and C_, but she receives C and r from Alice - assert carol_verify_dleq_domain_separated( - secret_msg=secret_msg, C=C, r=r, e=e, s=s, A=A + e, s = step2_bob_dleq(B_, a, p_bytes) + assert ( + e.serialize() + == "df1984d5c22f7e17afe33b8669f02f530f286ae3b00a1978edaf900f4721f65e" + ) + assert ( + s.serialize() + == "828404170c86f240c50ae0f5fc17bb6b82612d46b355e046d7cd84b0a3c934a0" )