Skip to content

Commit

Permalink
chore: fix most serialization failures
Browse files Browse the repository at this point in the history
  • Loading branch information
cfcosta committed Jul 30, 2024
1 parent ddd78fa commit 0bf2ca5
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 59 deletions.
12 changes: 6 additions & 6 deletions core/src/contracts/fusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,19 @@ pub fn fusion(context: &mut Context) -> Result<()> {
Hash::digest(&mut context.hasher, &ib)?,
);

let output = BlindedNote {
let note = BlindedNote {
asset_id: ia.asset_id,
amount: total,
secret: Hash::combine3(&mut context.hasher, OUTPUT_SEP, a, b)?,
};
context.write_stdout(&output);
context.write_stdout(&note);

let fusion = Output {
let output = Output {
a,
b,
c: Hash::digest(&mut context.hasher, &output)?,
c: Hash::digest(&mut context.hasher, &note)?,
};
context.write_journal(&fusion);
context.write_journal(&output);

Ok(())
}
Expand All @@ -77,7 +77,7 @@ mod tests {
a: a.clone(),
b: b.clone(),
};
input.to_slice(&mut context.stdin);
context.write_stdin(&input);

fusion(&mut context)?;
let result = BlindedNote::from_slice(&context.stdout)?;
Expand Down
5 changes: 5 additions & 0 deletions core/src/contracts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ impl<const STDIN: usize, const STDOUT: usize, const JOURNAL: usize>
r.read()
}

pub fn write_stdin<T: SerializeBytes>(&mut self, value: &T) {
let mut w = Writer::new(&mut self.stdin);
w.write(value);
}

pub fn write_stdout<T: SerializeBytes>(&mut self, value: &T) {
let mut w = Writer::new(&mut self.stdout);
w.write(value);
Expand Down
6 changes: 3 additions & 3 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ impl<'a> Reader<'a> {
}

pub fn read<T: SerializeBytes>(&mut self) -> Result<T> {
assert!(self.offset + T::SIZE <= self.data.len() - 1);
assert!(self.offset + T::SIZE <= self.data.len());

let result = T::from_slice(&self.data[self.offset..T::SIZE])?;
let result = T::from_slice(&self.data[self.offset..self.offset + T::SIZE])?;
self.offset += T::SIZE;

Ok(result)
Expand All @@ -53,7 +53,7 @@ impl<'a> Writer<'a> {
}

pub fn write<T: SerializeBytes>(&mut self, value: &T) {
assert!(self.offset + T::SIZE <= self.data.len() - 1);
assert!(self.offset + T::SIZE <= self.data.len());

value.to_slice(&mut self.data[self.offset..T::SIZE]);
self.offset += T::SIZE;
Expand Down
2 changes: 2 additions & 0 deletions core/src/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ impl SerializeBytes for u64 {

#[inline]
fn to_slice(&self, out: &mut [u8]) {
assert_eq!(out.len(), Self::SIZE);
out[..Self::SIZE].copy_from_slice(&self.to_le_bytes())
}

#[inline]
fn from_slice(input: &[u8]) -> Result<Self> {
assert_eq!(input.len(), Self::SIZE);
Ok(Self::from_le_bytes(input[..Self::SIZE].try_into()?))
}
}
Expand Down
19 changes: 13 additions & 6 deletions core/src/types/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ impl TryFrom<&[u8]> for Hash {
fn try_from(value: &[u8]) -> core::result::Result<Self, Self::Error> {
assert_eq!(value.len(), 32);

let bytes: [u8; 32] = unsafe { *(value.as_ptr() as *const [u8; 32]) };

Ok(Self(bytes))
Ok(Self(value.try_into()?))
}
}

Expand All @@ -99,15 +97,12 @@ impl SerializeBytes for Hash {

#[inline]
fn to_slice(&self, out: &mut [u8]) {
assert!(out.len() >= 32);

out.copy_from_slice(&self.0)
}

#[inline]
fn from_slice(input: &[u8]) -> Result<Self> {
assert!(input.len() >= 32);

input.try_into()
}
}
Expand All @@ -126,8 +121,20 @@ mod tests {
use proptest::prelude::*;
use test_strategy::proptest;

use crate::SerializeBytes;

use super::Hash;

#[proptest]
fn test_serialize_bytes(input: Hash) {
let mut buf = [0u8; 32];

input.to_slice(&mut buf);
let output = Hash::from_slice(&buf).unwrap();

prop_assert_eq!(input, output);
}

#[proptest]
fn test_try_from(input: [u8; 32]) {
let input_ref: &[u8] = &input;
Expand Down
24 changes: 13 additions & 11 deletions core/tests/serialize.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
macro_rules! generate_serialize_roundtrip_tests {
($($type:ty),+) => {
($([$type:ty, $val: expr]),+) => {
$(
paste::paste! {
#[cfg(feature = "std")]
Expand All @@ -8,11 +8,13 @@ macro_rules! generate_serialize_roundtrip_tests {
use mugraph_core::SerializeBytes;
use proptest::prelude::*;

let mut buffer = vec![0u8; <$type as SerializeBytes>::SIZE];
let size = <$type as SerializeBytes>::SIZE;
let mut buffer = vec![0u8; size];
value.to_slice(&mut buffer);

let deserialized = <$type as SerializeBytes>::from_slice(&buffer).unwrap();
prop_assert_eq!(value, deserialized);
prop_assert_eq!(size, $val);
}
}
)+
Expand All @@ -28,13 +30,13 @@ type FusionInput = mugraph_core::contracts::fusion::Input;
type FusionOutput = mugraph_core::contracts::fusion::Output;

generate_serialize_roundtrip_tests!(
u64,
Hash,
Signature,
FissionInput,
FissionOutput,
FusionInput,
FusionOutput,
Note,
BlindedNote
[u64, 8],
[Hash, 32],
[Signature, 64],
[FissionInput, 144],
[FissionOutput, 96],
[FusionInput, 208],
[FusionOutput, 96],
[Note, 104],
[BlindedNote, 72]
);
77 changes: 44 additions & 33 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,44 +7,63 @@ pub fn derive_serialize_bytes(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;

let (to_slice_impl, from_slice_impl, size_calc) = match &input.data {
let (impl_block, size_calc) = match &input.data {
Data::Struct(data_struct) => match &data_struct.fields {
Fields::Named(fields) => {
let to_slice_fields = fields.named.iter().map(|f| {
let name = &f.ident;
quote! {
w.write(&self.#name);
}
});
let field_data: Vec<_> = fields
.named
.iter()
.map(|f| {
let name = &f.ident;
let ty = &f.ty;
(name, ty)
})
.collect();

let mut offset = quote! { 0usize };
let to_slice_impl = field_data.iter().map(|(name, ty)| {
let size = quote! { <#ty as SerializeBytes>::SIZE };
let current = offset.clone();
offset = quote! { #offset + #size };

let from_slice_fields = fields.named.iter().map(|f| {
let name = &f.ident;
quote! {
#name: r.read()?,
<#ty as SerializeBytes>::to_slice(&self.#name, &mut out[#current..#offset]);
}
});

let size_fields = fields.named.iter().map(|f| {
let ty = &f.ty;
let mut offset = quote! { 0usize };
let from_slice_impl = field_data.iter().map(|(name, ty)| {
let size = quote! { <#ty as SerializeBytes>::SIZE };
let current = offset.clone();
offset = quote! { #offset + #size };

quote! {
#ty::SIZE +
#name: <#ty as SerializeBytes>::from_slice(&input[#current..#offset])?,
}
});

let size_calc = field_data
.iter()
.map(|(_, ty)| {
quote! { <#ty as SerializeBytes>::SIZE }
})
.fold(quote!(0), |acc, size| {
quote! { #acc + #size }
});

(
quote! {
let mut w = Writer::new(out);
#(#to_slice_fields)*
},
quote! {
let mut r = Reader::new(input);
Ok(Self {
#(#from_slice_fields)*
})
},
quote! {
#(#size_fields)* 0
fn to_slice(&self, out: &mut [u8]) {
#(#to_slice_impl)*
}

fn from_slice(input: &[u8]) -> Result<Self> {
Ok(Self {
#(#from_slice_impl)*
})
}
},
size_calc,
)
}
_ => panic!("SerializeBytes can only be derived for structs with named fields"),
Expand All @@ -56,15 +75,7 @@ pub fn derive_serialize_bytes(input: TokenStream) -> TokenStream {
impl SerializeBytes for #name {
const SIZE: usize = #size_calc;

#[inline]
fn to_slice(&self, out: &mut [u8]) {
#to_slice_impl
}

#[inline]
fn from_slice(input: &[u8]) -> Result<Self> {
#from_slice_impl
}
#impl_block
}
};

Expand Down

0 comments on commit 0bf2ca5

Please sign in to comment.