Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start to move the prefix out of the Branch, replacing prior word and prefix vec #5

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/stored.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ pub trait Store<V> {
hash_idx: Idx,
) -> Result<NodeHash, Self::Error>;

fn get_node(&self, hash_idx: Idx) -> Result<Node<&Branch<Idx>, &Leaf<V>>, Self::Error>;
fn get_node<'s>(
&'s self,
hash_idx: Idx,
) -> Result<Node<&'s Branch<'s, Idx>, &'s Leaf<V>>, Self::Error>;
}

impl<V, S: Store<V>> Store<V> for &S {
Expand Down
6 changes: 4 additions & 2 deletions src/stored/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use bumpalo::Bump;
use ouroboros::self_referencing;

use crate::{
transaction::nodes::{NodeRef, TrieRoot},
transaction::nodes::{NodeRef, PrefixesBuffer, TrieRoot},
Branch, Leaf, PortableHash, PortableHasher, TrieError,
};

Expand All @@ -20,12 +20,14 @@ type Result<T, E = TrieError> = core::result::Result<T, E>;
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Snapshot<V> {
/// The last branch is the root of the trie if it exists.
branches: Box<[Branch<Idx>]>,
branches: Box<[Branch<'static, Idx>]>,
/// A Snapshot containing only
leaves: Box<[Leaf<V>]>,

// we only store the hashes of the nodes that have not been visited.
unvisited_nodes: Box<[NodeHash]>,

prefixies_buffer: PrefixesBuffer,
}

impl<V: PortableHash> Snapshot<V> {
Expand Down
7 changes: 5 additions & 2 deletions src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ use crate::{
};

use self::nodes::{
Branch, KeyPosition, KeyPositionAdjacent, Leaf, Node, NodeRef, StoredLeafRef, TrieRoot,
Branch, KeyPosition, KeyPositionAdjacent, Leaf, Node, NodeRef, PrefixesBuffer, StoredLeafRef,
TrieRoot,
};

pub struct Transaction<S, V> {
Expand All @@ -42,7 +43,6 @@ impl<Db: DatabaseSet<V>, V: Clone + PortableHash> Transaction<SnapshotBuilder<Db
left,
right,
mask: branch.mask,
prior_word: branch.prior_word,
prefix: branch.prefix.clone(),
};

Expand Down Expand Up @@ -395,6 +395,7 @@ impl<S: Store<V>, V> Transaction<S, V> {
mask: new_branch.mask,
prior_word: new_branch.prior_word,
prefix: new_branch.prefix.clone(),
prefix,
}));

continue;
Expand Down Expand Up @@ -547,6 +548,7 @@ impl<Db, V: PortableHash + Clone> Transaction<SnapshotBuilder<Db, V>, V> {
Transaction {
current_root: builder.trie_root(),
data_store: builder,
prefixes_buffer: builder.prefixes_buffer,
}
}
}
Expand Down Expand Up @@ -580,6 +582,7 @@ impl<V: PortableHash + Clone> Transaction<Snapshot<V>, V> {
Ok(Transaction {
current_root: snapshot.trie_root()?,
data_store: snapshot,
prefixes_buffer: snapshot.prefixes_buffer,
})
}
}
Expand Down
127 changes: 89 additions & 38 deletions src/transaction/nodes.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use alloc::boxed::Box;
use core::{fmt, iter, mem};
use core::{fmt, iter, mem, slice, usize};

use crate::{hash::PortableHasher, stored, KeyHash, NodeHash, PortableHash, PortableUpdate};

Expand Down Expand Up @@ -74,20 +74,20 @@ pub enum Node<B, L> {
/// When executing against a `SnapshotBuilder`, it's a reference to a `NodeHash`,
/// which can in turn be used to retrieve the `Node`.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum NodeRef<V> {
ModBranch(Box<Branch<Self>>),
pub enum NodeRef<'s, V> {
ModBranch(Box<Branch<'s, Self>>),
ModLeaf(Box<Leaf<V>>),
Stored(stored::Idx),
}

impl<V> NodeRef<V> {
impl<V> NodeRef<'_, V> {
#[inline(always)]
pub fn temp_null_stored() -> Self {
NodeRef::Stored(u32::MAX)
}
}

impl<V> fmt::Debug for NodeRef<V> {
impl<V> fmt::Debug for NodeRef<'_, V> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::ModBranch(b) => f.debug_tuple("ModBranch").field(b).finish(),
Expand All @@ -97,7 +97,7 @@ impl<V> fmt::Debug for NodeRef<V> {
}
}

impl<V> From<Box<Branch<NodeRef<V>>>> for NodeRef<V> {
impl<'s, V> From<Box<Branch<'s, NodeRef<'s, V>>>> for NodeRef<'s, V> {
#[inline]
fn from(branch: Box<Branch<NodeRef<V>>>) -> Self {
NodeRef::ModBranch(branch)
Expand Down Expand Up @@ -263,27 +263,87 @@ mod tests {
}
}

/// Compact handle to a branch's key prefix: a single `u32` that is read either as an
/// inline key word or as a start index into a `PrefixesBuffer`.
///
/// Which of the two readings applies is not stored here; `get_prefix` derives it from the
/// distance between the branch's word index and its parent's word index.
pub struct PrefixBufferRef {
/// Inline prior word, or start index of a multi-word prefix.
///
/// - `0` if the branch occurs in the first word of the hash key.
/// - The prior word itself if the parent's word index is at most 1 less than the branch's.
/// - Otherwise the prefix spans several words and must be stored outside of the branch;
///   the value is then the index of that prefix.
///
/// The length of the prefix is the difference between the branch's word index and the
/// parent's word index.
prior_word_or_prefix_idx: u32,
}

impl PrefixBufferRef {
    /// Resolves this reference to the key words it stands for.
    ///
    /// * `prefixies` - buffer holding the multi-word prefixes.
    /// * `word_idx` - word index of the branch that owns this reference.
    /// * `parent_word_idx` - word index of that branch's parent.
    ///
    /// When the branch is at most one word past its parent, the stored value is itself the
    /// prior word and is returned as a one-element slice. Otherwise the stored value is the
    /// start index of a prefix of `word_idx - parent_word_idx` words inside `prefixies`.
    ///
    /// Returns `None` if `parent_word_idx` is greater than `word_idx`, or if the requested
    /// range is not fully contained in `prefixies`.
    pub fn get_prefix<'s, 'txn: 's>(
        &'s self,
        prefixies: &'txn PrefixesBuffer,
        word_idx: usize,
        parent_word_idx: usize,
    ) -> Option<&'s [u32]> {
        // A parent can never sit after its child. Report that as "no prefix" instead of
        // underflowing (a panic in debug builds, a wrapped length in release builds).
        let prefix_len = word_idx.checked_sub(parent_word_idx)?;

        if prefix_len <= 1 {
            // A branch in the first word of the key has no prior word, so the slot must be 0.
            debug_assert!(word_idx != 0 || self.prior_word_or_prefix_idx == 0);

            Some(slice::from_ref(&self.prior_word_or_prefix_idx))
        } else {
            prefixies.get_prefix(self.prior_word_or_prefix_idx as usize, prefix_len)
        }
    }
}

/// Append-only store of `u32` key words backing the multi-word branch prefixes.
///
/// Prefixes are laid out back to back in one vector. A prefix is addressed by the index
/// returned from [`PrefixesBuffer::push_prefix`] together with its length; the buffer itself
/// does not record where one prefix ends and the next begins.
#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct PrefixesBuffer {
    buffer: Vec<u32>,
}

impl PrefixesBuffer {
    /// Creates an empty buffer.
    pub fn new() -> Self {
        Self { buffer: Vec::new() }
    }

    /// Appends `prefix` to the buffer and returns the index of its first word.
    ///
    /// The caller must remember `prefix.len()` to read the prefix back with
    /// [`PrefixesBuffer::get_prefix`].
    pub fn push_prefix(&mut self, prefix: &[u32]) -> u32 {
        // Indices are handed out as `u32`, so the cast below would silently truncate once the
        // buffer holds more than `u32::MAX` words. Surface that in debug builds.
        debug_assert!(
            self.buffer.len() <= u32::MAX as usize,
            "prefix buffer start index does not fit in u32"
        );
        let idx = self.buffer.len() as u32;
        self.buffer.extend_from_slice(prefix);
        idx
    }

    /// Returns the `len` words starting at `idx`.
    ///
    /// Returns `None` when the range `idx..idx + len` is not fully inside the buffer,
    /// including when `idx + len` would overflow `usize`.
    pub fn get_prefix(&self, idx: usize, len: usize) -> Option<&[u32]> {
        // `checked_add` keeps absurd arguments on the `None` path instead of panicking on
        // overflow (debug) or wrapping to a bogus end bound (release).
        let end = idx.checked_add(len)?;
        self.buffer.get(idx..end)
    }
}

/// The key prefix carried by a branch, either stored inline, borrowed, or owned.
///
/// The variant is picked from where the branch's word index sits relative to the start of
/// the prefix; the longer forms are only needed when more than one key word separates them.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum PrefixCow<'a> {
/// The branch is in the first word of the key, so there is no prior word at all.
StartOfKey,
/// Exactly one prior key word, stored inline.
PriorWord(u32),
/// Several consecutive key words, borrowed for `'a`.
// NOTE(review): serde ships borrowed `Deserialize` impls for `&str` and `&[u8]` only;
// confirm that deriving `Deserialize` over `&'a [u32]` builds with the `serde` feature on.
Segment(&'a [u32]),
/// Several consecutive key words, owned by the value.
/// Presumably the owned counterpart of `Segment` — TODO confirm where it is constructed.
SegmentOwned(Box<[u32]>),
}

/// A branch node in the trie.
/// `NR` is the type of the node references.
/// `PR` is the type of reference to the prefix.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Branch<NR> {
pub struct Branch<'s, NR> {
pub left: NR,
pub right: NR,
pub mask: BranchMask,
/// The word at the `(bit_idx / 32) - 1`.
/// Common to both children.
/// Will be 0 if this node is the root.
pub prior_word: u32,
/// The segment of the hash key from the parent branch to `prior_word`.
/// Will be empty if the parent_branch.mask.bit_idx / 32 == self.mask.bit_idx / 32.
pub prefix: Box<[u32]>,
pub prefix: PrefixCow<'s>,
}

impl<NR> fmt::Debug for Branch<NR> {
impl<NR> fmt::Debug for Branch<'_, NR> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Branch")
.field("mask", &self.mask)
.field("prior_word", &self.prior_word)
.field("prefix", &self.prefix)
.finish()
}
Expand All @@ -310,7 +370,7 @@ pub enum KeyPositionAdjacent {
PrefixVec(usize),
}

impl<NR> Branch<NR> {
impl<NR> Branch<NR, PrefixCow<'a>> {
/// Returns the position of the key relative to the branch.
#[inline(always)]
pub fn key_position(&self, key_hash: &KeyHash) -> KeyPosition {
Expand Down Expand Up @@ -352,6 +412,7 @@ impl<NR> Branch<NR> {
/// Hash a branch node with known child hashes.
///
/// Caller must ensure that the hasher is reset before calling this function.
///
#[inline]
pub fn hash_branch<H: PortableHasher<32>>(
&self,
Expand All @@ -361,6 +422,7 @@ impl<NR> Branch<NR> {
) -> NodeHash {
hasher.portable_update(left);
hasher.portable_update(right);
// Security: it's important to hash the metadata to avoid a potential trie corruption attack.
hasher.portable_update(self.mask.bit_idx.to_le_bytes());
hasher.portable_update(self.mask.left_prefix.to_le_bytes());
hasher.portable_update(self.prior_word.to_le_bytes());
Expand All @@ -373,19 +435,7 @@ impl<NR> Branch<NR> {
}
}

impl<V> Branch<NodeRef<V>> {
pub(crate) fn from_stored(branch: &Branch<stored::Idx>) -> Branch<NodeRef<V>> {
Branch {
left: NodeRef::Stored(branch.left),
right: NodeRef::Stored(branch.right),
mask: branch.mask,
prior_word: branch.prior_word,
// TODO remove the clone
// Maybe use a AsRef<[u32]> instead of Box<[u32]>
prefix: branch.prefix.clone(),
}
}

impl<V, PR> Branch<NodeRef<V>, PR> {
/// A wrapper around `new_at_branch_ret` which returns nothing.
/// This exists to aid compiler inlining.
///
Expand Down Expand Up @@ -540,17 +590,19 @@ impl<V> Branch<NodeRef<V>> {

debug_assert!(new_leaf.key_hash.0[..word_idx] == old_leaf.as_ref().key_hash.0[..word_idx]);

let prior_word_idx = word_idx.saturating_sub(1);
let prefix = new_leaf.key_hash.0[prefix_start_idx..prior_word_idx].into();
let prior_word = if word_idx == 0 {
0
} else {
let prefix = if word_idx == 0 {
PrefixCow::StartOfKey
} else if prefix_start_idx == word_idx {
let prior_word_idx = word_idx - 1;
debug_assert_eq!(
new_leaf.key_hash.0[prior_word_idx],
old_leaf.as_ref().key_hash.0[prior_word_idx]
);

new_leaf.key_hash.0[prior_word_idx]
PrefixCow::PriorWord(new_leaf.key_hash.0[prior_word_idx])
} else if prefix_start_idx == word_idx - 1 {
PrefixCow::PriorWord(new_leaf.key_hash.0[word_idx - 1])
} else {
PrefixCow::Segment(&new_leaf.key_hash.0[prefix_start_idx..word_idx])
};

let mask = BranchMask::new(word_idx as u32, a, b);
Expand All @@ -577,7 +629,6 @@ impl<V> Branch<NodeRef<V>> {
left,
right,
mask,
prior_word,
prefix,
}),
// TODO use an enum
Expand Down
Loading