Skip to content

Commit

Permalink
Opt: memory: linear for [group] const values (#207)
Browse files Browse the repository at this point in the history
For memories with constant values that have sorts which are linear
groups, there is a way to optimize linear-scan memory-checking.

This patch implements that optimization.
  • Loading branch information
foxier25 committed Aug 19, 2024
1 parent 30fd17c commit 7fcad48
Show file tree
Hide file tree
Showing 14 changed files with 189 additions and 27 deletions.
23 changes: 23 additions & 0 deletions examples/ZoKrates/pf/const_linear_lookup.zok
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
struct T {
field v
field w
field x
field y
field z
}

const T[9] TABLE = [
T { v: 1, w: 12, x: 13, y: 14, z: 15 },
T { v: 2, w: 22, x: 23, y: 24, z: 25 },
T { v: 3, w: 32, x: 33, y: 34, z: 35 },
T { v: 4, w: 42, x: 43, y: 44, z: 45 },
T { v: 5, w: 52, x: 53, y: 54, z: 55 },
T { v: 6, w: 62, x: 63, y: 64, z: 65 },
T { v: 7, w: 72, x: 73, y: 74, z: 75 },
T { v: 8, w: 82, x: 83, y: 84, z: 85 },
T { v: 9, w: 92, x: 93, y: 94, z: 95 }
]

def main(field i) -> field:
T t = TABLE[i]
return t.v + t.w + t.x + t.y + t.z
5 changes: 5 additions & 0 deletions examples/circ.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,11 @@ fn main() {
"Final R1cs rounds: {}",
prover_data.precompute.stage_sizes().count() - 1
);
println!(
"Final Witext steps: {}, arguments: {}",
prover_data.precompute.num_steps(),
prover_data.precompute.num_step_args()
);
match action {
ProofAction::Count => (),
#[cfg(feature = "bellman")]
Expand Down
2 changes: 2 additions & 0 deletions examples/opa_bench.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![allow(clippy::mutable_key_type)]

use circ::cfg::clap::{self, Parser};
use circ::ir::term::*;
use circ::target::aby::assignment::ilp;
Expand Down
1 change: 1 addition & 0 deletions scripts/zokrates_test.zsh
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ function pf_test_isolate {
}

r1cs_test_count ./examples/ZoKrates/pf/mm4_cond.zok 120
r1cs_test_count ./examples/ZoKrates/pf/const_linear_lookup.zok 20
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsAdd.zok
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOnCurve.zok
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOrderCheck.zok
Expand Down
2 changes: 1 addition & 1 deletion src/circify/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ pub trait Embeddable {
/// * `name`: the name
/// * `visibility`: who knows it
/// * `precompute`: an optional term for pre-computing the values of this input. If a party
/// knows the inputs to the precomputation, they can use the precomputation.
/// knows the inputs to the precomputation, they can use the precomputation.
fn declare_input(
&self,
ctx: &mut CirCtx,
Expand Down
4 changes: 0 additions & 4 deletions src/ir/opt/chall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@
//!
//! Each challenge term c that depends on t is replaced with a variable v.
//! Let t' denote a rewritten term.
//!
//! Rules:
//! * round(v) >=
//! round(v
use log::{debug, trace};

use std::cell::RefCell;
Expand Down
46 changes: 41 additions & 5 deletions src/ir/opt/mem/lin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,47 @@ impl RewritePass for Linearizer {
.unwrap_or_else(|| a.val.default_term()),
)
} else {
let mut fields = (0..a.size).map(|idx| term![Op::Field(idx); tup.clone()]);
let first = fields.next().unwrap();
Some(a.key.elems_iter().take(a.size).skip(1).zip(fields).fold(first, |acc, (idx_c, field)| {
term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], field, acc]
}))
let value_sort = check(tup).as_tuple()[0].clone();
if value_sort.is_group() {
// if values are a group
// then emit v0 + ite(idx == i1, v1 - v0, 0) + ... it(idx = iN, vN - v0, 0)
// where +, -, 0 are defined by the group.
//
// we do this because if the values are constant, then the above sum is
// linear, which is very nice for most backends.
let mut fields =
(0..a.size).map(|idx| term![Op::Field(idx); tup.clone()]);
let first = fields.next().unwrap();
let zero = value_sort.group_identity();
Some(
value_sort.group_add_nary(
std::iter::once(first.clone())
.chain(
a.key
.elems_iter()
.take(a.size)
.skip(1)
.zip(fields)
.map(|(idx_c, field)| {
term![Op::Ite;
term![Op::Eq; idx.clone(), idx_c],
value_sort.group_sub(field, first.clone()),
zero.clone()
]
}),
)
.collect(),
),
)
} else {
// otherwise, ite(idx == iN, vN, ... ite(idx == i1, v1, v0) ... )
let mut fields =
(0..a.size).map(|idx| term![Op::Field(idx); tup.clone()]);
let first = fields.next().unwrap();
Some(a.key.elems_iter().take(a.size).skip(1).zip(fields).fold(first, |acc, (idx_c, field)| {
term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], field, acc]
}))
}
}
} else {
unreachable!()
Expand Down
11 changes: 9 additions & 2 deletions src/ir/opt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ pub enum Opt {
pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computations, optimizations: I) -> Computations {
for c in cs.comps.values() {
trace!("Before all opts: {}", text::serialize_computation(c));
info!("Before all opts: {} terms", c.stats().main.n_terms);
info!(
"Before all opts: {} terms",
c.stats().main.n_terms + c.stats().prec.n_terms
);
debug!("Before all opts: {:#?}", c.stats());
debug!("Before all opts: {:#?}", c.detailed_stats());
}
Expand Down Expand Up @@ -167,7 +170,11 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computations, optimizations: I)
fits_in_bits_ip::fits_in_bits_ip(c);
}
}
info!("After {:?}: {} terms", i, c.stats().main.n_terms);
info!(
"After {:?}: {} terms",
i,
c.stats().main.n_terms + c.stats().prec.n_terms
);
debug!("After {:?}: {:#?}", i, c.stats());
trace!("After {:?}: {}", i, text::serialize_computation(c));
#[cfg(debug_assertions)]
Expand Down
87 changes: 87 additions & 0 deletions src/ir/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,93 @@ impl Sort {
pub fn is_scalar(&self) -> bool {
!matches!(self, Sort::Tuple(..) | Sort::Array(..) | Sort::Map(..))
}

/// Is this sort a group?
pub fn is_group(&self) -> bool {
match self {
Sort::BitVector(_) | Sort::Int | Sort::Field(_) | Sort::Bool => true,
Sort::F32 | Sort::F64 | Sort::Array(_) | Sort::Map(_) => false,
Sort::Tuple(fields) => fields.iter().all(|f| f.is_group()),
}
}

/// The (n-ary) group operation for these terms.
pub fn group_add_nary(&self, ts: Vec<Term>) -> Term {
debug_assert!(ts.iter().all(|t| &check(t) == self));
match self {
Sort::BitVector(_) => term(BV_ADD, ts),
Sort::Bool => term(XOR, ts),
Sort::Field(_) => term(PF_ADD, ts),
Sort::Int => term(INT_ADD, ts),
Sort::Tuple(sorts) => term(
Op::Tuple,
sorts
.iter()
.enumerate()
.map(|(i, sort)| {
sort.group_add_nary(
ts.iter()
.map(|t| term(Op::Field(i), vec![t.clone()]))
.collect(),
)
})
.collect(),
),
_ => panic!("Not a group: {}", self),
}
}

/// Group inverse
pub fn group_neg(&self, t: Term) -> Term {
debug_assert_eq!(&check(&t), self);
match self {
Sort::BitVector(_) => term(BV_NEG, vec![t]),
Sort::Bool => term(NOT, vec![t]),
Sort::Field(_) => term(PF_NEG, vec![t]),
Sort::Int => term(
INT_MUL,
vec![leaf_term(Op::new_const(Value::Int(Integer::from(-1i8)))), t],
),
Sort::Tuple(sorts) => term(
Op::Tuple,
sorts
.iter()
.enumerate()
.map(|(i, sort)| sort.group_neg(term(Op::Field(i), vec![t.clone()])))
.collect(),
),
_ => panic!("Not a group: {}", self),
}
}

/// Group identity
pub fn group_identity(&self) -> Term {
match self {
Sort::BitVector(n_bits) => bv_lit(0, *n_bits),
Sort::Bool => bool_lit(false),
Sort::Field(f) => pf_lit(f.new_v(0)),
Sort::Int => leaf_term(Op::new_const(Value::Int(Integer::from(0i8)))),
Sort::Tuple(sorts) => term(
Op::Tuple,
sorts.iter().map(|sort| sort.group_identity()).collect(),
),
_ => panic!("Not a group: {}", self),
}
}

/// Group operation
pub fn group_add(&self, s: Term, t: Term) -> Term {
debug_assert_eq!(&check(&s), self);
debug_assert_eq!(&check(&t), self);
self.group_add_nary(vec![s, t])
}

/// Group elimination
pub fn group_sub(&self, s: Term, t: Term) -> Term {
debug_assert_eq!(&check(&s), self);
debug_assert_eq!(&check(&t), self);
self.group_add(s, self.group_neg(t))
}
}

mod hc {
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#![warn(missing_docs)]
#![deny(warnings)]
#![allow(rustdoc::private_intra_doc_links)]
#![allow(clippy::mutable_key_type)]

#[macro_use]
pub mod ir;
Expand Down
2 changes: 0 additions & 2 deletions src/target/aby/trans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -904,8 +904,6 @@ pub fn to_aby(cs: Computations, path: &Path, lang: &str, cm: &str, ss: &str) {
panic!("Unsupported sharing scheme: {}", ss);
}
};
#[cfg(feature = "bench")]
println!("LOG: Assignment {}: {:?}", name, now.elapsed());
s_map.insert(name.to_string(), assignments);
}

Expand Down
2 changes: 1 addition & 1 deletion src/target/r1cs/opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ fn constantly_true((a, b, c): &(Lc, Lc, Lc)) -> bool {
/// ## Parameters
///
/// * `lc_size_thresh`: the maximum size LC (number of non-constant monomials) that will be used
/// for propagation. `None` means no size limit.
/// for propagation. `None` means no size limit.
pub fn reduce_linearities(r1cs: R1cs, cfg: &CircCfg) -> R1cs {
let mut r = LinReducer::new(r1cs, cfg.r1cs.lc_elim_thresh).run();
r.update_stats();
Expand Down
10 changes: 10 additions & 0 deletions src/target/r1cs/wit_comp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ impl StagedWitComp {
pub fn num_stage_inputs(&self, n: usize) -> usize {
self.stages[n].inputs.len()
}

/// Number of steps
pub fn num_steps(&self) -> usize {
self.steps.len()
}

/// Number of step arguments
pub fn num_step_args(&self) -> usize {
self.step_args.len()
}
}

/// Evaluator interface
Expand Down
20 changes: 8 additions & 12 deletions src/target/smt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,7 @@ pub fn check_sat(t: &Term) -> bool {
let mut solver = make_solver((), false, false);
for c in PostOrderIter::new(t.clone()) {
if let Op::Var(v) = &c.op() {
solver
.declare_const(&SmtSymDisp(&*v.name), &v.sort)
.unwrap();
solver.declare_const(SmtSymDisp(&*v.name), &v.sort).unwrap();
}
}
assert!(check(t) == Sort::Bool);
Expand All @@ -380,9 +378,7 @@ fn get_model_solver(t: &Term, inc: bool) -> rsmt2::Solver<Parser> {
//solver.path_tee("solver_com").unwrap();
for c in PostOrderIter::new(t.clone()) {
if let Op::Var(v) = &c.op() {
solver
.declare_const(&SmtSymDisp(&*v.name), &v.sort)
.unwrap();
solver.declare_const(SmtSymDisp(&*v.name), &v.sort).unwrap();
}
}
assert!(check(t) == Sort::Bool);
Expand Down Expand Up @@ -590,13 +586,13 @@ mod test {
let mut solver = make_solver((), false, false);
for (v, val) in vs {
let s = val.sort();
solver.declare_const(&SmtSymDisp(&v), &s).unwrap();
solver.declare_const(SmtSymDisp(&v), &s).unwrap();
solver
.assert(&term![Op::Eq; var(v.to_string(), s), const_(val.clone())])
.assert(term![Op::Eq; var(v.to_string(), s), const_(val.clone())])
.unwrap();
}
let val = eval(&t, vs);
solver.assert(&term![Op::Eq; t, const_(val)]).unwrap();
solver.assert(term![Op::Eq; t, const_(val)]).unwrap();
solver.check_sat().unwrap()
}

Expand All @@ -605,14 +601,14 @@ mod test {
let mut solver = make_solver((), false, false);
for (v, val) in vs {
let s = val.sort();
solver.declare_const(&SmtSymDisp(&v), &s).unwrap();
solver.declare_const(SmtSymDisp(&v), &s).unwrap();
solver
.assert(&term![Op::Eq; var(v.to_string(), s), const_(val.clone())])
.assert(term![Op::Eq; var(v.to_string(), s), const_(val.clone())])
.unwrap();
}
let val = eval(&t, vs);
solver
.assert(&term![Op::Not; term![Op::Eq; t, const_(val)]])
.assert(term![Op::Not; term![Op::Eq; t, const_(val)]])
.unwrap();
solver.check_sat().unwrap()
}
Expand Down

0 comments on commit 7fcad48

Please sign in to comment.