wasm: delegate derive key to hd-keys-curves
cairomassimo committed Feb 14, 2024
1 parent 6d99ccf commit 194d7ba
Showing 2 changed files with 48 additions and 185 deletions.
1 change: 1 addition & 0 deletions ng/wasm/Cargo.toml
@@ -67,6 +67,7 @@ frost-secp256k1 = { git = "https://github.com/LIT-Protocol/frost.git" }
frost-taproot = { git = "https://github.com/LIT-Protocol/frost.git" }
lit-frost = { git = "https://github.com/LIT-Protocol/lit-frost.git" }
tsify = { version = "0.4.5", default-features = false, features = ["js"] }
hd-keys-curves = { git = "https://github.com/LIT-Protocol/hd-keys-curves.git", branch = "fix-wasm32-compile-error-convert-scalars", version = "0.2.0" }

# TODO(cairomassimo): remove once https://github.com/mikelodder7/bls12_381_plus/pull/4 lands on crates.io
[patch.crates-io]
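For context on how the new dependency is used: the ecdsa.rs diff below drops the hand-rolled hash-to-scalar and sum-of-products (Pippenger) code and delegates key derivation to the HDDeriver and HDDerivable traits from hd-keys-curves. A minimal sketch of the call pattern, written concretely for P-256 and assuming hd-keys-curves implements HDDeriver for p256::Scalar and HDDerivable for p256::ProjectivePoint (as the new trait bounds in the diff require):

use hd_keys_curves::{HDDerivable, HDDeriver};
use p256::{ProjectivePoint, Scalar};

// Domain-separation tag for P-256, as defined by `impl HdCtx for NistP256` below.
const CTX: &[u8] = b"LIT_HD_KEY_ID_P256_XMD:SHA-256_SSWU_RO_NUL_";

// Hash the key id into a deriver scalar, then fold the root public keys
// into the derived public key; mirrors `derive_key` in the diff below.
fn derive_public_key(id: &[u8], root_keys: &[ProjectivePoint]) -> ProjectivePoint {
    let deriver = Scalar::create(id, CTX);
    deriver.hd_derive_public_key(root_keys)
}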
232 changes: 47 additions & 185 deletions ng/wasm/src/ecdsa.rs
@@ -1,12 +1,12 @@
use elliptic_curve::{
group::{cofactor::CofactorGroup, GroupEncoding},
hash2curve::{ExpandMsgXmd, FromOkm, GroupDigest},
point::AffineCoordinates as _,
point::AffineCoordinates,
scalar::IsHigh as _,
sec1::{EncodedPoint, FromEncodedPoint, ModulusSize, ToEncodedPoint},
subtle::ConditionallySelectable as _,
CurveArithmetic, Field as _, Group as _, PrimeCurve, PrimeField, ScalarPrimitive,
CurveArithmetic, PrimeCurve, PrimeField,
};
use hd_keys_curves::{HDDerivable, HDDeriver};
use js_sys::Uint8Array;
use k256::Secp256k1;
use p256::NistP256;
@@ -39,12 +39,12 @@ impl HdCtx for NistP256 {
const CTX: &'static [u8] = b"LIT_HD_KEY_ID_P256_XMD:SHA-256_SSWU_RO_NUL_";
}

impl<C: PrimeCurve + CurveArithmetic + GroupDigest> Ecdsa<C>
impl<C: PrimeCurve + CurveArithmetic> Ecdsa<C>
where
C::AffinePoint: GroupEncoding + FromEncodedPoint<C>,
C::Scalar: FromOkm,
C::Scalar: HDDeriver,
C::FieldBytesSize: ModulusSize,
C::ProjectivePoint: CofactorGroup + FromEncodedPoint<C> + ToEncodedPoint<C>,
C::ProjectivePoint: CofactorGroup + HDDerivable + FromEncodedPoint<C> + ToEncodedPoint<C>,
C: HdCtx,
{
pub fn combine(
@@ -53,37 +53,13 @@ where
) -> JsResult<Uint8Array> {
let signature_shares = signature_shares
.into_iter()
.map(|s| {
let s = from_js::<Vec<u8>>(s)?;
let s =
C::Scalar::from_repr(<C::Scalar as PrimeField>::Repr::from_slice(&s).clone());
let s = Option::from(s);
let s = s.ok_or_else(|| JsError::new("cannot parse signature share"))?;

Ok(s)
})
.map(Self::scalar_from_js)
.collect::<JsResult<Vec<_>>>()?;

let big_r = presignature;
let big_r = from_js::<Vec<u8>>(big_r)?;
let big_r = EncodedPoint::<C>::from_bytes(big_r)?;
let big_r = C::AffinePoint::from_encoded_point(&big_r);
let big_r = Option::<C::AffinePoint>::from(big_r)
.ok_or_else(|| JsError::new("cannot parse input public key"))?;

let r = big_r.x();
let v = u8::conditional_select(&0, &1, big_r.y_is_odd());

let big_r = Self::point_from_js(presignature)?;
let s = Self::sum_scalars(signature_shares)?;
let s = s.to_repr();

let mut signature = Vec::new();
signature.extend_from_slice(&r);
signature.extend_from_slice(&s);
signature.push(v);
let signature = into_js(Bytes::new(signature.as_ref()))?;

Ok(signature)
Self::signature_into_js(big_r, s)
}

fn sum_scalars(values: Vec<C::Scalar>) -> JsResult<C::Scalar> {
@@ -102,176 +78,62 @@ where
let id = from_js::<Vec<u8>>(id)?;
let public_keys = public_keys
.into_iter()
.map(|k| {
let k = from_js::<Vec<u8>>(k)?;
let k = EncodedPoint::<C>::from_bytes(k)?;
let k = C::ProjectivePoint::from_encoded_point(&k);
let k =
Option::from(k).ok_or_else(|| JsError::new("cannot parse input public key"))?;

Ok(k)
})
.map(Self::point_from_js::<C::ProjectivePoint>)
.collect::<JsResult<Vec<_>>>()?;

let k = Self::derive_key_inner(id, public_keys)?;
let deriver = C::Scalar::create(&id, C::CTX);

let k = deriver.hd_derive_public_key(&public_keys);
let k = k.to_encoded_point(false);
let k = Bytes::new(k.as_bytes().as_ref());
let k = into_js(&k)?;

Ok(k)
}

fn derive_key_inner(
id: Vec<u8>,
public_keys: Vec<C::ProjectivePoint>,
) -> JsResult<C::ProjectivePoint> {
let scalar = C::hash_to_scalar::<ExpandMsgXmd<sha2::Sha256>>(&[&id], &[&C::CTX])?;
let mut powers = vec![C::Scalar::ONE; public_keys.len()];
powers[1] = scalar;
for i in 2..powers.len() {
powers[i] = powers[i - 1] * scalar;
}
let k = Self::sum_of_products_pippenger(&public_keys, &powers);
Ok(k)
}
fn scalar_from_js(s: Uint8Array) -> JsResult<C::Scalar> {
let s = from_js::<Vec<u8>>(s)?;
let s = C::Scalar::from_repr(<C::Scalar as PrimeField>::Repr::from_slice(&s).clone());
let s = Option::from(s);
let s = s.ok_or_else(|| JsError::new("cannot deserialize"))?;

fn sum_of_products_pippenger(
points: &[C::ProjectivePoint],
scalars: &[C::Scalar],
) -> C::ProjectivePoint {
const WINDOW: usize = 4;
const NUM_BUCKETS: usize = 1 << WINDOW;
const EDGE: usize = WINDOW - 1;
const MASK: u64 = (NUM_BUCKETS - 1) as u64;
Ok(s)
}

let scalars = Self::convert_scalars(scalars);
let num_components = std::cmp::min(points.len(), scalars.len());
let mut buckets = [C::ProjectivePoint::identity(); NUM_BUCKETS];
let mut res = C::ProjectivePoint::identity();
let mut num_doubles = 0;
let mut bit_sequence_index = 255usize;
fn point_from_js<T: FromEncodedPoint<C>>(q: Uint8Array) -> JsResult<T> {
let q = from_js::<Vec<u8>>(q)?;
let q = EncodedPoint::<C>::from_bytes(q)?;
let q = T::from_encoded_point(&q);
let q = Option::<T>::from(q);
let q = q.ok_or_else(|| JsError::new("cannot deserialize"))?;

loop {
for _ in 0..num_doubles {
res = res.double();
}
Ok(q)
}

let mut max_bucket = 0;
let word_index = bit_sequence_index >> 6;
let bit_index = bit_sequence_index & 63;
fn signature_into_js(big_r: C::AffinePoint, s: C::Scalar) -> JsResult<Uint8Array> {
let r = big_r.x();
let v = u8::conditional_select(&0, &1, big_r.y_is_odd());
let s = s.to_repr();

if bit_index < EDGE {
// we are on the edge of a word; have to look at the previous word, if it exists
if word_index == 0 {
// there is no word before
let smaller_mask = ((1 << (bit_index + 1)) - 1) as u64;
for i in 0..num_components {
let bucket_index: usize = (scalars[i][word_index] & smaller_mask) as usize;
if bucket_index > 0 {
buckets[bucket_index] += points[i];
if bucket_index > max_bucket {
max_bucket = bucket_index;
}
}
}
} else {
// there is a word before
let high_order_mask = ((1 << (bit_index + 1)) - 1) as u64;
let high_order_shift = EDGE - bit_index;
let low_order_mask = ((1 << high_order_shift) - 1) as u64;
let low_order_shift = 64 - high_order_shift;
let prev_word_index = word_index - 1;
for i in 0..num_components {
let mut bucket_index = ((scalars[i][word_index] & high_order_mask)
<< high_order_shift)
as usize;
bucket_index |= ((scalars[i][prev_word_index] >> low_order_shift)
& low_order_mask) as usize;
if bucket_index > 0 {
buckets[bucket_index] += points[i];
if bucket_index > max_bucket {
max_bucket = bucket_index;
}
}
}
}
} else {
let shift = bit_index - EDGE;
for i in 0..num_components {
let bucket_index: usize = ((scalars[i][word_index] >> shift) & MASK) as usize;
if bucket_index > 0 {
buckets[bucket_index] += points[i];
if bucket_index > max_bucket {
max_bucket = bucket_index;
}
}
}
}
res += &buckets[max_bucket];
for i in (1..max_bucket).rev() {
buckets[i] += buckets[i + 1];
res += buckets[i];
buckets[i + 1] = C::ProjectivePoint::identity();
}
buckets[1] = C::ProjectivePoint::identity();
if bit_sequence_index < WINDOW {
break;
}
bit_sequence_index -= WINDOW;
num_doubles = {
if bit_sequence_index < EDGE {
bit_sequence_index + 1
} else {
WINDOW
}
};
}
res
}
let bytes = Self::concat_rsv(r, s, v);
let signature = into_js(Bytes::new(&bytes))?;

#[cfg(target_pointer_width = "32")]
fn convert_scalars(scalars: &[C::Scalar]) -> Vec<[u64; 4]> {
scalars
.iter()
.map(|s| {
let mut out = [0u64; 4];
let primitive: ScalarPrimitive<C> = (*s).into();
let small_limbs = primitive
.as_limbs()
.iter()
.map(|l| l.0 as u64)
.collect::<Vec<_>>();
let mut i = 0;
let mut j = 0;
while i < small_limbs.len() && j < out.len() {
out[j] = small_limbs[i + 1] << 32 | small_limbs[i];
i += 2;
j += 1;
}
out
})
.collect::<Vec<_>>()
Ok(signature)
}

#[cfg(target_pointer_width = "64")]
fn convert_scalars(scalars: &[C::Scalar]) -> Vec<[u64; 4]> {
scalars
.iter()
.map(|s| {
let mut out = [0u64; 4];
let primitive: ScalarPrimitive<C> = (*s).into();
out.copy_from_slice(
primitive
.as_limbs()
.iter()
.map(|l| l.0 as u64)
.collect::<Vec<_>>()
.as_slice(),
);
out
})
.collect::<Vec<_>>()
fn concat_rsv(
r: <C::AffinePoint as AffineCoordinates>::FieldRepr,
s: <C::AffinePoint as AffineCoordinates>::FieldRepr,
v: u8,
) -> Vec<u8>
where
C: HdCtx,
{
let mut bytes = Vec::new();
bytes.extend_from_slice(&r);
bytes.extend_from_slice(&s);
bytes.push(v);
bytes
}
}

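The refactored combine path now funnels through signature_into_js and concat_rsv, which serialize the signature as r || s || v: the x-coordinate of the presignature point R, the combined scalar s, and one recovery byte that is 1 when R.y is odd and 0 when it is even. A rough sketch of the resulting layout for a 32-byte field such as secp256k1 or P-256 (placeholder values, not a real signature):

fn main() {
    // Placeholder bytes; real values come from the presignature point and the
    // summed signature shares.
    let r = [0u8; 32]; // x-coordinate of the presignature point R
    let s = [0u8; 32]; // combined signature scalar, big-endian field bytes
    let v = 1u8;       // parity of R.y: 0 = even, 1 = odd

    let mut signature = Vec::with_capacity(r.len() + s.len() + 1);
    signature.extend_from_slice(&r);
    signature.extend_from_slice(&s);
    signature.push(v);
    assert_eq!(signature.len(), 65);
}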