From a66be009b0597b0487d5f8bfc8902c0a4a57b7bc Mon Sep 17 00:00:00 2001 From: Nikita Masych Date: Wed, 11 Dec 2024 13:43:55 +0200 Subject: [PATCH] fix: post-review fixes --- .../gadgets/curves/sw_projective/extended.rs | 8 +- .../tower_extension/algebraic_torus.rs | 37 +++---- .../src/gadgets/tower_extension/fq12.rs | 49 +++++----- .../boojum/src/gadgets/tower_extension/fq2.rs | 12 +-- .../boojum/src/gadgets/tower_extension/fq6.rs | 97 ++++++++++--------- crates/boojum/src/gadgets/u256/mod.rs | 6 +- 6 files changed, 99 insertions(+), 110 deletions(-) diff --git a/crates/boojum/src/gadgets/curves/sw_projective/extended.rs b/crates/boojum/src/gadgets/curves/sw_projective/extended.rs index d4eca83..d16ddbb 100644 --- a/crates/boojum/src/gadgets/curves/sw_projective/extended.rs +++ b/crates/boojum/src/gadgets/curves/sw_projective/extended.rs @@ -548,13 +548,9 @@ where let one_nn = NN::allocated_constant(cs, T::one(), ¶ms); let mut safe_z = NN::conditionally_select(cs, is_point_at_infty, &one_nn, &self.z); let mut safe_z_squared = safe_z.square(cs); - safe_z_squared.normalize(cs); let mut safe_z_cubed = safe_z.mul(cs, &mut safe_z_squared); - safe_z_cubed.normalize(cs); - let mut x_for_safe_z = self.x.div_unchecked(cs, &mut safe_z_squared); - x_for_safe_z.normalize(cs); - let mut y_for_safe_z = self.y.div_unchecked(cs, &mut safe_z_cubed); - y_for_safe_z.normalize(cs); + let x_for_safe_z = self.x.div_unchecked(cs, &mut safe_z_squared); + let y_for_safe_z = self.y.div_unchecked(cs, &mut safe_z_cubed); let (default_x, default_y) = default.into_xy_unchecked(); diff --git a/crates/boojum/src/gadgets/tower_extension/algebraic_torus.rs b/crates/boojum/src/gadgets/tower_extension/algebraic_torus.rs index 906000e..e34d4ce 100644 --- a/crates/boojum/src/gadgets/tower_extension/algebraic_torus.rs +++ b/crates/boojum/src/gadgets/tower_extension/algebraic_torus.rs @@ -257,21 +257,11 @@ where let mut sum = self.encoding.clone().add(cs, &mut other.encoding); let lhs = encoding_new.clone().mul(cs, &mut sum); - // rhs = {(g + g') == 0} ? zero : (g * g' + \gamma) + // rhs = g * g' + \gamma let mut gamma = Fq6::gamma(cs, params); let mut rhs = self.encoding.clone().mul(cs, &mut other.encoding); let rhs = rhs.add(cs, &mut gamma); - let zero = Fq6::zero(cs, params); - let is_zero_sum = sum.is_zero(cs); - - let rhs = , P::Ex6>>::conditionally_select( - cs, - is_zero_sum, - &zero, - &rhs, - ); - // Enforce equality Fq6::enforce_equal(cs, &lhs, &rhs); @@ -320,20 +310,19 @@ where CS: ConstraintSystem, { let mut result = Self::one(cs, self.get_params()); - let mut found_one = false; - - for bit in BitIterator::new(exponent) { - let apply_squaring = Boolean::allocated_constant(cs, found_one); - let result_squared = result.square(cs); - result = Self::conditionally_select(cs, apply_squaring, &result_squared, &result); - if !found_one { - found_one = bit; - } + let mut base = self.clone(); + + for i in BitIterator::new(exponent) { + let mut squared = result.square(cs); + let mut squared_and_multiplied = squared.mul(cs, &mut base); + let shall_multiply = Boolean::allocated_constant(cs, i); - let result_multiplied = result.mul(cs, self); - let apply_multiplication = Boolean::allocated_constant(cs, bit); - result = - Self::conditionally_select(cs, apply_multiplication, &result_multiplied, &result); + result = Self::conditionally_select( + cs, + shall_multiply, + &mut squared_and_multiplied, + &mut squared, + ); result.normalize(cs); } diff --git a/crates/boojum/src/gadgets/tower_extension/fq12.rs b/crates/boojum/src/gadgets/tower_extension/fq12.rs index 998b4f8..7513cc4 100644 --- a/crates/boojum/src/gadgets/tower_extension/fq12.rs +++ b/crates/boojum/src/gadgets/tower_extension/fq12.rs @@ -80,23 +80,21 @@ where where CS: ConstraintSystem, { - let mut result = Self::one(cs, self.c0.c0.get_params()); - let mut found_one = false; + let mut result = Self::one(cs, self.get_params()); + let mut base = self.clone(); for i in BitIterator::new(exponent) { - let apply_squaring = Boolean::allocated_constant(cs, found_one); - let result_squared = result.square(cs); - result = Self::conditionally_select(cs, apply_squaring, &result_squared, &result); - if !found_one { - found_one = i; - } - - let result_multiplied = result.mul(cs, self); - let apply_multiplication = Boolean::allocated_constant(cs, i); - result = - Self::conditionally_select(cs, apply_multiplication, &result_multiplied, &result); - - // Normalize the result to stay in field + let mut squared = result.square(cs); + let mut squared_and_multiplied = squared.mul(cs, &mut base); + let shall_multiply = Boolean::allocated_constant(cs, i); + + result = Self::conditionally_select( + cs, + shall_multiply, + &mut squared_and_multiplied, + &mut squared, + ); + NonNativeField::normalize(&mut result, cs); } @@ -242,17 +240,20 @@ where where CS: ConstraintSystem, { - let mut ab = self.c0.mul(cs, &mut self.c1); - let mut c0c1 = self.c0.add(cs, &mut self.c1); + // Karatsuba: + let mut a0 = self.c0.clone(); + let mut a1 = self.c1.clone(); + + let mut v0 = a0.square(cs); + let mut v1 = a1.square(cs); - let mut c0 = self.c1.mul_by_nonresidue(cs); - let mut c0 = c0.add(cs, &mut self.c0); - let mut c0 = c0.mul(cs, &mut c0c1); - let mut c0 = c0.sub(cs, &mut ab); + let mut tmp = v1.mul_by_nonresidue(cs); // c1^2 * w + let c0 = v0.add(cs, &mut tmp); // c0^2 + c1^2 * w - let c1 = ab.double(cs); - let mut ab_residue = ab.mul_by_nonresidue(cs); - let c0 = c0.sub(cs, &mut ab_residue); + let mut a0_plus_a1 = a0.add(cs, &mut a1); + let mut a0_plus_a1_squared = a0_plus_a1.square(cs); + let mut tmp = a0_plus_a1_squared.sub(cs, &mut v0); + let c1 = tmp.sub(cs, &mut v1); // (c0 + c1)^2 - c0^2 - c1^2 <==> 2c0c1 Self::new(c0, c1) } diff --git a/crates/boojum/src/gadgets/tower_extension/fq2.rs b/crates/boojum/src/gadgets/tower_extension/fq2.rs index 24b5955..c9c810a 100644 --- a/crates/boojum/src/gadgets/tower_extension/fq2.rs +++ b/crates/boojum/src/gadgets/tower_extension/fq2.rs @@ -231,18 +231,16 @@ where where CS: ConstraintSystem, { - // Finding 8(a0 + a1*u) + // Finding 9(c0 + c1*u) let mut new = self.double(cs); new = new.double(cs); new = new.double(cs); + new = new.add(cs, self); // c0 <- 9*c0 - c1 - let mut c0 = new.c0.add(cs, &mut self.c0); - let c0 = c0.sub(cs, &mut self.c1); - - // c1 <- c0 + 9*c1 - let mut c1 = new.c1.add(cs, &mut self.c1); - let c1 = c1.add(cs, &mut self.c0); + let c0 = new.c0.sub(cs, &mut self.c1); + // c1 <- 9*c1 + c0 + let c1 = new.c1.add(cs, &mut self.c0); Self::new(c0, c1) } diff --git a/crates/boojum/src/gadgets/tower_extension/fq6.rs b/crates/boojum/src/gadgets/tower_extension/fq6.rs index 5409247..9b47731 100644 --- a/crates/boojum/src/gadgets/tower_extension/fq6.rs +++ b/crates/boojum/src/gadgets/tower_extension/fq6.rs @@ -258,25 +258,6 @@ where Self::new(c0, c1, c2) } - /// Multiplies the element `a=a0+a1*v+a2*v^2` in `Fq6` by the element `b = b1*v` - pub fn mul_by_c1(&mut self, cs: &mut CS, c1: &mut Fq2) -> Self - where - CS: ConstraintSystem, - { - let mut b_b = self.c1.mul(cs, c1); - let mut tmp = self.c1.add(cs, &mut self.c2); - - let mut t1 = c1.mul(cs, &mut tmp); - let mut t1 = t1.sub(cs, &mut b_b); - let t1 = t1.mul_by_nonresidue(cs); - - let mut tmp = self.c0.add(cs, &mut self.c1); - let mut t2 = c1.mul(cs, &mut tmp); - let t2 = t2.sub(cs, &mut b_b); - - Self::new(t1, t2, b_b) - } - /// Multiplies the element `a=a0+a1*v+a2*v^2` in `Fq6` by the element in `NonNativeField` pub fn mul_by_fq(&mut self, cs: &mut CS, c0: &mut NN) -> Self where @@ -303,6 +284,30 @@ where Self::new(t0, t1, t2) } + /// Multiplies the element `a=a0+a1*v+a2*v^2` in `Fq6` by the element `b = b1*v` + pub fn mul_by_c1(&mut self, cs: &mut CS, c1: &mut Fq2) -> Self + where + CS: ConstraintSystem, + { + // Suppose a = a0 + a1*v + a2*v^2. In this case, + // (a0 + a1*v + a2*v^2) * c1 * v = + // a2*c1*\xi + a0*c1*v + a1*c1*v^2 + + let mut a0 = self.c0.clone(); + let mut a1 = self.c1.clone(); + let mut a2 = self.c2.clone(); + + // new_c0 <- a2*c1*\xi + let mut new_c0 = a2.mul(cs, c1); + let new_c0 = new_c0.mul_by_nonresidue(cs); + // new_c1 <- a0*c1 + let new_c1 = a0.mul(cs, c1); + // new_c2 <- a1*c1 + let new_c2 = a1.mul(cs, c1); + + Self::new(new_c0, new_c1, new_c2) + } + /// Multiplies the element `a=a0+a1*v+a2*v^2` in `Fq6` by the element `c2*v^2` pub fn mul_by_c2(&mut self, cs: &mut CS, c2: &mut Fq2) -> Self where @@ -311,22 +316,17 @@ where // Suppose a = a0 + a1*v + a2*v^2. In this case, // (a0 + a1*v + a2*v^2) * c2 * v^2 = // a1*c2*\xi + a2*c2*\xi*v + a0*c2*v^2 - // NOTE: There might be a better way to calculate three coefficients - // without using 3 multiplications and 2 mul_by_nonresidues, similarly to mul_by_c1 - // Setting coefficients let mut a0 = self.c0.clone(); let mut a1 = self.c1.clone(); let mut a2 = self.c2.clone(); - // new_c0 <- a1*c2*\xi - let mut new_c0 = a1.mul(cs, c2); - new_c0 = new_c0.mul_by_nonresidue(cs); + let mut product = c2.mul_by_nonresidue(cs); + // new_c0 <- a1*c2*\xi + let new_c0 = a1.mul(cs, &mut product); // new_c1 <- a2*c2*\xi - let mut new_c1 = a2.mul(cs, c2); - new_c1 = new_c1.mul_by_nonresidue(cs); - + let new_c1 = a2.mul(cs, &mut product); // new_c2 <- a0*c2 let new_c2 = a0.mul(cs, c2); @@ -343,27 +343,34 @@ where where CS: ConstraintSystem, { - let mut a_a = self.c0.mul(cs, c0); - let mut b_b = self.c1.mul(cs, c1); + // (a0+a1v+a2v^2)(b0+b1*v) + // a0b0 +a1b0v + a2b0v^2 + a0b1v + a1b1v^2 + a2b1v^3 + // c0 = a0b0 + a2b1 xi + // c1 = a1b0 + a0b1 => (a1 + a0)(b0 + b1) - a1b1 - a0b0 + // c2 = a2b0 + a1b1 - let mut tmp = self.c1.add(cs, &mut self.c2); - let mut t1 = c1.mul(cs, &mut tmp); - let mut t1 = t1.sub(cs, &mut b_b); - let mut t1 = t1.mul_by_nonresidue(cs); - let t1 = t1.add(cs, &mut a_a); + let mut a0 = self.c0.clone(); + let mut a1 = self.c1.clone(); + let mut a2 = self.c2.clone(); + let mut b0 = c0.clone(); + let mut b1 = c1.clone(); - let mut tmp = self.c0.add(cs, &mut self.c2); - let mut t3 = c0.mul(cs, &mut tmp); - let mut t3 = t3.sub(cs, &mut a_a); - let t3 = t3.add(cs, &mut b_b); + let mut a0b0 = a0.mul(cs, &mut b0); + let mut a2b1 = a2.mul(cs, &mut b1); + let mut c0 = a2b1.mul_by_nonresidue(cs); + c0 = a0b0.add(cs, &mut c0); - let mut t2 = c0.add(cs, c1); - let mut tmp = self.c0.add(cs, &mut self.c1); - let mut t2 = t2.mul(cs, &mut tmp); - let mut t2 = t2.sub(cs, &mut a_a); - let t2 = t2.sub(cs, &mut b_b); + let mut a1b1 = a1.mul(cs, &mut b1); + let mut a1_plus_a0 = a1.add(cs, &mut a0); + let mut b0_plus_b1 = b0.add(cs, &mut b1); + let mut c1 = a1_plus_a0.mul(cs, &mut b0_plus_b1); + c1 = c1.sub(cs, &mut a1b1); + c1 = c1.sub(cs, &mut a0b0); - Self::new(t1, t2, t3) + let mut a2b0 = a2.mul(cs, &mut b0); + let c2 = a2b0.add(cs, &mut a1b1); + + Self::new(c0, c1, c2) } /// Find the inverse element in Fq6 diff --git a/crates/boojum/src/gadgets/u256/mod.rs b/crates/boojum/src/gadgets/u256/mod.rs index e97582d..c0b4785 100644 --- a/crates/boojum/src/gadgets/u256/mod.rs +++ b/crates/boojum/src/gadgets/u256/mod.rs @@ -394,9 +394,8 @@ impl UInt256 { let q = UInt256::allocate(cs, q); let r = UInt256::allocate(cs, r); - let mod_is_zero = Boolean::allocate(cs, m.is_zero()); + let mod_is_zero = modulo.is_zero(cs); let bool_true = Boolean::allocated_constant(cs, true); - let bool_false = Boolean::allocated_constant(cs, false); let (_, m_ge_than_r) = r.overflowing_sub(cs, &modulo); let m_ge_than_r = Boolean::conditionally_select(cs, mod_is_zero, &bool_true, &m_ge_than_r); @@ -408,8 +407,7 @@ impl UInt256 { let rhs = q.widening_mul(cs, &modulo, 8, 8); let r_u512 = r.to_u512(cs); - let (rhs, overflow) = rhs.overflowing_add(cs, &r_u512); - Boolean::enforce_equal(cs, &overflow, &bool_false); + let (rhs, _) = rhs.overflowing_add(cs, &r_u512); let are_equal = UInt512::equals(cs, &lhs, &rhs); Boolean::enforce_equal(cs, &are_equal, &bool_true);