From 65766a9e0770fbe1249e47babdcc9d9d3e13a133 Mon Sep 17 00:00:00 2001 From: Andrew Dirksen Date: Thu, 22 Apr 2021 17:52:32 -0700 Subject: [PATCH] MapStack Optimization impl the optimization described in https://github.com/docknetwork/rify/issues/9 Results are good. ``` test ancestry::infer_ ... bench: 1,201,454 ns/iter (+/- 20,579) test ancestry::infer_30 ... bench: 3,758,380 ns/iter (+/- 76,535) test ancestry::prove_ ... bench: 1,342,813 ns/iter (+/- 31,824) test ancestry::prove_30 ... bench: 4,210,072 ns/iter (+/- 46,268) test recursion_minimal::infer_ ... bench: 12,014 ns/iter (+/- 232) test recursion_minimal::prove_ ... bench: 16,378 ns/iter (+/- 357) ``` --- src/common.rs | 10 +++---- src/infer.rs | 11 ++------ src/mapstack.rs | 47 +++++++++++++++++++-------------- src/prove.rs | 40 +++++++++++++--------------- src/reasoner.rs | 69 +++++++++++++++++++------------------------------ src/rule.rs | 19 ++------------ 6 files changed, 81 insertions(+), 115 deletions(-) diff --git a/src/common.rs b/src/common.rs index 6ea7e23..7dc9f2b 100644 --- a/src/common.rs +++ b/src/common.rs @@ -33,9 +33,9 @@ where /// [crate::prove::prove] function does not immediately convert LowRuleApplication's to /// [RuleApplication]'s. Rather, it converts only the LowRuleApplication's it is going to return /// to the caller. -pub struct LowRuleApplication { +pub(crate) struct LowRuleApplication { pub rule_index: usize, - pub instantiations: BTreeMap, + pub instantiations: Vec>, } impl LowRuleApplication { @@ -65,7 +65,7 @@ impl LowRuleApplication { for unbound_human in original_rule.cononical_unbound() { let unbound_local: usize = uhul[unbound_human]; - let bound_global: usize = self.instantiations[&unbound_local]; + let bound_global: usize = self.instantiations.get(unbound_local).unwrap().unwrap(); let bound_human: &Bound = trans.back(bound_global).unwrap(); instantiations.push(bound_human.clone()); } @@ -81,7 +81,7 @@ impl LowRuleApplication { /// This is a helper function. 
It translates all four element of a quad, calling /// [Translator::forward] for each element. If any element has no translation /// this function will return `None`. -pub fn forward(translator: &Translator, key: &[T; 4]) -> Option { +pub(crate) fn forward(translator: &Translator, key: &[T; 4]) -> Option { let [s, p, o, g] = key; Some( [ @@ -97,7 +97,7 @@ pub fn forward(translator: &Translator, key: &[T; 4]) -> Option /// Reverse of [forward]. /// [forward] translates each element from `T` to `usize`. This function translates each element /// from `usize` to `T`. If any element has no translation this function will return `None`. -pub fn back(translator: &Translator, key: Quad) -> Option<[&T; 4]> { +pub(crate) fn back(translator: &Translator, key: Quad) -> Option<[&T; 4]> { let Quad { s, p, o, g } = key; Some([ translator.back(s.0)?, diff --git a/src/infer.rs b/src/infer.rs index e4bc0ed..b17c410 100644 --- a/src/infer.rs +++ b/src/infer.rs @@ -65,15 +65,8 @@ fn low_infer(premises: &[Quad], rules: &[LowRule]) -> Vec { } in rules.iter_mut() { rs.apply_related(new.clone(), if_all, inst, &mut |inst| { - let ins = inst.as_ref(); - for implied in then.iter() { - let new_quad = [ - ins[&implied.s.0], - ins[&implied.p.0], - ins[&implied.o.0], - ins[&implied.g.0], - ] - .into(); + for implied in then.iter().cloned() { + let new_quad = implied.local_to_global(inst).unwrap(); if !rs.contains(&new_quad) && !adding.contains(&new_quad) { to_add.insert(new_quad); } diff --git a/src/mapstack.rs b/src/mapstack.rs index a17ab67..08d9970 100644 --- a/src/mapstack.rs +++ b/src/mapstack.rs @@ -1,51 +1,58 @@ -use alloc::collections::BTreeMap; use core::fmt; use core::fmt::Debug; use core::iter::FromIterator; /// A mapping that keeps a history of writes. Writes to the map effect "pushes" to a stack. Those /// "pushes" can be undone with a "pop". 
-#[derive(Clone, Debug, PartialEq, Eq)] -pub struct MapStack { - current: BTreeMap, - history: Vec<(K, Option)>, +/// +/// Beware large keys when using this data structure. The memory used by this structure scales +/// with the size of the largest key! +#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Default)] +pub(crate) struct MapStack { + current: Vec>, + history: Vec<(usize, Option)>, } -impl MapStack { +impl MapStack { pub fn new() -> Self { Self { - current: BTreeMap::new(), + current: Vec::new(), history: Vec::new(), } } } -impl MapStack { +impl MapStack { /// Set a value appending to write history. - pub fn write(&mut self, key: K, val: V) { - let old_val = self.current.insert(key.clone(), val); - self.history.push((key, old_val)); + pub fn write(&mut self, key: usize, val: V) { + debug_assert!(key < 30 ^ 2, "Thats a pretty large key. Are you sure?"); + if self.current.len() <= key { + self.current.resize_with(key + 1, || None); + } + let mut val = Some(val); + core::mem::swap(&mut self.current[key], &mut val); + self.history.push((key, val)); } /// Undo most recent write or return error if there is no history to undo. pub fn undo(&mut self) -> Result<(), NoMoreHistory> { let (key, old_val) = self.history.pop().ok_or(NoMoreHistory)?; - match old_val { - Some(val) => self.current.insert(key, val), - None => self.current.remove(&key), - }; + self.current[key] = old_val; Ok(()) } -} -impl AsRef> for MapStack { - fn as_ref(&self) -> &BTreeMap { + /// Get the current value at key. 
+ pub fn get(&self, key: usize) -> Option<&V> { + self.current.get(key).and_then(|o| o.as_ref()) + } + + pub fn inner(&self) -> &Vec> { &self.current } } -impl FromIterator<(K, V)> for MapStack { - fn from_iter>(kvs: T) -> Self { +impl FromIterator<(usize, V)> for MapStack { + fn from_iter>(kvs: T) -> Self { let mut ret = Self::new(); for (k, v) in kvs.into_iter() { ret.write(k, v); diff --git a/src/prove.rs b/src/prove.rs index cf56a8d..145d5db 100644 --- a/src/prove.rs +++ b/src/prove.rs @@ -136,20 +136,13 @@ fn low_prove( ) in rules2.iter_mut() { rs.apply_related(fact.clone(), if_all, inst, &mut |inst| { - let ins = inst.as_ref(); - for implied in then.iter() { - let new_quad = [ - ins[&implied.s.0], - ins[&implied.p.0], - ins[&implied.o.0], - ins[&implied.g.0], - ] - .into(); + for implied in then.iter().cloned() { + let new_quad = implied.local_to_global(inst).unwrap(); if !rs.contains(&new_quad) && !adding_now.contains(&new_quad) { arguments.entry(new_quad.clone()).or_insert_with(|| { LowRuleApplication { rule_index: *rule_index, - instantiations: ins.clone(), + instantiations: inst.inner().clone(), } }); to_add.insert(new_quad); @@ -183,8 +176,11 @@ fn recall_proof( outp: &mut Vec, ) { let to_global_scope = |rra: &LowRuleApplication, locally_scoped: usize| -> usize { - let concrete = rules[rra.rule_index].inst.as_ref().get(&locally_scoped); - let found = rra.instantiations.get(&locally_scoped); + let concrete = rules[rra.rule_index].inst.get(locally_scoped); + let found = rra + .instantiations + .get(locally_scoped) + .and_then(|o| o.as_ref()); if let (Some(c), Some(f)) = (concrete, found) { debug_assert_eq!(c, f); } @@ -237,7 +233,7 @@ impl std::error::Error for CantProve {} /// An element of a deductive proof. Proofs can be transmitted and later validatated as long as the /// validator assumes the same rule list as the prover. /// -/// Unbound variables are bound to the values in `instanitations`. 
They are bound in order of +/// Unbound variables are bound to the values in `instantiations`. They are bound in order of /// initial appearance. /// /// Given the rule: @@ -311,31 +307,31 @@ impl RuleApplication { /// Panics /// /// panics if an unbound entity is not registered in map -/// panics if the canonical index of unbound (according to map) is too large to index instanitations +/// panics if the canonical index of unbound (according to map) is too large to index instantiations fn bind_claim( [s, p, o, g]: [Entity; 4], map: &BTreeMap<&Unbound, usize>, - instanitations: &[Bound], + instantiations: &[Bound], ) -> [Bound; 4] { [ - bind_entity(s, map, instanitations), - bind_entity(p, map, instanitations), - bind_entity(o, map, instanitations), - bind_entity(g, map, instanitations), + bind_entity(s, map, instantiations), + bind_entity(p, map, instantiations), + bind_entity(o, map, instantiations), + bind_entity(g, map, instantiations), ] } /// Panics /// /// panics if an unbound entity is not registered in map -/// panics if the canonical index of unbound (according to map) is too large to index instanitations +/// panics if the canonical index of unbound (according to map) is too large to index instantiations fn bind_entity( e: Entity, map: &BTreeMap<&Unbound, usize>, - instanitations: &[Bound], + instantiations: &[Bound], ) -> Bound { match e { - Entity::Unbound(a) => instanitations[map[&a]].clone(), + Entity::Unbound(a) => instantiations[map[&a]].clone(), Entity::Bound(e) => e, } } diff --git a/src/reasoner.rs b/src/reasoner.rs index 6557328..bcd5c70 100644 --- a/src/reasoner.rs +++ b/src/reasoner.rs @@ -3,7 +3,7 @@ use crate::vecset::VecSet; use core::cmp::Ordering; #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct Quad { +pub(crate) struct Quad { pub s: Subj, pub p: Prop, pub o: Obje, @@ -22,13 +22,12 @@ impl Quad { /// Attempt dereference all variable in self. 
pub fn local_to_global(self, inst: &Instantiations) -> Option { - let inst = inst.as_ref(); Some( [ - *inst.get(&self.s.0)?, - *inst.get(&self.p.0)?, - *inst.get(&self.o.0)?, - *inst.get(&self.g.0)?, + *inst.get(self.s.0)?, + *inst.get(self.p.0)?, + *inst.get(self.o.0)?, + *inst.get(self.g.0)?, ] .into(), ) @@ -82,10 +81,10 @@ type O = (Obje,); type G = (Grap,); /// Bindings of slots within the context of a rule. -pub type Instantiations = MapStack; +pub(crate) type Instantiations = MapStack; #[derive(Default)] -pub struct Reasoner { +pub(crate) struct Reasoner { claims: Vec, spog: VecSet, posg: VecSet, @@ -137,7 +136,7 @@ impl Reasoner { cb: &mut impl FnMut(&Instantiations), ) { debug_assert!(self.contains(&quad)); - debug_assert!(!rule.is_empty(), "potential confusing so disallowed"); + debug_assert!(!rule.is_empty(), "potentialy confusing so disallowed"); // for each atom of rule, if the atom can match the quad, bind the unbound variables in the // atom to the corresponding elements of quad, then call apply. for i in 0..rule.len() { @@ -181,7 +180,7 @@ impl Reasoner { let to_write = strictest.clone().zip(quad.clone()); for (k, v) in &to_write { debug_assert!( - if let Some(committed_v) = instantiations.as_ref().get(&k) { + if let Some(committed_v) = instantiations.get(*k) { committed_v == v } else { true @@ -200,13 +199,12 @@ impl Reasoner { /// Return a slice representing all possible matches to the pattern provided. /// pattern is in a local scope. 
instantiations is a partial translation of that /// local scope to the global scope represented by self.claims - fn matches(&self, pattern: &Quad, instantiations: &Instantiations) -> &[usize] { - let inst = instantiations.as_ref(); + fn matches(&self, pattern: &Quad, inst: &Instantiations) -> &[usize] { let pattern: (Option, Option, Option, Option) = ( - inst.get(&pattern.s.0).cloned().map(Subj), - inst.get(&pattern.p.0).cloned().map(Prop), - inst.get(&pattern.o.0).cloned().map(Obje), - inst.get(&pattern.g.0).cloned().map(Grap), + inst.get(pattern.s.0).cloned().map(Subj), + inst.get(pattern.p.0).cloned().map(Prop), + inst.get(pattern.o.0).cloned().map(Obje), + inst.get(pattern.g.0).cloned().map(Grap), ); match pattern { (Some(s), Some(p), Some(o), Some(g)) => (s, p, o, g).search(self), @@ -251,7 +249,7 @@ impl Reasoner { } } -fn evict<'a, T>(index: usize, unordered_list: &'a mut [T]) -> Option<(&'a T, &'a mut [T])> { +fn evict(index: usize, unordered_list: &mut [T]) -> Option<(&T, &mut [T])> { if index >= unordered_list.len() { None } else { @@ -265,12 +263,11 @@ fn evict<'a, T>(index: usize, unordered_list: &'a mut [T]) -> Option<(&'a T, &'a /// returns whether rule_part could validly be applied to quad assuming the already given /// instantiations -fn can_match(quad: Quad, rule_part: Quad, instantiations: &Instantiations) -> bool { - let inst = instantiations.as_ref(); +fn can_match(quad: Quad, rule_part: Quad, inst: &Instantiations) -> bool { rule_part .zip(quad) .iter() - .all(|(rp, q)| match inst.get(&rp) { + .all(|(rp, q)| match inst.get(*rp) { Some(a) => a == q, None => true, }) @@ -352,7 +349,7 @@ mod tests { LowRule, Rule, }; use crate::translator::Translator; - use alloc::collections::{BTreeMap, BTreeSet}; + use alloc::collections::BTreeSet; #[test] fn ancestry_raw() { @@ -408,7 +405,7 @@ mod tests { // This test only does one round of reasoning, no forward chaining. We will need a forward // chaining test eventually. 
- let mut results = Vec::>::new(); + let mut results = Vec::>>::new(); for rule in rules { let Rule { mut if_all, @@ -416,7 +413,7 @@ mod tests { mut inst, } = rule.clone(); ts.apply(&mut if_all, &mut inst, &mut |inst| { - results.push(inst.as_ref().clone()) + results.push(inst.inner().clone()) }); } @@ -425,20 +422,15 @@ mod tests { // The second rule, (?a ancestor ?b ?g) and (?b ancestor ?c ?g) -> (?a ancestor ?c ?g), // should not activate because results from application of first rule have not been added // to the rdf store so there are there are are not yet any ancestry relations present. - let mut expected_intantiations: Vec> = nodes + let mut expected_intantiations: Vec>> = nodes .iter() .zip(nodes.iter().cycle().skip(1)) .map(|(a, b)| { - [ - (0, *a), - (1, parent.0), - (2, *b), - (3, ancestor.0), - (4, default_graph.0), - ] - .iter() - .cloned() - .collect() + [*a, parent.0, *b, ancestor.0, default_graph.0] + .iter() + .cloned() + .map(Some) + .collect() }) .collect(); results.sort(); @@ -517,14 +509,7 @@ mod tests { let then = &rule.then; ts.apply(&mut if_all, &mut inst, &mut |inst| { for implied in then { - let inst = inst.as_ref(); - let new: Quad = [ - inst[&implied.s.0], - inst[&implied.p.0], - inst[&implied.o.0], - inst[&implied.g.0], - ] - .into(); + let new: Quad = implied.clone().local_to_global(inst).unwrap(); if !ts.contains(&new) { to_add.insert(new); } diff --git a/src/rule.rs b/src/rule.rs index ead7f4a..605e566 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -27,19 +27,6 @@ pub(crate) struct LowRule { pub inst: Instantiations, // partially maps the local scope to some global scope } -// impl LowRule { -// /// List all implications as globaly scoped quads. -// /// -// /// # Panics -// /// -// /// Panics if rule is not sufficiently instantiated. 
-// pub fn instantiated(&self) -> impl Iterator { -// let inst = self.inst.as_ref(); -// self.then.iter().map(|local| { -// }) -// } -// } - #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug)] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub enum Entity { @@ -298,8 +285,7 @@ mod test { .map(|local_name| { re_rulea .inst - .as_ref() - .get(&local_name) + .get(*local_name) .map(|global_name| trans.back(*global_name).unwrap().clone()) }) .collect(); @@ -353,8 +339,7 @@ mod test { .map(|local_name| { re_ruleb .inst - .as_ref() - .get(&local_name) + .get(*local_name) .map(|global_name| trans.back(*global_name).unwrap().clone()) }) .collect();