diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 5db5a81aaf..3775dfc073 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -29,6 +29,7 @@ pub fn optimize(mut pil_file: Analyzed) -> Analyzed { extract_constant_lookups(&mut pil_file); remove_constant_witness_columns(&mut pil_file); simplify_identities(&mut pil_file); + remove_equal_constrained_witness_columns(&mut pil_file); remove_trivial_identities(&mut pil_file); remove_duplicate_identities(&mut pil_file); remove_unreferenced_definitions(&mut pil_file); @@ -85,7 +86,7 @@ fn remove_unreferenced_definitions(pil_file: &mut Analyzed) Box::new(value.iter().flat_map(|v| { v.all_children().flat_map(|e| { if let AlgebraicExpression::Reference(AlgebraicReference { poly_id, .. }) = e { - Some(poly_id_to_definition_name[poly_id].into()) + Some(poly_id_to_definition_name[poly_id].0.into()) } else { None } @@ -120,33 +121,43 @@ fn remove_unreferenced_definitions(pil_file: &mut Analyzed) pil_file.remove_trait_impls(&impls_to_remove); } -/// Builds a lookup-table that can be used to turn array elements -/// (in form of their poly ids) into the names of the arrays. +/// Builds a lookup-table that can be used to turn all poly ids into the names of the symbols that define them. +/// For array elements, this contains the array name and the index of the element in the array. fn build_poly_id_to_definition_name_lookup( pil_file: &Analyzed, -) -> BTreeMap { +) -> BTreeMap)> { let mut poly_id_to_definition_name = BTreeMap::new(); #[allow(clippy::iter_over_hash_type)] for (name, (symbol, _)) in &pil_file.definitions { if matches!(symbol.kind, SymbolKind::Poly(_)) { - symbol.array_elements().for_each(|(_, id)| { - poly_id_to_definition_name.insert(id, name); - }); + symbol + .array_elements() + .enumerate() + .for_each(|(idx, (_, id))| { + let array_pos = if symbol.is_array() { Some(idx) } else { None }; + poly_id_to_definition_name.insert(id, (name, array_pos)); + }); } } + #[allow(clippy::iter_over_hash_type)] for (name, (symbol, _)) in &pil_file.intermediate_columns { - symbol.array_elements().for_each(|(_, id)| { - poly_id_to_definition_name.insert(id, name); - }); + symbol + .array_elements() + .enumerate() + .for_each(|(idx, (_, id))| { + let array_pos = if symbol.is_array() { Some(idx) } else { None }; + poly_id_to_definition_name.insert(id, (name, array_pos)); + }); } + poly_id_to_definition_name } /// Collect all names that are referenced in identities and public declarations. fn collect_required_symbols<'a, T: FieldElement>( pil_file: &'a Analyzed, - poly_id_to_definition_name: &BTreeMap, + poly_id_to_definition_name: &BTreeMap)>, ) -> HashSet> { let mut required_names: HashSet> = Default::default(); required_names.extend( @@ -165,7 +176,7 @@ fn collect_required_symbols<'a, T: FieldElement>( for id in &pil_file.identities { id.pre_visit_expressions(&mut |e: &AlgebraicExpression| { if let AlgebraicExpression::Reference(AlgebraicReference { poly_id, .. }) = e { - required_names.insert(poly_id_to_definition_name[poly_id].into()); + required_names.insert(poly_id_to_definition_name[poly_id].0.into()); } }); } @@ -248,7 +259,6 @@ fn deduplicate_fixed_columns(pil_file: &mut Analyzed) { .unzip(); // substitute all occurences in expressions. - pil_file.post_visit_expressions_in_identities_mut(&mut |e| { if let AlgebraicExpression::Reference(r) = e { if let Some((new_name, new_id)) = replacement_by_id.get(&r.poly_id) { @@ -533,23 +543,44 @@ fn constrained_to_constant( None } -/// Removes identities that evaluate to zero and lookups with empty columns. +/// Removes identities that evaluate to zero (including constraints of the form "X = X") and lookups with empty columns. fn remove_trivial_identities(pil_file: &mut Analyzed) { let to_remove = pil_file .identities .iter() .enumerate() .filter_map(|(index, identity)| match identity { - Identity::Polynomial(PolynomialIdentity { expression, .. }) => { - if let AlgebraicExpression::Number(n) = expression { + Identity::Polynomial(PolynomialIdentity { expression, .. }) => match expression { + AlgebraicExpression::Number(n) => { if *n == 0.into() { - return Some(index); + Some(index) + } else { + // Otherwise the constraint is not satisfiable, + // but better to get the error elsewhere. + None } - // Otherwise the constraint is not satisfiable, - // but better to get the error elsewhere. } - None - } + AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { + left, + op: AlgebraicBinaryOperator::Sub, + right, + }) => { + if let ( + AlgebraicExpression::Reference(left), + AlgebraicExpression::Reference(right), + ) = (left.as_ref(), right.as_ref()) + { + if !left.next && !right.next && left == right { + Some(index) + } else { + None + } + } else { + None + } + } + _ => None, + }, Identity::Lookup(LookupIdentity { left, right, .. }) | Identity::Permutation(PermutationIdentity { left, right, .. }) | Identity::PhantomLookup(PhantomLookupIdentity { left, right, .. }) @@ -688,3 +719,103 @@ fn remove_duplicate_identities(pil_file: &mut Analyzed) { .collect(); pil_file.remove_identities(&to_remove); } + +/// Identifies witness columns that are directly constrained to be equal to other witness columns +/// through polynomial identities of the form "x = y" and returns a tuple ((name, id), (name, id)) +/// for each pair of identified columns +fn equal_constrained( + expression: &AlgebraicExpression, + poly_id_to_array_elem: &BTreeMap)>, +) -> Option<((String, PolyID), (String, PolyID))> { + match expression { + AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { + left, + op: AlgebraicBinaryOperator::Sub, + right, + }) => match (left.as_ref(), right.as_ref()) { + (AlgebraicExpression::Reference(l), AlgebraicExpression::Reference(r)) => { + let is_valid = |x: &AlgebraicReference| { + x.is_witness() + && !x.next + && poly_id_to_array_elem.get(&x.poly_id).unwrap().1.is_none() + }; + + if is_valid(l) && is_valid(r) { + Some(if l.poly_id > r.poly_id { + ((l.name.clone(), l.poly_id), (r.name.clone(), r.poly_id)) + } else { + ((r.name.clone(), r.poly_id), (l.name.clone(), l.poly_id)) + }) + } else { + None + } + } + _ => None, + }, + _ => None, + } +} + +fn remove_equal_constrained_witness_columns(pil_file: &mut Analyzed) { + let poly_id_to_array_elem = build_poly_id_to_definition_name_lookup(pil_file); + let substitutions: Vec<_> = pil_file + .identities + .iter() + .filter_map(|id| { + if let Identity::Polynomial(PolynomialIdentity { expression, .. }) = id { + equal_constrained(expression, &poly_id_to_array_elem) + } else { + None + } + }) + .collect(); + + let substitutions = resolve_transitive_substitutions(substitutions); + + let (subs_by_id, subs_by_name): (HashMap<_, _>, HashMap<_, _>) = substitutions + .iter() + .map(|((name, id), to_keep)| ((id, to_keep), (name, to_keep))) + .unzip(); + + pil_file.post_visit_expressions_in_identities_mut(&mut |e: &mut AlgebraicExpression<_>| { + if let AlgebraicExpression::Reference(ref mut reference) = e { + if let Some((replacement_name, replacement_id)) = subs_by_id.get(&reference.poly_id) { + reference.poly_id = *replacement_id; + reference.name = replacement_name.clone(); + } + } + }); + + pil_file.post_visit_expressions_mut(&mut |e: &mut Expression| { + if let Expression::Reference(_, Reference::Poly(reference)) = e { + if let Some((replacement_name, _)) = subs_by_name.get(&reference.name) { + reference.name = replacement_name.clone(); + } + } + }); +} + +fn resolve_transitive_substitutions( + subs: Vec<((String, PolyID), (String, PolyID))>, +) -> Vec<((String, PolyID), (String, PolyID))> { + let mut result = subs.clone(); + let mut changed = true; + + while changed { + changed = false; + for i in 0..result.len() { + let (_, target1) = &result[i].1; + if let Some(j) = result + .iter() + .position(|((_, source2), _)| source2 == target1) + { + let ((name1, source1), _) = &result[i]; + let (_, (name3, target2)) = &result[j]; + result[i] = ((name1.clone(), *source1), (name3.clone(), *target2)); + changed = true; + } + } + } + + result +} diff --git a/pilopt/tests/optimizer.rs b/pilopt/tests/optimizer.rs index 65ba80f4c4..731aeb5746 100644 --- a/pilopt/tests/optimizer.rs +++ b/pilopt/tests/optimizer.rs @@ -19,12 +19,10 @@ fn replace_fixed() { "#; let expectation = r#"namespace N(65536); col witness X; - col witness Y; query |i| { let _: expr = 1_expr; }; - N::X = N::Y; - N::Y = 7 * N::X; + N::X = 7 * N::X; "#; let optimized = optimize(analyze_string::(input).unwrap()).to_string(); assert_eq!(optimized, expectation); @@ -102,11 +100,7 @@ fn intermediate() { col intermediate = x; intermediate = intermediate; "#; - let expectation = r#"namespace N(65536); - col witness x; - col intermediate = N::x; - N::intermediate = N::intermediate; -"#; + let expectation = r#""#; let optimized = optimize(analyze_string::(input).unwrap()).to_string(); assert_eq!(optimized, expectation); } @@ -375,3 +369,73 @@ fn handle_array_references_in_prover_functions() { let optimized = optimize(analyze_string::(input).unwrap()).to_string(); assert_eq!(optimized, expectation); } + +#[test] +fn equal_constrained_array_elements_empty() { + let input = r#"namespace N(65536); + col witness w[20]; + w[4] = w[7]; + "#; + let expectation = r#"namespace N(65536); + col witness w[20]; + N::w[4] = N::w[7]; +"#; + let optimized = optimize(analyze_string::(input).unwrap()).to_string(); + assert_eq!(optimized, expectation); +} + +#[test] +fn equal_constrained_array_elements_query() { + let input = r#"namespace N(65536); + col witness w[20]; + w[4] = w[7]; + query |i| { + let _ = w[4] + w[7] - w[5]; + }; + "#; + let expectation = r#"namespace N(65536); + col witness w[20]; + N::w[4] = N::w[7]; + query |i| { + let _: expr = N::w[4_int] + N::w[7_int] - N::w[5_int]; + }; +"#; + let optimized = optimize(analyze_string::(input).unwrap()).to_string(); + assert_eq!(optimized, expectation); +} + +#[test] +fn equal_constrained_array_elements() { + let input = r#"namespace N(65536); + col witness w[20]; + w[4] = w[7]; + w[3] = w[5]; + w[7] + w[1] = 5; + "#; + let expectation = r#"namespace N(65536); + col witness w[20]; + N::w[4] = N::w[7]; + N::w[3] = N::w[5]; + N::w[7] + N::w[1] = 5; +"#; + let optimized = optimize(analyze_string::(input).unwrap()).to_string(); + assert_eq!(optimized, expectation); +} + +#[test] +fn equal_constrained_transitive() { + let input = r#"namespace N(65536); + col witness a; + col witness b; + col witness c; + a = b; + b = c; + a + b + c = 5; + "#; + let expectation = r#"namespace N(65536); + col witness a; + N::a + N::a + N::a = 5; +"#; + let optimized = optimize(analyze_string::(input).unwrap()).to_string(); + assert_eq!(optimized, expectation); +} diff --git a/pipeline/tests/asm.rs b/pipeline/tests/asm.rs index 99ecf47326..c9db9b3d3d 100644 --- a/pipeline/tests/asm.rs +++ b/pipeline/tests/asm.rs @@ -366,12 +366,6 @@ fn full_pil_constant() { regular_test_all_fields(f, Default::default()); } -#[test] -fn intermediate() { - let f = "asm/intermediate.asm"; - regular_test_all_fields(f, Default::default()); -} - #[test] fn intermediate_nested() { let f = "asm/intermediate_nested.asm"; diff --git a/riscv-executor/src/lib.rs b/riscv-executor/src/lib.rs index 3a887d2f0d..69a2a94032 100644 --- a/riscv-executor/src/lib.rs +++ b/riscv-executor/src/lib.rs @@ -20,7 +20,7 @@ use builder::TraceBuilder; use itertools::Itertools; use powdr_ast::{ - analyzed::Analyzed, + analyzed::{AlgebraicExpression, Analyzed, Identity, LookupIdentity}, asm_analysis::{AnalysisASMFile, CallableSymbol, FunctionStatement, LabelStatement, Machine}, parsed::{ asm::{parse_absolute_path, AssignmentRegister, DebugDirective}, @@ -659,6 +659,18 @@ mod builder { } } + pub fn try_get_col(&self, name: &str) -> Option> { + if let ExecMode::Trace = self.mode { + self.trace + .cols + .get(name) + .and_then(|col| col.last()) + .copied() + } else { + None + } + } + pub fn push_row(&mut self) { if let ExecMode::Trace = self.mode { self.trace @@ -1079,10 +1091,21 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> { self.proc.backup_reg_mem(); - set_col!(X, get_col!(X_const)); - set_col!(Y, get_col!(Y_const)); - set_col!(Z, get_col!(Z_const)); - set_col!(W, get_col!(W_const)); + if let Some(x_const) = self.proc.try_get_col("main::X_const") { + set_col!(X, x_const); + } + + if let Some(y_const) = self.proc.try_get_col("main::Y_const") { + set_col!(Y, y_const); + } + + if let Some(z_const) = self.proc.try_get_col("main::Z_const") { + set_col!(Z, z_const); + } + + if let Some(w_const) = self.proc.try_get_col("main::W_const") { + set_col!(W, w_const); + } self.proc .set_col(&format!("main::instr_{name}"), Elem::from_u32_as_fe(1)); @@ -2371,24 +2394,37 @@ fn execute_inner( }) .unwrap_or_default(); - // program columns to witness columns - let program_cols: HashMap<_, _> = if let Some(fixed) = &fixed { - fixed - .iter() - .filter_map(|(name, _col)| { - if !name.starts_with("main__rom::p_") { - return None; - } - let wit_name = format!("main::{}", name.strip_prefix("main__rom::p_").unwrap()); - if !witness_cols.contains(&wit_name) { - return None; - } - Some((name.clone(), wit_name)) - }) - .collect() - } else { - Default::default() - }; + //program columns to witness columns + let program_cols: HashMap<_, _> = opt_pil + .map(|pil| { + pil.identities + .iter() + .flat_map(|id| match id { + Identity::Lookup(LookupIdentity { left, right, .. }) => left + .expressions + .iter() + .zip(right.expressions.iter()) + .filter_map(|(l, r)| match (l, r) { + ( + AlgebraicExpression::Reference(l), + AlgebraicExpression::Reference(r), + ) => { + if r.name.starts_with("main__rom::p_") + && witness_cols.contains(&l.name) + { + Some((r.name.clone(), l.name.clone())) + } else { + None + } + } + _ => None, + }) + .collect::>(), + _ => vec![], + }) + .collect() + }) + .unwrap_or_default(); let proc = match TraceBuilder::<'_, F>::new( main_machine, diff --git a/test_data/asm/intermediate.asm b/test_data/asm/intermediate.asm deleted file mode 100644 index 4584503218..0000000000 --- a/test_data/asm/intermediate.asm +++ /dev/null @@ -1,11 +0,0 @@ -machine Intermediate with - latch: latch, - operation_id: operation_id, - degree: 8 -{ - col fixed latch = [1]*; - col fixed operation_id = [0]*; - col witness x; - col intermediate = x; - intermediate = intermediate; -} diff --git a/test_data/asm/set_hint.asm b/test_data/asm/set_hint.asm index cd25283d73..f52ee78412 100644 --- a/test_data/asm/set_hint.asm +++ b/test_data/asm/set_hint.asm @@ -7,5 +7,5 @@ let new_col_with_hint: -> expr = constr || { machine Main with degree: 4 { let x; let w = new_col_with_hint(); - x = w; + x = w + 1; } \ No newline at end of file