diff --git a/firewood/src/merkle/stream.rs b/firewood/src/merkle/stream.rs index a3442e8bc..39e01df2b 100644 --- a/firewood/src/merkle/stream.rs +++ b/firewood/src/merkle/stream.rs @@ -88,13 +88,40 @@ impl<'a, S: ShaleStore + Send + Sync, T> Stream for MerkleKeyValueStream<' .get_node(*merkle_root) .map_err(|e| api::Error::InternalError(Box::new(e)))?; - // `get_node_and_parents_by_key` will traverse the trie, nibble by nibble, pushing the - // node-ref of each node it visited into a the "parents" vec. - // if it reaches a node that matches the last nibble AND that node has a value - // it will return with (Some(node), parents), otherwise `(None, parents)` - let (found_node, mut visited_node_path) = merkle - .get_node_and_parents_by_key(root_node, &key) - .map_err(|e| api::Error::InternalError(Box::new(e)))?; + // traverse the trie along each nibble until we find a node with a value + // TODO: merkle.iter_by_key(key) will simplify this entire code-block. + let (found_node, mut visited_node_path) = { + let mut visited_node_path = vec![]; + + let found_node = merkle + .get_node_by_key_with_callbacks( + root_node, + &key, + |node_addr, i| visited_node_path.push((node_addr, i)), + |_, _| {}, + ) + .map_err(|e| api::Error::InternalError(Box::new(e)))?; + + let mut visited_node_path = visited_node_path + .into_iter() + .map(|(node, pos)| merkle.get_node(node).map(|node| (node, pos))) + .collect::, _>>() + .map_err(|e| api::Error::InternalError(Box::new(e)))?; + + let last_visited_node_not_branch = visited_node_path + .last() + .map(|(node, _)| { + matches!(node.inner(), NodeType::Leaf(_) | NodeType::Extension(_)) + }) + .unwrap_or_default(); + + // we only want branch in the visited node-path to start + if last_visited_node_not_branch { + visited_node_path.pop(); + } + + (found_node, visited_node_path) + }; if let Some(found_node) = found_node { let value = match found_node.inner() { @@ -122,6 +149,11 @@ impl<'a, S: ShaleStore + Send + Sync, T> Stream for MerkleKeyValueStream<' } IteratorState::Iterating { visited_node_path } => { + visited_node_path + .last() + .as_ref() + .map(|(last, pos)| (last.inner(), pos)); + let next = find_next_result(merkle, visited_node_path) .map_err(|e| api::Error::InternalError(Box::new(e))) .transpose(); @@ -676,6 +708,9 @@ mod tests { check_stream_is_done(stream).await; } + // TODO: proper test for key between siblings 0x00 and 0xff + // TODO: start key greater than all keys + async fn check_stream_is_done(mut stream: S) where S: FusedStream + Unpin,