Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pilopt: equal-constrained witness columns removal #2086

Closed
wants to merge 57 commits into from
Closed
Show file tree
Hide file tree
Changes from 50 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
c9606ed
Opt: equal-constrained witness columns removal
gzanitti Nov 13, 2024
567e7f1
avoid removing columns when next is true
gzanitti Nov 13, 2024
f1a1f6e
some reparse::* failing
gzanitti Nov 13, 2024
761dc48
remove removed
gzanitti Nov 14, 2024
0df29b4
code simplification
gzanitti Nov 14, 2024
e0f7b18
N::Y = N::Y not being caught
gzanitti Nov 14, 2024
b4e6bd5
Merge remote-tracking branch 'upstream/main' into equal_constrained_opt
gzanitti Nov 14, 2024
86f9f2c
Merge remote-tracking branch 'upstream/main' into equal_constrained_opt
gzanitti Nov 14, 2024
fde0d6b
removed A = A case from identities
gzanitti Nov 14, 2024
684b979
preserve order
gzanitti Nov 14, 2024
523a0d8
remove after simplification
gzanitti Nov 14, 2024
6efa22d
always left
gzanitti Nov 14, 2024
99cb450
Merge remote-tracking branch 'upstream/main' into equal_constrained_opt
gzanitti Nov 15, 2024
b02522a
Code improvements & test fix
gzanitti Nov 15, 2024
278ed35
remove_unreferenced_definitions before remove_trivial_identities to l…
gzanitti Nov 15, 2024
b1c56e4
- set_hint
gzanitti Nov 15, 2024
ede7992
checking and updating defs
gzanitti Nov 19, 2024
3c378a3
Docs, msgs, unused imports, etc
gzanitti Nov 19, 2024
a3dd4be
Merge remote-tracking branch 'upstream/main' into equal_constrained_opt
gzanitti Nov 19, 2024
d8cbf75
set_hint?
gzanitti Nov 19, 2024
19aca28
avoid get values from *_const in RISCV when pilopt removed them
gzanitti Nov 20, 2024
a4f6a4e
contains_key
gzanitti Nov 20, 2024
906a75b
fix output
gzanitti Nov 20, 2024
af1f679
change set_hint() test to avoid opt
gzanitti Nov 20, 2024
be6f87b
minor fix
gzanitti Nov 20, 2024
a8f0864
empty spaces & minor details
gzanitti Nov 20, 2024
83a8b88
namespace & try_get_col
gzanitti Nov 20, 2024
2622366
avoid double call
gzanitti Nov 20, 2024
239d573
more tests
gzanitti Nov 21, 2024
dfc9424
avoid array elem
gzanitti Nov 21, 2024
dde999a
names
gzanitti Nov 21, 2024
b4f485e
Function simplification, new doc, etc
gzanitti Nov 25, 2024
c53022c
Minor issue after merge main fixed
gzanitti Nov 29, 2024
b9cbcbe
only update definitions
gzanitti Nov 29, 2024
d7e6875
Merge remote-tracking branch 'upstream/main' into equal_constrained_opt
gzanitti Nov 29, 2024
39dba03
Merge remote-tracking branch 'upstream/main' into equal_constrained_opt
gzanitti Nov 29, 2024
32e74ef
back
gzanitti Nov 29, 2024
ba335e8
missed println
gzanitti Nov 29, 2024
92b21bf
Merge remote-tracking branch 'upstream/main' into equal_constrained_opt
gzanitti Dec 2, 2024
a86bd85
test estark
gzanitti Dec 2, 2024
5d070be
lint
gzanitti Dec 2, 2024
ecde331
removed from lookups
gzanitti Dec 2, 2024
8019fc4
cols in lookup & better code
gzanitti Dec 2, 2024
a213672
Merge remote-tracking branch 'upstream/main' into equal_constrained_opt
gzanitti Dec 2, 2024
0b3e774
toopt, removed
gzanitti Dec 2, 2024
79d3d80
build
gzanitti Dec 2, 2024
3dfeec2
Merge remote-tracking branch 'upstream/main' into equal_constrained_opt
gzanitti Dec 3, 2024
e8629ca
Merge remote-tracking branch 'upstream/main' into equal_constrained_opt
gzanitti Dec 4, 2024
deeca6b
Code improvements. test still failing
gzanitti Dec 4, 2024
df5b7be
fix program columns to witness columns
gzanitti Dec 5, 2024
c859b60
Update pilopt/src/lib.rs
gzanitti Dec 10, 2024
42c08b2
Update pilopt/src/lib.rs
gzanitti Dec 10, 2024
934f668
minor changes
gzanitti Dec 10, 2024
1fd819f
main + more improvements
gzanitti Dec 10, 2024
7dd3a12
transitive substitutions
gzanitti Dec 10, 2024
c01dca9
equal_constrained_transitive
gzanitti Dec 10, 2024
926eedd
Merge remote-tracking branch 'upstream/main' into equal_constrained_opt
gzanitti Dec 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 135 additions & 23 deletions pilopt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub fn optimize<T: FieldElement>(mut pil_file: Analyzed<T>) -> Analyzed<T> {
extract_constant_lookups(&mut pil_file);
remove_constant_witness_columns(&mut pil_file);
simplify_identities(&mut pil_file);
remove_equal_constrained_witness_columns(&mut pil_file);
remove_trivial_identities(&mut pil_file);
remove_duplicate_identities(&mut pil_file);
remove_unreferenced_definitions(&mut pil_file);
Expand Down Expand Up @@ -85,7 +86,7 @@ fn remove_unreferenced_definitions<T: FieldElement>(pil_file: &mut Analyzed<T>)
Box::new(value.iter().flat_map(|v| {
v.all_children().flat_map(|e| {
if let AlgebraicExpression::Reference(AlgebraicReference { poly_id, .. }) = e {
Some(poly_id_to_definition_name[poly_id].into())
Some(poly_id_to_definition_name[poly_id].0.into())
} else {
None
}
Expand Down Expand Up @@ -120,31 +121,47 @@ fn remove_unreferenced_definitions<T: FieldElement>(pil_file: &mut Analyzed<T>)
pil_file.remove_trait_impls(&impls_to_remove);
}

/// Builds a lookup-table that can be used to turn array elements
/// (in form of their poly ids) into the names of the arrays.
/// Builds a lookup-table that can be used to turn all symbols
/// (including array elements) in the form of their poly ids, into the names of the symbols.
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
/// The boolean flag indicates whether the symbol belongs to an array or not
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
fn build_poly_id_to_definition_name_lookup(
pil_file: &Analyzed<impl FieldElement>,
) -> BTreeMap<PolyID, &String> {
let mut poly_id_to_definition_name = BTreeMap::new();
) -> BTreeMap<PolyID, (&String, bool)> {
let mut poly_id_to_info = BTreeMap::new();

for (name, (symbol, _)) in &pil_file.definitions {
if matches!(symbol.kind, SymbolKind::Poly(_)) {
symbol.array_elements().for_each(|(_, id)| {
poly_id_to_definition_name.insert(id, name);
});
if symbol.is_array() {
symbol.array_elements().for_each(|(_, id)| {
poly_id_to_info.insert(id, (name, true));
});
} else {
symbol.array_elements().for_each(|(_, id)| {
poly_id_to_info.insert(id, (name, false));
});
}
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
}
}

for (name, (symbol, _)) in &pil_file.intermediate_columns {
symbol.array_elements().for_each(|(_, id)| {
poly_id_to_definition_name.insert(id, name);
});
if symbol.is_array() {
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
symbol.array_elements().for_each(|(_, id)| {
poly_id_to_info.insert(id, (name, true));
});
} else {
symbol.array_elements().for_each(|(_, id)| {
poly_id_to_info.insert(id, (name, false));
});
}
}
poly_id_to_definition_name

poly_id_to_info
}

/// Collect all names that are referenced in identities and public declarations.
fn collect_required_symbols<'a, T: FieldElement>(
pil_file: &'a Analyzed<T>,
poly_id_to_definition_name: &BTreeMap<PolyID, &'a String>,
poly_id_to_definition_name: &BTreeMap<PolyID, (&'a String, bool)>,
) -> HashSet<SymbolReference<'a>> {
let mut required_names: HashSet<SymbolReference<'a>> = Default::default();
required_names.extend(
Expand All @@ -163,7 +180,7 @@ fn collect_required_symbols<'a, T: FieldElement>(
for id in &pil_file.identities {
id.pre_visit_expressions(&mut |e: &AlgebraicExpression<T>| {
if let AlgebraicExpression::Reference(AlgebraicReference { poly_id, .. }) = e {
required_names.insert(poly_id_to_definition_name[poly_id].into());
required_names.insert(poly_id_to_definition_name[poly_id].0.into());
}
});
}
Expand Down Expand Up @@ -246,7 +263,6 @@ fn deduplicate_fixed_columns<T: FieldElement>(pil_file: &mut Analyzed<T>) {
.unzip();

// substitute all occurences in expressions.

pil_file.post_visit_expressions_in_identities_mut(&mut |e| {
if let AlgebraicExpression::Reference(r) = e {
if let Some((new_name, new_id)) = replacement_by_id.get(&r.poly_id) {
Expand Down Expand Up @@ -531,23 +547,44 @@ fn constrained_to_constant<T: FieldElement>(
None
}

/// Removes identities that evaluate to zero and lookups with empty columns.
/// Removes identities that evaluate to zero (including constraints of the form "X = X") and lookups with empty columns.
fn remove_trivial_identities<T: FieldElement>(pil_file: &mut Analyzed<T>) {
let to_remove = pil_file
.identities
.iter()
.enumerate()
.filter_map(|(index, identity)| match identity {
Identity::Polynomial(PolynomialIdentity { expression, .. }) => {
if let AlgebraicExpression::Number(n) = expression {
Identity::Polynomial(PolynomialIdentity { expression, .. }) => match expression {
AlgebraicExpression::Number(n) => {
if *n == 0.into() {
return Some(index);
Some(index)
} else {
// Otherwise the constraint is not satisfiable,
// but better to get the error elsewhere.
None
}
// Otherwise the constraint is not satisfiable,
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
// but better to get the error elsewhere.
}
None
}
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation {
left,
op: AlgebraicBinaryOperator::Sub,
right,
}) => {
if let (
AlgebraicExpression::Reference(left),
AlgebraicExpression::Reference(right),
) = (left.as_ref(), right.as_ref())
{
if left.is_witness() && right.is_witness() && left == right {
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
Some(index)
} else {
None
}
} else {
None
}
}
_ => None,
},
Identity::Lookup(LookupIdentity { left, right, .. })
| Identity::Permutation(PermutationIdentity { left, right, .. })
| Identity::PhantomLookup(PhantomLookupIdentity { left, right, .. })
Expand Down Expand Up @@ -671,3 +708,78 @@ fn remove_duplicate_identities<T: FieldElement>(pil_file: &mut Analyzed<T>) {
.collect();
pil_file.remove_identities(&to_remove);
}

/// Identifies witness columns that are directly constrained to be equal to other witness columns
/// through polynomial identities of the form "x = y" and returns a tuple ((name, id), (name, id))
/// for each pair of identified columns
fn equal_constrained<T: FieldElement>(
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
expression: &AlgebraicExpression<T>,
poly_id_to_array_elem: &BTreeMap<PolyID, (&String, bool)>,
) -> Option<((String, PolyID), (String, PolyID))> {
match expression {
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation {
left,
op: AlgebraicBinaryOperator::Sub,
right,
}) => match (left.as_ref(), right.as_ref()) {
(AlgebraicExpression::Reference(l), AlgebraicExpression::Reference(r)) => {
let is_valid = |x: &AlgebraicReference| {
x.is_witness()
&& !x.next
&& poly_id_to_array_elem
.get(&x.poly_id)
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
.map_or(false, |&(_, b)| !b)
};

if is_valid(l) && is_valid(r) {
Some(if l.poly_id > r.poly_id {
((l.name.clone(), l.poly_id), (r.name.clone(), r.poly_id))
} else {
((r.name.clone(), r.poly_id), (l.name.clone(), l.poly_id))
})
} else {
None
}
}
_ => None,
},
_ => None,
}
}

fn remove_equal_constrained_witness_columns<T: FieldElement>(pil_file: &mut Analyzed<T>) {
let poly_id_to_array_elem = build_poly_id_to_definition_name_lookup(pil_file);
let substitutions: Vec<_> = pil_file
.identities
.iter()
.filter_map(|id| {
if let Identity::Polynomial(PolynomialIdentity { expression, .. }) = id {
equal_constrained(expression, &poly_id_to_array_elem)
} else {
None
}
})
.collect();

gzanitti marked this conversation as resolved.
Show resolved Hide resolved
let (subs_by_id, subs_by_name): (HashMap<_, _>, HashMap<_, _>) = substitutions
.iter()
.map(|((name, id), to_keep)| ((id, to_keep), (name, to_keep)))
.unzip();

pil_file.post_visit_expressions_in_identities_mut(&mut |e: &mut AlgebraicExpression<_>| {
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
if let AlgebraicExpression::Reference(ref mut reference) = e {
if let Some((replacement_name, replacement_id)) = subs_by_id.get(&reference.poly_id) {
reference.poly_id = *replacement_id;
reference.name = replacement_name.clone();
}
}
});

pil_file.post_visit_expressions_mut(&mut |e: &mut Expression| {
if let Expression::Reference(_, Reference::Poly(reference)) = e {
if let Some((replacement_name, _)) = subs_by_name.get(&reference.name) {
reference.name = replacement_name.clone();
}
}
});
}
56 changes: 53 additions & 3 deletions pilopt/tests/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@ fn replace_fixed() {
"#;
let expectation = r#"namespace N(65536);
col witness X;
col witness Y;
query |i| {
let _: expr = 1_expr;
};
N::X = N::Y;
N::Y = 7 * N::X;
N::X = 7 * N::X;
"#;
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
Expand Down Expand Up @@ -375,3 +373,55 @@ fn handle_array_references_in_prover_functions() {
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
}

#[test]
fn equal_constrained_array_elements_empty() {
let input = r#"namespace N(65536);
col witness w[20];
w[4] = w[7];
"#;
let expectation = r#"namespace N(65536);
col witness w[20];
N::w[4] = N::w[7];
"#;
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
}

#[test]
fn equal_constrained_array_elements_query() {
let input = r#"namespace N(65536);
col witness w[20];
w[4] = w[7];
query |i| {
let _ = w[4] + w[7] - w[5];
};
"#;
let expectation = r#"namespace N(65536);
col witness w[20];
N::w[4] = N::w[7];
query |i| {
let _: expr = N::w[4_int] + N::w[7_int] - N::w[5_int];
};
"#;
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
}

#[test]
fn equal_constrained_array_elements() {
let input = r#"namespace N(65536);
col witness w[20];
w[4] = w[7];
w[3] = w[5];
w[7] + w[1] = 5;
"#;
let expectation = r#"namespace N(65536);
col witness w[20];
N::w[4] = N::w[7];
N::w[3] = N::w[5];
N::w[7] + N::w[1] = 5;
"#;
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
}
Loading