From 67188c578d0065502c4a5a9855d2342a265feb8b Mon Sep 17 00:00:00 2001 From: callebtc <93376500+callebtc@users.noreply.github.com> Date: Mon, 26 Feb 2024 00:30:32 +0100 Subject: [PATCH] Revert "Wallet: deprecate old hash to curve (#457)" This reverts commit b06d93c5ff3c5ac29937a5eb9c4a2033e446f12f. --- cashu/core/crypto/b_dhke.py | 74 ++++++++++---------- cashu/core/settings.py | 2 +- cashu/wallet/wallet.py | 26 +++---- tests/test_crypto.py | 133 ++++++++++++++++-------------------- 4 files changed, 105 insertions(+), 130 deletions(-) diff --git a/cashu/core/crypto/b_dhke.py b/cashu/core/crypto/b_dhke.py index d098c236..78b3510f 100644 --- a/cashu/core/crypto/b_dhke.py +++ b/cashu/core/crypto/b_dhke.py @@ -55,10 +55,26 @@ from secp256k1 import PrivateKey, PublicKey + +def hash_to_curve(message: bytes) -> PublicKey: + """Generates a point from the message hash and checks if the point lies on the curve. + If it does not, iteratively tries to compute a new point from the hash.""" + point = None + msg_to_hash = message + while point is None: + _hash = hashlib.sha256(msg_to_hash).digest() + try: + # will error if point does not lie on curve + point = PublicKey(b"\x02" + _hash, raw=True) + except Exception: + msg_to_hash = _hash + return point + + DOMAIN_SEPARATOR = b"Secp256k1_HashToCurve_Cashu_" -def hash_to_curve(message: bytes) -> PublicKey: +def hash_to_curve_domain_separated(message: bytes) -> PublicKey: """Generates a secp256k1 point from a message. The point is generated by hashing the message with a domain separator and then @@ -94,6 +110,15 @@ def step1_alice( return B_, r +def step1_alice_domain_separated( + secret_msg: str, blinding_factor: Optional[PrivateKey] = None +) -> tuple[PublicKey, PrivateKey]: + Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8")) + r = blinding_factor or PrivateKey() + B_: PublicKey = Y + r.pubkey # type: ignore + return B_, r + + def step2_bob(B_: PublicKey, a: PrivateKey) -> Tuple[PublicKey, PrivateKey, PrivateKey]: C_: PublicKey = B_.mult(a) # type: ignore # produce dleq proof @@ -111,11 +136,17 @@ def verify(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool: valid = C == Y.mult(a) # type: ignore # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 if not valid: - valid = verify_deprecated(a, C, secret_msg) + valid = verify_domain_separated(a, C, secret_msg) # END: BACKWARDS COMPATIBILITY < 0.15.1 return valid +def verify_domain_separated(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool: + Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8")) + valid = C == Y.mult(a) # type: ignore + return valid + + def hash_e(*publickeys: PublicKey) -> bytes: e_ = "" for p in publickeys: @@ -171,45 +202,12 @@ def carol_verify_dleq( valid = alice_verify_dleq(B_, C_, e, s, A) # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 if not valid: - return carol_verify_dleq_deprecated(secret_msg, r, C, e, s, A) + return carol_verify_dleq_domain_separated(secret_msg, r, C, e, s, A) # END: BACKWARDS COMPATIBILITY < 0.15.1 return valid -# -------- Deprecated hash_to_curve before 0.15.0 -------- - - -def hash_to_curve_deprecated(message: bytes) -> PublicKey: - """Generates a point from the message hash and checks if the point lies on the curve. - If it does not, iteratively tries to compute a new point from the hash.""" - point = None - msg_to_hash = message - while point is None: - _hash = hashlib.sha256(msg_to_hash).digest() - try: - # will error if point does not lie on curve - point = PublicKey(b"\x02" + _hash, raw=True) - except Exception: - msg_to_hash = _hash - return point - - -def step1_alice_deprecated( - secret_msg: str, blinding_factor: Optional[PrivateKey] = None -) -> tuple[PublicKey, PrivateKey]: - Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8")) - r = blinding_factor or PrivateKey() - B_: PublicKey = Y + r.pubkey # type: ignore - return B_, r - - -def verify_deprecated(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool: - Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8")) - valid = C == Y.mult(a) # type: ignore - return valid - - -def carol_verify_dleq_deprecated( +def carol_verify_dleq_domain_separated( secret_msg: str, r: PrivateKey, C: PublicKey, @@ -217,7 +215,7 @@ def carol_verify_dleq_deprecated( s: PrivateKey, A: PublicKey, ) -> bool: - Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8")) + Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8")) C_: PublicKey = C + A.mult(r) # type: ignore B_: PublicKey = Y + r.pubkey # type: ignore valid = alice_verify_dleq(B_, C_, e, s, A) diff --git a/cashu/core/settings.py b/cashu/core/settings.py index d66d6ff7..77bf6cb6 100644 --- a/cashu/core/settings.py +++ b/cashu/core/settings.py @@ -123,7 +123,7 @@ class WalletSettings(CashuSettings): mint_port: int = Field(default=3338) wallet_name: str = Field(default="wallet") wallet_unit: str = Field(default="sat") - wallet_use_deprecated_h2c: bool = Field(default=False) + wallet_domain_separation: bool = Field(default=False) api_port: int = Field(default=4448) api_host: str = Field(default="127.0.0.1") diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index 9dd07b82..aef3dc13 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -998,11 +998,9 @@ async def pay_lightning( # NUT-08, the mint will imprint these outputs with a value depending on the # amount of fees we overpaid. n_change_outputs = calculate_number_of_blank_outputs(fee_reserve_sat) - ( - change_secrets, - change_rs, - change_derivation_paths, - ) = await self.generate_n_secrets(n_change_outputs) + change_secrets, change_rs, change_derivation_paths = ( + await self.generate_n_secrets(n_change_outputs) + ) change_outputs, change_rs = self._construct_outputs( n_change_outputs * [1], change_secrets, change_rs ) @@ -1128,15 +1126,14 @@ async def _construct_proofs( C = b_dhke.step3_alice( C_, r, self.keysets[promise.id].public_keys[promise.amount] ) - - if not settings.wallet_use_deprecated_h2c: - B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 + if not settings.wallet_domain_separation: + B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs + # END: BACKWARDS COMPATIBILITY < 0.15.1 else: - B_, r = b_dhke.step1_alice_deprecated( + B_, r = b_dhke.step1_alice_domain_separated( secret, r ) # recompute B_ for dleq proofs - # END: BACKWARDS COMPATIBILITY < 0.15.1 proof = Proof( id=promise.id, @@ -1199,13 +1196,12 @@ def _construct_outputs( rs_ = [None] * len(amounts) if not rs else rs rs_return: List[PrivateKey] = [] for secret, amount, r in zip(secrets, amounts, rs_): - if not settings.wallet_use_deprecated_h2c: - B_, r = b_dhke.step1_alice(secret, r or None) # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1 - else: - B_, r = b_dhke.step1_alice_deprecated(secret, r or None) + if not settings.wallet_domain_separation: + B_, r = b_dhke.step1_alice(secret, r or None) # END: BACKWARDS COMPATIBILITY < 0.15.1 - + else: + B_, r = b_dhke.step1_alice_domain_separated(secret, r or None) rs_return.append(r) output = BlindedMessage( amount=amount, B_=B_.serialize().hex(), id=self.keyset_id diff --git a/tests/test_crypto.py b/tests/test_crypto.py index c031521f..279145a6 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -1,11 +1,12 @@ from cashu.core.crypto.b_dhke import ( alice_verify_dleq, carol_verify_dleq, + carol_verify_dleq_domain_separated, hash_e, hash_to_curve, - hash_to_curve_deprecated, + hash_to_curve_domain_separated, step1_alice, - step1_alice_deprecated, + step1_alice_domain_separated, step2_bob, step2_bob_dleq, step3_alice, @@ -21,19 +22,30 @@ def test_hash_to_curve(): ) assert ( result.serialize().hex() - == "024cce997d3b518f739663b757deaec95bcd9473c30a14ac2fd04023a739d1a725" + == "0266687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925" + ) + + result = hash_to_curve( + bytes.fromhex( + "0000000000000000000000000000000000000000000000000000000000000001" + ) + ) + assert ( + result.serialize().hex() + == "02ec4916dd28fc4c10d78e287ca5d9cc51ee1ae73cbfde08c6b37324cbfaac8bc5" ) def test_hash_to_curve_iteration(): + """This input causes multiple rounds of the hash_to_curve algorithm.""" result = hash_to_curve( bytes.fromhex( - "0000000000000000000000000000000000000000000000000000000000000001" + "0000000000000000000000000000000000000000000000000000000000000002" ) ) assert ( result.serialize().hex() - == "022e7158e11c9506f1aa4248bf531298daa7febd6194f003edcd9b93ade6253acf" + == "02076c988b353fcbb748178ecb286bc9d0b4acf474d4ba31ba62334e46c97c416a" ) @@ -50,7 +62,7 @@ def test_step1(): assert ( B_.serialize().hex() - == "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" + == "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" ) assert blinding_factor.private_key == bytes.fromhex( "0000000000000000000000000000000000000000000000000000000000000001" @@ -76,7 +88,7 @@ def test_step2(): C_, e, s = step2_bob(B_, a) assert ( C_.serialize().hex() - == "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" + == "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" ) @@ -85,7 +97,7 @@ def test_step3(): # C_ from test_step2 C_ = PublicKey( bytes.fromhex( - "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" + "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" ), raw=True, ) @@ -106,7 +118,7 @@ def test_step3(): assert ( C.serialize().hex() - == "0271bf0d702dbad86cbe0af3ab2bfba70a0338f22728e412d88a830ed0580b9de4" + == "03c724d7e6a5443b39ac8acf11f40420adc4f99a02e7cc1b57703d9391f6d129cd" ) @@ -164,11 +176,11 @@ def test_dleq_step2_bob_dleq(): e, s = step2_bob_dleq(B_, a, p_bytes) assert ( e.serialize() - == "a608ae30a54c6d878c706240ee35d4289b68cfe99454bbfa6578b503bce2dbe1" + == "9818e061ee51d5c8edc3342369a554998ff7b4381c8652d724cdf46429be73d9" ) assert ( s.serialize() - == "a608ae30a54c6d878c706240ee35d4289b68cfe99454bbfa6578b503bce2dbe2" + == "9818e061ee51d5c8edc3342369a554998ff7b4381c8652d724cdf46429be73da" ) # differs from e only in least significant byte because `a = 0x1` # change `a` @@ -181,11 +193,11 @@ def test_dleq_step2_bob_dleq(): e, s = step2_bob_dleq(B_, a, p_bytes) assert ( e.serialize() - == "076cbdda4f368053c33056c438df014d1875eb3c8b28120bece74b6d0e6381bb" + == "df1984d5c22f7e17afe33b8669f02f530f286ae3b00a1978edaf900f4721f65e" ) assert ( s.serialize() - == "b6d41ac1e12415862bf8cace95e5355e9262eab8a11d201dadd3b6e41584ea6e" + == "828404170c86f240c50ae0f5fc17bb6b82612d46b355e046d7cd84b0a3c934a0" ) @@ -294,47 +306,36 @@ def test_dleq_carol_verify_from_bob(): assert carol_verify_dleq(secret_msg=secret_msg, C=C, r=r, e=e, s=s, A=A) -# TESTS FOR DEPRECATED HASH TO CURVE +# TESTS FOR DOMAIN SEPARATED HASH TO CURVE -def test_hash_to_curve_deprecated(): - result = hash_to_curve_deprecated( +def test_hash_to_curve_domain_separated(): + result = hash_to_curve_domain_separated( bytes.fromhex( "0000000000000000000000000000000000000000000000000000000000000000" ) ) assert ( result.serialize().hex() - == "0266687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925" - ) - - result = hash_to_curve_deprecated( - bytes.fromhex( - "0000000000000000000000000000000000000000000000000000000000000001" - ) - ) - assert ( - result.serialize().hex() - == "02ec4916dd28fc4c10d78e287ca5d9cc51ee1ae73cbfde08c6b37324cbfaac8bc5" + == "024cce997d3b518f739663b757deaec95bcd9473c30a14ac2fd04023a739d1a725" ) -def test_hash_to_curve_iteration_deprecated(): - """This input causes multiple rounds of the hash_to_curve algorithm.""" - result = hash_to_curve_deprecated( +def test_hash_to_curve_domain_separated_iterative(): + result = hash_to_curve_domain_separated( bytes.fromhex( - "0000000000000000000000000000000000000000000000000000000000000002" + "0000000000000000000000000000000000000000000000000000000000000001" ) ) assert ( result.serialize().hex() - == "02076c988b353fcbb748178ecb286bc9d0b4acf474d4ba31ba62334e46c97c416a" + == "022e7158e11c9506f1aa4248bf531298daa7febd6194f003edcd9b93ade6253acf" ) -def test_step1_deprecated(): +def test_step1_domain_separated(): secret_msg = "test_message" - B_, blinding_factor = step1_alice_deprecated( + B_, blinding_factor = step1_alice_domain_separated( secret_msg, blinding_factor=PrivateKey( privkey=bytes.fromhex( @@ -345,15 +346,15 @@ def test_step1_deprecated(): assert ( B_.serialize().hex() - == "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" + == "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" ) assert blinding_factor.private_key == bytes.fromhex( "0000000000000000000000000000000000000000000000000000000000000001" ) -def test_step2_deprecated(): - B_, _ = step1_alice_deprecated( +def test_step2_domain_separated(): + B_, _ = step1_alice_domain_separated( "test_message", blinding_factor=PrivateKey( privkey=bytes.fromhex( @@ -371,16 +372,16 @@ def test_step2_deprecated(): C_, e, s = step2_bob(B_, a) assert ( C_.serialize().hex() - == "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" + == "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" ) -def test_step3_deprecated(): +def test_step3_domain_separated(): # C = C_ - A.mult(r) - # C_ from test_step2_deprecated + # C_ from test_step2 C_ = PublicKey( bytes.fromhex( - "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2" + "025cc16fe33b953e2ace39653efb3e7a7049711ae1d8a2f7a9108753f1cdea742b" ), raw=True, ) @@ -401,52 +402,32 @@ def test_step3_deprecated(): assert ( C.serialize().hex() - == "03c724d7e6a5443b39ac8acf11f40420adc4f99a02e7cc1b57703d9391f6d129cd" + == "0271bf0d702dbad86cbe0af3ab2bfba70a0338f22728e412d88a830ed0580b9de4" ) -def test_dleq_step2_bob_dleq_deprecated(): - B_, _ = step1_alice_deprecated( - "test_message", - blinding_factor=PrivateKey( - privkey=bytes.fromhex( - "0000000000000000000000000000000000000000000000000000000000000001" - ), - raw=True, - ), - ) +def test_dleq_carol_verify_from_bob_domain_separated(): a = PrivateKey( privkey=bytes.fromhex( "0000000000000000000000000000000000000000000000000000000000000001" ), raw=True, ) - p_bytes = bytes.fromhex( - "0000000000000000000000000000000000000000000000000000000000000001" - ) # 32 bytes - e, s = step2_bob_dleq(B_, a, p_bytes) - assert ( - e.serialize() - == "9818e061ee51d5c8edc3342369a554998ff7b4381c8652d724cdf46429be73d9" - ) - assert ( - s.serialize() - == "9818e061ee51d5c8edc3342369a554998ff7b4381c8652d724cdf46429be73da" - ) # differs from e only in least significant byte because `a = 0x1` - - # change `a` - a = PrivateKey( + A = a.pubkey + assert A + secret_msg = "test_message" + r = PrivateKey( privkey=bytes.fromhex( - "0000000000000000000000000000000000000000000000000000000000001111" + "0000000000000000000000000000000000000000000000000000000000000001" ), raw=True, ) - e, s = step2_bob_dleq(B_, a, p_bytes) - assert ( - e.serialize() - == "df1984d5c22f7e17afe33b8669f02f530f286ae3b00a1978edaf900f4721f65e" - ) - assert ( - s.serialize() - == "828404170c86f240c50ae0f5fc17bb6b82612d46b355e046d7cd84b0a3c934a0" + B_, _ = step1_alice_domain_separated(secret_msg, r) + C_, e, s = step2_bob(B_, a) + assert alice_verify_dleq(B_, C_, e, s, A) + C = step3_alice(C_, r, A) + + # carol does not know B_ and C_, but she receives C and r from Alice + assert carol_verify_dleq_domain_separated( + secret_msg=secret_msg, C=C, r=r, e=e, s=s, A=A )