From 4b370d65d0733aa2b80bd5f0efeda17342136238 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Wed, 4 Dec 2024 11:46:15 +0000 Subject: [PATCH] Update the current_witness_index after optimisation --- .../acvm/src/compiler/transformers/mod.rs | 269 +++++++++++++++++- 1 file changed, 264 insertions(+), 5 deletions(-) diff --git a/acvm-repo/acvm/src/compiler/transformers/mod.rs b/acvm-repo/acvm/src/compiler/transformers/mod.rs index c9ce4ac7895..c33d06d3c00 100644 --- a/acvm-repo/acvm/src/compiler/transformers/mod.rs +++ b/acvm-repo/acvm/src/compiler/transformers/mod.rs @@ -1,5 +1,12 @@ +use std::collections::BTreeSet; + use acir::{ - circuit::{brillig::BrilligOutputs, Circuit, ExpressionWidth, Opcode}, + circuit::{ + self, + brillig::{BrilligInputs, BrilligOutputs}, + opcodes::{BlackBoxFuncCall, FunctionInput, MemOp}, + Circuit, ExpressionWidth, Opcode, + }, native_types::{Expression, Witness}, AcirField, }; @@ -79,8 +86,6 @@ pub(super) fn transform_internal( &mut next_witness_index, ); - // Update next_witness counter - next_witness_index += (intermediate_variables.len() - len) as u32; let mut new_opcodes = Vec::new(); for (g, (norm, w)) in intermediate_variables.iter().skip(len) { // de-normalize @@ -160,13 +165,267 @@ pub(super) fn transform_internal( let mut merge_optimizer = MergeExpressionsOptimizer::new(); let (opcodes, new_acir_opcode_positions) = merge_optimizer.eliminate_intermediate_variable(&acir, new_acir_opcode_positions); - // n.b. we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less. - let acir = Circuit { + + // n.b. if we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less. + let mut acir = Circuit { current_witness_index, expression_width, opcodes, // The optimizer does not add new public inputs ..acir }; + + // After the elimination of intermediate variables the `current_witness_index` is potentially higher than it needs to be, + // which would cause gaps if we ran the optimization a second time, making it look like new variables were added. + // Here we figure out what is the final state of witnesses by visiting each opcode. + let witnesses = WitnessCollector::collect_from_circuit(&acir); + if let Some(max_witness) = witnesses.last() { + acir.current_witness_index = max_witness.0; + } + (acir, new_acir_opcode_positions) } + +/// Collect all witnesses in a circuit. +#[derive(Default, Clone, Debug)] +struct WitnessCollector { + witnesses: BTreeSet, +} + +impl WitnessCollector { + /// Collect all witnesses in a circuit. + fn collect_from_circuit(circuit: &Circuit) -> BTreeSet { + let mut collector = Self::default(); + collector.extend_from_circuit(circuit); + collector.witnesses + } + + fn add(&mut self, witness: Witness) { + self.witnesses.insert(witness); + } + + fn add_many(&mut self, witnesses: &[Witness]) { + self.witnesses.extend(witnesses); + } + + /// Add all witnesses from the circuit. + fn extend_from_circuit(&mut self, circuit: &Circuit) { + self.witnesses.extend(&circuit.private_parameters); + self.witnesses.extend(&circuit.public_parameters.0); + self.witnesses.extend(&circuit.return_values.0); + for opcode in &circuit.opcodes { + self.extend_from_opcode(opcode) + } + } + + /// Add witnesses from the opcode. + fn extend_from_opcode(&mut self, opcode: &Opcode) { + match opcode { + Opcode::AssertZero(expr) => { + self.extend_from_expr(expr); + } + Opcode::BlackBoxFuncCall(call) => self.extend_from_blackbox(call), + Opcode::MemoryOp { block_id: _, op, predicate } => { + let MemOp { operation, index, value } = op; + self.extend_from_expr(operation); + self.extend_from_expr(index); + self.extend_from_expr(value); + if let Some(pred) = predicate { + self.extend_from_expr(pred); + } + } + Opcode::MemoryInit { block_id: _, init, block_type: _ } => { + for w in init { + self.add(*w); + } + } + // We keep the display for a BrilligCall and circuit Call separate as they + // are distinct in their functionality and we should maintain this separation for debugging. + Opcode::BrilligCall { id: _, inputs, outputs, predicate } => { + if let Some(pred) = predicate { + self.extend_from_expr(pred); + } + self.extend_from_brillig_inputs(inputs); + self.extend_from_brillig_outputs(outputs); + } + Opcode::Call { id: _, inputs, outputs, predicate } => { + if let Some(pred) = predicate { + self.extend_from_expr(pred); + } + self.add_many(&inputs); + self.add_many(&outputs); + } + } + } + + fn extend_from_expr(&mut self, expr: &Expression) { + for i in &expr.mul_terms { + self.add(i.1); + self.add(i.2); + } + for i in &expr.linear_combinations { + self.add(i.1); + } + } + + fn extend_from_brillig_inputs(&mut self, inputs: &[BrilligInputs]) { + for input in inputs { + match input { + BrilligInputs::Single(expr) => { + self.extend_from_expr(expr); + } + BrilligInputs::Array(exprs) => { + for expr in exprs { + self.extend_from_expr(expr); + } + } + BrilligInputs::MemoryArray(_) => {} + } + } + } + + fn extend_from_brillig_outputs(&mut self, outputs: &[BrilligOutputs]) { + for output in outputs { + match output { + BrilligOutputs::Simple(w) => { + self.add(*w); + } + BrilligOutputs::Array(ws) => self.add_many(ws), + } + } + } + + fn extend_from_blackbox(&mut self, call: &BlackBoxFuncCall) { + match call { + BlackBoxFuncCall::AES128Encrypt { inputs, iv, key, outputs } => { + self.extend_from_function_inputs(inputs.as_slice()); + self.extend_from_function_inputs(iv.as_slice()); + self.extend_from_function_inputs(key.as_slice()); + self.add_many(outputs); + } + BlackBoxFuncCall::AND { lhs, rhs, output } => { + self.extend_from_function_input(lhs); + self.extend_from_function_input(rhs); + self.add(*output); + } + BlackBoxFuncCall::XOR { lhs, rhs, output } => { + self.extend_from_function_input(lhs); + self.extend_from_function_input(rhs); + self.add(*output); + } + BlackBoxFuncCall::RANGE { input } => { + self.extend_from_function_input(input); + } + BlackBoxFuncCall::Blake2s { inputs, outputs } => { + self.extend_from_function_inputs(inputs.as_slice()); + self.add_many(outputs.as_slice()); + } + BlackBoxFuncCall::Blake3 { inputs, outputs } => { + self.extend_from_function_inputs(inputs.as_slice()); + self.add_many(outputs.as_slice()); + } + BlackBoxFuncCall::SchnorrVerify { + public_key_x, + public_key_y, + signature, + message, + output, + } => { + self.extend_from_function_input(public_key_x); + self.extend_from_function_input(public_key_y); + self.extend_from_function_inputs(signature.as_slice()); + self.extend_from_function_inputs(message.as_slice()); + self.add(*output); + } + BlackBoxFuncCall::EcdsaSecp256k1 { + public_key_x, + public_key_y, + signature, + hashed_message, + output, + } => { + self.extend_from_function_inputs(public_key_x.as_slice()); + self.extend_from_function_inputs(public_key_y.as_slice()); + self.extend_from_function_inputs(signature.as_slice()); + self.extend_from_function_inputs(hashed_message.as_slice()); + self.add(*output); + } + BlackBoxFuncCall::EcdsaSecp256r1 { + public_key_x, + public_key_y, + signature, + hashed_message, + output, + } => { + self.extend_from_function_inputs(public_key_x.as_slice()); + self.extend_from_function_inputs(public_key_y.as_slice()); + self.extend_from_function_inputs(signature.as_slice()); + self.extend_from_function_inputs(hashed_message.as_slice()); + self.add(*output); + } + BlackBoxFuncCall::MultiScalarMul { points, scalars, outputs } => { + self.extend_from_function_inputs(points.as_slice()); + self.extend_from_function_inputs(scalars.as_slice()); + let (x, y, i) = outputs; + self.add(*x); + self.add(*y); + self.add(*i); + } + BlackBoxFuncCall::EmbeddedCurveAdd { input1, input2, outputs } => { + self.extend_from_function_inputs(input1.as_slice()); + self.extend_from_function_inputs(input2.as_slice()); + let (x, y, i) = outputs; + self.add(*x); + self.add(*y); + self.add(*i); + } + BlackBoxFuncCall::Keccakf1600 { inputs, outputs } => { + self.extend_from_function_inputs(inputs.as_slice()); + self.add_many(outputs.as_slice()); + } + BlackBoxFuncCall::RecursiveAggregation { + verification_key, + proof, + public_inputs, + key_hash, + proof_type: _, + } => { + self.extend_from_function_inputs(verification_key.as_slice()); + self.extend_from_function_inputs(proof.as_slice()); + self.extend_from_function_inputs(public_inputs.as_slice()); + self.extend_from_function_input(key_hash); + } + BlackBoxFuncCall::BigIntAdd { .. } + | BlackBoxFuncCall::BigIntSub { .. } + | BlackBoxFuncCall::BigIntMul { .. } + | BlackBoxFuncCall::BigIntDiv { .. } => {} + BlackBoxFuncCall::BigIntFromLeBytes { inputs, modulus: _, output: _ } => { + self.extend_from_function_inputs(inputs.as_slice()); + } + BlackBoxFuncCall::BigIntToLeBytes { input: _, outputs } => { + self.add_many(outputs.as_slice()); + } + BlackBoxFuncCall::Poseidon2Permutation { inputs, outputs, len: _ } => { + self.extend_from_function_inputs(inputs.as_slice()); + self.add_many(outputs.as_slice()); + } + BlackBoxFuncCall::Sha256Compression { inputs, hash_values, outputs } => { + self.extend_from_function_inputs(inputs.as_slice()); + self.extend_from_function_inputs(hash_values.as_slice()); + self.add_many(outputs.as_slice()); + } + } + } + + fn extend_from_function_input(&mut self, input: &FunctionInput) { + if let circuit::opcodes::ConstantOrWitnessEnum::Witness(witness) = input.input() { + self.add(witness); + } + } + + fn extend_from_function_inputs(&mut self, inputs: &[FunctionInput]) { + for input in inputs { + self.extend_from_function_input(input); + } + } +}