Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix internal alignment in frunk heterogeneous lists #929

Merged
merged 6 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions linera-witty/src/runtime/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,39 @@ use crate::{Layout, WitType};
use frunk::{hlist, hlist_pat, HList};
use std::borrow::Cow;

#[cfg(test)]
#[path = "unit_tests/memory.rs"]
mod tests;

/// An address for a location in a guest WebAssembly module's memory.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GuestPointer(pub(crate) u32);

impl GuestPointer {
/// Returns a new address that's the current address advanced to add padding to ensure it's
/// aligned to the `alignment` byte boundary.
pub fn aligned_at(&self, alignment: u32) -> Self {
pub const fn aligned_at(&self, alignment: u32) -> Self {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yay for const functions!

// The following computation is equivalent to:
// `(alignment - (self.0 % alignment)) % alignment`.
// Source: https://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
let padding = (-(self.0 as i32) & (alignment as i32 - 1)) as u32;
Copy link
Contributor

@ma2bd ma2bd Aug 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe comment something like

// Compute `-self.0` modulo `alignment`

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, and linked to Wikipedia where I first found the formula 😆


GuestPointer(self.0 + padding)
}

/// Returns a new address that's the current address advanced to after the size of `T`.
pub fn after<T: WitType>(&self) -> Self {
pub const fn after<T: WitType>(&self) -> Self {
GuestPointer(self.0 + T::SIZE)
}

/// Returns a new address that's the current address advanced to add padding to ensure it's
/// aligned properly for `T`.
pub fn after_padding_for<T: WitType>(&self) -> Self {
pub const fn after_padding_for<T: WitType>(&self) -> Self {
self.aligned_at(<T::Layout as Layout>::ALIGNMENT)
}

/// Returns the address of an element in a contiguous list of properly aligned `T` types.
pub fn index<T: WitType>(&self, index: u32) -> Self {
pub const fn index<T: WitType>(&self, index: u32) -> Self {
let element_size = GuestPointer(T::SIZE).after_padding_for::<T>();

GuestPointer(self.0 + index * element_size.0)
Expand Down
25 changes: 25 additions & 0 deletions linera-witty/src/runtime/unit_tests/memory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) Zefchain Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

//! Unit tests for guest Wasm module memory manipulation.

use super::GuestPointer;

/// Test aligning memory addresses.
///
/// Check that the resulting address is aligned and that it never advances more than the alignment
/// amount.
#[test]
fn align_guest_pointer() {
for alignment_bits in 0..3 {
let alignment = 1 << alignment_bits;
let alignment_mask = alignment - 1;

for start_offset in 0..32 {
let address = GuestPointer(start_offset).aligned_at(alignment);

assert_eq!(address.0 & alignment_mask, 0);
assert!(address.0 - start_offset < alignment);
}
}
}
77 changes: 71 additions & 6 deletions linera-witty/src/type_traits/implementations/frunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,19 @@ impl WitStore for HNil {
impl<Head, Tail> WitType for HCons<Head, Tail>
where
Head: WitType,
Tail: WitType,
Tail: WitType + SizeCalculation,
Head::Layout: Add<Tail::Layout>,
<Head::Layout as Add<Tail::Layout>>::Output: Layout,
{
const SIZE: u32 = Head::SIZE + Tail::SIZE;
const SIZE: u32 = Self::SIZE_STARTING_AT_BYTE_BOUNDARIES[0];

type Layout = <Head::Layout as Add<Tail::Layout>>::Output;
}

impl<Head, Tail> WitLoad for HCons<Head, Tail>
where
Head: WitLoad,
Tail: WitLoad,
Tail: WitLoad + SizeCalculation,
Head::Layout: Add<Tail::Layout>,
<Head::Layout as Add<Tail::Layout>>::Output: Layout,
<Self::Layout as Layout>::Flat:
Expand All @@ -94,7 +94,12 @@ where
{
Ok(HCons {
head: Head::load(memory, location)?,
tail: Tail::load(memory, location.after::<Head>())?,
tail: Tail::load(
memory,
location
.after::<Head>()
.after_padding_for::<Tail::FirstElement>(),
)?,
})
}

Expand All @@ -118,7 +123,7 @@ where
impl<Head, Tail> WitStore for HCons<Head, Tail>
where
Head: WitStore,
Tail: WitStore,
Tail: WitStore + SizeCalculation,
Head::Layout: Add<Tail::Layout>,
<Head::Layout as Add<Tail::Layout>>::Output: Layout,
<Head::Layout as Layout>::Flat: Add<<Tail::Layout as Layout>::Flat>,
Expand All @@ -136,7 +141,12 @@ where
<Instance::Runtime as Runtime>::Memory: RuntimeMemory<Instance>,
{
self.head.store(memory, location)?;
self.tail.store(memory, location.after::<Head>())?;
self.tail.store(
memory,
location
.after::<Head>()
.after_padding_for::<Tail::FirstElement>(),
)?;

Ok(())
}
Expand All @@ -155,3 +165,58 @@ where
Ok(head_layout + tail_layout)
}
}

/// Helper trait used to calculate the size of a heterogeneous list considering internal alignment.
///
/// Assumes the maximum alignment necessary for any type is 8 bytes, which is the alignment for the
/// largest flat types (`i64` and `f64`).
trait SizeCalculation {
/// The size of the list considering the current size calculation starts at different offsets
/// inside an 8-byte window.
const SIZE_STARTING_AT_BYTE_BOUNDARIES: [u32; 8];

/// The type of the first element of the list, used to determine the current necessary
/// alignment.
type FirstElement: WitType;
}

impl SizeCalculation for HNil {
const SIZE_STARTING_AT_BYTE_BOUNDARIES: [u32; 8] = [0; 8];

type FirstElement = ();
}

/// Unrolls a `for`-like loop so that it runs in a `const` context.
macro_rules! unroll_for {
($binding:ident in [ $($elements:expr),* $(,)? ] $body:tt) => {
$(
let $binding = $elements;
$body
)*
};
}

impl<Head, Tail> SizeCalculation for HCons<Head, Tail>
where
Head: WitType,
Tail: SizeCalculation,
{
const SIZE_STARTING_AT_BYTE_BOUNDARIES: [u32; 8] = {
let mut size_at_boundaries = [0; 8];

unroll_for!(boundary_offset in [0, 1, 2, 3, 4, 5, 6, 7] {
let memory_location = GuestPointer(boundary_offset)
.after_padding_for::<Head>()
.after::<Head>();

let tail_size = Tail::SIZE_STARTING_AT_BYTE_BOUNDARIES[memory_location.0 as usize % 8];

size_at_boundaries[boundary_offset as usize] =
memory_location.0 - boundary_offset + tail_size;
});

size_at_boundaries
};

type FirstElement = Head;
}
2 changes: 2 additions & 0 deletions linera-witty/src/type_traits/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
mod custom_types;
mod frunk;
mod std;
#[cfg(test)]
mod tests;
106 changes: 106 additions & 0 deletions linera-witty/src/type_traits/implementations/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright (c) Zefchain Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

//! Unit tests for implementations of the custom traits for existing types.

use crate::{FakeInstance, InstanceWithMemory, Layout, WitLoad, WitStore};
use frunk::hlist;
use std::fmt::Debug;

/// Test roundtrip of a heterogeneous list that doesn't need any internal padding.
#[test]
fn hlist_without_padding() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe assert the SIZE of the WitTypes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. Done inside the test_memory_roundtrip function.

let input = hlist![
0x1011_1213_1415_1617_1819_1a1b_1c1d_1e1f_u128,
0x2021_2223_2425_2627_i64,
0x3031_3233_u32,
0x4041_i16,
true,
];

test_memory_roundtrip(
input,
&[
0x1f, 0x1e, 0x1d, 0x1c, 0x1b, 0x1a, 0x19, 0x18, 0x17, 0x16, 0x15, 0x14, 0x13, 0x12,
0x11, 0x10, 0x27, 0x26, 0x25, 0x24, 0x23, 0x22, 0x21, 0x20, 0x33, 0x32, 0x31, 0x30,
0x41, 0x40, 0x01,
],
);
test_flattening_roundtrip(
input,
hlist![
0x1819_1a1b_1c1d_1e1f_i64,
0x1011_1213_1415_1617_i64,
0x2021_2223_2425_2627_i64,
0x3031_3233_i32,
0x0000_4041_i32,
0x0000_0001_i32,
],
);
}

/// Test roundtrip of a heterogeneous list that needs internal padding between some of its elements.
#[test]
fn hlist_with_padding() {
let input = hlist![
true,
0x1011_i16,
0x2021_u16,
0x3031_3233_u32,
0x4041_4243_4445_4647_i64,
];

test_memory_roundtrip(
input,
&[
0x01, 0, 0x11, 0x10, 0x21, 0x20, 0, 0, 0x33, 0x32, 0x31, 0x30, 0, 0, 0, 0, 0x47, 0x46,
0x45, 0x44, 0x43, 0x42, 0x41, 0x40,
],
);
test_flattening_roundtrip(
input,
hlist![
0x0000_0001_i32,
0x0000_1011_i32,
0x0000_2021_i32,
0x3031_3233_i32,
0x4041_4243_4445_4647_i64,
],
);
}

/// Test storing an instance of `T` to memory, checking that the `memory_data` bytes are correctly
/// written, and check that the instance can be loaded from those bytes.
fn test_memory_roundtrip<T>(input: T, memory_data: &[u8])
where
T: Debug + Eq + WitLoad + WitStore,
{
let mut instance = FakeInstance::default();
let mut memory = instance.memory().unwrap();
let length = memory_data.len() as u32;

assert_eq!(length, T::SIZE);

let address = memory.allocate(length).unwrap();

input.store(&mut memory, address).unwrap();

assert_eq!(memory.read(address, length).unwrap(), memory_data);
assert_eq!(T::load(&memory, address).unwrap(), input);
}

/// Test lowering an instance of `T`, checking that the resulting flat layout matches the expected
/// `flat_layout`, and check that the instance can be lifted from that flat layout.
fn test_flattening_roundtrip<T>(input: T, flat_layout: <T::Layout as Layout>::Flat)
where
T: Debug + Eq + WitLoad + WitStore,
<T::Layout as Layout>::Flat: Debug + Eq,
{
let mut instance = FakeInstance::default();
let mut memory = instance.memory().unwrap();

let lowered_layout = input.lower(&mut memory).unwrap();

assert_eq!(lowered_layout, flat_layout);
assert_eq!(T::lift_from(lowered_layout, &memory).unwrap(), input);
}
2 changes: 1 addition & 1 deletion linera-witty/tests/common/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub struct SimpleWrapper(pub bool);
#[derive(Clone, Copy, Debug, Eq, PartialEq, WitType, WitLoad, WitStore)]
pub struct TupleWithoutPadding(pub u64, pub i32, pub i16);

/// A tuple struct that requires internal padding in its memory layout between all of its fields.
/// A tuple struct that requires internal padding in its memory layout between two of its fields.
#[derive(Clone, Copy, Debug, Eq, PartialEq, WitType, WitLoad, WitStore)]
pub struct TupleWithPadding(pub u16, pub u32, pub i64);

Expand Down