Skip to content

Commit

Permalink
Optimize infer() and prove().
Browse files Browse the repository at this point in the history
This implements differential rule applications as suggested in #1

As a side effect of implementing that optimization I also reduced the cases where LowRule needed to be cloned within the hot path. The reduction in clones made a small difference, but differential rule applications had the greatest effect.

Before:
```
running 4 tests
test ancestry::infer_          ... bench:   2,339,370 ns/iter (+/- 36,027)
test ancestry::prove_          ... bench:   2,332,249 ns/iter (+/- 61,357)
test recursion_minimal::infer_ ... bench:      35,339 ns/iter (+/- 1,012)
test recursion_minimal::prove_ ... bench:      34,850 ns/iter (+/- 312)
```

After:
```
test ancestry::infer_          ... bench:   1,683,208 ns/iter (+/- 49,885)
test ancestry::prove_          ... bench:   1,792,665 ns/iter (+/- 22,206)
test recursion_minimal::infer_ ... bench:      17,601 ns/iter (+/- 457)
test recursion_minimal::prove_ ... bench:      23,608 ns/iter (+/- 677)
```

We see a ~24% improvement for the 20 node ancestry test case. The percentage improvement becomes larger as node count increases. The previous version was too slow to even include a benchmark with a larger node count (I lost patience waiting). The new version is fast enough to scale to larger node counts, so I've added a couple of larger benchmarks in this commit: `ancestry::infer_30` and `ancestry::prove_30`
  • Loading branch information
bddap committed Apr 22, 2021
1 parent 5ac1e27 commit 7bf82a1
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 78 deletions.
24 changes: 19 additions & 5 deletions benches/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ mod ancestry {
}

// Contains intentional leak; don't use outside of tests.
fn facts() -> Vec<[&'static str; 4]> {
let nodes: Vec<&str> = (0..20)
fn facts(numnodes: usize) -> Vec<[&'static str; 4]> {
let nodes: Vec<&str> = (0..numnodes)
.map(|n| Box::leak(format!("node_{}", n).into_boxed_str()) as &str)
.collect();
let facts: Vec<[&str; 4]> = nodes
Expand All @@ -46,16 +46,30 @@ mod ancestry {

#[bench]
fn infer_(b: &mut Bencher) {
let facts = facts();
let facts = facts(20);
let rules = rules();
b.iter(|| infer(&facts, &rules));
}

#[bench]
fn prove_(b: &mut Bencher) {
let facts = facts();
let facts = facts(20);
let rules = rules();
b.iter(|| prove(&facts, &[[PARENT, PARENT, PARENT, PARENT]], &rules));
b.iter(|| prove(&facts, &[[PARENT, PARENT, PARENT, PARENT]], &rules).unwrap_err());
}

/// Benchmarks `infer` over the ancestry rules with facts generated for 30 nodes.
#[bench]
fn infer_30(b: &mut Bencher) {
    // Inputs are built once, before the closure handed to `Bencher::iter`.
    let premises = facts(30);
    let rule_set = rules();
    b.iter(|| infer(&premises, &rule_set));
}

/// Benchmarks `prove` over the ancestry rules with facts generated for 30 nodes.
///
/// The requested claim is expected to be unprovable: the closure unwraps the error, so
/// the benchmark panics if `prove` ever returns `Ok`.
#[bench]
fn prove_30(b: &mut Bencher) {
    // Inputs are built once, before the closure handed to `Bencher::iter`.
    let premises = facts(30);
    let rule_set = rules();
    let goal = [[PARENT, PARENT, PARENT, PARENT]];
    b.iter(|| prove(&premises, &goal, &rule_set).unwrap_err());
}
}

Expand Down
70 changes: 46 additions & 24 deletions src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,36 +28,58 @@ pub fn infer<Unbound: Ord + Clone, Bound: Ord + Clone>(
/// A version of infer that operates on lowered input and returns output in lowered form.
fn low_infer(premises: &[Quad], rules: &[LowRule]) -> Vec<Quad> {
let mut rs = Reasoner::default();
for prem in premises {
rs.insert(prem.clone());

let mut to_add: BTreeSet<Quad> = premises.iter().cloned().collect();
let initial_len = to_add.len(); // number of premises after dedup
assert!(initial_len <= premises.len());

// apply_related is sufficient except in the case of unconditional rules.
// In order to avoid calling apply() directly, rules with an empty "if" clause
// list get special treatment.
for rule in rules {
if rule.if_all.is_empty() {
for rule_part in &rule.then {
to_add.insert(rule_part.clone().local_to_global(&rule.inst).unwrap());
}
}
}
let initial_len = rs.claims_ref().len(); // number of premises after dedup
debug_assert!(initial_len <= premises.len());
let mut rules: Vec<LowRule> = rules
.iter()
.filter(|r| !r.if_all.is_empty())
.cloned()
.collect();

// subsequent reasoning is done in a loop using apply_related
loop {
let mut to_add = BTreeSet::<Quad>::new();
for rr in rules.iter() {
rs.apply(&mut rr.if_all.clone(), &mut rr.inst.clone(), &mut |inst| {
let ins = inst.as_ref();
for implied in &rr.then {
let new_quad = [
ins[&implied.s.0],
ins[&implied.p.0],
ins[&implied.o.0],
ins[&implied.g.0],
]
.into();
if !rs.contains(&new_quad) {
to_add.insert(new_quad);
}
}
});
}
if to_add.is_empty() {
break;
}
for new in to_add.into_iter() {
rs.insert(new);
let mut adding = BTreeSet::default();
core::mem::swap(&mut to_add, &mut adding);
for new in adding.iter().cloned() {
rs.insert(new.clone());
for LowRule {
ref mut if_all,
then,
ref mut inst,
} in rules.iter_mut()
{
rs.apply_related(new.clone(), if_all, inst, &mut |inst| {
let ins = inst.as_ref();
for implied in then.iter() {
let new_quad = [
ins[&implied.s.0],
ins[&implied.p.0],
ins[&implied.o.0],
ins[&implied.g.0],
]
.into();
if !rs.contains(&new_quad) && !adding.contains(&new_quad) {
to_add.insert(new_quad);
}
}
});
}
}
}

Expand Down
95 changes: 59 additions & 36 deletions src/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::translator::Translator;
use alloc::collections::{BTreeMap, BTreeSet};
use core::fmt::{Debug, Display};

/// Locate a proof of some composite claims given the provied premises and rules.
/// Locate a proof of some composite claims given the provided premises and rules.
///
/// ```
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand Down Expand Up @@ -93,47 +93,70 @@ fn low_prove(
rules: &[LowRule],
) -> Result<Vec<LowRuleApplication>, CantProve> {
let mut rs = Reasoner::default();
for prem in premises {
rs.insert(prem.clone());
}

// statement (Quad) is proved by applying rule (LowRuleApplication)
let mut arguments: BTreeMap<Quad, LowRuleApplication> = BTreeMap::new();
let mut to_add: BTreeSet<Quad> = premises.iter().cloned().collect();

for (rule_index, rule) in rules.iter().enumerate() {
if rule.if_all.is_empty() {
for implied in &rule.then {
let new_quad = implied.clone().local_to_global(&rule.inst).unwrap();
if to_add.insert(new_quad.clone()) {
arguments.insert(
new_quad,
LowRuleApplication {
rule_index,
instantiations: Default::default(),
},
);
}
}
}
}
let mut rules2: Vec<(usize, LowRule)> = rules
.iter()
.cloned()
.enumerate()
.filter(|(_index, rule)| !rule.if_all.is_empty())
.collect();

// reason
loop {
if to_prove.iter().all(|tp| rs.contains(tp)) {
break;
}
let mut to_add = BTreeSet::<Quad>::new();
for (rule_index, rr) in rules.iter().enumerate() {
rs.apply(&mut rr.if_all.clone(), &mut rr.inst.clone(), &mut |inst| {
let ins = inst.as_ref();
for implied in &rr.then {
let new_quad = [
ins[&implied.s.0],
ins[&implied.p.0],
ins[&implied.o.0],
ins[&implied.g.0],
]
.into();
if !rs.contains(&new_quad) {
arguments
.entry(new_quad.clone())
.or_insert_with(|| LowRuleApplication {
rule_index,
instantiations: ins.clone(),
while !to_add.is_empty() && !to_prove.iter().all(|tp| rs.contains(tp)) {
let mut adding_now = BTreeSet::<Quad>::new();
core::mem::swap(&mut adding_now, &mut to_add);
for fact in &adding_now {
rs.insert(fact.clone());
for (
rule_index,
LowRule {
ref mut if_all,
then,
ref mut inst,
},
) in rules2.iter_mut()
{
rs.apply_related(fact.clone(), if_all, inst, &mut |inst| {
let ins = inst.as_ref();
for implied in then.iter() {
let new_quad = [
ins[&implied.s.0],
ins[&implied.p.0],
ins[&implied.o.0],
ins[&implied.g.0],
]
.into();
if !rs.contains(&new_quad) && !adding_now.contains(&new_quad) {
arguments.entry(new_quad.clone()).or_insert_with(|| {
LowRuleApplication {
rule_index: *rule_index,
instantiations: ins.clone(),
}
});
to_add.insert(new_quad);
to_add.insert(new_quad);
}
}
}
});
}
if to_add.is_empty() {
break;
}
for new in to_add.into_iter() {
rs.insert(new);
});
}
}
}

Expand Down
91 changes: 78 additions & 13 deletions src/reasoner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,31 @@ pub struct Quad {
pub g: Grap,
}

impl Quad {
    /// Pair every term of `self` with the term of `other` found in the same position.
    ///
    /// The result is ordered `s`, `p`, `o`, `g`; each tuple is `(term of self, term of other)`.
    fn zip(self, other: Quad) -> [(usize, usize); 4] {
        let s_pair = (self.s.0, other.s.0);
        let p_pair = (self.p.0, other.p.0);
        let o_pair = (self.o.0, other.o.0);
        let g_pair = (self.g.0, other.g.0);
        [s_pair, p_pair, o_pair, g_pair]
    }

    /// Attempt to dereference every variable in `self` through `inst`.
    ///
    /// Each of the four terms is looked up in `inst`; the looked-up values form the
    /// returned quad. Returns `None` as soon as one term has no entry in `inst`.
    pub fn local_to_global(self, inst: &Instantiations) -> Option<Quad> {
        let bindings = inst.as_ref();
        let s = *bindings.get(&self.s.0)?;
        let p = *bindings.get(&self.p.0)?;
        let o = *bindings.get(&self.o.0)?;
        let g = *bindings.get(&self.g.0)?;
        Some([s, p, o, g].into())
    }
}

impl From<[usize; 4]> for Quad {
fn from([s, p, o, g]: [usize; 4]) -> Self {
Self {
Expand Down Expand Up @@ -102,6 +127,34 @@ impl Reasoner {
.all(|wn| wn[0] == wn[1]));
}

/// Given a statement, find in this store all valid instantiations of `rule` that involve
/// that statement, reporting each one through `cb`.
///
/// It is expected that the statement has already been [Self::insert]ed; this is checked
/// with a `debug_assert!`. An empty `rule` is disallowed (also a `debug_assert!`).
///
/// Every atom of `rule` is tried in turn: when `can_match` accepts the atom for `quad`,
/// the atom's terms are written to `instantiations` as bound to the corresponding terms
/// of `quad`, [Self::apply] is run on the remaining atoms, and the writes are undone
/// again, one `undo` per write.
///
/// `rule` is reordered in place as a side effect of selecting each atom.
///
/// NOTE(review): `apply` receives the remaining atoms mutably. If it reorders them too
/// (not confirmed from this function alone), selecting atoms by index could visit one
/// atom twice and skip another. Verify.
pub fn apply_related(
    &self,
    quad: Quad,
    rule: &mut [Quad],
    instantiations: &mut Instantiations,
    cb: &mut impl FnMut(&Instantiations),
) {
    debug_assert!(self.contains(&quad));
    debug_assert!(!rule.is_empty(), "potential confusing so disallowed");
    let atom_count = rule.len();
    for position in 0..atom_count {
        // Move the atom under consideration to the front and split it from the others.
        let (atom, remaining) = evict(position, rule).unwrap();
        if !can_match(quad.clone(), atom.clone(), instantiations) {
            continue;
        }
        // Bind the atom's terms to the quad's terms, search the remaining atoms, then
        // roll back exactly as many writes as were made.
        let bindings = atom.clone().zip(quad.clone());
        for (local, global) in &bindings {
            instantiations.write(*local, *global);
        }
        self.apply(remaining, instantiations, cb);
        for _ in &bindings {
            instantiations.undo().unwrap();
        }
    }
}

/// Find in this store all possible valid instantiations of rule. Report the
/// instantiations through a callback.
/// TODO: This function is recursive, but not tail recursive. Rules that are too long may
Expand All @@ -125,12 +178,7 @@ impl Reasoner {
// in the requirement to the instantiation then recurse.
for index in self.matches(strictest, instantiations) {
let quad = &self.claims[*index];
let to_write = [
(strictest.s.0, quad.s.0),
(strictest.p.0, quad.p.0),
(strictest.o.0, quad.o.0),
(strictest.g.0, quad.g.0),
];
let to_write = strictest.clone().zip(quad.clone());
for (k, v) in &to_write {
debug_assert!(
if let Some(committed_v) = instantiations.as_ref().get(&k) {
Expand Down Expand Up @@ -193,24 +241,41 @@ impl Reasoner {
) -> Option<(&'rule Quad, &'rule mut [Quad])> {
let index_strictest = (0..rule.len())
.min_by_key(|index| self.matches(&rule[*index], instantiations).len())?;
rule.swap(0, index_strictest);
let (strictest, less_strict) = rule.split_first_mut().expect("rule to be non-empty");
Some((strictest, less_strict))
evict(index_strictest, rule)
}

/// Get the deduplicated history of all claims that were inserted into this reasoner.
/// The returned list will be in insertion order.
pub fn claims(self) -> Vec<Quad> {
self.claims
}
}

/// Get the deduplicated history of all claims that were inserted into this reasoner.
/// The returned list will be in insertion order.
pub fn claims_ref(&self) -> &[Quad] {
&self.claims
/// Select the element at `index` by swapping it to the front of `unordered_list`, then
/// return it together with the rest of the slice.
///
/// Returns `None`, leaving the slice untouched, when `index` is out of bounds (which
/// includes every index for an empty slice). The swap changes element order, so this is
/// only suitable for slices whose order carries no meaning.
fn evict<'a, T>(index: usize, unordered_list: &'a mut [T]) -> Option<(&'a T, &'a mut [T])> {
    if index >= unordered_list.len() {
        return None;
    }
    unordered_list.swap(0, index);
    // `index < len` was established above, so the slice has at least one element.
    let (head, tail) = unordered_list
        .split_first_mut()
        .expect("list to be non-empty");
    Some((head, tail))
}

/// Returns whether `rule_part` could validly be applied to `quad`, assuming the already
/// given instantiations.
///
/// Terms are compared position by position:
/// - a rule variable that already has an instantiation must be instantiated to the term
///   `quad` holds in that position;
/// - a rule variable without an instantiation is free to bind, but when the same variable
///   occurs more than once within `rule_part`, every occurrence must face the same term
///   of `quad`. Otherwise no single binding could satisfy the atom, and a caller that
///   writes the bindings one after another would silently overwrite the first with the
///   second.
fn can_match(quad: Quad, rule_part: Quad, instantiations: &Instantiations) -> bool {
    let inst = instantiations.as_ref();
    let pairs = rule_part.zip(quad);
    pairs
        .iter()
        .enumerate()
        .all(|(position, (rp, q))| match inst.get(&rp) {
            Some(a) => a == q,
            // Not instantiated yet: reject if an earlier occurrence of this same variable
            // in the atom was paired with a different term of the quad.
            None => pairs[..position]
                .iter()
                .all(|(earlier_rp, earlier_q)| earlier_rp != rp || earlier_q == q),
        })
}

trait Indexed {
fn target(rs: &Reasoner) -> &VecSet<usize>;
fn qcmp(&self, quad: &Quad) -> Ordering;
Expand Down
13 changes: 13 additions & 0 deletions src/rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ pub(crate) struct LowRule {
pub inst: Instantiations, // partially maps the local scope to some global scope
}

// impl LowRule {
// /// List all implications as globally scoped quads.
// ///
// /// # Panics
// ///
// /// Panics if rule is not sufficiently instantiated.
// pub fn instantiated(&self) -> impl Iterator<Item = Quad> {
// let inst = self.inst.as_ref();
// self.then.iter().map(|local| {
// })
// }
// }

#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum Entity<Unbound, Bound> {
Expand Down

0 comments on commit 7bf82a1

Please sign in to comment.