From 194d7ba7be4f6cd84b018a30065dcc4d66e68684 Mon Sep 17 00:00:00 2001 From: Massimo Cairo Date: Wed, 14 Feb 2024 17:46:32 +0200 Subject: [PATCH] wasm: delegate derive key to `hd-keys-curves` --- ng/wasm/Cargo.toml | 1 + ng/wasm/src/ecdsa.rs | 232 +++++++++---------------------------------- 2 files changed, 48 insertions(+), 185 deletions(-) diff --git a/ng/wasm/Cargo.toml b/ng/wasm/Cargo.toml index 9583397a13..aa760d40bc 100644 --- a/ng/wasm/Cargo.toml +++ b/ng/wasm/Cargo.toml @@ -67,6 +67,7 @@ frost-secp256k1 = { git = "https://github.com/LIT-Protocol/frost.git" } frost-taproot = { git = "https://github.com/LIT-Protocol/frost.git" } lit-frost = { git = "https://github.com/LIT-Protocol/lit-frost.git" } tsify = { version = "0.4.5", default-features = false, features = ["js"] } +hd-keys-curves = { git = "https://github.com/LIT-Protocol/hd-keys-curves.git", branch = "fix-wasm32-compile-error-convert-scalars", version = "0.2.0" } # TODO(cairomassimo): remove once https://github.com/mikelodder7/bls12_381_plus/pull/4 lands on crates.io [patch.crates-io] diff --git a/ng/wasm/src/ecdsa.rs b/ng/wasm/src/ecdsa.rs index da31ea37e8..715ea592ce 100644 --- a/ng/wasm/src/ecdsa.rs +++ b/ng/wasm/src/ecdsa.rs @@ -1,12 +1,12 @@ use elliptic_curve::{ group::{cofactor::CofactorGroup, GroupEncoding}, - hash2curve::{ExpandMsgXmd, FromOkm, GroupDigest}, - point::AffineCoordinates as _, + point::AffineCoordinates, scalar::IsHigh as _, sec1::{EncodedPoint, FromEncodedPoint, ModulusSize, ToEncodedPoint}, subtle::ConditionallySelectable as _, - CurveArithmetic, Field as _, Group as _, PrimeCurve, PrimeField, ScalarPrimitive, + CurveArithmetic, PrimeCurve, PrimeField, }; +use hd_keys_curves::{HDDerivable, HDDeriver}; use js_sys::Uint8Array; use k256::Secp256k1; use p256::NistP256; @@ -39,12 +39,12 @@ impl HdCtx for NistP256 { const CTX: &'static [u8] = b"LIT_HD_KEY_ID_P256_XMD:SHA-256_SSWU_RO_NUL_"; } -impl Ecdsa +impl Ecdsa where C::AffinePoint: GroupEncoding + FromEncodedPoint, - C::Scalar: FromOkm, + C::Scalar: HDDeriver, C::FieldBytesSize: ModulusSize, - C::ProjectivePoint: CofactorGroup + FromEncodedPoint + ToEncodedPoint, + C::ProjectivePoint: CofactorGroup + HDDerivable + FromEncodedPoint + ToEncodedPoint, C: HdCtx, { pub fn combine( @@ -53,37 +53,13 @@ where ) -> JsResult { let signature_shares = signature_shares .into_iter() - .map(|s| { - let s = from_js::>(s)?; - let s = - C::Scalar::from_repr(::Repr::from_slice(&s).clone()); - let s = Option::from(s); - let s = s.ok_or_else(|| JsError::new("cannot parse signature share"))?; - - Ok(s) - }) + .map(Self::scalar_from_js) .collect::>>()?; - let big_r = presignature; - let big_r = from_js::>(big_r)?; - let big_r = EncodedPoint::::from_bytes(big_r)?; - let big_r = C::AffinePoint::from_encoded_point(&big_r); - let big_r = Option::::from(big_r) - .ok_or_else(|| JsError::new("cannot parse input public key"))?; - - let r = big_r.x(); - let v = u8::conditional_select(&0, &1, big_r.y_is_odd()); - + let big_r = Self::point_from_js(presignature)?; let s = Self::sum_scalars(signature_shares)?; - let s = s.to_repr(); - - let mut signature = Vec::new(); - signature.extend_from_slice(&r); - signature.extend_from_slice(&s); - signature.push(v); - let signature = into_js(Bytes::new(signature.as_ref()))?; - Ok(signature) + Self::signature_into_js(big_r, s) } fn sum_scalars(values: Vec) -> JsResult { @@ -102,19 +78,12 @@ where let id = from_js::>(id)?; let public_keys = public_keys .into_iter() - .map(|k| { - let k = from_js::>(k)?; - let k = EncodedPoint::::from_bytes(k)?; - let k = C::ProjectivePoint::from_encoded_point(&k); - let k = - Option::from(k).ok_or_else(|| JsError::new("cannot parse input public key"))?; - - Ok(k) - }) + .map(Self::point_from_js::) .collect::>>()?; - let k = Self::derive_key_inner(id, public_keys)?; + let deriver = C::Scalar::create(&id, C::CTX); + let k = deriver.hd_derive_public_key(&public_keys); let k = k.to_encoded_point(false); let k = Bytes::new(k.as_bytes().as_ref()); let k = into_js(&k)?; @@ -122,156 +91,49 @@ where Ok(k) } - fn derive_key_inner( - id: Vec, - public_keys: Vec, - ) -> JsResult { - let scalar = C::hash_to_scalar::>(&[&id], &[&C::CTX])?; - let mut powers = vec![C::Scalar::ONE; public_keys.len()]; - powers[1] = scalar; - for i in 2..powers.len() { - powers[i] = powers[i - 1] * scalar; - } - let k = Self::sum_of_products_pippenger(&public_keys, &powers); - Ok(k) - } + fn scalar_from_js(s: Uint8Array) -> JsResult { + let s = from_js::>(s)?; + let s = C::Scalar::from_repr(::Repr::from_slice(&s).clone()); + let s = Option::from(s); + let s = s.ok_or_else(|| JsError::new("cannot deserialize"))?; - fn sum_of_products_pippenger( - points: &[C::ProjectivePoint], - scalars: &[C::Scalar], - ) -> C::ProjectivePoint { - const WINDOW: usize = 4; - const NUM_BUCKETS: usize = 1 << WINDOW; - const EDGE: usize = WINDOW - 1; - const MASK: u64 = (NUM_BUCKETS - 1) as u64; + Ok(s) + } - let scalars = Self::convert_scalars(scalars); - let num_components = std::cmp::min(points.len(), scalars.len()); - let mut buckets = [C::ProjectivePoint::identity(); NUM_BUCKETS]; - let mut res = C::ProjectivePoint::identity(); - let mut num_doubles = 0; - let mut bit_sequence_index = 255usize; + fn point_from_js>(q: Uint8Array) -> JsResult { + let q = from_js::>(q)?; + let q = EncodedPoint::::from_bytes(q)?; + let q = T::from_encoded_point(&q); + let q = Option::::from(q); + let q = q.ok_or_else(|| JsError::new("cannot deserialize"))?; - loop { - for _ in 0..num_doubles { - res = res.double(); - } + Ok(q) + } - let mut max_bucket = 0; - let word_index = bit_sequence_index >> 6; - let bit_index = bit_sequence_index & 63; + fn signature_into_js(big_r: C::AffinePoint, s: C::Scalar) -> JsResult { + let r = big_r.x(); + let v = u8::conditional_select(&0, &1, big_r.y_is_odd()); + let s = s.to_repr(); - if bit_index < EDGE { - // we are on the edge of a word; have to look at the previous word, if it exists - if word_index == 0 { - // there is no word before - let smaller_mask = ((1 << (bit_index + 1)) - 1) as u64; - for i in 0..num_components { - let bucket_index: usize = (scalars[i][word_index] & smaller_mask) as usize; - if bucket_index > 0 { - buckets[bucket_index] += points[i]; - if bucket_index > max_bucket { - max_bucket = bucket_index; - } - } - } - } else { - // there is a word before - let high_order_mask = ((1 << (bit_index + 1)) - 1) as u64; - let high_order_shift = EDGE - bit_index; - let low_order_mask = ((1 << high_order_shift) - 1) as u64; - let low_order_shift = 64 - high_order_shift; - let prev_word_index = word_index - 1; - for i in 0..num_components { - let mut bucket_index = ((scalars[i][word_index] & high_order_mask) - << high_order_shift) - as usize; - bucket_index |= ((scalars[i][prev_word_index] >> low_order_shift) - & low_order_mask) as usize; - if bucket_index > 0 { - buckets[bucket_index] += points[i]; - if bucket_index > max_bucket { - max_bucket = bucket_index; - } - } - } - } - } else { - let shift = bit_index - EDGE; - for i in 0..num_components { - let bucket_index: usize = ((scalars[i][word_index] >> shift) & MASK) as usize; - if bucket_index > 0 { - buckets[bucket_index] += points[i]; - if bucket_index > max_bucket { - max_bucket = bucket_index; - } - } - } - } - res += &buckets[max_bucket]; - for i in (1..max_bucket).rev() { - buckets[i] += buckets[i + 1]; - res += buckets[i]; - buckets[i + 1] = C::ProjectivePoint::identity(); - } - buckets[1] = C::ProjectivePoint::identity(); - if bit_sequence_index < WINDOW { - break; - } - bit_sequence_index -= WINDOW; - num_doubles = { - if bit_sequence_index < EDGE { - bit_sequence_index + 1 - } else { - WINDOW - } - }; - } - res - } + let bytes = Self::concat_rsv(r, s, v); + let signature = into_js(Bytes::new(&bytes))?; - #[cfg(target_pointer_width = "32")] - fn convert_scalars(scalars: &[C::Scalar]) -> Vec<[u64; 4]> { - scalars - .iter() - .map(|s| { - let mut out = [0u64; 4]; - let primitive: ScalarPrimitive = (*s).into(); - let small_limbs = primitive - .as_limbs() - .iter() - .map(|l| l.0 as u64) - .collect::>(); - let mut i = 0; - let mut j = 0; - while i < small_limbs.len() && j < out.len() { - out[j] = small_limbs[i + 1] << 32 | small_limbs[i]; - i += 2; - j += 1; - } - out - }) - .collect::>() + Ok(signature) } - #[cfg(target_pointer_width = "64")] - fn convert_scalars(scalars: &[C::Scalar]) -> Vec<[u64; 4]> { - scalars - .iter() - .map(|s| { - let mut out = [0u64; 4]; - let primitive: ScalarPrimitive = (*s).into(); - out.copy_from_slice( - primitive - .as_limbs() - .iter() - .map(|l| l.0 as u64) - .collect::>() - .as_slice(), - ); - out - }) - .collect::>() + fn concat_rsv( + r: ::FieldRepr, + s: ::FieldRepr, + v: u8, + ) -> Vec + where + C: HdCtx, + { + let mut bytes = Vec::new(); + bytes.extend_from_slice(&r); + bytes.extend_from_slice(&s); + bytes.push(v); + bytes } }