Skip to content

Commit

Permalink
Fix soundness issue of TransparentWrapper derive macro. (#173)
Browse files Browse the repository at this point in the history
Uses the compiler to check that all non-wrapped fields are actually 1-ZSTs,
and uses Zeroable to check that all non-wrapped fields are "conjurable".

Additionally, relaxes the bound of `PhantomData<T: Zeroable>: Zeroable` to all `T: ?Sized`.
  • Loading branch information
zachs18 authored Feb 17, 2023
1 parent d1655f5 commit 1039388
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 19 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ Cargo.lock
**/*.rs.bk

/derive/target/
/derive/.vscode/
43 changes: 32 additions & 11 deletions derive/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,7 @@ impl Derivable for CheckedBitPattern {

Ok(assert_fields_are_maybe_pod)
}
Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed
* OK by NoUninit */
Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed OK by NoUninit */
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
}
}
Expand Down Expand Up @@ -273,21 +272,43 @@ impl Derivable for TransparentWrapper {
}

fn asserts(input: &DeriveInput) -> Result<TokenStream> {
let (impl_generics, _ty_generics, where_clause) =
input.generics.split_for_impl();
let fields = get_struct_fields(input)?;
let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
Some(wrapped_type) => wrapped_type.to_string(),
None => unreachable!(), /* other code will already reject this derive */
};
let mut wrapped_fields = fields
.iter()
.filter(|field| field.ty.to_token_stream().to_string() == wrapped_type);
if let None = wrapped_fields.next() {
bail!("TransparentWrapper must have one field of the wrapped type");
};
if let Some(_) = wrapped_fields.next() {
bail!("TransparentWrapper can only have one field of the wrapped type")
let mut wrapped_field_ty = None;
let mut nonwrapped_field_tys = vec![];
for field in fields.iter() {
let field_ty = &field.ty;
if field_ty.to_token_stream().to_string() == wrapped_type {
if wrapped_field_ty.is_some() {
bail!(
"TransparentWrapper can only have one field of the wrapped type"
);
}
wrapped_field_ty = Some(field_ty);
} else {
nonwrapped_field_tys.push(field_ty);
}
}
if let Some(wrapped_field_ty) = wrapped_field_ty {
Ok(quote!(
const _: () = {
#[repr(transparent)]
struct AssertWrappedIsWrapped #impl_generics((u8, ::core::marker::PhantomData<#wrapped_field_ty>), #(#nonwrapped_field_tys),*) #where_clause;
fn assert_zeroable<Z: ::bytemuck::Zeroable>() {}
fn check #impl_generics () #where_clause {
#(
assert_zeroable::<#nonwrapped_field_tys>();
)*
}
};
))
} else {
Ok(quote!())
bail!("TransparentWrapper must have one field of the wrapped type")
}
}

Expand Down
34 changes: 32 additions & 2 deletions derive/tests/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use bytemuck::{
AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod,
TransparentWrapper, Zeroable,
};
use std::marker::PhantomData;
use std::marker::{PhantomData, PhantomPinned};

#[derive(Copy, Clone, Pod, Zeroable)]
#[repr(C)]
Expand Down Expand Up @@ -64,6 +64,14 @@ struct TransparentWithZeroSized<T> {
b: PhantomData<T>,
}

struct MyZst<T>(PhantomData<T>, [u8; 0], PhantomPinned);
unsafe impl<T> Zeroable for MyZst<T> {}

#[derive(TransparentWrapper)]
#[repr(transparent)]
#[transparent(u16)]
struct TransparentTupleWithCustomZeroSized<T>(u16, MyZst<T>);

#[repr(u8)]
#[derive(Clone, Copy, Contiguous)]
enum ContiguousWithValues {
Expand Down Expand Up @@ -169,6 +177,21 @@ struct AnyBitPatternTest {
#[repr(transparent)]
struct NewtypeWrapperTest<T>(T);

/// ```compile_fail
/// use bytemuck::TransparentWrapper;
///
/// struct NonTransparentSafeZST;
///
/// #[derive(TransparentWrapper)]
/// #[repr(transparent)]
/// struct Wrapper<T>(T, NonTransparentSafeZST);
/// ```
#[derive(
Debug, Copy, Clone, PartialEq, Eq, Pod, Zeroable, TransparentWrapper,
)]
#[repr(transparent)]
struct TransarentWrapperZstTest<T>(T);

#[test]
fn fails_cast_contiguous() {
let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5);
Expand Down Expand Up @@ -207,7 +230,14 @@ fn fails_cast_bytelit() {
fn passes_cast_bytelit() {
let res =
bytemuck::checked::cast_slice::<u8, CheckedBitPatternEnumByteLit>(b"CAB");
assert_eq!(res, [CheckedBitPatternEnumByteLit::C, CheckedBitPatternEnumByteLit::A, CheckedBitPatternEnumByteLit::B]);
assert_eq!(
res,
[
CheckedBitPatternEnumByteLit::C,
CheckedBitPatternEnumByteLit::A,
CheckedBitPatternEnumByteLit::B
]
);
}

#[test]
Expand Down
41 changes: 40 additions & 1 deletion src/transparent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ use super::*;
/// the only non-ZST field.
///
/// 2. Any fields *other* than the `Inner` field must be trivially constructable
/// ZSTs, for example `PhantomData`, `PhantomPinned`, etc.
/// ZSTs, for example `PhantomData`, `PhantomPinned`, etc. (When deriving
/// `TransparentWrapper` on a type with ZST fields, the ZST fields must be
/// [`Zeroable`]).
///
/// 3. The `Wrapper` may not impose additional alignment requirements over
/// `Inner`.
Expand Down Expand Up @@ -84,6 +86,43 @@ use super::*;
/// let mut buf = [1, 2, 3u8];
/// let sm = Slice::wrap_mut(&mut buf);
/// ```
///
/// ## Deriving
///
/// When deriving, the non-wrapped fields must uphold all the normal requirements,
/// and must also be `Zeroable`.
///
#[cfg_attr(feature = "derive", doc = "```")]
#[cfg_attr(
not(feature = "derive"),
doc = "```ignore
// This example requires the `derive` feature."
)]
/// use bytemuck::TransparentWrapper;
/// use std::marker::PhantomData;
///
/// #[derive(TransparentWrapper)]
/// #[repr(transparent)]
/// #[transparent(usize)]
/// struct Wrapper<T: ?Sized>(usize, PhantomData<T>); // PhantomData<T> implements Zeroable for all T
/// ```
///
/// Here, an error will occur, because `MyZst` does not implement `Zeroable`.
///
#[cfg_attr(feature = "derive", doc = "```compile_fail")]
#[cfg_attr(
not(feature = "derive"),
doc = "```ignore
// This example requires the `derive` feature."
)]
/// use bytemuck::TransparentWrapper;
/// struct MyZst;
///
/// #[derive(TransparentWrapper)]
/// #[repr(transparent)]
/// #[transparent(usize)]
/// struct Wrapper(usize, MyZst); // MyZst does not implement Zeroable
/// ```
pub unsafe trait TransparentWrapper<Inner: ?Sized> {
/// Convert the inner type into the wrapper type.
#[inline]
Expand Down
2 changes: 1 addition & 1 deletion src/zeroable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ unsafe impl<T> Zeroable for *const [T] {}
unsafe impl Zeroable for *mut str {}
unsafe impl Zeroable for *const str {}

unsafe impl<T: Zeroable> Zeroable for PhantomData<T> {}
unsafe impl<T: ?Sized> Zeroable for PhantomData<T> {}
unsafe impl Zeroable for PhantomPinned {}
unsafe impl<T: Zeroable> Zeroable for ManuallyDrop<T> {}
unsafe impl<T: Zeroable> Zeroable for core::cell::UnsafeCell<T> {}
Expand Down
34 changes: 30 additions & 4 deletions tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#![allow(dead_code)]

use bytemuck::{ByteEq, ByteHash, Pod, TransparentWrapper, Zeroable};
use std::marker::PhantomData;

#[derive(Copy, Clone, Pod, Zeroable, ByteEq, ByteHash)]
#[repr(C)]
Expand All @@ -26,7 +27,7 @@ struct TransparentWithZeroSized {

#[derive(TransparentWrapper)]
#[repr(transparent)]
struct TransparentWithGeneric<T> {
struct TransparentWithGeneric<T: ?Sized> {
a: T,
}

Expand All @@ -39,13 +40,38 @@ fn test_generic<T>(x: T) -> TransparentWithGeneric<T> {
#[derive(TransparentWrapper)]
#[repr(transparent)]
#[transparent(T)]
struct TransparentWithGenericAndZeroSized<T> {
a: T,
b: ()
struct TransparentWithGenericAndZeroSized<T: ?Sized> {
a: (),
b: T,
}

/// Ensuring that no additional bounds are emitted.
/// See https://github.com/Lokathor/bytemuck/issues/145
fn test_generic_with_zst<T>(x: T) -> TransparentWithGenericAndZeroSized<T> {
TransparentWithGenericAndZeroSized::wrap(x)
}

#[derive(TransparentWrapper)]
#[repr(transparent)]
struct TransparentUnsized {
a: dyn std::fmt::Debug,
}

type DynDebug = dyn std::fmt::Debug;

#[derive(TransparentWrapper)]
#[repr(transparent)]
#[transparent(DynDebug)]
struct TransparentUnsizedWithZeroSized {
a: (),
b: DynDebug,
}

#[derive(TransparentWrapper)]
#[repr(transparent)]
#[transparent(DynDebug)]
struct TransparentUnsizedWithGenericZeroSizeds<T: ?Sized, U: ?Sized> {
a: PhantomData<T>,
b: PhantomData<U>,
c: DynDebug,
}

0 comments on commit 1039388

Please sign in to comment.