Skip to content

Commit

Permalink
fix: post-review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
NikitaMasych committed Dec 11, 2024
1 parent da49f64 commit 6af6df5
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 110 deletions.
8 changes: 2 additions & 6 deletions crates/boojum/src/gadgets/curves/sw_projective/extended.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,13 +548,9 @@ where
let one_nn = NN::allocated_constant(cs, T::one(), &params);
let mut safe_z = NN::conditionally_select(cs, is_point_at_infty, &one_nn, &self.z);
let mut safe_z_squared = safe_z.square(cs);
safe_z_squared.normalize(cs);
let mut safe_z_cubed = safe_z.mul(cs, &mut safe_z_squared);
safe_z_cubed.normalize(cs);
let mut x_for_safe_z = self.x.div_unchecked(cs, &mut safe_z_squared);
x_for_safe_z.normalize(cs);
let mut y_for_safe_z = self.y.div_unchecked(cs, &mut safe_z_cubed);
y_for_safe_z.normalize(cs);
let x_for_safe_z = self.x.div_unchecked(cs, &mut safe_z_squared);
let y_for_safe_z = self.y.div_unchecked(cs, &mut safe_z_cubed);

let (default_x, default_y) = default.into_xy_unchecked();

Expand Down
37 changes: 13 additions & 24 deletions crates/boojum/src/gadgets/tower_extension/algebraic_torus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,21 +257,11 @@ where
let mut sum = self.encoding.clone().add(cs, &mut other.encoding);
let lhs = encoding_new.clone().mul(cs, &mut sum);

// rhs = {(g + g') == 0} ? zero : (g * g' + \gamma)
// rhs = g * g' + \gamma
let mut gamma = Fq6::gamma(cs, params);
let mut rhs = self.encoding.clone().mul(cs, &mut other.encoding);
let rhs = rhs.add(cs, &mut gamma);

let zero = Fq6::zero(cs, params);
let is_zero_sum = sum.is_zero(cs);

let rhs = <Fq6<F, T, NonNativeFieldOverU16<F, T, N>, P::Ex6>>::conditionally_select(
cs,
is_zero_sum,
&zero,
&rhs,
);

// Enforce equality
Fq6::enforce_equal(cs, &lhs, &rhs);

Expand Down Expand Up @@ -320,20 +310,19 @@ where
CS: ConstraintSystem<F>,
{
let mut result = Self::one(cs, self.get_params());
let mut found_one = false;

for bit in BitIterator::new(exponent) {
let apply_squaring = Boolean::allocated_constant(cs, found_one);
let result_squared = result.square(cs);
result = Self::conditionally_select(cs, apply_squaring, &result_squared, &result);
if !found_one {
found_one = bit;
}
let mut base = self.clone();

for i in BitIterator::new(exponent) {
let mut squared = result.square(cs);
let mut squared_and_multiplied = squared.mul(cs, &mut base);
let shall_multiply = Boolean::allocated_constant(cs, i);

let result_multiplied = result.mul(cs, self);
let apply_multiplication = Boolean::allocated_constant(cs, bit);
result =
Self::conditionally_select(cs, apply_multiplication, &result_multiplied, &result);
result = Self::conditionally_select(
cs,
shall_multiply,
&mut squared_and_multiplied,
&mut squared,
);

result.normalize(cs);
}
Expand Down
49 changes: 25 additions & 24 deletions crates/boojum/src/gadgets/tower_extension/fq12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,23 +80,21 @@ where
where
CS: ConstraintSystem<F>,
{
let mut result = Self::one(cs, self.c0.c0.get_params());
let mut found_one = false;
let mut result = Self::one(cs, self.get_params());
let mut base = self.clone();

for i in BitIterator::new(exponent) {
let apply_squaring = Boolean::allocated_constant(cs, found_one);
let result_squared = result.square(cs);
result = Self::conditionally_select(cs, apply_squaring, &result_squared, &result);
if !found_one {
found_one = i;
}

let result_multiplied = result.mul(cs, self);
let apply_multiplication = Boolean::allocated_constant(cs, i);
result =
Self::conditionally_select(cs, apply_multiplication, &result_multiplied, &result);

// Normalize the result to stay in field
let mut squared = result.square(cs);
let mut squared_and_multiplied = squared.mul(cs, &mut base);
let shall_multiply = Boolean::allocated_constant(cs, i);

result = Self::conditionally_select(
cs,
shall_multiply,
&mut squared_and_multiplied,
&mut squared,
);

NonNativeField::normalize(&mut result, cs);
}

Expand Down Expand Up @@ -242,17 +240,20 @@ where
where
CS: ConstraintSystem<F>,
{
let mut ab = self.c0.mul(cs, &mut self.c1);
let mut c0c1 = self.c0.add(cs, &mut self.c1);
// Karatsuba:
let mut a0 = self.c0.clone();
let mut a1 = self.c1.clone();

let mut v0 = a0.square(cs);
let mut v1 = a1.square(cs);

let mut c0 = self.c1.mul_by_nonresidue(cs);
let mut c0 = c0.add(cs, &mut self.c0);
let mut c0 = c0.mul(cs, &mut c0c1);
let mut c0 = c0.sub(cs, &mut ab);
let mut tmp = v1.mul_by_nonresidue(cs); // c1^2 * w
let c0 = v0.add(cs, &mut tmp); // c0^2 + c1^2 * w

let c1 = ab.double(cs);
let mut ab_residue = ab.mul_by_nonresidue(cs);
let c0 = c0.sub(cs, &mut ab_residue);
let mut a0_plus_a1 = a0.add(cs, &mut a1);
let mut a0_plus_a1_squared = a0_plus_a1.square(cs);
let mut tmp = a0_plus_a1_squared.sub(cs, &mut v0);
let c1 = tmp.sub(cs, &mut v1); // (c0 + c1)^2 - c0^2 - c1^2 <==> 2c0c1

Self::new(c0, c1)
}
Expand Down
12 changes: 5 additions & 7 deletions crates/boojum/src/gadgets/tower_extension/fq2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,18 +231,16 @@ where
where
CS: ConstraintSystem<F>,
{
// Finding 8(a0 + a1*u)
// Finding 9(c0 + c1*u)
let mut new = self.double(cs);
new = new.double(cs);
new = new.double(cs);
new = new.add(cs, self);

// c0 <- 9*c0 - c1
let mut c0 = new.c0.add(cs, &mut self.c0);
let c0 = c0.sub(cs, &mut self.c1);

// c1 <- c0 + 9*c1
let mut c1 = new.c1.add(cs, &mut self.c1);
let c1 = c1.add(cs, &mut self.c0);
let c0 = new.c0.sub(cs, &mut self.c1);
// c1 <- 9*c1 + c0
let c1 = new.c1.add(cs, &mut self.c0);

Self::new(c0, c1)
}
Expand Down
97 changes: 52 additions & 45 deletions crates/boojum/src/gadgets/tower_extension/fq6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,25 +258,6 @@ where
Self::new(c0, c1, c2)
}

/// Multiplies the element `a=a0+a1*v+a2*v^2` in `Fq6` by the element `b = b1*v`
pub fn mul_by_c1<CS>(&mut self, cs: &mut CS, c1: &mut Fq2<F, T, NN, P::Ex2>) -> Self
where
CS: ConstraintSystem<F>,
{
let mut b_b = self.c1.mul(cs, c1);
let mut tmp = self.c1.add(cs, &mut self.c2);

let mut t1 = c1.mul(cs, &mut tmp);
let mut t1 = t1.sub(cs, &mut b_b);
let t1 = t1.mul_by_nonresidue(cs);

let mut tmp = self.c0.add(cs, &mut self.c1);
let mut t2 = c1.mul(cs, &mut tmp);
let t2 = t2.sub(cs, &mut b_b);

Self::new(t1, t2, b_b)
}

/// Multiplies the element `a=a0+a1*v+a2*v^2` in `Fq6` by the element in `NonNativeField`
pub fn mul_by_fq<CS>(&mut self, cs: &mut CS, c0: &mut NN) -> Self
where
Expand All @@ -303,6 +284,30 @@ where
Self::new(t0, t1, t2)
}

/// Multiplies the element `a=a0+a1*v+a2*v^2` in `Fq6` by the element `b = b1*v`
pub fn mul_by_c1<CS>(&mut self, cs: &mut CS, c1: &mut Fq2<F, T, NN, P::Ex2>) -> Self
where
CS: ConstraintSystem<F>,
{
// Suppose a = a0 + a1*v + a2*v^2. In this case,
// (a0 + a1*v + a2*v^2) * c1 * v =
// a2*c1*\xi + a0*c1*v + a1*c1*v^2

let mut a0 = self.c0.clone();
let mut a1 = self.c1.clone();
let mut a2 = self.c2.clone();

// new_c0 <- a2*c1*\xi
let mut new_c0 = a2.mul(cs, c1);
let new_c0 = new_c0.mul_by_nonresidue(cs);
// new_c1 <- a0*c1
let new_c1 = a0.mul(cs, c1);
// new_c2 <- a1*c1
let new_c2 = a1.mul(cs, c1);

Self::new(new_c0, new_c1, new_c2)
}

/// Multiplies the element `a=a0+a1*v+a2*v^2` in `Fq6` by the element `c2*v^2`
pub fn mul_by_c2<CS>(&mut self, cs: &mut CS, c2: &mut Fq2<F, T, NN, P::Ex2>) -> Self
where
Expand All @@ -311,22 +316,17 @@ where
// Suppose a = a0 + a1*v + a2*v^2. In this case,
// (a0 + a1*v + a2*v^2) * c2 * v^2 =
// a1*c2*\xi + a2*c2*\xi*v + a0*c2*v^2
// NOTE: There might be a better way to calculate three coefficients
// without using 3 multiplications and 2 mul_by_nonresidues, similarly to mul_by_c1

// Setting coefficients
let mut a0 = self.c0.clone();
let mut a1 = self.c1.clone();
let mut a2 = self.c2.clone();

// new_c0 <- a1*c2*\xi
let mut new_c0 = a1.mul(cs, c2);
new_c0 = new_c0.mul_by_nonresidue(cs);
let mut product = c2.mul_by_nonresidue(cs);

// new_c0 <- a1*c2*\xi
let new_c0 = a1.mul(cs, &mut product);
// new_c1 <- a2*c2*\xi
let mut new_c1 = a2.mul(cs, c2);
new_c1 = new_c1.mul_by_nonresidue(cs);

let new_c1 = a2.mul(cs, &mut product);
// new_c2 <- a0*c2
let new_c2 = a0.mul(cs, c2);

Expand All @@ -343,27 +343,34 @@ where
where
CS: ConstraintSystem<F>,
{
let mut a_a = self.c0.mul(cs, c0);
let mut b_b = self.c1.mul(cs, c1);
// (a0+a1v+a2v^2)(b0+b1*v)
// a0b0 +a1b0v + a2b0v^2 + a0b1v + a1b1v^2 + a2b1v^3
// c0 = a0b0 + a2b1 xi
// c1 = a1b0 + a0b1 => (a1 + a0)(b0 + b1) - a1b1 - a0b0
// c2 = a2b0 + a1b1

let mut tmp = self.c1.add(cs, &mut self.c2);
let mut t1 = c1.mul(cs, &mut tmp);
let mut t1 = t1.sub(cs, &mut b_b);
let mut t1 = t1.mul_by_nonresidue(cs);
let t1 = t1.add(cs, &mut a_a);
let mut a0 = self.c0.clone();
let mut a1 = self.c1.clone();
let mut a2 = self.c2.clone();
let mut b0 = c0.clone();
let mut b1 = c1.clone();

let mut tmp = self.c0.add(cs, &mut self.c2);
let mut t3 = c0.mul(cs, &mut tmp);
let mut t3 = t3.sub(cs, &mut a_a);
let t3 = t3.add(cs, &mut b_b);
let mut a0b0 = a0.mul(cs, &mut b0);
let mut a2b1 = a2.mul(cs, &mut b1);
let mut c0 = a2b1.mul_by_nonresidue(cs);
c0 = a0b0.add(cs, &mut c0);

let mut t2 = c0.add(cs, c1);
let mut tmp = self.c0.add(cs, &mut self.c1);
let mut t2 = t2.mul(cs, &mut tmp);
let mut t2 = t2.sub(cs, &mut a_a);
let t2 = t2.sub(cs, &mut b_b);
let mut a1b1 = a1.mul(cs, &mut b1);
let mut a1_plus_a0 = a1.add(cs, &mut a0);
let mut b0_plus_b1 = b0.add(cs, &mut b1);
let mut c1 = a1_plus_a0.mul(cs, &mut b0_plus_b1);
c1 = c1.sub(cs, &mut a1b1);
c1 = c1.sub(cs, &mut a0b0);

Self::new(t1, t2, t3)
let mut a2b0 = a2.mul(cs, &mut b0);
let c2 = a2b0.add(cs, &mut a1b1);

Self::new(c0, c1, c2)
}

/// Find the inverse element in Fq6
Expand Down
6 changes: 2 additions & 4 deletions crates/boojum/src/gadgets/u256/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,8 @@ impl<F: SmallField> UInt256<F> {
let q = UInt256::allocate(cs, q);
let r = UInt256::allocate(cs, r);

let mod_is_zero = Boolean::allocate(cs, m.is_zero());
let mod_is_zero = modulo.is_zero(cs);
let bool_true = Boolean::allocated_constant(cs, true);
let bool_false = Boolean::allocated_constant(cs, false);

let (_, m_ge_than_r) = r.overflowing_sub(cs, &modulo);
let m_ge_than_r = Boolean::conditionally_select(cs, mod_is_zero, &bool_true, &m_ge_than_r);
Expand All @@ -408,8 +407,7 @@ impl<F: SmallField> UInt256<F> {

let rhs = q.widening_mul(cs, &modulo, 8, 8);
let r_u512 = r.to_u512(cs);
let (rhs, overflow) = rhs.overflowing_add(cs, &r_u512);
Boolean::enforce_equal(cs, &overflow, &bool_false);
let (rhs, _) = rhs.overflowing_add(cs, &r_u512);

let are_equal = UInt512::equals(cs, &lhs, &rhs);
Boolean::enforce_equal(cs, &are_equal, &bool_true);
Expand Down

0 comments on commit 6af6df5

Please sign in to comment.