diff --git a/Cargo.toml b/Cargo.toml index 6297a23..00342b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,8 @@ [workspace] members = [ "polynomial", - "univariate-polynomial-iop-zerotest" + "univariate-polynomial-iop-zerotest", + "zk-plonky2-permutation-circuit" ] resolver = "2" diff --git a/zk-plonky2-permutation-circuit/Cargo.toml b/zk-plonky2-permutation-circuit/Cargo.toml new file mode 100644 index 0000000..3f30d95 --- /dev/null +++ b/zk-plonky2-permutation-circuit/Cargo.toml @@ -0,0 +1,10 @@ +[package] +edition = "2021" +name = "zk-plonky2-permutation-circuit" +version = "0.1.0" + +[dependencies] +anyhow = "1.0.81" +plonky2 = "0.2.0" +plonky2_field = "0.2.0" +array_tool = "1.0.3" diff --git a/zk-plonky2-permutation-circuit/src/custom_gates/mod.rs b/zk-plonky2-permutation-circuit/src/custom_gates/mod.rs new file mode 100644 index 0000000..5a2a8f4 --- /dev/null +++ b/zk-plonky2-permutation-circuit/src/custom_gates/mod.rs @@ -0,0 +1 @@ +pub mod switch; diff --git a/zk-plonky2-permutation-circuit/src/custom_gates/switch.rs b/zk-plonky2-permutation-circuit/src/custom_gates/switch.rs new file mode 100644 index 0000000..08a538f --- /dev/null +++ b/zk-plonky2-permutation-circuit/src/custom_gates/switch.rs @@ -0,0 +1,266 @@ +use array_tool::vec::Union; +use std::marker::PhantomData; + +use plonky2::{ + gates::gate::Gate, + hash::hash_types::RichField, + iop::{generator::{GeneratedValues, WitnessGenerator, WitnessGeneratorRef}, target::Target, wire::Wire, witness::{Witness, WitnessWrite, PartitionWitness}}, + plonk::circuit_data::CommonCircuitData, + util::serialization::{Buffer, IoResult, Read, Write}, +}; +use plonky2_field::types::Field; +use plonky2_field::extension::Extendable; + +/// A gate for conditionally swapping input values based on a boolean. +#[derive(Copy, Clone, Debug)] +pub struct SwitchGate, const D: usize> { + _phantom: PhantomData, +} + +impl, const D: usize> SwitchGate { + pub fn new() -> Self { + Self{_phantom: PhantomData} + } + + pub fn wire_switch_bool() -> usize { + 0 + } + + pub fn wire_input_1() -> usize { + 1 + } + + pub fn wire_input_2() -> usize { + 2 + } + + pub fn wire_output_1() -> usize { + 3 + } + + pub fn wire_output_2() -> usize { + 4 + } +} + +impl, const D: usize> Gate for SwitchGate { + fn id(&self) -> String { + format!("{self:?}") + } + + fn serialize(&self, dst: &mut Vec, common_data: &CommonCircuitData) -> IoResult<()> { + dst.write_bool(false) // TODO: remove + } + + fn deserialize(src: &mut Buffer, common_data: &CommonCircuitData) -> IoResult + where + Self: Sized, + { + Ok(Self{_phantom: PhantomData}) + } + + fn eval_unfiltered( + &self, + vars: plonky2::plonk::vars::EvaluationVars, + ) -> Vec<>::Extension> { + let mut constraints = Vec::with_capacity(4); + let switch_bool = vars.local_wires[Self::wire_switch_bool()]; + let not_switch = F::Extension::ONE - switch_bool; + + let input_1 = vars.local_wires[Self::wire_input_1()]; + let input_2 = vars.local_wires[Self::wire_input_2()]; + let output_1 = vars.local_wires[Self::wire_output_1()]; + let output_2 = vars.local_wires[Self::wire_output_2()]; + + constraints.push(not_switch * (output_1 - input_1)); + constraints.push(not_switch * (output_2 - input_2)); + + constraints.push(switch_bool * (output_2 - input_1)); + constraints.push(switch_bool * (output_1 - input_2)); + constraints + } + + fn eval_unfiltered_circuit( + &self, + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + vars: plonky2::plonk::vars::EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(4); + + let one = builder.one_extension(); + + let switch_bool = vars.local_wires[Self::wire_switch_bool()]; + let not_switch = builder.sub_extension(one, switch_bool); + + let input_1 = vars.local_wires[Self::wire_input_1()]; + let input_2 = vars.local_wires[Self::wire_input_2()]; + let output_1 = vars.local_wires[Self::wire_output_1()]; + let output_2 = vars.local_wires[Self::wire_output_2()]; + + constraints + .push(builder.mul_extension(not_switch, builder.sub_extension(input_1, output_1))); + constraints + .push(builder.mul_extension(not_switch, builder.sub_extension(input_2, output_2))); + + constraints + .push(builder.mul_extension(switch_bool, builder.sub_extension(input_1, output_2))); + constraints + .push(builder.mul_extension(switch_bool, builder.sub_extension(input_2, output_1))); + + constraints + } + + fn generators(&self, row: usize, local_constants: &[F]) -> Vec> { + // unimplemented!() + let g = Box::new(SwitchGenerator:: { + row, + gate: *self, + }); + vec![g] + } + + fn num_wires(&self) -> usize { + 5 + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 2 + } + + fn num_constraints(&self) -> usize { + 4 + } +} + +#[derive(Debug)] +struct SwitchGenerator, const D: usize> { + row: usize, + gate: SwitchGate, +} + +impl, const D: usize> SwitchGenerator { + /// List of wire targets for inputs and outputs + fn dependencies_inputs_outputs(&self) -> Vec { + let local_target = |column| Target::wire(self.row, column); + + let mut deps = Vec::new(); + + deps.push(local_target(SwitchGate::wire_first_input())); + deps.push(local_target(SwitchGate::wire_second_input())); + deps.push(local_target(SwitchGate::wire_first_output())); + deps.push(local_target(SwitchGate::wire_second_output())); + + deps + } + + /// List of wire targets for inputs and switching boolean + fn dependencies_switch_inputs(&self) -> Vec { + let local_target = |column| Target::wire(self.row, column); + + let mut deps = Vec::new(); + + deps.push(local_target(SwitchGate::wire_first_input())); + deps.push(local_target(SwitchGate::wire_second_input())); + deps.push(local_target(SwitchGate::wire_switch_bool())); + + deps + } + + /// Run when all input and output wires are present + fn set_switch_wire(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let get_local_wire = |column| { + witness.get_wire(Wire { + row: self.row, + column, + }) + }; + let switch_bool_wire = Wire { + row: self.row, + column: SwitchGate::wire_switch_bool(), + }; + + let mut input_1 = get_local_wire(SwitchGate::wire_input_1()); + let mut input_2 = get_local_wire(SwitchGate::wire_input_2()); + let mut output_1 = get_local_wire(SwitchGate::wire_output_1()); + let mut output_2 = get_local_wire(SwitchGate::wire_output_2()); + + if input_1 == output_1 && input_2 == output_2 { + out_buffer.set_wire(switch_bool_wire, F::ZERO); + } else if input_1 == output_2 && input_2 == output_1 { + out_buffer.set_wire(switch_bool_wire, F::ONE); + } else { + panic!("No permutation from given inputs to given outputs"); + } + } + + /// Run when only inputs and switching boolean is available + fn set_output_wires(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let get_local_wire = |column| { + witness.get_wire(Wire { + row: self.row, + column, + }) + }; + let switch_bool_wire = Wire { + row: self.row, + column: SwitchGate::wire_switch_bool(), + }; + + let mut input_1 = get_local_wire(SwitchGate::wire_input_1()); + let mut input_2 = get_local_wire(SwitchGate::wire_input_2()); + let mut output_1 = get_local_wire(SwitchGate::wire_output_1()); + let mut output_2 = get_local_wire(SwitchGate::wire_output_2()); + + let (expected_output_1, expected_output_2) = if switch_bool_wire == F::ZERO { + (input_1, input_2) + } else if switch_bool_wire == F::ONE { + (input_2, input_1) + } else { + panic!("Invalid switch bool value"); + }; + + out_buffer.set_wire(output_1, expected_output_1); + out_buffer.set_wire(output_2, expected_output_2); + } +} + +impl, const D: usize> WitnessGenerator for SwitchGenerator { + fn id(&self) -> String { + format!("{self:?}") + } + + fn serialize(&self, dst: &mut Vec, common_data: &CommonCircuitData) -> IoResult<()> { + dst.write_usize(self.row) + } + + fn deserialize(src: &mut Buffer, common_data: &CommonCircuitData) -> IoResult + where + Self: Sized { + Ok(Self{row: src.read_bool().unwrap(), gate: SwitchGate::new()}) + } + /// Register the different columns to watch + fn watch_list(&self) -> Vec { + self.dependencies_inputs_outputs() + .union(self.dependencies_switch_inputs()) + } + + /// Figure out which columns change and set the remaining + /// Can work in two modes: + /// 1. If input and switch wires are pre-populated + /// 2. If input and output wires are pre-populated + fn run(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) -> bool { + if witness.contains_all(&self.dependencies_switch_inputs()) { + self.set_output_wires(witness, out_buffer); + true + } else if witness.contains_all(&self.dependencies_inputs_outputs()) { + self.set_switch_wire(witness, out_buffer); + true + } else { + false + } + } +} diff --git a/zk-plonky2-permutation-circuit/src/lib.rs b/zk-plonky2-permutation-circuit/src/lib.rs new file mode 100644 index 0000000..bacdca9 --- /dev/null +++ b/zk-plonky2-permutation-circuit/src/lib.rs @@ -0,0 +1,68 @@ +mod custom_gates; + +use custom_gates::switch::SwitchGate; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::target::Target; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2_field::extension::Extendable; + +// Inspired by https://github.com/0xPolygonZero/plonky2-waksman/blob/main/src/permutation.rs + +/// Assert that two set of targets are permutation of each other +pub fn assert_permutation_circuit, const D: usize>( + builder: &mut CircuitBuilder, + a: Vec, + b: Vec, +) { + assert_eq!( + a.len(), + b.len(), + "Permutation must have same number of inputs and outputs" + ); + + match a.len() { + // Two empty lists are permutations of one another, trivially. + 0 => (), + // Two singleton lists are permutations of one another as long as their items are equal. + 1 => { + builder.connect(a[0], b[0]); + } + 2 => assert_permutation_2x2_circuit(builder, a[0], a[1], b[0], b[1]), + // For larger lists, we recursively use two smaller permutation networks. + _ => unimplemented!(), // assert_permutation_helper_circuit(builder, a, b), + } +} + +/// Assert that [a1, a2] is a permutation of [b1, b2]. +fn assert_permutation_2x2_circuit, const D: usize>( + builder: &mut CircuitBuilder, + a1: Target, + a2: Target, + b1: Target, + b2: Target, +) { + let (_switch, out_1, out_2) = create_switch_circuit(builder, a1, a2); + // Add constraints + builder.connect(b1, out_1); + builder.connect(b2, out_2); +} + +/// Given two input wire chunks, add a new switch to the circuit (by adding one copy to a switch +/// gate). Returns the wire for the switch boolean, and the two output wire chunks. +fn create_switch_circuit, const D: usize>( + builder: &mut CircuitBuilder, + a1: Target, + a2: Target, +) -> (Target, Target, Target) { + let gate = SwitchGate::new(); + let (row, _next_copy) = builder.find_slot(gate, &vec![], &[]); + + builder.connect(a1, Target::wire(row, SwitchGate::wire_input_1())); + builder.connect(a2, Target::wire(row, SwitchGate::wire_input_2())); + + ( + Target::wire(row, SwitchGate::wire_switch_bool()), + Target::wire(row, SwitchGate::wire_output_1()), + Target::wire(row, SwitchGate::wire_output_2()), + ) +}