Skip to content

Commit

Permalink
Fix internal alignment in frunk heterogeneous lists (#929)
Browse files Browse the repository at this point in the history
* Make some `GuestPointer` methods `const`

Allow them to be used in the `const` context that calculates the size of
a heterogeneous list.

* Add comment explaining how padding is calculated

Link to the Wikipedia source.

* Test `GuestPointer::align_at`

Ensure that the non-trivial padding calculation is correct.

* Fix alignment between heterogeneous list elements

Because heterogeneous lists are recursive types, some special handling
is needed. Alignment between internal elements must be calculated
separately, because otherwise the alignment for the whole list (the
tail) is used, which is not the same as the alignment as the next
element.

The same is true when reading and writing to memory, because the
alignment between each element must be considered, and not the alignment
for all the remaining elements.

* Fix documentation for test type

There's only one place where padding is needed.

* Add roundtrip tests for some `hlist` types

Make sure that they can be stored in memory and read from memory, and
that they can be lowered into its flat layout and lifted back from it.
Ensure that their internal element alignment is properly respected.
  • Loading branch information
jvff authored Aug 4, 2023
1 parent a577933 commit 074c130
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 11 deletions.
15 changes: 11 additions & 4 deletions linera-witty/src/runtime/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,39 @@ use crate::{Layout, WitType};
use frunk::{hlist, hlist_pat, HList};
use std::borrow::Cow;

#[cfg(test)]
#[path = "unit_tests/memory.rs"]
mod tests;

/// An address for a location in a guest WebAssembly module's memory.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GuestPointer(pub(crate) u32);

impl GuestPointer {
/// Returns a new address that's the current address advanced to add padding to ensure it's
/// aligned to the `alignment` byte boundary.
pub fn aligned_at(&self, alignment: u32) -> Self {
pub const fn aligned_at(&self, alignment: u32) -> Self {
// The following computation is equivalent to:
// `(alignment - (self.0 % alignment)) % alignment`.
// Source: https://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
let padding = (-(self.0 as i32) & (alignment as i32 - 1)) as u32;

GuestPointer(self.0 + padding)
}

/// Returns a new address that's the current address advanced to after the size of `T`.
pub fn after<T: WitType>(&self) -> Self {
pub const fn after<T: WitType>(&self) -> Self {
GuestPointer(self.0 + T::SIZE)
}

/// Returns a new address that's the current address advanced to add padding to ensure it's
/// aligned properly for `T`.
pub fn after_padding_for<T: WitType>(&self) -> Self {
pub const fn after_padding_for<T: WitType>(&self) -> Self {
self.aligned_at(<T::Layout as Layout>::ALIGNMENT)
}

/// Returns the address of an element in a contiguous list of properly aligned `T` types.
pub fn index<T: WitType>(&self, index: u32) -> Self {
pub const fn index<T: WitType>(&self, index: u32) -> Self {
let element_size = GuestPointer(T::SIZE).after_padding_for::<T>();

GuestPointer(self.0 + index * element_size.0)
Expand Down
25 changes: 25 additions & 0 deletions linera-witty/src/runtime/unit_tests/memory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) Zefchain Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

//! Unit tests for guest Wasm module memory manipulation.
use super::GuestPointer;

/// Test aligning memory addresses.
///
/// Check that the resulting address is aligned and that it never advances more than the alignment
/// amount.
#[test]
fn align_guest_pointer() {
for alignment_bits in 0..3 {
let alignment = 1 << alignment_bits;
let alignment_mask = alignment - 1;

for start_offset in 0..32 {
let address = GuestPointer(start_offset).aligned_at(alignment);

assert_eq!(address.0 & alignment_mask, 0);
assert!(address.0 - start_offset < alignment);
}
}
}
77 changes: 71 additions & 6 deletions linera-witty/src/type_traits/implementations/frunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,19 @@ impl WitStore for HNil {
impl<Head, Tail> WitType for HCons<Head, Tail>
where
Head: WitType,
Tail: WitType,
Tail: WitType + SizeCalculation,
Head::Layout: Add<Tail::Layout>,
<Head::Layout as Add<Tail::Layout>>::Output: Layout,
{
const SIZE: u32 = Head::SIZE + Tail::SIZE;
const SIZE: u32 = Self::SIZE_STARTING_AT_BYTE_BOUNDARIES[0];

type Layout = <Head::Layout as Add<Tail::Layout>>::Output;
}

impl<Head, Tail> WitLoad for HCons<Head, Tail>
where
Head: WitLoad,
Tail: WitLoad,
Tail: WitLoad + SizeCalculation,
Head::Layout: Add<Tail::Layout>,
<Head::Layout as Add<Tail::Layout>>::Output: Layout,
<Self::Layout as Layout>::Flat:
Expand All @@ -94,7 +94,12 @@ where
{
Ok(HCons {
head: Head::load(memory, location)?,
tail: Tail::load(memory, location.after::<Head>())?,
tail: Tail::load(
memory,
location
.after::<Head>()
.after_padding_for::<Tail::FirstElement>(),
)?,
})
}

Expand All @@ -118,7 +123,7 @@ where
impl<Head, Tail> WitStore for HCons<Head, Tail>
where
Head: WitStore,
Tail: WitStore,
Tail: WitStore + SizeCalculation,
Head::Layout: Add<Tail::Layout>,
<Head::Layout as Add<Tail::Layout>>::Output: Layout,
<Head::Layout as Layout>::Flat: Add<<Tail::Layout as Layout>::Flat>,
Expand All @@ -136,7 +141,12 @@ where
<Instance::Runtime as Runtime>::Memory: RuntimeMemory<Instance>,
{
self.head.store(memory, location)?;
self.tail.store(memory, location.after::<Head>())?;
self.tail.store(
memory,
location
.after::<Head>()
.after_padding_for::<Tail::FirstElement>(),
)?;

Ok(())
}
Expand All @@ -155,3 +165,58 @@ where
Ok(head_layout + tail_layout)
}
}

/// Helper trait used to calculate the size of a heterogeneous list considering internal alignment.
///
/// Assumes the maximum alignment necessary for any type is 8 bytes, which is the alignment for the
/// largest flat types (`i64` and `f64`).
trait SizeCalculation {
/// The size of the list considering the current size calculation starts at different offsets
/// inside an 8-byte window.
const SIZE_STARTING_AT_BYTE_BOUNDARIES: [u32; 8];

/// The type of the first element of the list, used to determine the current necessary
/// alignment.
type FirstElement: WitType;
}

impl SizeCalculation for HNil {
const SIZE_STARTING_AT_BYTE_BOUNDARIES: [u32; 8] = [0; 8];

type FirstElement = ();
}

/// Unrolls a `for`-like loop so that it runs in a `const` context.
macro_rules! unroll_for {
($binding:ident in [ $($elements:expr),* $(,)? ] $body:tt) => {
$(
let $binding = $elements;
$body
)*
};
}

impl<Head, Tail> SizeCalculation for HCons<Head, Tail>
where
Head: WitType,
Tail: SizeCalculation,
{
const SIZE_STARTING_AT_BYTE_BOUNDARIES: [u32; 8] = {
let mut size_at_boundaries = [0; 8];

unroll_for!(boundary_offset in [0, 1, 2, 3, 4, 5, 6, 7] {
let memory_location = GuestPointer(boundary_offset)
.after_padding_for::<Head>()
.after::<Head>();

let tail_size = Tail::SIZE_STARTING_AT_BYTE_BOUNDARIES[memory_location.0 as usize % 8];

size_at_boundaries[boundary_offset as usize] =
memory_location.0 - boundary_offset + tail_size;
});

size_at_boundaries
};

type FirstElement = Head;
}
2 changes: 2 additions & 0 deletions linera-witty/src/type_traits/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
mod custom_types;
mod frunk;
mod std;
#[cfg(test)]
mod tests;
106 changes: 106 additions & 0 deletions linera-witty/src/type_traits/implementations/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright (c) Zefchain Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

//! Unit tests for implementations of the custom traits for existing types.
use crate::{FakeInstance, InstanceWithMemory, Layout, WitLoad, WitStore};
use frunk::hlist;
use std::fmt::Debug;

/// Test roundtrip of a heterogeneous list that doesn't need any internal padding.
#[test]
fn hlist_without_padding() {
let input = hlist![
0x1011_1213_1415_1617_1819_1a1b_1c1d_1e1f_u128,
0x2021_2223_2425_2627_i64,
0x3031_3233_u32,
0x4041_i16,
true,
];

test_memory_roundtrip(
input,
&[
0x1f, 0x1e, 0x1d, 0x1c, 0x1b, 0x1a, 0x19, 0x18, 0x17, 0x16, 0x15, 0x14, 0x13, 0x12,
0x11, 0x10, 0x27, 0x26, 0x25, 0x24, 0x23, 0x22, 0x21, 0x20, 0x33, 0x32, 0x31, 0x30,
0x41, 0x40, 0x01,
],
);
test_flattening_roundtrip(
input,
hlist![
0x1819_1a1b_1c1d_1e1f_i64,
0x1011_1213_1415_1617_i64,
0x2021_2223_2425_2627_i64,
0x3031_3233_i32,
0x0000_4041_i32,
0x0000_0001_i32,
],
);
}

/// Test roundtrip of a heterogeneous list that needs internal padding between some of its elements.
#[test]
fn hlist_with_padding() {
let input = hlist![
true,
0x1011_i16,
0x2021_u16,
0x3031_3233_u32,
0x4041_4243_4445_4647_i64,
];

test_memory_roundtrip(
input,
&[
0x01, 0, 0x11, 0x10, 0x21, 0x20, 0, 0, 0x33, 0x32, 0x31, 0x30, 0, 0, 0, 0, 0x47, 0x46,
0x45, 0x44, 0x43, 0x42, 0x41, 0x40,
],
);
test_flattening_roundtrip(
input,
hlist![
0x0000_0001_i32,
0x0000_1011_i32,
0x0000_2021_i32,
0x3031_3233_i32,
0x4041_4243_4445_4647_i64,
],
);
}

/// Test storing an instance of `T` to memory, checking that the `memory_data` bytes are correctly
/// written, and check that the instance can be loaded from those bytes.
fn test_memory_roundtrip<T>(input: T, memory_data: &[u8])
where
T: Debug + Eq + WitLoad + WitStore,
{
let mut instance = FakeInstance::default();
let mut memory = instance.memory().unwrap();
let length = memory_data.len() as u32;

assert_eq!(length, T::SIZE);

let address = memory.allocate(length).unwrap();

input.store(&mut memory, address).unwrap();

assert_eq!(memory.read(address, length).unwrap(), memory_data);
assert_eq!(T::load(&memory, address).unwrap(), input);
}

/// Test lowering an instance of `T`, checking that the resulting flat layout matches the expected
/// `flat_layout`, and check that the instance can be lifted from that flat layout.
fn test_flattening_roundtrip<T>(input: T, flat_layout: <T::Layout as Layout>::Flat)
where
T: Debug + Eq + WitLoad + WitStore,
<T::Layout as Layout>::Flat: Debug + Eq,
{
let mut instance = FakeInstance::default();
let mut memory = instance.memory().unwrap();

let lowered_layout = input.lower(&mut memory).unwrap();

assert_eq!(lowered_layout, flat_layout);
assert_eq!(T::lift_from(lowered_layout, &memory).unwrap(), input);
}
2 changes: 1 addition & 1 deletion linera-witty/tests/common/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub struct SimpleWrapper(pub bool);
#[derive(Clone, Copy, Debug, Eq, PartialEq, WitType, WitLoad, WitStore)]
pub struct TupleWithoutPadding(pub u64, pub i32, pub i16);

/// A tuple struct that requires internal padding in its memory layout between all of its fields.
/// A tuple struct that requires internal padding in its memory layout between two of its fields.
#[derive(Clone, Copy, Debug, Eq, PartialEq, WitType, WitLoad, WitStore)]
pub struct TupleWithPadding(pub u16, pub u32, pub i64);

Expand Down

0 comments on commit 074c130

Please sign in to comment.