diff --git a/src/allocator.rs b/src/allocator.rs index e5dd0771..0927d4d5 100644 --- a/src/allocator.rs +++ b/src/allocator.rs @@ -2,6 +2,7 @@ use crate::err_utils::err; use crate::number::{node_from_number, number_from_u8, Number}; use crate::reduction::EvalErr; use chia_bls::{G1Element, G2Element}; +use std::cell::RefCell; const MAX_NUM_ATOMS: usize = 62500000; const MAX_NUM_PAIRS: usize = 62500000; @@ -12,36 +13,42 @@ const NODE_PTR_IDX_MASK: u32 = (1 << NODE_PTR_IDX_BITS) - 1; pub struct NodePtr(u32); enum ObjectType { + // The low bits form an index into the pair_vec Pair, + // The low bits form an index into the atom_vec Bytes, + // The low bits are the atom itself (unsigned integer, 26 bits) + SmallAtom, } // The top 6 bits of the NodePtr indicate what type of object it is impl NodePtr { - pub const NIL: Self = Self::new(ObjectType::Bytes, 0); + pub const NIL: Self = Self::new(ObjectType::SmallAtom, 0); const fn new(t: ObjectType, idx: usize) -> Self { debug_assert!(idx <= NODE_PTR_IDX_MASK as usize); NodePtr(((t as u32) << NODE_PTR_IDX_BITS) | (idx as u32)) } - fn node_type(&self) -> (ObjectType, usize) { + fn node_type(&self) -> (ObjectType, u32) { ( match self.0 >> NODE_PTR_IDX_BITS { 0 => ObjectType::Pair, 1 => ObjectType::Bytes, + 2 => ObjectType::SmallAtom, _ => { panic!("unknown NodePtr type"); } }, - (self.0 & NODE_PTR_IDX_MASK) as usize, + (self.0 & NODE_PTR_IDX_MASK), ) } pub(crate) fn as_index(&self) -> usize { match self.node_type() { - (ObjectType::Pair, idx) => idx * 2, - (ObjectType::Bytes, idx) => idx * 2 + 1, + (ObjectType::Pair, idx) => (idx as usize) * 3, + (ObjectType::Bytes, idx) => (idx as usize) * 3 + 1, + (ObjectType::SmallAtom, idx) => (idx as usize) * 3 + 2, } } } @@ -83,6 +90,13 @@ pub struct Checkpoint { u8s: usize, pairs: usize, atoms: usize, + small_atoms: usize, +} + +pub enum NodeVisitor<'a> { + Buffer(&'a [u8]), + U32(u32), + Pair(NodePtr, NodePtr), } #[derive(Debug)] @@ -100,8 +114,18 @@ pub struct Allocator { // on. 
atom_vec: Vec, + // index into temp_buf array + temp_idx: RefCell, + + // temporary buffers for storing SmallAtoms in to return from atom() + temp_vec: Vec>, + // the atom_vec may not grow past this heap_limit: usize, + + // the number of small atoms we've allocated. We keep track of these to ensure the limit on the + // number of atoms is identical to what it was before the small-atom optimization + small_atoms: usize, } impl Default for Allocator { @@ -110,6 +134,40 @@ impl Default for Allocator { } } +pub fn canonical_positive_integer(v: &[u8]) -> bool { + if v.is_empty() { + // empty buffer is 0/nil + true + } else if (v.len() == 1 && v[0] == 0) + // a 1-byte buffer of 0 is not the canonical representation of 0 + || (v[0] & 0x80) != 0 + // if the top bit is set, it's a negative number (i.e. not positive) + || (v[0] == 0 && (v[1] & 0x80) == 0) + { + // if the top byte is a 0 but the top bit of the next byte is not set, that's a redundant + // leading zero. i.e. not canonical representation + false + } else { + true + } +} + +pub fn len_for_value(val: u32) -> usize { + if val == 0 { + 0 + } else if val < 0x80 { + 1 + } else if val < 0x8000 { + 2 + } else if val < 0x800000 { + 3 + } else if val < 0x80000000 { + 4 + } else { + 5 + } +} + impl Allocator { pub fn new() -> Self { Self::new_limited(u32::MAX as usize) } @@ -119,20 +177,26 @@ impl Allocator { // we have a maximum of 4 GiB heap, because pointers are 32 bit unsigned assert!(heap_limit <= u32::MAX as usize); + let mut temp_vec = Vec::>::with_capacity(64); + for _ in 0..16 { + temp_vec.push(RefCell::default()); + } + let mut r = Self { u8_vec: Vec::new(), pair_vec: Vec::new(), atom_vec: Vec::new(), - heap_limit, + temp_idx: RefCell::new(0), + temp_vec, + // subtract 1 to compensate for the one() we used to allocate unconditionally + heap_limit: heap_limit - 1, + // initialize this to 2 to behave as if we had allocated atoms for + // nil() and one(), like we used to + small_atoms: 2, }; r.u8_vec.reserve(1024 * 
1024); r.atom_vec.reserve(256); r.pair_vec.reserve(256); - r.u8_vec.push(1_u8); - // Preallocated empty list - r.atom_vec.push(AtomBuf { start: 0, end: 0 }); - // Preallocated 1 - r.atom_vec.push(AtomBuf { start: 0, end: 1 }); r } @@ -144,6 +208,7 @@ impl Allocator { u8s: self.u8_vec.len(), pairs: self.pair_vec.len(), atoms: self.atom_vec.len(), + small_atoms: self.small_atoms, } } @@ -158,6 +223,7 @@ impl Allocator { self.u8_vec.truncate(cp.u8s); self.pair_vec.truncate(cp.pairs); self.atom_vec.truncate(cp.atoms); + self.small_atoms = cp.small_atoms; } pub fn new_atom(&mut self, v: &[u8]) -> Result { @@ -166,16 +232,37 @@ impl Allocator { return err(self.nil(), "out of memory"); } let idx = self.atom_vec.len(); - if idx == MAX_NUM_ATOMS { - return err(self.nil(), "too many atoms"); + self.check_atom_limit()?; + if v.len() <= 3 && canonical_positive_integer(v) { + let mut ret: u32 = 0; + for b in v { + ret <<= 8; + ret |= *b as u32; + } + self.small_atoms += 1; + Ok(NodePtr::new(ObjectType::SmallAtom, ret as usize)) + } else { + self.u8_vec.extend_from_slice(v); + let end = self.u8_vec.len() as u32; + self.atom_vec.push(AtomBuf { start, end }); + Ok(NodePtr::new(ObjectType::Bytes, idx)) } - self.u8_vec.extend_from_slice(v); - let end = self.u8_vec.len() as u32; - self.atom_vec.push(AtomBuf { start, end }); - Ok(NodePtr::new(ObjectType::Bytes, idx)) + } + + pub fn new_small_number(&mut self, v: u32) -> Result { + debug_assert!(v <= NODE_PTR_IDX_MASK); + self.check_atom_limit()?; + self.small_atoms += 1; + Ok(NodePtr::new(ObjectType::SmallAtom, v as usize)) } pub fn new_number(&mut self, v: Number) -> Result { + use num_traits::ToPrimitive; + if let Some(val) = v.to_u32() { + if val <= NODE_PTR_IDX_MASK { + return self.new_small_number(val); + } + } node_from_number(self, &v) } @@ -197,35 +284,65 @@ impl Allocator { } pub fn new_substr(&mut self, node: NodePtr, start: u32, end: u32) -> Result { - if self.atom_vec.len() == MAX_NUM_ATOMS { - return err(self.nil(), "too 
many atoms"); - } - let (ObjectType::Bytes, idx) = node.node_type() else { - return err(node, "(internal error) substr expected atom, got pair"); - }; - let atom = self.atom_vec[idx]; - let atom_len = atom.end - atom.start; - if start > atom_len { - return err(node, "substr start out of bounds"); - } - if end > atom_len { - return err(node, "substr end out of bounds"); + self.check_atom_limit()?; + + fn bounds_check(node: NodePtr, start: u32, end: u32, len: u32) -> Result<(), EvalErr> { + if start > len { + return err(node, "substr start out of bounds"); + } + if end > len { + return err(node, "substr end out of bounds"); + } + if end < start { + return err(node, "substr invalid bounds"); + } + Ok(()) } - if end < start { - return err(node, "substr invalid bounds"); + + match node.node_type() { + (ObjectType::Pair, _) => err(node, "(internal error) substr expected atom, got pair"), + (ObjectType::Bytes, idx) => { + let atom = self.atom_vec[idx as usize]; + let atom_len = atom.end - atom.start; + bounds_check(node, start, end, atom_len)?; + let idx = self.atom_vec.len(); + self.atom_vec.push(AtomBuf { + start: atom.start + start, + end: atom.start + end, + }); + Ok(NodePtr::new(ObjectType::Bytes, idx)) + } + (ObjectType::SmallAtom, val) => { + let len = len_for_value(val) as u32; + bounds_check(node, start, end, len)?; + let buf: [u8; 4] = val.to_be_bytes(); + let buf = &buf[4 - len as usize..]; + let substr = &buf[start as usize..end as usize]; + if !canonical_positive_integer(substr) { + let start = self.u8_vec.len(); + let end = start + substr.len(); + self.u8_vec.extend_from_slice(substr); + let idx = self.atom_vec.len(); + self.atom_vec.push(AtomBuf { + start: start as u32, + end: end as u32, + }); + Ok(NodePtr::new(ObjectType::Bytes, idx)) + } else { + let mut new_val: u32 = 0; + for i in substr { + new_val <<= 8; + new_val |= *i as u32; + } + self.small_atoms += 1; + Ok(NodePtr::new(ObjectType::SmallAtom, new_val as usize)) + } + } } - let idx = 
self.atom_vec.len(); - self.atom_vec.push(AtomBuf { - start: atom.start + start, - end: atom.start + end, - }); - Ok(NodePtr::new(ObjectType::Bytes, idx)) } pub fn new_concat(&mut self, new_size: usize, nodes: &[NodePtr]) -> Result { - if self.atom_vec.len() == MAX_NUM_ATOMS { - return err(self.nil(), "too many atoms"); - } + self.check_atom_limit()?; let start = self.u8_vec.len(); if self.heap_limit - start < new_size { return err(self.nil(), "out of memory"); @@ -234,19 +351,29 @@ impl Allocator { let mut counter: usize = 0; for node in nodes { - let (ObjectType::Bytes, idx) = node.node_type() else { - self.u8_vec.truncate(start); - return err(*node, "(internal error) concat expected atom, got pair"); - }; - - let term = self.atom_vec[idx]; - if counter + term.len() > new_size { - self.u8_vec.truncate(start); - return err(*node, "(internal error) concat passed invalid new_size"); + match node.node_type() { + (ObjectType::Pair, _) => { + self.u8_vec.truncate(start); + return err(*node, "(internal error) concat expected atom, got pair"); + } + (ObjectType::Bytes, idx) => { + let term = self.atom_vec[idx as usize]; + if counter + term.len() > new_size { + self.u8_vec.truncate(start); + return err(*node, "(internal error) concat passed invalid new_size"); + } + self.u8_vec + .extend_from_within(term.start as usize..term.end as usize); + counter += term.len(); + } + (ObjectType::SmallAtom, val) => { + let len = len_for_value(val) as u32; + let buf: [u8; 4] = val.to_be_bytes(); + let buf = &buf[4 - len as usize..]; + self.u8_vec.extend_from_slice(buf); + counter += len as usize; + } } - self.u8_vec - .extend_from_within(term.start as usize..term.end as usize); - counter += term.len(); } if counter != new_size { self.u8_vec.truncate(start); @@ -265,15 +392,67 @@ impl Allocator { } pub fn atom_eq(&self, lhs: NodePtr, rhs: NodePtr) -> bool { - self.atom(lhs) == self.atom(rhs) + match (lhs.node_type(), rhs.node_type()) { + ((ObjectType::Pair, _), _) | (_, 
(ObjectType::Pair, _)) => { + panic!("atom_eq() called on pair"); + } + ((ObjectType::Bytes, lhs), (ObjectType::Bytes, rhs)) => { + let lhs = self.atom_vec[lhs as usize]; + let rhs = self.atom_vec[rhs as usize]; + self.u8_vec[lhs.start as usize..lhs.end as usize] + == self.u8_vec[rhs.start as usize..rhs.end as usize] + } + ((ObjectType::SmallAtom, lhs), (ObjectType::SmallAtom, rhs)) => lhs == rhs, + ((ObjectType::SmallAtom, val), (ObjectType::Bytes, idx)) + | ((ObjectType::Bytes, idx), (ObjectType::SmallAtom, val)) => { + let atom = self.atom_vec[idx as usize]; + let len = len_for_value(val) as u32; + if (atom.end - atom.start) != len { + return false; + } + if val == 0 { + return true; + } + + if self.u8_vec[atom.start as usize] & 0x80 != 0 { + // SmallAtom only represents positive values + // if the byte buffer is negative, they can't match + return false; + } + + // since we know the value of atom is small, we can turn it into a u32 and compare + // against val + let mut atom_val: u32 = 0; + for i in atom.start..atom.end { + atom_val <<= 8; + atom_val |= self.u8_vec[i as usize] as u32; + } + val == atom_val + } + } } pub fn atom(&self, node: NodePtr) -> &[u8] { match node.node_type() { (ObjectType::Bytes, idx) => { - let atom = self.atom_vec[idx]; + let atom = self.atom_vec[idx as usize]; &self.u8_vec[atom.start as usize..atom.end as usize] } + (ObjectType::SmallAtom, val) => { + let len = len_for_value(val); + let mut idx = self.temp_idx.borrow_mut(); + *self.temp_vec[*idx].borrow_mut() = val.to_be_bytes(); + let ret = unsafe { + self.temp_vec[*idx] + .try_borrow_unguarded() + .expect("(internal error) temporary buffer problem in Allocator::atom()") + }; + *idx += 1; + if *idx == self.temp_vec.len() { + *idx = 0; + } + &ret[4 - len..] 
+ } _ => { panic!("expected atom, got pair"); } @@ -283,52 +462,101 @@ impl Allocator { pub fn atom_len(&self, node: NodePtr) -> usize { match node.node_type() { (ObjectType::Bytes, idx) => { - let atom = self.atom_vec[idx]; + let atom = self.atom_vec[idx as usize]; (atom.end - atom.start) as usize } + (ObjectType::SmallAtom, val) => len_for_value(val), _ => { panic!("expected atom, got pair"); } } } + pub fn small_number(&self, node: NodePtr) -> Option { + match node.node_type() { + (ObjectType::SmallAtom, val) => Some(val), + _ => None, + } + } + pub fn number(&self, node: NodePtr) -> Number { - number_from_u8(self.atom(node)) + match node.node_type() { + (ObjectType::Bytes, idx) => { + let atom = self.atom_vec[idx as usize]; + number_from_u8(&self.u8_vec[atom.start as usize..atom.end as usize]) + } + (ObjectType::SmallAtom, val) => Number::from(val), + _ => { + panic!("number() calld on pair"); + } + } } pub fn g1(&self, node: NodePtr) -> Result { - let blob = match self.sexp(node) { - SExp::Atom => self.atom(node), - _ => { + let idx = match node.node_type() { + (ObjectType::Bytes, idx) => idx, + (ObjectType::SmallAtom, _) => { + return err(node, "atom is not G1 size, 48 bytes"); + } + (ObjectType::Pair, _) => { return err(node, "pair found, expected G1 point"); } }; - let array: [u8; 48] = blob + let atom = self.atom_vec[idx as usize]; + if atom.end - atom.start != 48 { + return err(node, "atom is not G1 size, 48 bytes"); + } + + let array: &[u8; 48] = &self.u8_vec[atom.start as usize..atom.end as usize] .try_into() - .map_err(|_| EvalErr(node, "atom is not G1 size, 48 bytes".to_string()))?; - G1Element::from_bytes(&array) + .expect("atom size is not 48 bytes"); + G1Element::from_bytes(array) .map_err(|_| EvalErr(node, "atom is not a G1 point".to_string())) } pub fn g2(&self, node: NodePtr) -> Result { - let blob = match self.sexp(node) { - SExp::Atom => self.atom(node), - _ => { + let idx = match node.node_type() { + (ObjectType::Bytes, idx) => idx, + 
(ObjectType::SmallAtom, _) => { + return err(node, "atom is not G2 size, 96 bytes"); + } + (ObjectType::Pair, _) => { return err(node, "pair found, expected G2 point"); } }; - let array = blob + let atom = self.atom_vec[idx as usize]; + if atom.end - atom.start != 96 { + return err(node, "atom is not G2 size, 96 bytes"); + } + + let array: &[u8; 96] = &self.u8_vec[atom.start as usize..atom.end as usize] .try_into() - .map_err(|_| EvalErr(node, "atom is not G2 size, 96 bytes".to_string()))?; - G2Element::from_bytes(&array) + .expect("atom size is not 96 bytes"); + + G2Element::from_bytes(array) .map_err(|_| EvalErr(node, "atom is not a G2 point".to_string())) } + pub fn node<'a>(&'a self, node: NodePtr) -> NodeVisitor<'a> { + match node.node_type() { + (ObjectType::Bytes, idx) => { + let atom = self.atom_vec[idx as usize]; + let buf = &self.u8_vec[atom.start as usize..atom.end as usize]; + NodeVisitor::<'a>::Buffer(buf) + } + (ObjectType::SmallAtom, val) => NodeVisitor::U32(val), + (ObjectType::Pair, idx) => { + let pair = self.pair_vec[idx as usize]; + NodeVisitor::Pair(pair.first, pair.rest) + } + } + } + pub fn sexp(&self, node: NodePtr) -> SExp { match node.node_type() { - (ObjectType::Bytes, _) => SExp::Atom, + (ObjectType::Bytes, _) | (ObjectType::SmallAtom, _) => SExp::Atom, (ObjectType::Pair, idx) => { - let pair = self.pair_vec[idx]; + let pair = self.pair_vec[idx as usize]; SExp::Pair(pair.first, pair.rest) } } @@ -347,11 +575,20 @@ impl Allocator { } pub fn nil(&self) -> NodePtr { - NodePtr::new(ObjectType::Bytes, 0) + NodePtr::new(ObjectType::SmallAtom, 0) } pub fn one(&self) -> NodePtr { - NodePtr::new(ObjectType::Bytes, 1) + NodePtr::new(ObjectType::SmallAtom, 1) + } + + #[inline] + fn check_atom_limit(&self) -> Result<(), EvalErr> { + if self.atom_vec.len() + self.small_atoms == MAX_NUM_ATOMS { + err(self.nil(), "too many atoms") + } else { + Ok(()) + } } #[cfg(feature = "counters")] @@ -359,6 +596,11 @@ impl Allocator { self.atom_vec.len() } + 
#[cfg(feature = "counters")] + pub fn small_atom_count(&self) -> usize { + self.small_atoms + } + #[cfg(feature = "counters")] pub fn pair_count(&self) -> usize { self.pair_vec.len() @@ -373,14 +615,14 @@ impl Allocator { #[test] fn test_node_as_index() { assert_eq!(NodePtr::new(ObjectType::Pair, 0).as_index(), 0); - assert_eq!(NodePtr::new(ObjectType::Pair, 1).as_index(), 2); - assert_eq!(NodePtr::new(ObjectType::Pair, 2).as_index(), 4); - assert_eq!(NodePtr::new(ObjectType::Pair, 3).as_index(), 6); + assert_eq!(NodePtr::new(ObjectType::Pair, 1).as_index(), 3); + assert_eq!(NodePtr::new(ObjectType::Pair, 2).as_index(), 6); + assert_eq!(NodePtr::new(ObjectType::Pair, 3).as_index(), 9); assert_eq!(NodePtr::new(ObjectType::Bytes, 0).as_index(), 1); - assert_eq!(NodePtr::new(ObjectType::Bytes, 1).as_index(), 3); - assert_eq!(NodePtr::new(ObjectType::Bytes, 2).as_index(), 5); - assert_eq!(NodePtr::new(ObjectType::Bytes, 3).as_index(), 7); - assert_eq!(NodePtr::new(ObjectType::Bytes, 4).as_index(), 9); + assert_eq!(NodePtr::new(ObjectType::Bytes, 1).as_index(), 4); + assert_eq!(NodePtr::new(ObjectType::Bytes, 2).as_index(), 7); + assert_eq!(NodePtr::new(ObjectType::Bytes, 3).as_index(), 10); + assert_eq!(NodePtr::new(ObjectType::Bytes, 4).as_index(), 13); } #[test] @@ -396,36 +638,49 @@ fn test_atom_eq_1() { }; let a3 = a.new_substr(a2, 0, 1).unwrap(); let a4 = a.new_number(1.into()).unwrap(); + let a5 = a.new_small_number(1).unwrap(); assert!(a.atom_eq(a0, a0)); assert!(a.atom_eq(a0, a1)); assert!(a.atom_eq(a0, a2)); assert!(a.atom_eq(a0, a3)); assert!(a.atom_eq(a0, a4)); + assert!(a.atom_eq(a0, a5)); assert!(a.atom_eq(a1, a0)); assert!(a.atom_eq(a1, a1)); assert!(a.atom_eq(a1, a2)); assert!(a.atom_eq(a1, a3)); assert!(a.atom_eq(a1, a4)); + assert!(a.atom_eq(a1, a5)); assert!(a.atom_eq(a2, a0)); assert!(a.atom_eq(a2, a1)); assert!(a.atom_eq(a2, a2)); assert!(a.atom_eq(a2, a3)); assert!(a.atom_eq(a2, a4)); + assert!(a.atom_eq(a2, a5)); assert!(a.atom_eq(a3, a0)); 
assert!(a.atom_eq(a3, a1)); assert!(a.atom_eq(a3, a2)); assert!(a.atom_eq(a3, a3)); assert!(a.atom_eq(a3, a4)); + assert!(a.atom_eq(a3, a5)); assert!(a.atom_eq(a4, a0)); assert!(a.atom_eq(a4, a1)); assert!(a.atom_eq(a4, a2)); assert!(a.atom_eq(a4, a3)); assert!(a.atom_eq(a4, a4)); + assert!(a.atom_eq(a4, a5)); + + assert!(a.atom_eq(a5, a0)); + assert!(a.atom_eq(a5, a1)); + assert!(a.atom_eq(a5, a2)); + assert!(a.atom_eq(a5, a3)); + assert!(a.atom_eq(a5, a4)); + assert!(a.atom_eq(a5, a5)); } #[test] @@ -468,9 +723,9 @@ fn test_atom_eq() { let a0 = a.nil(); let a1 = a.one(); let a2 = a.new_atom(&[1]).unwrap(); - let a3 = a.new_atom(&[0x5, 0x39]).unwrap(); - let a4 = a.new_number(1.into()).unwrap(); - let a5 = a.new_number(1337.into()).unwrap(); + let a3 = a.new_atom(&[0xfa, 0xc7]).unwrap(); + let a4 = a.new_small_number(1).unwrap(); + let a5 = a.new_number((-1337).into()).unwrap(); assert!(a.atom_eq(a0, a0)); assert!(!a.atom_eq(a0, a1)); @@ -559,6 +814,30 @@ fn test_node_ptr_overflow() { NodePtr::new(ObjectType::Bytes, NODE_PTR_IDX_MASK + 1); } +#[cfg(dbg)] +#[test] +#[should_panic] +fn test_invalid_small_number() { + let mut a = Allocator::new(); + a.new_small_number(NODE_PTR_IDX_MASK + 1); +} + +#[cfg(test)] +#[rstest] +#[case(0, 0)] +#[case(1, 1)] +#[case(0x7f, 1)] +#[case(0x80, 2)] +#[case(0x7fff, 2)] +#[case(0x7fffff, 3)] +#[case(0x800000, 4)] +#[case(0x7fffffff, 4)] +#[case(0x80000000, 5)] +#[case(0xffffffff, 5)] +fn test_len_for_value(#[case] val: u32, #[case] len: usize) { + assert_eq!(len_for_value(val), len); +} + #[test] fn test_nil() { let a = Allocator::new(); @@ -614,7 +893,21 @@ fn test_allocate_atom_limit() { let _ = a.new_atom(b"foo").unwrap(); } assert_eq!(a.new_atom(b"foobar").unwrap_err().1, "too many atoms"); - assert_eq!(a.u8_vec.len(), (MAX_NUM_ATOMS - 2) * 3 + 1); + assert_eq!(a.u8_vec.len(), 0); + assert_eq!(a.small_atoms, MAX_NUM_ATOMS); +} + +#[test] +fn test_allocate_small_number_limit() { + let mut a = Allocator::new(); + + for _ in 
0..MAX_NUM_ATOMS - 2 { + // exhaust the number of atoms allowed to be allocated + let _ = a.new_atom(b"foo").unwrap(); + } + assert_eq!(a.new_small_number(3).unwrap_err().1, "too many atoms"); + assert_eq!(a.u8_vec.len(), 0); + assert_eq!(a.small_atoms, MAX_NUM_ATOMS); } #[test] @@ -627,7 +920,8 @@ fn test_allocate_substr_limit() { } let atom = a.new_atom(b"foo").unwrap(); assert_eq!(a.new_substr(atom, 1, 2).unwrap_err().1, "too many atoms"); - assert_eq!(a.u8_vec.len(), (MAX_NUM_ATOMS - 2) * 3 + 1); + assert_eq!(a.u8_vec.len(), 0); + assert_eq!(a.small_atoms, MAX_NUM_ATOMS); } #[test] @@ -640,7 +934,8 @@ fn test_allocate_concat_limit() { } let atom = a.new_atom(b"foo").unwrap(); assert_eq!(a.new_concat(3, &[atom]).unwrap_err().1, "too many atoms"); - assert_eq!(a.u8_vec.len(), (MAX_NUM_ATOMS - 2) * 3 + 1); + assert_eq!(a.u8_vec.len(), 0); + assert_eq!(a.small_atoms, MAX_NUM_ATOMS); } #[test] @@ -694,6 +989,41 @@ fn test_substr() { ); } +#[test] +fn test_substr_small_number() { + let mut a = Allocator::new(); + let atom = a.new_atom(b"a\x80").unwrap(); + assert!(a.small_number(atom).is_some()); + + let sub = a.new_substr(atom, 0, 1).unwrap(); + assert_eq!(a.atom(sub), b"a"); + assert!(a.small_number(sub).is_some()); + let sub = a.new_substr(atom, 1, 2).unwrap(); + assert_eq!(a.atom(sub), b"\x80"); + assert!(a.small_number(sub).is_none()); + let sub = a.new_substr(atom, 1, 1).unwrap(); + assert_eq!(a.atom(sub), b""); + let sub = a.new_substr(atom, 0, 0).unwrap(); + assert_eq!(a.atom(sub), b""); + + assert_eq!( + a.new_substr(atom, 1, 0).unwrap_err().1, + "substr invalid bounds" + ); + assert_eq!( + a.new_substr(atom, 3, 3).unwrap_err().1, + "substr start out of bounds" + ); + assert_eq!( + a.new_substr(atom, 0, 3).unwrap_err().1, + "substr end out of bounds" + ); + assert_eq!( + a.new_substr(atom, u32::MAX, 2).unwrap_err().1, + "substr start out of bounds" + ); +} + #[test] fn test_concat() { let mut a = Allocator::new(); @@ -788,7 +1118,7 @@ fn test_sexp() { 
#[test] fn test_concat_limit() { - let mut a = Allocator::new_limited(6 + 3); + let mut a = Allocator::new_limited(6); let atom1 = a.new_atom(b"f").unwrap(); let atom2 = a.new_atom(b"o").unwrap(); let atom3 = a.new_atom(b"o").unwrap(); @@ -1328,3 +1658,100 @@ fn test_number_roundtrip(#[case] value: Number) { let atom = a.new_number(value.clone()).expect("new_number()"); assert_eq!(a.number(atom), value); } + +#[cfg(test)] +#[rstest] +#[case(0)] +#[case(1)] +#[case(0x7f)] +#[case(0x80)] +#[case(0xff)] +#[case(0x100)] +#[case(0x7fff)] +#[case(0x8000)] +#[case(0xffff)] +#[case(0x10000)] +#[case(0x7ffff)] +#[case(0x80000)] +#[case(0xfffff)] +#[case(0x100000)] +#[case(0x7fffff)] +#[case(0x800000)] +#[case(0xffffff)] +#[case(0x1000000)] +#[case(0x3ffffff)] +fn test_small_number_roundtrip(#[case] value: u32) { + let mut a = Allocator::new(); + let atom = a.new_small_number(value).expect("new_small_number()"); + assert_eq!(a.small_number(atom).expect("small_number()"), value); +} + +#[cfg(test)] +#[rstest] +#[case(0.into(), true)] +#[case(1.into(), true)] +#[case(0x3ffffff.into(), true)] +#[case(0x4000000.into(), false)] +#[case(0x7f.into(), true)] +#[case(0x80.into(), true)] +#[case(0xff.into(), true)] +#[case(0x100.into(), true)] +#[case(0x7fff.into(), true)] +#[case(0x8000.into(), true)] +#[case(0xffff.into(), true)] +#[case(0x10000.into(), true)] +#[case(0x7ffff.into(), true)] +#[case(0x80000.into(), true)] +#[case(0xfffff.into(), true)] +#[case(0x100000.into(), true)] +#[case(0x7ffffff.into(), false)] +#[case(0x8000000.into(), false)] +#[case(0xfffffff.into(), false)] +#[case(0x10000000.into(), false)] +#[case(0x7ffffffff_u64.into(), false)] +#[case(0x8000000000_u64.into(), false )] +#[case(0xffffffffff_u64.into(), false)] +#[case(0x10000000000_u64.into(), false)] +#[case((-1).into(), false)] +#[case((-0x7f).into(), false)] +#[case((-0x80).into(), false)] +#[case((-0x10000000000_i64).into(), false)] +fn test_auto_small_number(#[case] value: Number, #[case] 
expect_small: bool) { + let mut a = Allocator::new(); + let atom = a.new_number(value.clone()).expect("new_number()"); + assert_eq!(a.small_number(atom).is_some(), expect_small); + if let Some(v) = a.small_number(atom) { + use num_traits::ToPrimitive; + assert_eq!(v, value.to_u32().unwrap()); + } + assert_eq!(a.number(atom), value); +} + +#[cfg(test)] +#[rstest] +// redundant leading zeros are not canonical +#[case(&[0x00], false)] +#[case(&[0x00, 0x7f], false)] +// negative numbers cannot be small ints +#[case(&[0x80], false)] +#[case(&[0xff], false)] +#[case(&[0xff, 0xff], false)] +#[case(&[0x80, 0xff, 0xff], false)] +// we use a simple heuristic, for atoms. if we have more than 3 bytes, we assume +// it's not small. Even though it would have fit in 26 bits +#[case(&[0x1, 0xff, 0xff, 0xff], false)] +// small positive integers can be small +#[case(&[0x01], true)] +#[case(&[0x00, 0xff], true)] +#[case(&[0x7f, 0xff], true)] +#[case(&[0x7f, 0xff, 0xff], true)] +fn test_auto_small_number_from_buf(#[case] buf: &[u8], #[case] expect_small: bool) { + let mut a = Allocator::new(); + let atom = a.new_atom(buf).expect("new_atom()"); + assert_eq!(a.small_number(atom).is_some(), expect_small); + if let Some(v) = a.small_number(atom) { + use num_traits::ToPrimitive; + assert_eq!(v, a.number(atom).to_u32().expect("to_u32()")); + } + assert_eq!(buf, a.atom(atom)); +} diff --git a/src/chia_dialect.rs b/src/chia_dialect.rs index 140d306d..1c973d6b 100644 --- a/src/chia_dialect.rs +++ b/src/chia_dialect.rs @@ -70,8 +70,8 @@ impl Dialect for ChiaDialect { max_cost: Cost, extension: OperatorSet, ) -> Response { - let b = allocator.atom(o); - if b.len() == 4 { + let op_len = allocator.atom_len(o); + if op_len == 4 { // these are unknown operators with assigned cost // the formula is: // +---+---+---+------------+ @@ -83,6 +83,7 @@ impl Dialect for ChiaDialect { // (3 bytes) + 2 bits // cost_function + let b = allocator.atom(o); let opcode = u32::from_be_bytes(b.try_into().unwrap()); 
// the secp operators have a fixed cost of 1850000 and 1300000, @@ -97,10 +98,13 @@ impl Dialect for ChiaDialect { }; return f(allocator, argument_list, max_cost); } - if b.len() != 1 { + if op_len != 1 { return unknown_operator(allocator, o, argument_list, self.flags, max_cost); } - let f = match b[0] { + let Some(op) = allocator.small_number(o) else { + return unknown_operator(allocator, o, argument_list, self.flags, max_cost); + }; + let f = match op { // 1 = quote // 2 = apply 3 => op_if, @@ -146,7 +150,7 @@ impl Dialect for ChiaDialect { _ => { if extension == OperatorSet::BLS || (self.flags & ENABLE_BLS_OPS_OUTSIDE_GUARD) != 0 { - match b[0] { + match op { 48 => op_coinid, 49 => op_bls_g1_subtract, 50 => op_bls_g1_multiply, @@ -179,16 +183,14 @@ impl Dialect for ChiaDialect { f(allocator, argument_list, max_cost) } - fn quote_kw(&self) -> &[u8] { - &[1] + fn quote_kw(&self) -> u32 { + 1 } - - fn apply_kw(&self) -> &[u8] { - &[2] + fn apply_kw(&self) -> u32 { + 2 } - - fn softfork_kw(&self) -> &[u8] { - &[36] + fn softfork_kw(&self) -> u32 { + 36 } // interpret the extension argument passed to the softfork operator, and diff --git a/src/dialect.rs b/src/dialect.rs index ef14af95..b6103ceb 100644 --- a/src/dialect.rs +++ b/src/dialect.rs @@ -10,9 +10,9 @@ pub enum OperatorSet { } pub trait Dialect { - fn quote_kw(&self) -> &[u8]; - fn apply_kw(&self) -> &[u8]; - fn softfork_kw(&self) -> &[u8]; + fn quote_kw(&self) -> u32; + fn apply_kw(&self) -> u32; + fn softfork_kw(&self) -> u32; fn softfork_extension(&self, ext: u32) -> OperatorSet; fn op( &self, diff --git a/src/more_ops.rs b/src/more_ops.rs index 0c395d54..54b39d64 100644 --- a/src/more_ops.rs +++ b/src/more_ops.rs @@ -4,7 +4,7 @@ use std::ops::BitAndAssign; use std::ops::BitOrAssign; use std::ops::BitXorAssign; -use crate::allocator::{Allocator, NodePtr, SExp}; +use crate::allocator::{len_for_value, Allocator, NodePtr, NodeVisitor, SExp}; use crate::cost::{check_cost, Cost}; use crate::err_utils::err; use 
crate::number::Number; @@ -365,9 +365,21 @@ pub fn op_add(a: &mut Allocator, mut input: NodePtr, max_cost: Cost) -> Response cost + (byte_count as Cost * ARITH_COST_PER_BYTE), max_cost, )?; - let (v, len) = int_atom(a, arg, "+")?; - byte_count += len; - total += v; + + match a.node(arg) { + NodeVisitor::Buffer(buf) => { + use crate::number::number_from_u8; + total += number_from_u8(buf); + byte_count += buf.len(); + } + NodeVisitor::U32(val) => { + total += val; + byte_count += len_for_value(val); + } + NodeVisitor::Pair(_, _) => { + return err(arg, "+ requires int args"); + } + } } let total = a.new_number(total)?; cost += byte_count as Cost * ARITH_COST_PER_BYTE; @@ -383,12 +395,25 @@ pub fn op_subtract(a: &mut Allocator, mut input: NodePtr, max_cost: Cost) -> Res input = rest; cost += ARITH_COST_PER_ARG; check_cost(a, cost + byte_count as Cost * ARITH_COST_PER_BYTE, max_cost)?; - let (v, len) = int_atom(a, arg, "-")?; - byte_count += len; if is_first { - total += v; + let (v, len) = int_atom(a, arg, "-")?; + byte_count = len; + total = v; } else { - total -= v; + match a.node(arg) { + NodeVisitor::Buffer(buf) => { + use crate::number::number_from_u8; + total -= number_from_u8(buf); + byte_count += buf.len(); + } + NodeVisitor::U32(val) => { + total -= val; + byte_count += len_for_value(val); + } + NodeVisitor::Pair(_, _) => { + return err(arg, "- requires int args"); + } + } }; is_first = false; } @@ -411,14 +436,24 @@ pub fn op_multiply(a: &mut Allocator, mut input: NodePtr, max_cost: Cost) -> Res continue; } - let (v0, l1) = int_atom(a, arg, "*")?; + let l1 = match a.node(arg) { + NodeVisitor::Buffer(buf) => { + use crate::number::number_from_u8; + total *= number_from_u8(buf); + buf.len() + } + NodeVisitor::U32(val) => { + total *= val; + len_for_value(val) + } + NodeVisitor::Pair(_, _) => { + return err(arg, "* requires int args"); + } + }; - total *= v0; cost += MUL_COST_PER_OP; - cost += (l0 + l1) as Cost * MUL_LINEAR_COST_PER_BYTE; cost += (l0 * l1) as 
Cost / MUL_SQUARE_COST_PER_BYTE_DIVIDER; - l0 = limbs_for_int(&total); } let total = a.new_number(total)?; @@ -490,10 +525,20 @@ pub fn op_mod(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> Response { pub fn op_gr(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> Response { let [v0, v1] = get_args::<2>(a, input, ">")?; - let (v0, v0_len) = int_atom(a, v0, ">")?; - let (v1, v1_len) = int_atom(a, v1, ">")?; - let cost = GR_BASE_COST + (v0_len + v1_len) as Cost * GR_COST_PER_BYTE; - Ok(Reduction(cost, if v0 > v1 { a.one() } else { a.nil() })) + + match (a.small_number(v0), a.small_number(v1)) { + (Some(lhs), Some(rhs)) => { + let cost = + GR_BASE_COST + (len_for_value(lhs) + len_for_value(rhs)) as Cost * GR_COST_PER_BYTE; + Ok(Reduction(cost, if lhs > rhs { a.one() } else { a.nil() })) + } + _ => { + let (v0, v0_len) = int_atom(a, v0, ">")?; + let (v1, v1_len) = int_atom(a, v1, ">")?; + let cost = GR_BASE_COST + (v0_len + v1_len) as Cost * GR_COST_PER_BYTE; + Ok(Reduction(cost, if v0 > v1 { a.one() } else { a.nil() })) + } + } } pub fn op_gr_bytes(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> Response { diff --git a/src/op_utils.rs b/src/op_utils.rs index b04a6d65..78a9403c 100644 --- a/src/op_utils.rs +++ b/src/op_utils.rs @@ -1,4 +1,4 @@ -use crate::allocator::{Allocator, NodePtr, SExp}; +use crate::allocator::{Allocator, NodePtr, NodeVisitor, SExp}; use crate::cost::Cost; use crate::err_utils::err; use crate::number::Number; @@ -279,37 +279,36 @@ pub fn uint_atom( args: NodePtr, op_name: &str, ) -> Result { - let bytes = match a.sexp(args) { - SExp::Atom => a.atom(args), - _ => { - return err(args, &format!("{op_name} requires int arg")); + match a.node(args) { + NodeVisitor::Buffer(bytes) => { + if bytes.is_empty() { + return Ok(0); + } + + if (bytes[0] & 0x80) != 0 { + return err(args, &format!("{op_name} requires positive int arg")); + } + + // strip leading zeros + let mut buf: &[u8] = bytes; + while !buf.is_empty() && buf[0] == 0 { + buf 
= &buf[1..]; + } + + if buf.len() > SIZE { + return err(args, &format!("{op_name} requires u{} arg", SIZE * 8)); + } + + let mut ret = 0; + for b in buf { + ret <<= 8; + ret |= *b as u64; + } + Ok(ret) } - }; - - if bytes.is_empty() { - return Ok(0); - } - - if (bytes[0] & 0x80) != 0 { - return err(args, &format!("{op_name} requires positive int arg")); - } - - // strip leading zeros - let mut buf: &[u8] = bytes; - while !buf.is_empty() && buf[0] == 0 { - buf = &buf[1..]; - } - - if buf.len() > SIZE { - return err(args, &format!("{op_name} requires u{} arg", SIZE * 8)); - } - - let mut ret = 0; - for b in buf { - ret <<= 8; - ret |= *b as u64; + NodeVisitor::U32(val) => Ok(val as u64), + NodeVisitor::Pair(_, _) => err(args, &format!("{op_name} requires int arg")), } - Ok(ret) } #[cfg(test)] @@ -532,18 +531,16 @@ fn test_u64_from_bytes() { } pub fn i32_atom(a: &Allocator, args: NodePtr, op_name: &str) -> Result { - let buf = match a.sexp(args) { - SExp::Atom => a.atom(args), - _ => { - return err(args, &format!("{op_name} requires int32 args")); - } - }; - match i32_from_u8(buf) { - Some(v) => Ok(v), - _ => err( - args, - &format!("{op_name} requires int32 args (with no leading zeros)"), - ), + match a.node(args) { + NodeVisitor::Buffer(buf) => match i32_from_u8(buf) { + Some(v) => Ok(v), + _ => err( + args, + &format!("{op_name} requires int32 args (with no leading zeros)"), + ), + }, + NodeVisitor::U32(val) => Ok(val as i32), + NodeVisitor::Pair(_, _) => err(args, &format!("{op_name} requires int32 args")), } } diff --git a/src/run_program.rs b/src/run_program.rs index 4354c973..0a3fad2b 100644 --- a/src/run_program.rs +++ b/src/run_program.rs @@ -1,9 +1,9 @@ -use super::traverse_path::traverse_path; -use crate::allocator::{Allocator, Checkpoint, NodePtr, SExp}; +use super::traverse_path::{traverse_path, traverse_path_fast}; +use crate::allocator::{Allocator, Checkpoint, NodePtr, NodeVisitor, SExp}; use crate::cost::Cost; use crate::dialect::{Dialect, 
OperatorSet}; use crate::err_utils::err; -use crate::op_utils::{atom, first, get_args, uint_atom}; +use crate::op_utils::{first, get_args, uint_atom}; use crate::reduction::{EvalErr, Reduction, Response}; // lowered from 46 @@ -44,6 +44,7 @@ pub struct Counters { pub env_stack_usage: usize, pub op_stack_usage: usize, pub atom_count: u32, + pub small_atom_count: u32, pub pair_count: u32, pub heap_size: u32, } @@ -56,6 +57,7 @@ impl Counters { env_stack_usage: 0, op_stack_usage: 0, atom_count: 0, + small_atom_count: 0, pair_count: 0, heap_size: 0, } @@ -228,9 +230,8 @@ impl<'a, D: Dialect> RunProgramContext<'a, D> { operand_list: NodePtr, env: NodePtr, ) -> Result { - let op_atom = self.allocator.atom(operator_node); // special case check for quote - if op_atom == self.dialect.quote_kw() { + if self.allocator.small_number(operator_node) == Some(self.dialect.quote_kw()) { self.push(operand_list)?; Ok(QUOTE_COST) } else { @@ -257,7 +258,7 @@ impl<'a, D: Dialect> RunProgramContext<'a, D> { operands = rest; } // ensure a correct nil terminator - if !self.allocator.atom(operands).is_empty() { + if self.allocator.atom_len(operands) != 0 { err(operand_list, "bad operand list") } else { self.push(self.allocator.nil())?; @@ -278,7 +279,13 @@ impl<'a, D: Dialect> RunProgramContext<'a, D> { // put a bunch of ops on op_stack let SExp::Pair(op_node, op_list) = self.allocator.sexp(program) else { // the program is just a bitfield path through the env tree - let r: Reduction = traverse_path(self.allocator, self.allocator.atom(program), env)?; + let r = match self.allocator.node(program) { + NodeVisitor::Buffer(buf) => traverse_path(self.allocator, buf, env)?, + NodeVisitor::U32(val) => traverse_path_fast(self.allocator, val, env)?, + NodeVisitor::Pair(_, _) => { + panic!("expected atom, got pair"); + } + }; self.push(r.1)?; return Ok(r.0); }; @@ -339,14 +346,15 @@ impl<'a, D: Dialect> RunProgramContext<'a, D> { fn apply_op(&mut self, current_cost: Cost, max_cost: Cost) -> Result { 
let operand_list = self.pop()?; let operator = self.pop()?; - let op_atom = atom(self.allocator, operator, "(internal error) apply")?; if self.env_stack.pop().is_none() { return err(operator, "runtime error: env stack empty"); } - if op_atom == self.dialect.apply_kw() { + let op_atom = self.allocator.small_number(operator); + + if op_atom == Some(self.dialect.apply_kw()) { let [new_operator, env] = get_args::<2>(self.allocator, operand_list, "apply")?; self.eval_pair(new_operator, env).map(|c| c + APPLY_COST) - } else if op_atom == self.dialect.softfork_kw() { + } else if op_atom == Some(self.dialect.softfork_kw()) { let expected_cost = uint_atom::<8>( self.allocator, first(self.allocator, operand_list)?, @@ -531,6 +539,7 @@ pub fn run_program_with_counters<'a, D: Dialect>( let mut rpc = RunProgramContext::new(allocator, dialect); let ret = rpc.run_program(program, env, max_cost); rpc.counters.atom_count = rpc.allocator.atom_count() as u32; + rpc.counters.small_atom_count = rpc.allocator.small_atom_count() as u32; rpc.counters.pair_count = rpc.allocator.pair_count() as u32; rpc.counters.heap_size = rpc.allocator.heap_size() as u32; (rpc.counters, ret) @@ -1331,9 +1340,10 @@ fn test_counters() { assert_eq!(counters.val_stack_usage, 3015); assert_eq!(counters.env_stack_usage, 1005); assert_eq!(counters.op_stack_usage, 3014); - assert_eq!(counters.atom_count, 2040); + assert_eq!(counters.atom_count, 998); + assert_eq!(counters.small_atom_count, 1042); assert_eq!(counters.pair_count, 22077); - assert_eq!(counters.heap_size, 771884); + assert_eq!(counters.heap_size, 769963); assert_eq!(result.unwrap().0, cost); } diff --git a/src/runtime_dialect.rs b/src/runtime_dialect.rs index 373e974d..20b00eea 100644 --- a/src/runtime_dialect.rs +++ b/src/runtime_dialect.rs @@ -55,16 +55,14 @@ impl Dialect for RuntimeDialect { } } - fn quote_kw(&self) -> &[u8] { - &self.quote_kw + fn quote_kw(&self) -> u32 { + self.quote_kw[0] as u32 } - - fn apply_kw(&self) -> &[u8] { - 
&self.apply_kw + fn apply_kw(&self) -> u32 { + self.apply_kw[0] as u32 } - - fn softfork_kw(&self) -> &[u8] { - &self.softfork_kw + fn softfork_kw(&self) -> u32 { + self.softfork_kw[0] as u32 } fn softfork_extension(&self, _ext: u32) -> OperatorSet { diff --git a/src/serde/parse_atom.rs b/src/serde/parse_atom.rs index 11ad8c2d..77b2ca56 100644 --- a/src/serde/parse_atom.rs +++ b/src/serde/parse_atom.rs @@ -1,6 +1,6 @@ use std::io::{Cursor, Read, Result, Seek, SeekFrom}; -use crate::allocator::{Allocator, NodePtr}; +use crate::allocator::{canonical_positive_integer, Allocator, NodePtr}; use super::errors::{bad_encoding, internal_error}; @@ -85,7 +85,16 @@ pub fn parse_atom( Ok(allocator.nil()) } else { let blob = parse_atom_ptr(f, first_byte)?; - Ok(allocator.new_atom(blob)?) + if blob.len() <= 3 && canonical_positive_integer(blob) { + let mut val: u32 = 0; + for i in blob { + val <<= 8; + val |= *i as u32; + } + Ok(allocator.new_small_number(val)?) + } else { + Ok(allocator.new_atom(blob)?) + } } } diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 4c86454d..314bcaf3 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -4,7 +4,7 @@ use std::io::ErrorKind; use std::io::Write; use super::write_atom::write_atom; -use crate::allocator::{Allocator, NodePtr, SExp}; +use crate::allocator::{len_for_value, Allocator, NodePtr, NodeVisitor}; const CONS_BOX_MARKER: u8 = 0xff; @@ -41,12 +41,14 @@ impl Write for LimitedWriter { pub fn node_to_stream(a: &Allocator, node: NodePtr, f: &mut W) -> io::Result<()> { let mut values: Vec = vec![node]; while let Some(v) = values.pop() { - let n = a.sexp(v); - match n { - SExp::Atom => { - write_atom(f, a.atom(v))?; + match a.node(v) { + NodeVisitor::Buffer(buf) => write_atom(f, buf)?, + NodeVisitor::U32(val) => { + let buf = val.to_be_bytes(); + let len = len_for_value(val); + write_atom(f, &buf[4 - len..])? 
} - SExp::Pair(left, right) => { + NodeVisitor::Pair(left, right) => { f.write_all(&[CONS_BOX_MARKER])?; values.push(right); values.push(left); diff --git a/src/traverse_path.rs b/src/traverse_path.rs index 27127397..fa47b0dc 100644 --- a/src/traverse_path.rs +++ b/src/traverse_path.rs @@ -72,6 +72,42 @@ pub fn traverse_path(allocator: &Allocator, node_index: &[u8], args: NodePtr) -> Ok(Reduction(cost, arg_list)) } +// The cost calculation for this version of traverse_path assumes the node_index has the canonical +// integer representation (which is true for SmallAtom in the allocator). If there are any +// redundant leading zeros, the slow path must be used +pub fn traverse_path_fast(allocator: &Allocator, mut node_index: u32, args: NodePtr) -> Response { + if node_index == 0 { + return Ok(Reduction( + TRAVERSE_BASE_COST + TRAVERSE_COST_PER_BIT, + allocator.nil(), + )); + } + + let mut arg_list: NodePtr = args; + + let mut cost: Cost = TRAVERSE_BASE_COST + TRAVERSE_COST_PER_BIT; + let mut num_bits = 0; + while node_index != 1 { + let SExp::Pair(left, right) = allocator.sexp(arg_list) else { + return Err(EvalErr(arg_list, "path into atom".into())); + }; + + let is_bit_set: bool = (node_index & 0x01) != 0; + arg_list = if is_bit_set { right } else { left }; + node_index >>= 1; + num_bits += 1 + } + + cost += num_bits * TRAVERSE_COST_PER_BIT; + // since positive numbers sometimes need a leading zero, e.g. 0x80, 0x8000 etc. 
We also + // need to add the cost of that leading zero byte + if num_bits == 7 || num_bits == 15 || num_bits == 23 || num_bits == 31 { + cost += TRAVERSE_COST_PER_ZERO_BYTE; + } + + Ok(Reduction(cost, arg_list)) +} + #[test] fn test_msb_mask() { assert_eq!(msb_mask(0x0), 0x0); @@ -166,3 +202,61 @@ fn test_traverse_path() { EvalErr(n2, "path into atom".to_string()) ); } + +#[test] +fn test_traverse_path_fast_fast() { + use crate::allocator::Allocator; + + let mut a = Allocator::new(); + let nul = a.nil(); + let n1 = a.new_atom(&[0, 1, 2]).unwrap(); + let n2 = a.new_atom(&[4, 5, 6]).unwrap(); + + assert_eq!(traverse_path_fast(&a, 0, n1).unwrap(), Reduction(44, nul)); + assert_eq!(traverse_path_fast(&a, 0b1, n1).unwrap(), Reduction(44, n1)); + assert_eq!(traverse_path_fast(&a, 0b1, n2).unwrap(), Reduction(44, n2)); + + let n3 = a.new_pair(n1, n2).unwrap(); + assert_eq!(traverse_path_fast(&a, 0b1, n3).unwrap(), Reduction(44, n3)); + assert_eq!(traverse_path_fast(&a, 0b10, n3).unwrap(), Reduction(48, n1)); + assert_eq!(traverse_path_fast(&a, 0b11, n3).unwrap(), Reduction(48, n2)); + assert_eq!(traverse_path_fast(&a, 0b11, n3).unwrap(), Reduction(48, n2)); + + let list = a.new_pair(n1, nul).unwrap(); + let list = a.new_pair(n2, list).unwrap(); + + assert_eq!( + traverse_path_fast(&a, 0b10, list).unwrap(), + Reduction(48, n2) + ); + assert_eq!( + traverse_path_fast(&a, 0b101, list).unwrap(), + Reduction(52, n1) + ); + assert_eq!( + traverse_path_fast(&a, 0b111, list).unwrap(), + Reduction(52, nul) + ); + + // errors + assert_eq!( + traverse_path_fast(&a, 0b1011, list).unwrap_err(), + EvalErr(nul, "path into atom".to_string()) + ); + assert_eq!( + traverse_path_fast(&a, 0b1101, list).unwrap_err(), + EvalErr(n1, "path into atom".to_string()) + ); + assert_eq!( + traverse_path_fast(&a, 0b1001, list).unwrap_err(), + EvalErr(n1, "path into atom".to_string()) + ); + assert_eq!( + traverse_path_fast(&a, 0b1010, list).unwrap_err(), + EvalErr(n2, "path into atom".to_string()) + ); 
+ assert_eq!( + traverse_path_fast(&a, 0b1110, list).unwrap_err(), + EvalErr(n2, "path into atom".to_string()) + ); +}