From cf0a57175a9628c8b7cd139ef42d3f8f35d1a3dd Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 29 Oct 2024 09:21:17 +0700 Subject: [PATCH] wip --- expander_compiler/src/circuit/layered/opt.rs | 4 +- .../src/circuit/layered/serde.rs | 5 +- expander_compiler/src/compile/mod.rs | 76 ++++-- expander_compiler/src/frontend/builder.rs | 6 + expander_compiler/src/frontend/mod.rs | 10 +- expander_compiler/src/layering/compile.rs | 17 ++ .../src/layering/layer_layout.rs | 2 +- expander_compiler/src/layering/mod.rs | 10 +- expander_compiler/src/layering/tests.rs | 4 +- expander_compiler/src/lib.rs | 1 + expander_compiler/src/zkcuda/kernel.rs | 230 ++++++++++++++++++ expander_compiler/src/zkcuda/mod.rs | 2 + expander_compiler/src/zkcuda/traits.rs | 15 ++ 13 files changed, 348 insertions(+), 34 deletions(-) create mode 100644 expander_compiler/src/zkcuda/kernel.rs create mode 100644 expander_compiler/src/zkcuda/mod.rs create mode 100644 expander_compiler/src/zkcuda/traits.rs diff --git a/expander_compiler/src/circuit/layered/opt.rs b/expander_compiler/src/circuit/layered/opt.rs index 9bc232b..2262ffb 100644 --- a/expander_compiler/src/circuit/layered/opt.rs +++ b/expander_compiler/src/circuit/layered/opt.rs @@ -682,7 +682,7 @@ impl Circuit { mod tests { use crate::circuit::layered; use crate::field::FieldArith; - use crate::layering::compile; + use crate::layering::{compile, CompileOptions}; use crate::{ circuit::{ config::{Config, GF2Config as C}, @@ -711,7 +711,7 @@ mod tests { } }, } - let (lc, _) = compile(&root); + let (lc, _) = compile(&root, CompileOptions { is_zkcuda: false }); assert_eq!(lc.validate(), Ok(())); Some(lc) } diff --git a/expander_compiler/src/circuit/layered/serde.rs b/expander_compiler/src/circuit/layered/serde.rs index 99776ce..b0baf93 100644 --- a/expander_compiler/src/circuit/layered/serde.rs +++ b/expander_compiler/src/circuit/layered/serde.rs @@ -214,7 +214,10 @@ mod tests { config.seed = i + 10000; let root = RootCircuit::::random(&config); 
assert_eq!(root.validate(), Ok(())); - let (circuit, _) = crate::layering::compile(&root); + let (circuit, _) = crate::layering::compile( + &root, + crate::layering::CompileOptions { is_zkcuda: false }, + ); assert_eq!(circuit.validate(), Ok(())); let mut buf = Vec::new(); circuit.serialize_into(&mut buf).unwrap(); diff --git a/expander_compiler/src/compile/mod.rs b/expander_compiler/src/compile/mod.rs index a3fa6a0..29a52e1 100644 --- a/expander_compiler/src/compile/mod.rs +++ b/expander_compiler/src/compile/mod.rs @@ -47,9 +47,9 @@ fn print_stat(stat_name: &str, stat: usize, is_last: bool) { } } -pub fn compile( +pub fn compile_step_1( r_source: &ir::source::RootCircuit, -) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { +) -> Result<(ir::hint_normalized::RootCircuit, InputMapping), Error> { r_source.validate()?; let mut src_im = InputMapping::new_identity(r_source.input_size()); @@ -78,19 +78,12 @@ pub fn compile( r_hint_normalized_opt .validate() .map_err(|e| e.prepend("hint normalized ir circuit invalid"))?; - let ho_stats = r_hint_normalized_opt.get_stats(); - print_info("built hint normalized ir"); - print_stat("numInputs", ho_stats.num_inputs, false); - print_stat("numConstraints", ho_stats.num_constraints, false); - print_stat("numInsns", ho_stats.num_insns, false); - print_stat("numVars", ho_stats.num_variables, false); - print_stat("numTerms", ho_stats.num_terms, true); - - let (r_hint_less, mut r_hint_exported) = r_hint_normalized_opt.remove_and_export_hints(); - r_hint_exported - .validate() - .map_err(|e| e.prepend("hint exported circuit invalid"))?; + Ok((r_hint_normalized_opt, src_im)) +} +pub fn compile_step_2( + r_hint_less: ir::hint_less::RootCircuit, +) -> Result<(ir::dest::RootCircuit, InputMapping), Error> { let mut hl_im = InputMapping::new_identity(r_hint_less.input_size()); let r_hint_less_opt = optimize_until_fixed_point(&r_hint_less, &mut hl_im, |r| { @@ -151,8 +144,12 @@ pub fn compile( r_dest_opt 
.validate_circuit_has_inputs() .map_err(|e| e.prepend("dest ir circuit invalid"))?; + Ok((r_dest_opt, hl_im)) +} - let (mut lc, dest_im) = layering::compile(&r_dest_opt); +pub fn compile_step_3( + mut lc: layered::Circuit, +) -> Result, Error> { lc.validate() .map_err(|e| e.prepend("layered circuit invalid"))?; @@ -172,6 +169,47 @@ pub fn compile( lc.validate() .map_err(|e| e.prepend("layered circuit invalid1"))?; lc.sort_everything(); // for deterministic output + Ok(lc) +} + +pub fn compile_step_4( + r_hint_exported: ir::hint_normalized::RootCircuit, + src_im: &mut InputMapping, +) -> Result, Error> { + r_hint_exported + .validate() + .map_err(|e| e.prepend("final hint exported circuit invalid"))?; + let r_hint_exported_opt = optimize_until_fixed_point(&r_hint_exported, src_im, |r| { + let (r, im) = r.remove_unreachable(); + (r, im) + }); + Ok(r_hint_exported_opt) +} + +pub fn compile( + r_source: &ir::source::RootCircuit, +) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { + let (r_hint_normalized_opt, mut src_im) = compile_step_1(r_source)?; + + let ho_stats = r_hint_normalized_opt.get_stats(); + print_info("built hint normalized ir"); + print_stat("numInputs", ho_stats.num_inputs, false); + print_stat("numConstraints", ho_stats.num_constraints, false); + print_stat("numInsns", ho_stats.num_insns, false); + print_stat("numVars", ho_stats.num_variables, false); + print_stat("numTerms", ho_stats.num_terms, true); + + let (r_hint_less, mut r_hint_exported) = r_hint_normalized_opt.remove_and_export_hints(); + r_hint_exported + .validate() + .map_err(|e| e.prepend("hint exported circuit invalid"))?; + + let (r_dest_opt, mut hl_im) = compile_step_2(r_hint_less)?; + + let (lc, dest_im) = + layering::compile(&r_dest_opt, layering::CompileOptions { is_zkcuda: false }); + + let lc = compile_step_3(lc)?; let lc_stats = lc.get_stats(); print_info("built layered circuit"); @@ -193,14 +231,8 @@ pub fn compile( .iter() .map(|&x| x.max(1)) .collect(); - 
r_hint_exported - .validate() - .map_err(|e| e.prepend("final hint exported circuit invalid"))?; - let mut r_hint_exported_opt = optimize_until_fixed_point(&r_hint_exported, &mut src_im, |r| { - let (r, im) = r.remove_unreachable(); - (r, im) - }); + let mut r_hint_exported_opt = compile_step_4(r_hint_exported, &mut src_im)?; r_hint_exported_opt.add_back_removed_inputs(&src_im); r_hint_exported_opt .validate() diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index ca42c27..d7fc1ec 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -65,6 +65,12 @@ impl ToVariableOrValue for &Variable { } } +impl Variable { + pub fn id(&self) -> usize { + self.id + } +} + impl Builder { pub fn new(num_inputs: usize) -> (Self, Vec) { ( diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index 1b087b3..f9673cf 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -2,11 +2,11 @@ use builder::RootBuilder; use crate::circuit::{ir, layered}; -mod api; -mod builder; -mod circuit; -mod variables; -mod witness; +pub mod api; +pub mod builder; +pub mod circuit; +pub mod variables; +pub mod witness; pub use circuit::declare_circuit; pub type API = builder::RootBuilder; diff --git a/expander_compiler/src/layering/compile.rs b/expander_compiler/src/layering/compile.rs index 06080b7..c775294 100644 --- a/expander_compiler/src/layering/compile.rs +++ b/expander_compiler/src/layering/compile.rs @@ -9,7 +9,9 @@ use crate::circuit::{ }; use crate::utils::pool::Pool; +use super::layer_layout::merge_layouts; use super::layer_layout::{LayerLayout, LayerLayoutContext, LayerReq}; +use super::CompileOptions; pub struct CompileContext<'a, C: Config> { // the root circuit @@ -38,6 +40,8 @@ pub struct CompileContext<'a, C: Config> { pub input_order: Vec, pub root_has_constraints: bool, + + pub opts: CompileOptions, } pub 
struct IrContext<'a, C: Config> { @@ -112,6 +116,19 @@ impl<'a, C: Config> CompileContext<'a, C> { })); } self.layout_ids = layout_ids; + if self.opts.is_zkcuda { + let layout_vec = + merge_layouts(vec![], (0..self.circuits[&0].lcs[0].vars.len()).collect()); + let id = self.layer_layout_pool.add(&LayerLayout { + circuit_id: 0, + layer: 0, + size: layout_vec.len(), + inner: super::layer_layout::LayerLayoutInner::Dense { + placement: layout_vec, + }, + }); + self.layout_ids[0] = id; + } // 5. generate wires let mut layers = Vec::with_capacity(self.circuits[&0].output_layer); diff --git a/expander_compiler/src/layering/layer_layout.rs b/expander_compiler/src/layering/layer_layout.rs index bd29f2a..ba3f073 100644 --- a/expander_compiler/src/layering/layer_layout.rs +++ b/expander_compiler/src/layering/layer_layout.rs @@ -343,7 +343,7 @@ impl<'a, C: Config> CompileContext<'a, C> { } } -fn merge_layouts(s: Vec>, additional: Vec) -> Vec { +pub fn merge_layouts(s: Vec>, additional: Vec) -> Vec { // currently it's a simple greedy algorithm // sort groups by size, and then place them one by one // since their size are always 2^n, the result is aligned diff --git a/expander_compiler/src/layering/mod.rs b/expander_compiler/src/layering/mod.rs index c9f4cf7..ccf2344 100644 --- a/expander_compiler/src/layering/mod.rs +++ b/expander_compiler/src/layering/mod.rs @@ -14,7 +14,14 @@ mod wire; #[cfg(test)] mod tests; -pub fn compile(rc: &ir::dest::RootCircuit) -> (layered::Circuit, InputMapping) { +pub struct CompileOptions { + pub is_zkcuda: bool, +} + +pub fn compile( + rc: &ir::dest::RootCircuit, + opts: CompileOptions, +) -> (layered::Circuit, InputMapping) { let mut ctx = compile::CompileContext { rc, circuits: HashMap::new(), @@ -27,6 +34,7 @@ pub fn compile(rc: &ir::dest::RootCircuit) -> (layered::Circuit layers: Vec::new(), input_order: Vec::new(), root_has_constraints: false, + opts, }; ctx.compile(); let l0_size = ctx.compiled_circuits[ctx.layers[0]].num_inputs; diff --git 
a/expander_compiler/src/layering/tests.rs b/expander_compiler/src/layering/tests.rs index faeb6e7..db5ee96 100644 --- a/expander_compiler/src/layering/tests.rs +++ b/expander_compiler/src/layering/tests.rs @@ -7,7 +7,7 @@ use crate::circuit::{ use crate::field::FieldArith; -use super::compile; +use super::{compile, CompileOptions}; pub fn test_input( rc: &IrRootCircuit, @@ -29,7 +29,7 @@ pub fn compile_and_random_test( n_tests: usize, ) -> (layered::Circuit, InputMapping) { assert!(rc.validate().is_ok()); - let (lc, input_mapping) = compile(rc); + let (lc, input_mapping) = compile(rc, CompileOptions { is_zkcuda: false }); //print!("{}", lc); assert_eq!(lc.validate(), Ok(())); assert_eq!(rc.input_size(), input_mapping.cur_size()); diff --git a/expander_compiler/src/lib.rs b/expander_compiler/src/lib.rs index bfe42bf..e4fd48e 100644 --- a/expander_compiler/src/lib.rs +++ b/expander_compiler/src/lib.rs @@ -6,3 +6,4 @@ pub mod frontend; pub mod hints; pub mod layering; pub mod utils; +pub mod zkcuda; diff --git a/expander_compiler/src/zkcuda/kernel.rs b/expander_compiler/src/zkcuda/kernel.rs new file mode 100644 index 0000000..a5c250c --- /dev/null +++ b/expander_compiler/src/zkcuda/kernel.rs @@ -0,0 +1,230 @@ +use crate::circuit::{ + config::Config, + input_mapping::{InputMapping, EMPTY}, + ir::{self, expr}, + layered::Circuit as LayeredCircuit, +}; +use crate::field::FieldArith; +use crate::frontend::*; + +pub struct Kernel { + pub witness_solver: ir::hint_normalized::RootCircuit, + pub layered_circuit: LayeredCircuit, + pub io: Vec, + pub hint_input: Option, +} + +pub struct IOVecOffset { + pub len: usize, + pub input_offset: Option, + pub output_offset: Option, + pub witness_solver_input_offset: Option, + pub witness_solver_output_offset: Option, +} + +pub struct IOVecSpec { + pub len: usize, + pub is_input: bool, + pub is_output: bool, +} + +fn dup_inputs(api: &mut API, inputs: &Vec) -> Vec { + use extra::UnconstrainedAPI; + let mut res = vec![]; + for x in inputs 
{ + res.push(api.unconstrained_identity(x)); + } + res +} + +pub fn compile_with_spec(f: F, io_specs: &[IOVecSpec]) -> Result, Error> +where + C: Config, + F: Fn(&mut API, &mut Vec>), +{ + let total_inputs = io_specs + .iter() + .map(|spec| spec.len * (spec.is_input as usize + spec.is_output as usize)) + .sum(); + let (mut root_builder, input_variables, _) = API::::new(total_inputs, 0); + let mut io_vars = vec![]; + let mut expected_outputs = vec![]; + let mut inputs_offsets = vec![]; + let mut expected_outputs_offsets = vec![]; + let mut global_input_offset = 0; + for spec in io_specs { + let mut cur_inputs = vec![]; + if spec.is_input { + for i in 0..spec.len { + cur_inputs.push(input_variables[global_input_offset + i]); + } + inputs_offsets.push(global_input_offset); + global_input_offset += spec.len; + } else { + for _ in 0..spec.len { + cur_inputs.push(root_builder.constant(0)); + } + inputs_offsets.push(0); + } + io_vars.push(cur_inputs); + } + let n_in = global_input_offset; + for spec in io_specs { + if spec.is_output { + let mut cur_outputs = vec![]; + for i in 0..spec.len { + cur_outputs.push(input_variables[global_input_offset + i]); + } + expected_outputs.push(cur_outputs); + expected_outputs_offsets.push(global_input_offset); + global_input_offset += spec.len; + } else { + expected_outputs.push(vec![]); + expected_outputs_offsets.push(0); + } + } + let mut io_off = vec![]; + for i in 0..io_specs.len() { + io_off.push(IOVecOffset { + len: io_specs[i].len, + input_offset: if io_specs[i].is_input { + Some(inputs_offsets[i]) + } else { + None + }, + output_offset: if io_specs[i].is_output { + Some(expected_outputs_offsets[i]) + } else { + None + }, + witness_solver_input_offset: if io_specs[i].is_input { + Some(inputs_offsets[i]) + } else { + None + }, + witness_solver_output_offset: if io_specs[i].is_output { + Some(expected_outputs_offsets[i]) + } else { + None + }, + }); + } + f(&mut root_builder, &mut io_vars); + let mut output_offsets = vec![]; + let 
mut global_output_offset = 0; + let mut output_vars = vec![]; + for (i, spec) in io_specs.iter().enumerate() { + if spec.is_output { + for (x, y) in io_vars[i].iter().zip(expected_outputs[i].iter()) { + root_builder.assert_is_equal(x, y); + output_vars.push(*x); + } + output_offsets.push(global_output_offset); + global_output_offset += spec.len; + } else { + output_offsets.push(0); + } + } + let dup_out = root_builder.memorized_simple_call(dup_inputs, &output_vars); + let output_vars_ids: Vec = dup_out.iter().map(|x| x.id()).collect(); + + // prevent optimization + let mut r_source = root_builder.build(); + let c0 = r_source.circuits.get_mut(&0).unwrap(); + for i in 1..=total_inputs { + c0.outputs.push(i); + } + c0.outputs.extend_from_slice(&output_vars_ids); + // compile step 1 + let (r_hint_normalized_opt, src_im) = crate::compile::compile_step_1(&r_source)?; + for (i, x) in src_im.mapping().iter().enumerate() { + assert_eq!(i, *x); + } + // export hints + let (mut r_hint_less, mut r_hint_exported) = r_hint_normalized_opt.remove_and_export_hints(); + // remove additional hints, move them to user outputs + let rl_c0 = r_hint_less.circuits.get_mut(&0).unwrap(); + let re_c0 = r_hint_exported.circuits.get_mut(&0).unwrap(); + let n_out = output_vars_ids.len(); + let off1 = re_c0.outputs.len() - n_out; + let off2 = n_in; + for i in 0..n_out { + re_c0.outputs.swap(off1 + i, off2 + i); + } + rl_c0.num_inputs -= n_out; + let mut add_insns = vec![]; + for i in 0..n_out { + add_insns.push(ir::hint_less::Instruction::LinComb(expr::LinComb { + terms: vec![expr::LinCombTerm { + var: n_in + i + 1, + coef: C::CircuitField::one(), + }], + constant: C::CircuitField::zero(), + })); + } + add_insns.extend_from_slice(&rl_c0.instructions); + rl_c0.instructions = add_insns; + assert_eq!(rl_c0.outputs.len(), n_in + n_out * 2); + rl_c0.outputs.truncate(n_in + n_out); + let num_inputs_with_hint = rl_c0.num_inputs; + // compile step 2 + let (mut r_dest_opt, hl_im) = 
crate::compile::compile_step_2(r_hint_less)?; + for (i, x) in hl_im.mapping().iter().enumerate() { + assert_eq!(i, *x); + } + // remove outputs that were used to prevent optimization + let rd_c0 = r_dest_opt.circuits.get_mut(&0).unwrap(); + assert_eq!(rd_c0.outputs.len(), n_in + n_out); + rd_c0.outputs = vec![]; + // compile step 3 + let (lc, dest_im) = crate::layering::compile( + &r_dest_opt, + crate::layering::CompileOptions { is_zkcuda: true }, + ); + for (i, x) in dest_im.mapping().iter().enumerate() { + if i < num_inputs_with_hint { + assert_eq!(i, *x); + } else { + assert_eq!(*x, EMPTY); + } + } + let lc = crate::compile::compile_step_3(lc)?; + // compile step 4 + let mut tmp_im = InputMapping::new_identity(r_hint_exported.input_size()); + let mut r_hint_exported_opt = crate::compile::compile_step_4(r_hint_exported, &mut tmp_im)?; + for (i, x) in tmp_im.mapping().iter().enumerate() { + assert_eq!(i, *x); + } + let re_c0 = r_hint_exported_opt.circuits.get_mut(&0).unwrap(); + re_c0.outputs.truncate(off1); + let hint_size = re_c0.outputs.len() - n_in - n_out; + let hint_io = if hint_size > 0 { + Some(IOVecOffset { + len: hint_size, + input_offset: Some(n_in + n_out), + output_offset: None, + witness_solver_input_offset: None, + witness_solver_output_offset: Some(n_in + n_out), + }) + } else { + None + }; + + Ok(Kernel { + witness_solver: r_hint_exported_opt, + layered_circuit: lc, + io: io_off, + hint_input: hint_io, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn example_kernel_1(api: &mut API, a: &mut Vec>) { + let x = a[1][1]; + a[0][0] = x; + a[1][2] = api.add(x, 1); + } +} diff --git a/expander_compiler/src/zkcuda/mod.rs b/expander_compiler/src/zkcuda/mod.rs new file mode 100644 index 0000000..55c0966 --- /dev/null +++ b/expander_compiler/src/zkcuda/mod.rs @@ -0,0 +1,2 @@ +pub mod kernel; +pub mod traits; diff --git a/expander_compiler/src/zkcuda/traits.rs b/expander_compiler/src/zkcuda/traits.rs new file mode 100644 index 0000000..5e14435 ---
/dev/null +++ b/expander_compiler/src/zkcuda/traits.rs @@ -0,0 +1,15 @@ +use crate::circuit::config::Config; + +use super::kernel::Kernel; + +pub trait Commitment {} + +pub trait Proof {} + +pub trait ProvingSystem { + type Proof: Proof; + type Commitment: Commitment; + fn commit(vals: &[C::CircuitField]) -> Self::Commitment; + fn prove(kernel: &Kernel, commitments: &[Self::Commitment]) -> Self::Proof; + fn verify(kernel: &Kernel, proof: Self::Proof, commitments: &[Self::Commitment]) -> bool; +}