Skip to content

Commit

Permalink
now tests passes
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenfeizhang committed Dec 9, 2024
1 parent f40f589 commit f1fb7a5
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 35 deletions.
2 changes: 2 additions & 0 deletions expander_compiler/src/frontend/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub trait BasicAPI<C: Config> {
binary_op!(xor);
binary_op!(or);
binary_op!(and);

fn display(&self, _x: impl ToVariableOrValue<C::CircuitField>) {}
fn div(
&mut self,
x: impl ToVariableOrValue<C::CircuitField>,
Expand Down
6 changes: 6 additions & 0 deletions expander_compiler/src/frontend/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ pub struct DebugBuilder<C: Config> {
}

impl<C: Config> BasicAPI<C> for DebugBuilder<C> {
fn display(&self, x: impl ToVariableOrValue<<C as Config>::CircuitField>) {
let x = self.convert_to_value(x);
println!("x: {:?}", x);
}

fn add(
&mut self,
x: impl ToVariableOrValue<C::CircuitField>,
Expand All @@ -36,6 +41,7 @@ impl<C: Config> BasicAPI<C> for DebugBuilder<C> {
) -> Variable {
let x = self.convert_to_value(x);
let y = self.convert_to_value(y);
// println!("sub x: {:?}, y: {:?}", x, y);
self.return_as_variable(x - y)
}
fn mul(
Expand Down
10 changes: 10 additions & 0 deletions expander_compiler/src/frontend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,22 @@ pub mod extra {
let mut inputs = Vec::new();
let mut public_inputs = Vec::new();
assignment.dump_into(&mut inputs, &mut public_inputs);

println!("input len: {}", inputs.len());
println!("public input len: {}", public_inputs.len());

let (mut root_builder, input_variables, public_input_variables) =
DebugBuilder::<C>::new(inputs, public_inputs);
let mut circuit = circuit.clone();
let mut vars_ptr = input_variables.as_slice();
let mut public_vars_ptr = public_input_variables.as_slice();
circuit.load_from(&mut vars_ptr, &mut public_vars_ptr);

for (i, v) in input_variables.iter().enumerate() {
println!("{}: {:?} {}", i, root_builder.value_of(v), v.id);
}
println!("num vars: {:?}", circuit.num_vars());

circuit.define(&mut root_builder);
}
}
Expand Down
22 changes: 19 additions & 3 deletions rsa_circuit/src/tests/u2048_mul_no_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use expander_compiler::{
declare_circuit,
frontend::{BN254Config, Define, Variable, API},
};
use extra::{debug_eval, DebugAPI};
use halo2curves::bn256::Fr;
use num_bigint::BigUint;
use num_traits::Num;
Expand All @@ -19,14 +20,26 @@ declare_circuit!(MulNoModCircuit {
result: [Variable; 2 * N_LIMBS],
});

impl Define<BN254Config> for MulNoModCircuit<Variable> {
fn define(&self, builder: &mut API<BN254Config>) {
impl GenericDefine<BN254Config> for MulNoModCircuit<Variable> {
fn define<Builder: RootAPI<BN254Config>>(&self, builder: &mut Builder) {
let x = U2048Variable::from_raw(self.x);
let y = U2048Variable::from_raw(self.y);
let two_to_120 = builder.constant(BN_TWO_TO_120);

let res = U2048Variable::mul_without_mod_reduction(&x, &y, &two_to_120, builder);

for i in 0..2 * N_LIMBS {
// println!("{:?}",
// <Builder as DebugAPI<BN254Config>>::value_of(&builder, res[i]));
println!("{i}");
builder.display(res[i]);
builder.display(self.result[i]);
}

for i in 0..2 * N_LIMBS {
// println!("{:?}",
// <Builder as DebugAPI<BN254Config>>::value_of(&builder, res[i]));

builder.assert_is_equal(res[i], self.result[i]);
}
}
Expand Down Expand Up @@ -60,7 +73,7 @@ impl MulNoModCircuit<Fr> {
}
#[test]
fn test_mul_without_mod() {
let compile_result = compile(&MulNoModCircuit::default()).unwrap();
let compile_result = compile_generic(&MulNoModCircuit::default()).unwrap();

{
// Test case 1: Simple multiplication with no carries
Expand Down Expand Up @@ -291,12 +304,15 @@ fn test_mul_without_mod() {
}

let assignment = MulNoModCircuit::<Fr>::create_circuit(x, x, result);

let witness = compile_result
.witness_solver
.solve_witness(&assignment)
.unwrap();
let output = compile_result.layered_circuit.run(&witness);

debug_eval(&MulNoModCircuit::default(), &assignment);

println!("x");
for i in 0..N_LIMBS {
println!("{} {:0x?}", i, x[i]);
Expand Down
22 changes: 12 additions & 10 deletions rsa_circuit/src/u120.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use expander_compiler::frontend::{extra::UnconstrainedAPI, BN254Config, BasicAPI, Variable, API};
use expander_compiler::frontend::{
extra::UnconstrainedAPI, BN254Config, BasicAPI, RootAPI, Variable, API,
};

#[inline]
// TODO:
// Assert the variable x is 120 bits, via LogUp
pub fn range_proof_u120(_x: &Variable, _builder: &mut API<BN254Config>) {}
pub fn range_proof_u120<Builder: RootAPI<BN254Config>>(_x: &Variable, _builder: &mut Builder) {}

// Accumulate up to 2^120 variables
pub fn accumulate_u120(
pub fn accumulate_u120<Builder: RootAPI<BN254Config>>(
x: &[Variable],
two_to_120: &Variable,
builder: &mut API<BN254Config>,
builder: &mut Builder,
) -> (Variable, Variable) {
assert!(x.len() > 1, "length is {}", x.len());

Expand Down Expand Up @@ -42,12 +44,12 @@ pub fn accumulate_u120(
// Does not ensure:
// - x, y are 120 bits
// - carry_in is 1 bit
pub(crate) fn add_u120(
pub(crate) fn add_u120<Builder: RootAPI<BN254Config>>(
x: &Variable,
y: &Variable,
carry_in: &Variable,
two_to_120: &Variable,
builder: &mut API<BN254Config>,
builder: &mut Builder,
) -> (Variable, Variable) {
let x_plus_y = builder.add(x, y);
let sum = builder.add(x_plus_y, carry_in);
Expand All @@ -73,12 +75,12 @@ pub(crate) fn add_u120(
// Does not ensure:
// - x, y are 120 bits
// - carry_in is 120 bit
pub(crate) fn mul_u120(
pub(crate) fn mul_u120<Builder: RootAPI<BN254Config>>(
x: &Variable,
y: &Variable,
carry_in: &Variable,
two_to_120: &Variable,
builder: &mut API<BN254Config>,
builder: &mut Builder,
) -> (Variable, Variable) {
let x_mul_y = builder.mul(x, y);
let left = builder.add(x_mul_y, carry_in);
Expand All @@ -96,10 +98,10 @@ pub(crate) fn mul_u120(

// check if x < y
// assumption: x, y are 120 bits
pub(crate) fn is_less_than_u120(
pub(crate) fn is_less_than_u120<Builder: RootAPI<BN254Config>>(
x: &Variable,
y: &Variable,
builder: &mut API<BN254Config>,
builder: &mut Builder,
) -> Variable {
let diff = builder.sub(x, y);
let byte_decomp = crate::util::unconstrained_byte_decomposition(&diff, builder);
Expand Down
56 changes: 38 additions & 18 deletions rsa_circuit/src/u2048.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use expander_compiler::frontend::{extra::UnconstrainedAPI, BN254Config, BasicAPI, Variable, API};
use expander_compiler::frontend::{
extra::UnconstrainedAPI, BN254Config, BasicAPI, RootAPI, Variable, API,
};

use crate::{
constants::N_LIMBS,
Expand All @@ -18,10 +20,10 @@ impl U2048Variable {

#[inline]
// generate a bool variable for the comparison of two U2048 variables
pub fn unconstrained_greater_eq(
pub fn unconstrained_greater_eq<Builder: RootAPI<BN254Config>>(
&self,
other: &Self,
builder: &mut API<BN254Config>,
builder: &mut Builder,
) -> Variable {
// Start from most significant limb (N_LIMBS-1) and work down
let mut gt_flags = Vec::with_capacity(N_LIMBS);
Expand Down Expand Up @@ -52,7 +54,11 @@ impl U2048Variable {
}

#[inline]
pub fn assert_is_less_than(&self, other: &Self, builder: &mut API<BN254Config>) -> Variable {
pub fn assert_is_less_than<Builder: RootAPI<BN254Config>>(
&self,
other: &Self,
builder: &mut Builder,
) -> Variable {
let mut result = builder.constant(0);
let mut all_eq_so_far = builder.constant(1);

Expand Down Expand Up @@ -94,7 +100,11 @@ impl U2048Variable {

// Helper function to check if one U2048 is greater than or equal to another
#[inline]
pub fn assert_is_greater_eq(&self, other: &Self, builder: &mut API<BN254Config>) -> Variable {
pub fn assert_is_greater_eq<Builder: RootAPI<BN254Config>>(
&self,
other: &Self,
builder: &mut Builder,
) -> Variable {
let less = other.assert_is_less_than(self, builder);
let eq = self.assert_is_equal(other, builder);

Expand All @@ -109,7 +119,11 @@ impl U2048Variable {

// Helper function to check equality
#[inline]
pub fn assert_is_equal(&self, other: &Self, builder: &mut API<BN254Config>) -> Variable {
pub fn assert_is_equal<Builder: RootAPI<BN254Config>>(
&self,
other: &Self,
builder: &mut Builder,
) -> Variable {
let mut is_equal = builder.constant(1);

for i in 0..N_LIMBS {
Expand All @@ -126,14 +140,14 @@ impl U2048Variable {
#[inline]
// add two U2048 variables with mod reductions
// a + b = result + carry * modulus
pub fn assert_add(
pub fn assert_add<Builder: RootAPI<BN254Config>>(
x: &U2048Variable,
y: &U2048Variable,
result: &U2048Variable,
carry: &Variable,
modulus: &U2048Variable,
two_to_120: &Variable,
builder: &mut API<BN254Config>,
builder: &mut Builder,
) {
// First compute raw sum x + y with carries between limbs
let mut sum = vec![];
Expand Down Expand Up @@ -184,14 +198,14 @@ impl U2048Variable {
#[inline]
// assert multiplication of two U2048 variables
// x * y = result + carry * modulus
pub fn assert_mul(
pub fn assert_mul<Builder: RootAPI<BN254Config>>(
x: &U2048Variable,
y: &U2048Variable,
result: &U2048Variable,
carry: &U2048Variable,
modulus: &U2048Variable,
two_to_120: &Variable,
builder: &mut API<BN254Config>,
builder: &mut Builder,
) {
let zero = builder.constant(0);
// x * y
Expand Down Expand Up @@ -226,15 +240,15 @@ impl U2048Variable {
}

#[inline]
pub fn mul_without_mod_reduction(
pub fn mul_without_mod_reduction<Builder: RootAPI<BN254Config>>(
x: &U2048Variable,
y: &U2048Variable,
two_to_120: &Variable,
builder: &mut API<BN254Config>,
builder: &mut Builder,
) -> Vec<Variable> {
let zero = builder.constant(0);
let mut local_res = vec![zero; 2 * N_LIMBS];
let mut addition_carries = vec![zero; 2 * N_LIMBS];
let mut addition_carries = vec![zero; 2 * N_LIMBS + 1];

for i in 0..N_LIMBS {
for j in 0..N_LIMBS {
Expand All @@ -252,9 +266,10 @@ impl U2048Variable {
two_to_120,
builder,
);

local_res[target_position] = sum;
addition_carries[target_position] =
builder.add(addition_carries[target_position], new_carry);
addition_carries[target_position+1] =
builder.add(addition_carries[target_position+1], new_carry);

// update mul_carry to result[target+1]
let (sum, new_carry) = add_u120(
Expand All @@ -265,13 +280,18 @@ impl U2048Variable {
builder,
);
local_res[target_position + 1] = sum;
addition_carries[target_position + 1] =
builder.add(addition_carries[target_position + 1], new_carry);
addition_carries[target_position + 2] =
builder.add(addition_carries[target_position + 2], new_carry);
}
}
for i in 0..2 * N_LIMBS {
println!("{i}");
builder.display(local_res[i]);
builder.display(addition_carries[i]);
}

// integrate carries into result
let mut cur_carry = addition_carries[0];
let mut cur_carry = builder.constant(0);
for i in 0..2 * N_LIMBS {
(local_res[i], cur_carry) = add_u120(
&local_res[i],
Expand Down
8 changes: 4 additions & 4 deletions rsa_circuit/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use expander_compiler::frontend::{BN254Config, Variable, API};
use extra::UnconstrainedAPI;
use halo2curves::bn256::Fr;

pub(crate) fn unconstrained_byte_decomposition(
pub(crate) fn unconstrained_byte_decomposition<Builder: RootAPI<BN254Config>>(
x: &Variable,
builder: &mut API<BN254Config>,
builder: &mut Builder,
) -> Vec<Variable> {
let mut res = vec![];
let mut x = x.clone();
Expand All @@ -21,10 +21,10 @@ pub(crate) fn unconstrained_byte_decomposition(

// assert bit decomposition
// the constant_scalars are 2^8, 2^16, ... 2^248
pub fn byte_decomposition(
pub fn byte_decomposition<Builder: RootAPI<BN254Config>>(
x: &Variable,
constant_scalars: &[Variable],
builder: &mut API<BN254Config>,
builder: &mut Builder,
) -> Vec<Variable> {
let bytes = unconstrained_byte_decomposition(x, builder);
// todo: constraint each byte to be less than 256 via logup
Expand Down

0 comments on commit f1fb7a5

Please sign in to comment.