Skip to content

Commit

Permalink
Refactor safe body generation
Browse files Browse the repository at this point in the history
  • Loading branch information
adpaco-aws committed Jul 26, 2024
1 parent 5078b07 commit b05de12
Showing 1 changed file with 34 additions and 32 deletions.
66 changes: 34 additions & 32 deletions library/kani_macros/src/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,27 +386,9 @@ pub fn expand_derive_invariant(item: proc_macro::TokenStream) -> proc_macro::Tok
let derive_item = parse_macro_input!(item as DeriveInput);
let item_name = &derive_item.ident;

let has_item_safety_constraint =
derive_item.attrs.iter().any(|attr| attr.path().is_ident("safety_constraint"));
let has_field_safety_constraints = has_field_safety_constraints(&item_name, &derive_item.data);

if has_item_safety_constraint && has_field_safety_constraints {
abort!(Span::call_site(), "Cannot derive `Invariant` for `{}`", item_name;
note = item_name.span() =>
"`#[safety_constaint(...)]` cannot be used in struct AND its fields"
)
}

let safe_body = safe_body(&item_name, &derive_item);
let field_refs = field_refs(&item_name, &derive_item.data);

let safe_body_from_attrs = if has_item_safety_constraint {
safe_body_from_struct_attr(&item_name, &derive_item)
} else {
safe_body_from_fields_attr(&item_name, &derive_item.data)
};

let safe_body_default = safe_body_default(&item_name, &derive_item.data);

// Add a bound `T: Invariant` to every type parameter T.
let generics = add_trait_bound_invariant(derive_item.generics);
// Generate an expression to sum up the heap size of each field.
Expand All @@ -418,13 +400,42 @@ pub fn expand_derive_invariant(item: proc_macro::TokenStream) -> proc_macro::Tok
fn is_safe(&self) -> bool {
let obj = self;
#field_refs
#safe_body_default && #safe_body_from_attrs
#safe_body
}
}
};
proc_macro::TokenStream::from(expanded)
}

fn safe_body(item_name: &Ident, derive_input: &DeriveInput) -> TokenStream {
let has_item_safety_constraint =
derive_input.attrs.iter().any(|attr| attr.path().is_ident("safety_constraint"));
let has_field_safety_constraints = has_field_safety_constraints(&item_name, &derive_input.data);

if has_item_safety_constraint && has_field_safety_constraints {
abort!(Span::call_site(), "Cannot derive `Invariant` for `{}`", item_name;
note = item_name.span() =>
"`#[safety_constaint(...)]` cannot be used in struct AND its fields"
)
}

let safe_body_from_attrs_opt: Option<TokenStream> = if has_item_safety_constraint {
Some(safe_body_from_struct_attr(&item_name, &derive_input))
} else if has_field_safety_constraints {
Some(safe_body_from_fields_attr(&item_name, &derive_input.data))
} else {
None
};

let safe_body_default = safe_body_default(&item_name, &derive_input.data);

if let Some(safe_body_from_attrs) = safe_body_from_attrs_opt {
quote! { #safe_body_default && #safe_body_from_attrs }
} else {
safe_body_default
}
}

fn has_field_safety_constraints(ident: &Ident, data: &Data) -> bool {
match data {
Data::Struct(struct_data) => has_field_safety_constraints_inner(ident, &struct_data.fields),
Expand Down Expand Up @@ -480,21 +491,12 @@ fn safe_body_from_fields_attr(ident: &Ident, data: &Data) -> TokenStream {
fn struct_invariant_conjunction(ident: &Ident, fields: &Fields) -> TokenStream {
match fields {
// Expands to the expression
// `true && <safety_cond1> && <safety_cond2> && ..`
// where `safety_condN` is
// - `self.fieldN.is_safe() && <cond>` if a condition `<cond>` was
// specified through the `#[safety_constraint(<cond>)]` helper attribute, or
// - `self.fieldN.is_safe()` otherwise
//
// Therefore, if `#[safety_constraint(<cond>)]` isn't specified for any field, this expands to
// `true && self.field1.is_safe() && self.field2.is_safe() && ..`
// `<safety_cond1> && <safety_cond2> && ..`
// where `<safety_condN>` is the safety condition specified for the N-th field.
Fields::Named(ref fields) => {
let safety_conds: Vec<TokenStream> =
fields.named.iter().filter_map(|field| parse_safety_expr(ident, field)).collect();
// An initial value is required for empty structs
safety_conds.iter().fold(quote! { true }, |acc, cond| {
quote! { #acc && #cond }
})
quote! { #(#safety_conds)&&* }
}
Fields::Unnamed(_) => {
quote! {
Expand Down

0 comments on commit b05de12

Please sign in to comment.