Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
siq1 committed Oct 30, 2024
1 parent 5e81ce3 commit 600c9fd
Show file tree
Hide file tree
Showing 6 changed files with 348 additions and 70 deletions.
13 changes: 13 additions & 0 deletions expander_compiler/src/circuit/ir/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,16 @@ impl<Irc: IrConfig> RootCircuit<Irc> {
Ok((res, cond))
}
}

impl<Irc: IrConfig> Hash for RootCircuit<Irc> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.num_public_inputs.hash(state);
self.expected_num_output_zeroes.hash(state);
let mut keys = self.circuits.keys().collect::<Vec<_>>();
keys.sort();
for k in keys.iter() {
k.hash(state);
self.circuits[k].hash(state);
}
}
}
6 changes: 3 additions & 3 deletions expander_compiler/src/circuit/layered/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,15 @@ pub struct GateCustom<C: Config> {
pub coef: Coef<C>,
}

#[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq)]
#[derive(Debug, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub struct Allocation {
pub input_offset: usize,
pub output_offset: usize,
}

pub type ChildSpec = (usize, Vec<Allocation>);

#[derive(Default, Debug, Clone, PartialOrd, Ord, PartialEq, Eq)]
#[derive(Default, Debug, Hash, Clone, PartialOrd, Ord, PartialEq, Eq)]
pub struct Segment<C: Config> {
pub num_inputs: usize,
pub num_outputs: usize,
Expand All @@ -164,7 +164,7 @@ pub struct Segment<C: Config> {
pub gate_customs: Vec<GateCustom<C>>,
}

#[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq)]
#[derive(Debug, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub struct Circuit<C: Config> {
pub num_public_inputs: usize,
pub num_actual_outputs: usize,
Expand Down
197 changes: 158 additions & 39 deletions expander_compiler/src/zkcuda/context.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::circuit::config::Config;
use crate::field::FieldArith;
use crate::{circuit::config::Config, utils::pool::Pool};

use super::{kernel::Kernel, proving_system::ProvingSystem};

Expand All @@ -10,24 +10,37 @@ pub struct DeviceMemory<C: Config, P: ProvingSystem<C>> {

pub type DeviceMemoryHandle = usize;

pub struct WrappedProof<C: Config, P: ProvingSystem<C>> {
pub proof: P::Proof,
pub kernel_id: usize,
pub commitment_indices: Vec<usize>,
pub parallel_count: usize,
pub is_broadcast: Vec<bool>,
}

pub struct Context<C: Config, P: ProvingSystem<C>> {
pub kernels: Pool<Kernel<C>>,
pub device_memories: Vec<DeviceMemory<C, P>>,
pub proofs: Vec<P::Proof>,
pub proofs: Vec<WrappedProof<C, P>>,
}

pub struct CombinedProof<C: Config, P: ProvingSystem<C>> {
pub kernels: Vec<Kernel<C>>,
pub commitments: Vec<P::Commitment>,
pub proofs: Vec<P::Proof>,
pub proofs: Vec<WrappedProof<C, P>>,
}

impl<C: Config, P: ProvingSystem<C>> Context<C, P> {
pub fn new() -> Self {
Self {
impl<C: Config, P: ProvingSystem<C>> Default for Context<C, P> {
fn default() -> Self {
Context {
kernels: Pool::new(),
device_memories: vec![],
proofs: vec![],
}
}
}

impl<C: Config, P: ProvingSystem<C>> Context<C, P> {
pub fn copy_to_device(&mut self, host_memory: &[C::CircuitField]) -> DeviceMemoryHandle {
self.device_memories.push(DeviceMemory {
values: host_memory.to_vec(),
Expand All @@ -40,65 +53,171 @@ impl<C: Config, P: ProvingSystem<C>> Context<C, P> {
self.device_memories[device_memory_handle].values.clone()
}

pub fn call_kernel(&mut self, kernel: &Kernel<C>, ios: &mut [Option<DeviceMemoryHandle>]) {
let mut ws_inputs = vec![C::CircuitField::zero(); kernel.witness_solver.input_size()];
pub fn call_kernel(
&mut self,
kernel: &Kernel<C>,
ios: &mut [Option<DeviceMemoryHandle>],
parallel_count: usize,
is_broadcast: &[bool],
) {
if kernel.witness_solver_io.len() != ios.len() {
panic!("Invalid number of inputs/outputs");
}
if kernel.witness_solver_io.len() != is_broadcast.len() {
panic!("Invalid number of is_broadcast");
}
for i in 0..ios.len() {
if is_broadcast[i] {
assert!(kernel.witness_solver_io[i].output_offset.is_none());
assert_eq!(
self.device_memories[ios[i].unwrap()].values.len(),
kernel.witness_solver_io[i].len
);
} else if kernel.witness_solver_io[i].input_offset.is_some() {
assert_eq!(
self.device_memories[ios[i].unwrap()].values.len(),
kernel.witness_solver_io[i].len * parallel_count
);
}
}

let kernel_id = self.kernels.add(kernel);

let mut handles = vec![];
for (input, ws_input) in ios.iter().zip(kernel.witness_solver_io.iter()) {
let mut lc_is_broadcast = vec![];
for ((input, ws_input), ib) in ios
.iter()
.zip(kernel.witness_solver_io.iter())
.zip(is_broadcast)
{
assert_eq!(input.is_some(), ws_input.input_offset.is_some());
if input.is_none() {
continue;
if input.is_some() {
handles.push(input.unwrap());
lc_is_broadcast.push(*ib);
}
handles.push(input.unwrap());
let device_memory = &self.device_memories[input.unwrap()];
assert_eq!(device_memory.values.len(), ws_input.len);
let offset = ws_input.input_offset.unwrap();
for (i, x) in device_memory.values.iter().enumerate() {
ws_inputs[offset + i] = *x;
}

let mut output_vecs = vec![vec![]; ios.len()];
let mut hint_output_vec = vec![];

for parallel_i in 0..parallel_count {
let mut ws_inputs = vec![C::CircuitField::zero(); kernel.witness_solver.input_size()];
for (i, (input, ws_input)) in
ios.iter().zip(kernel.witness_solver_io.iter()).enumerate()
{
if input.is_none() {
continue;
}
let device_memory = &self.device_memories[input.unwrap()];
let offset = ws_input.input_offset.unwrap();
if is_broadcast[i] {
for (i, x) in device_memory.values.iter().enumerate() {
ws_inputs[offset + i] = *x;
}
} else {
for (i, x) in device_memory
.values
.iter()
.skip(parallel_i * ws_input.len)
.take(ws_input.len)
.enumerate()
{
ws_inputs[offset + i] = *x;
}
}
}
let ws_outputs = kernel
.witness_solver
.eval_with_public_inputs(ws_inputs, &[])
.unwrap(); // TODO: handle error
for (i, ws_input) in kernel.witness_solver_io.iter().enumerate() {
if ws_input.output_offset.is_none() {
continue;
}
let offset = ws_input.output_offset.unwrap();
let values = &ws_outputs[offset..offset + ws_input.len];
output_vecs[i].extend_from_slice(values);
}
if let Some(hint_io) = &kernel.witness_solver_hint_input {
let values = &ws_outputs
[hint_io.output_offset.unwrap()..hint_io.output_offset.unwrap() + hint_io.len];
hint_output_vec.extend_from_slice(values);
}
}
let ws_outputs = kernel
.witness_solver
.eval_with_public_inputs(ws_inputs, &[])
.unwrap(); // TODO: handle error2
for (output, ws_input) in ios.iter_mut().zip(kernel.witness_solver_io.iter()) {

for ((output, ws_input), ov) in ios
.iter_mut()
.zip(kernel.witness_solver_io.iter())
.zip(output_vecs.into_iter())
{
if ws_input.output_offset.is_none() {
*output = None;
continue;
}
let offset = ws_input.output_offset.unwrap();
let values = ws_outputs[offset..offset + ws_input.len].to_vec();
let commitment = P::commit(&values);
let device_memory = DeviceMemory { values, commitment };
let commitment = P::commit(&ov);
let device_memory = DeviceMemory {
values: ov,
commitment,
};
self.device_memories.push(device_memory);
handles.push(self.device_memories.len() - 1);
*output = Some(self.device_memories.len() - 1);
lc_is_broadcast.push(false);
}
if let Some(hint_io) = &kernel.witness_solver_hint_input {
let values = ws_outputs
[hint_io.output_offset.unwrap()..hint_io.output_offset.unwrap() + hint_io.len]
.to_vec();
let commitment = P::commit(&values);
let device_memory = DeviceMemory { values, commitment };
if kernel.witness_solver_hint_input.is_some() {
let commitment = P::commit(&hint_output_vec);
let device_memory = DeviceMemory {
values: hint_output_vec,
commitment,
};
self.device_memories.push(device_memory);
handles.push(self.device_memories.len() - 1);
lc_is_broadcast.push(false);
}
let commitment_refs: Vec<_> = handles
.iter()
.map(|&x| &self.device_memories[x].commitment)
.collect();
let proof = P::prove(kernel, &commitment_refs);
self.proofs.push(proof);
// TODO: encode commitment id in proof
let proof = P::prove(kernel, &commitment_refs, parallel_count, &lc_is_broadcast);
self.proofs.push(WrappedProof {
proof,
kernel_id,
commitment_indices: handles,
parallel_count,
is_broadcast: lc_is_broadcast,
});
}

pub fn get_proof(&self) -> CombinedProof<C, P> {
pub fn to_proof(self) -> CombinedProof<C, P> {
CombinedProof {
kernels: self.kernels.vec().clone(),
commitments: self
.device_memories
.iter()
.map(|x| x.commitment.clone())
.into_iter()
.map(|x| x.commitment)
.collect(),
proofs: self.proofs.clone(),
proofs: self.proofs,
}
}
}

impl<C: Config, P: ProvingSystem<C>> CombinedProof<C, P> {
pub fn verify(&self) -> bool {
for proof in self.proofs.iter() {
if !P::verify(
&self.kernels[proof.kernel_id],
&proof.proof,
&proof
.commitment_indices
.iter()
.map(|&x| &self.commitments[x])
.collect::<Vec<_>>(),
proof.parallel_count,
&proof.is_broadcast,
) {
return false;
}
}
true
}
}
32 changes: 30 additions & 2 deletions expander_compiler/src/zkcuda/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::circuit::{
use crate::field::FieldArith;
use crate::frontend::*;

#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Kernel<C: Config> {
pub witness_solver: ir::hint_normalized::RootCircuit<C>,
pub layered_circuit: LayeredCircuit<C>,
Expand All @@ -15,12 +16,14 @@ pub struct Kernel<C: Config> {
pub layered_circuit_input: Vec<LayeredCircuitInputVec>,
}

#[derive(Default, Debug, Clone, Hash, PartialEq, Eq)]
pub struct WitnessSolverIOVec {
pub len: usize,
pub input_offset: Option<usize>,
pub output_offset: Option<usize>,
}

#[derive(Default, Debug, Clone, Hash, PartialEq, Eq)]
pub struct LayeredCircuitInputVec {
pub len: usize,
pub offset: usize,
Expand Down Expand Up @@ -177,8 +180,7 @@ where
}
// remove outputs that used for prevent optimization
let rd_c0 = r_dest_opt.circuits.get_mut(&0).unwrap();
assert_eq!(rd_c0.outputs.len(), n_in + n_out * 2);
rd_c0.outputs = vec![];
rd_c0.outputs.truncate(rd_c0.outputs.len() - n_in - n_out);
// compile step 3
let (lc, dest_im) = crate::layering::compile(
&r_dest_opt,
Expand Down Expand Up @@ -233,4 +235,30 @@ mod tests {
a[0][0] = x;
a[1][2] = api.add(x, 1);
}

#[test]
fn test_1() {
let kernel: Kernel<M31Config> = compile_with_spec(
example_kernel_1,
&[
IOVecSpec {
len: 1,
is_input: true,
is_output: true,
},
IOVecSpec {
len: 3,
is_input: true,
is_output: true,
},
],
)
.unwrap();
println!(
"{} {} {}",
kernel.layered_circuit.num_public_inputs,
kernel.layered_circuit.num_actual_outputs,
kernel.layered_circuit.expected_num_output_zeroes
);
}
}
Loading

0 comments on commit 600c9fd

Please sign in to comment.