diff --git a/src/lib.rs b/src/lib.rs index 2b5fc99..9db8664 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,5 @@ #![allow(clippy::type_complexity)] #![cfg_attr(not(feature = "std"), no_std)] -#[cfg(not(feature = "std"))] -extern crate core as std; extern crate alloc; @@ -9,7 +7,7 @@ pub mod modified; pub mod stored; use core::fmt::Debug; -use std::{iter, mem}; +use core::{iter, mem}; use alloc::{boxed::Box, string::String, vec::Vec}; pub use modified::*; @@ -554,7 +552,7 @@ impl, V: AsRef<[u8]>> Transaction { let hash = data_store .get_unvisited_hash(*stored_idx) .copied() - .map_err(|e| e.into())?; + .map_err(|e| format!("Error in `calc_root_hash_node`: {e}"))?; Ok(hash) } } @@ -602,7 +600,9 @@ impl, V: AsRef<[u8]>> Transaction { key_hash: &KeyHash, ) -> Result, String> { loop { - let node = data_store.get_node(stored_idx).map_err(|e| e.into())?; + let node = data_store + .get_node(stored_idx) + .map_err(|e| format!("Error in `get_stored_node`: {e}"))?; match node { // TODO check that the KeyPosition is optimized out. Node::Branch(branch) => match branch.descend(key_hash) { @@ -622,7 +622,10 @@ impl, V: AsRef<[u8]>> Transaction { } } - match data_store.get_node(stored_idx).map_err(|e| e.into())? { + match data_store + .get_node(stored_idx) + .map_err(|e| format!("Error in `get_stored_node`: {e}"))? + { Node::Leaf(leaf) => Ok(Some(&leaf.value)), _ => unreachable!("Prior loop only breaks on a leaf"), } @@ -674,7 +677,9 @@ impl, V: AsRef<[u8]>> Transaction { } } NodeRef::Stored(stored_idx) => { - let new_node = data_store.get_node(*stored_idx).map_err(|e| e.into())?; + let new_node = data_store.get_node(*stored_idx).map_err(|e| { + format!("Error at `{}:{}:{}`: `{e}`", file!(), line!(), column!()) + })?; match new_node { stored::Node::Branch(new_branch) => { *root = NodeRef::ModBranch(Box::new(Branch { @@ -793,7 +798,9 @@ impl, V: AsRef<[u8]>> Transaction { } NodeRef::Stored(stored_idx) => { // TODO this is an artificial load of leaf.value. - let new_node = data_store.get_node(*stored_idx).map_err(|e| e.into())?; + let new_node = data_store + .get_node(*stored_idx) + .map_err(|e| format!("Error in `insert_below_branch`: {e}"))?; match new_node { stored::Node::Branch(new_branch) => { *next = NodeRef::ModBranch(Box::new(Branch { diff --git a/src/stored.rs b/src/stored.rs index badc14a..283001d 100644 --- a/src/stored.rs +++ b/src/stored.rs @@ -1,7 +1,7 @@ pub mod memory_db; pub mod merkle; -use alloc::{fmt::Debug, string::String}; +use alloc::fmt::Debug; use core::{fmt::Display, hash::Hash}; use crate::{Branch, Leaf}; @@ -9,7 +9,7 @@ use crate::{Branch, Leaf}; pub type Idx = u32; pub trait Store { - type Error: Into + Debug; + type Error: Display; /// Must return a hash of a node that has not been visited. /// May return a hash of a node that has already been visited. diff --git a/src/stored/merkle.rs b/src/stored/merkle.rs index 7de6012..9d5adcc 100644 --- a/src/stored/merkle.rs +++ b/src/stored/merkle.rs @@ -71,18 +71,28 @@ impl> Snapshot { // I dislike using an explicit mutable stack. // I have an idea for abusing async for high performance segmented stacks fn calc_root_hash_inner(&self, node: Idx) -> Result { - match self.get_node(node) { - Ok(Node::Branch(branch)) => { - let left = self.calc_root_hash_inner(branch.left)?; - let right = self.calc_root_hash_inner(branch.right)?; + let idx = node as usize; + let leaf_offset = self.branches.len(); + let unvisited_offset = leaf_offset + self.leaves.len(); - Ok(branch.hash_branch(&left, &right)) - } - Ok(Node::Leaf(leaf)) => Ok(leaf.hash_leaf()), - Err(_) => self - .get_unvisited_hash(node) - .copied() - .map_err(|_| format!("Invalid snapshot: node {} not found", node)), + if let Some(branch) = self.branches.get(idx) { + let left = self.calc_root_hash_inner(branch.left)?; + let right = self.calc_root_hash_inner(branch.right)?; + + Ok(branch.hash_branch(&left, &right)) + } else if let Some(leaf) = self.leaves.get(idx - leaf_offset) { + Ok(leaf.hash_leaf()) + } else if let Some(hash) = self.unvisited_nodes.get(idx - unvisited_offset) { + Ok(*hash) + } else { + Err(format!( + "Invalid snapshot: node {} not found\n\ + Snapshot has {} branches, {} leaves, and {} unvisited nodes", + idx, + self.branches.len(), + self.leaves.len(), + self.unvisited_nodes.len(), + )) } } } @@ -91,9 +101,7 @@ impl> Store for Snapshot { type Error = Error; fn get_unvisited_hash(&self, idx: Idx) -> Result<&NodeHash> { - let idx = idx as usize - self.branches.len() - self.leaves.len(); - - self.unvisited_nodes.get(idx).ok_or_else(|| { + let error = || { format!( "Invalid snapshot: no unvisited node at index {}\n\ Snapshot has {} branches, {} leaves, and {} unvisited nodes", @@ -102,7 +110,15 @@ impl> Store for Snapshot { self.leaves.len(), self.unvisited_nodes.len(), ) - }) + }; + + let idx = idx as usize; + if idx < self.branches.len() + self.leaves.len() { + return Err(error()); + } + let idx = idx - self.branches.len() - self.leaves.len(); + + self.unvisited_nodes.get(idx).ok_or_else(error) } fn get_node(&self, idx: Idx) -> Result, &Leaf>> { @@ -127,16 +143,16 @@ impl> Store for Snapshot { } } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct SnapshotBuilder<'a, Db, V> { pub db: Db, bump: &'a Bump, /// The root of the trie is always at index 0 - nodes: RefCell>>, + nodes: RefCell, &'a Leaf>>)>>, } -type NodeHashMaybeNode<'a, V> = (NodeHash, Option, &'a Leaf>>); +type NodeHashMaybeNode<'a, V> = (&'a NodeHash, Option, &'a Leaf>>); impl<'a, Db: DatabaseGet, V: Clone> Store for SnapshotBuilder<'a, Db, V> { type Error = Error; @@ -147,7 +163,7 @@ impl<'a, Db: DatabaseGet, V: Clone> Store for SnapshotBuilder<'a, Db, V> { self.nodes .borrow() .get(hash_idx) - .map(|(hash, _)| hash) + .map(|(hash, _)| *hash) .ok_or_else(|| { format!( "Invalid snapshot: no unvisited node at index {}\n\ @@ -160,18 +176,14 @@ impl<'a, Db: DatabaseGet, V: Clone> Store for SnapshotBuilder<'a, Db, V> { fn get_node(&self, hash_idx: Idx) -> Result, &Leaf>, Self::Error> { let hash_idx = hash_idx as usize; + let mut nodes = self.nodes.borrow_mut(); - let Some((hash, o_node)) = self - .nodes - .borrow() - .get(hash_idx) - .map(|(hash, o_node)| (hash, *o_node)) - else { + let Some((hash, o_node)) = nodes.get(hash_idx).map(|(hash, o_node)| (hash, *o_node)) else { return Err(format!( "Invalid snapshot: no node at index {}\n\ SnapshotBuilder has {} nodes", hash_idx, - self.nodes.borrow().len() + nodes.len() )); }; @@ -179,18 +191,39 @@ impl<'a, Db: DatabaseGet, V: Clone> Store for SnapshotBuilder<'a, Db, V> { return Ok(node); } - let next_idx = self.nodes.borrow().len() as Idx; - let (node, left, right) = Self::get_from_db(self.bump, &self.db, hash, next_idx)?; + let node = self + .db + .get(hash) + .map_err(|e| format!("Error getting {hash} from database: `{e}`"))?; - let add_unvisited = |hash: Option| { - if let Some(hash) = hash { - self.nodes.borrow_mut().push(self.bump.alloc((hash, None))) + let node = match node { + Node::Branch(Branch { + mask, + left, + right, + prior_word, + prefix, + }) => { + let idx = nodes.len() as Idx; + + let left = self.bump.alloc(left); + let right = self.bump.alloc(right); + + nodes.push((&*left, None)); + nodes.push((&*right, None)); + + Node::Branch(&*self.bump.alloc(Branch { + mask, + left: idx, + right: idx + 1, + prior_word, + prefix, + })) } + Node::Leaf(leaf) => Node::Leaf(&*self.bump.alloc(leaf)), }; - add_unvisited(left); - add_unvisited(right); - + nodes[hash_idx].1 = Some(node); Ok(node) } } @@ -212,9 +245,8 @@ impl<'a, Db, V> SnapshotBuilder<'a, Db, V> { } pub fn with_root_hash(self, root_hash: NodeHash) -> Self { - self.nodes - .borrow_mut() - .push(self.bump.alloc((root_hash, None))); + let root_hash = self.bump.alloc(root_hash); + self.nodes.borrow_mut().push((&*root_hash, None)); self } @@ -249,54 +281,10 @@ impl<'a, Db, V> SnapshotBuilder<'a, Db, V> { state.build() } } - - #[inline(always)] - fn get_from_db( - bump: &'a Bump, - db: &Db, - hash: &NodeHash, - next_idx: Idx, - ) -> Result< - ( - Node<&'a Branch, &'a Leaf>, - Option, - Option, - ), - String, - > - where - Db: DatabaseGet, - { - let node = db - .get(hash) - .map_err(|e| format!("Error getting {hash} from database: `{e}`"))?; - - Ok(match node { - Node::Branch(Branch { - mask, - left, - right, - prior_word, - prefix, - }) => ( - Node::Branch(&*bump.alloc(Branch { - mask, - left: next_idx, - right: next_idx + 1, - prior_word, - prefix, - })), - Some(left), - Some(right), - ), - - Node::Leaf(leaf) => (Node::Leaf(&*bump.alloc(leaf)), None, None), - }) - } } struct SnapshotBuilderFold<'v, 'a, V> { - nodes: &'v [&'a NodeHashMaybeNode<'a, V>], + nodes: &'v [NodeHashMaybeNode<'a, V>], /// The count of branches that will be in the snapshot branch_count: u32, /// The count of leaves that will be in the snapshot @@ -309,7 +297,7 @@ struct SnapshotBuilderFold<'v, 'a, V> { } impl<'v, 'a, V> SnapshotBuilderFold<'v, 'a, V> { - fn new(nodes: &'v [&'a NodeHashMaybeNode<'_, V>]) -> Self { + fn new(nodes: &'v [NodeHashMaybeNode<'a, V>]) -> Self { let mut branch_count = 0; let mut leaf_count = 0; let mut unvisited_count = 0; diff --git a/tests/build_store_modify.rs b/tests/build_store_modify.rs index f2ee259..d0ad7dc 100644 --- a/tests/build_store_modify.rs +++ b/tests/build_store_modify.rs @@ -66,9 +66,9 @@ fn end_to_end_example(maps: Vec>) { SnapshotBuilder::<_, [u8; 8]>::empty(db, &bump).with_trie_root_hash(prior_root_hash), ); - // for (k, v) in merged_map.iter() { - // let v = v.to_be_bytes(); - // let ret_v = txn.get(k).unwrap().unwrap(); - // assert_eq!(v, *ret_v); - // } + for (k, v) in merged_map.iter() { + let v = v.to_le_bytes(); + let ret_v = txn.get(k).unwrap().unwrap(); + assert_eq!(v, *ret_v); + } }