Skip to content

Commit

Permalink
Implement AggregateKind::{Adt,Closure,Generator}
Browse files Browse the repository at this point in the history
  • Loading branch information
tautschnig committed Apr 13, 2023
1 parent 4dacf96 commit 88cb51b
Show file tree
Hide file tree
Showing 3 changed files with 351 additions and 10 deletions.
307 changes: 301 additions & 6 deletions kani-compiler/src/codegen_cprover_gotoc/codegen/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ use cbmc::goto_program::{Expr, Location, Stmt, Symbol, Type};
use cbmc::MachineModel;
use cbmc::{btree_string_map, InternString, InternedString};
use num::bigint::BigInt;
use rustc_index::vec::IndexVec;
use rustc_middle::mir::{AggregateKind, BinOp, CastKind, NullOp, Operand, Place, Rvalue, UnOp};
use rustc_middle::ty::adjustment::PointerCast;
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, Instance, IntTy, Ty, TyCtxt, UintTy, VtblEntry};
use rustc_target::abi::{FieldsShape, Size, TagEncoding, Variants};
use rustc_middle::ty::{self, AdtDef, Instance, IntTy, Ty, TyCtxt, UintTy, VtblEntry};
use rustc_target::abi::{FieldsShape, LayoutS, Size, TagEncoding, VariantIdx, Variants};
use std::collections::BTreeMap;
use tracing::{debug, warn};

Expand Down Expand Up @@ -279,6 +280,266 @@ impl<'tcx> GotocCtx<'tcx> {
}
}

/// Create a struct expression for an enum using `Variants::Single` as layout (an enum where
/// only one variant has data).
fn codegen_rvalue_enum_single(
&mut self,
variant_index: &VariantIdx,
operands: &[Operand<'tcx>],
overall_t: Type,
adt: &AdtDef<'_>,
) -> Expr {
let variant = &adt.variants()[*variant_index];
let components = overall_t.lookup_components(&self.symbol_table).unwrap().clone();
Expr::struct_expr_with_nondet_fields(
overall_t,
variant
.fields
.iter()
.zip(operands.iter().zip(components.iter()))
.filter_map(|(f, (o, c))| {
let op_expr = self.codegen_operand(o);
let op_width = op_expr.typ().sizeof_in_bits(&self.symbol_table);
if op_width == 0 {
None
} else {
Some((
InternedString::from(f.name.to_string()),
op_expr.transmute_to(c.typ(), &self.symbol_table),
))
}
})
.collect(),
&self.symbol_table,
)
}

/// Create a struct expression for an enum using `Variants::Multiple` with direct encoding as
/// layout (a tagged enum).
fn codegen_rvalue_enum_direct(
&mut self,
variant_index: &VariantIdx,
operands: &[Operand<'tcx>],
overall_t: Type,
adt: &AdtDef<'_>,
variants: &IndexVec<VariantIdx, LayoutS<VariantIdx>>,
) -> Expr {
let fields = overall_t.lookup_components(&self.symbol_table).unwrap().clone();
assert_eq!(fields.len(), 2, "TagEncoding::Direct encountered for enum with empty variants");
assert_eq!(
fields[0].name().to_string(),
"case",
"Unexpected field in enum/generator. Please report your failing case at https://github.com/model-checking/kani/issues/1465"
);
let case_value = Expr::int_constant(variant_index.index(), fields[0].typ());
assert_eq!(
fields[1].name().to_string(),
"cases",
"Unexpected field in enum/generator. Please report your failing case at https://github.com/model-checking/kani/issues/1465"
);
assert!(matches!(variants[*variant_index].variants, Variants::Single { .. }));
let variant = &adt.variants()[*variant_index];
if variant.fields.is_empty() {
Expr::struct_expr_with_nondet_fields(
overall_t,
btree_string_map![("case", case_value)],
&self.symbol_table,
)
} else {
let target_component = fields[1]
.typ()
.lookup_field(variant.name.to_string(), &self.symbol_table)
.unwrap()
.clone();
let cases_value = Expr::union_expr(
fields[1].typ(),
target_component.name(),
Expr::struct_expr_from_values(
target_component.typ(),
variants[*variant_index]
.fields
.index_by_increasing_offset()
.map(|idx| self.codegen_operand(&operands[idx]))
.collect(),
&self.symbol_table,
),
&self.symbol_table,
);
Expr::struct_expr_from_values(
overall_t,
vec![case_value, cases_value],
&self.symbol_table,
)
}
}

/// Create an initializer for an enum using niche encoding. This is done while having access to
/// the lvalue so that it can be selectively updated when a variant is being used that is
/// smaller than the maximum (also known as untagged) variant.
pub fn codegen_enum_assignment(
&mut self,
lvalue: Expr,
variant_index: &VariantIdx,
operands: &[Operand<'tcx>],
ty: Ty<'tcx>,
location: Location,
) -> Stmt {
let overall_t = self.codegen_ty(ty);
let layout = self.layout_of(ty);
let adt = match &ty.kind() {
ty::Adt(adt, _) => adt,
_ => unreachable!(),
};
match &layout.variants {
Variants::Single { .. } => lvalue.assign(
self.codegen_rvalue_enum_single(variant_index, operands, overall_t, adt),
location,
),
Variants::Multiple { tag_encoding: TagEncoding::Direct, variants, .. } => lvalue
.assign(
self.codegen_rvalue_enum_direct(
variant_index,
operands,
overall_t,
adt,
variants,
),
location,
),
Variants::Multiple {
tag,
tag_encoding:
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
variants,
tag_field,
..
} => {
let variant_layout = &variants[*variant_index];
let target_and_member = {
let variant = &adt.variants()[*variant_index];
if variant_layout.size.bytes() == 0 {
None
} else {
let target_component = overall_t
.lookup_field(variant.name.to_string(), &self.symbol_table)
.unwrap()
.clone();
let member_expr = Expr::struct_expr_from_values(
target_component.typ(),
variant_layout
.fields
.index_by_increasing_offset()
.map(|idx| self.codegen_operand(&operands[idx]))
.collect(),
&self.symbol_table,
);
Some((target_component, member_expr))
}
};
if variant_index == untagged_variant {
// from codegen_enum_niche
let non_zst_count =
variants.iter().filter(|layout| layout.size.bytes() > 0).count();
let (target_component, member_expr) = match target_and_member {
Some((target_component, member_expr)) => (target_component, member_expr),
_ => unreachable!(),
};
if non_zst_count > 1 {
lvalue.assign(
Expr::union_expr(
overall_t,
target_component.name(),
member_expr,
&self.symbol_table,
),
location,
)
} else {
lvalue.assign(
Expr::struct_expr_with_nondet_fields(
overall_t,
btree_string_map![(target_component.name(), member_expr)],
&self.symbol_table,
),
location,
)
}
} else {
let niche_value = self.compute_enum_niche_value(
ty,
variant_index,
tag,
niche_variants,
niche_start,
);
let niche_offset = layout.fields.offset(*tag_field);
let lvalue_niche = self.codegen_get_niche(
lvalue.clone(),
niche_offset,
niche_value.typ().clone(),
);
if variant_layout.size.bytes() > 0 {
let (target_component, member_expr) = match target_and_member {
Some((target_component, member_expr)) => {
(target_component, member_expr)
}
_ => unreachable!(),
};
let lvalue_data = lvalue.reinterpret_cast(target_component.typ());
// Assign the tag (niche) _after_ the data as the latter may include
// non-deterministic padding values at the address where the nice is. The
// subsequent niche assignment will then set those to a fixed value.
Stmt::block(
vec![
lvalue_data.assign(member_expr, location),
lvalue_niche.assign(niche_value, location),
],
location,
)
} else {
lvalue_niche.assign(niche_value, location)
}
}
}
}
}

/// Create an initializer for a generator struct.
fn codegen_rvalue_generator(&mut self, operands: &[Operand<'tcx>], ty: Ty<'tcx>) -> Expr {
let layout = self.layout_of(ty);
let discriminant_field = match &layout.variants {
Variants::Multiple { tag_encoding: TagEncoding::Direct, tag_field, .. } => tag_field,
_ => unreachable!("Generators have more than one variant and use direct encoding"),
};
let overall_t = self.codegen_ty(ty);
let direct_fields = overall_t.lookup_field("direct_fields", &self.symbol_table).unwrap();
let mut operands_iter = operands.iter();
let direct_fields_expr = Expr::struct_expr_with_nondet_fields(
direct_fields.typ(),
layout
.fields
.index_by_increasing_offset()
.map(|idx| {
let field_ty = layout.field(self, idx).ty;
if idx == *discriminant_field {
(
InternedString::from("case"),
Expr::int_constant(0, self.codegen_ty(field_ty)),
)
} else {
(
self.generator_field_name(idx),
self.codegen_operand(operands_iter.next().unwrap()),
)
}
})
.collect(),
&self.symbol_table,
);
assert!(operands_iter.next().is_none());
Expr::union_expr(overall_t, "direct_fields", direct_fields_expr, &self.symbol_table)
}

fn codegen_rvalue_aggregate(
&mut self,
k: &AggregateKind<'tcx>,
Expand All @@ -304,7 +565,43 @@ impl<'tcx> GotocCtx<'tcx> {
)
}
}
AggregateKind::Tuple => {
AggregateKind::Adt(_, _, _, _, Some(active_field_index)) => {
assert!(res_ty.is_union());
assert_eq!(operands.len(), 1);
let typ = self.codegen_ty(res_ty);
let components = typ.lookup_components(&self.symbol_table).unwrap();
Expr::union_expr(
typ,
components[active_field_index].name(),
self.codegen_operand(&operands[0]),
&self.symbol_table,
)
}
AggregateKind::Adt(_, _, _, _, _) if res_ty.is_simd() => {
let typ = self.codegen_ty(res_ty);
let layout = self.layout_of(res_ty);
let vector_element_type = typ.base_type().unwrap().clone();
Expr::vector_expr(
typ,
layout
.fields
.index_by_increasing_offset()
.map(|idx| {
let cgo = self.codegen_operand(&operands[idx]);
if *cgo.typ() == vector_element_type {
cgo
} else {
cgo.transmute_to(vector_element_type.clone(), &self.symbol_table)
}
})
.collect(),
)
}
AggregateKind::Adt(..) if res_ty.is_enum() => {
// codegen_statement handles this case so that we have access to the lvalue
unreachable!()
}
AggregateKind::Adt(..) | AggregateKind::Closure(..) | AggregateKind::Tuple => {
let typ = self.codegen_ty(res_ty);
let layout = self.layout_of(res_ty);
Expr::struct_expr_from_values(
Expand All @@ -317,9 +614,7 @@ impl<'tcx> GotocCtx<'tcx> {
&self.symbol_table,
)
}
AggregateKind::Adt(_, _, _, _, _) => unimplemented!(),
AggregateKind::Closure(_, _) => unimplemented!(),
AggregateKind::Generator(_, _, _) => unimplemented!(),
AggregateKind::Generator(_, _, _) => self.codegen_rvalue_generator(&operands, res_ty),
}
}

Expand Down
31 changes: 27 additions & 4 deletions kani-compiler/src/codegen_cprover_gotoc/codegen/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use cbmc::goto_program::{Expr, Location, Stmt, Type};
use rustc_hir::def_id::DefId;
use rustc_middle::mir;
use rustc_middle::mir::{
AssertKind, BasicBlock, NonDivergingIntrinsic, Operand, Place, Statement, StatementKind,
SwitchTargets, Terminator, TerminatorKind,
AggregateKind, AssertKind, BasicBlock, NonDivergingIntrinsic, Operand, Place, Rvalue,
Statement, StatementKind, SwitchTargets, Terminator, TerminatorKind,
};
use rustc_middle::ty;
use rustc_middle::ty::layout::LayoutOf;
Expand Down Expand Up @@ -48,9 +48,32 @@ impl<'tcx> GotocCtx<'tcx> {
.goto_expr
.assign(self.codegen_rvalue(r, location).cast_to(Type::c_bool()), location)
} else {
unwrap_or_return_codegen_unimplemented_stmt!(self, self.codegen_place(l))
match r {
Rvalue::Aggregate(ref k, operands) if rty.is_enum() => {
if let AggregateKind::Adt(_, variant_index, _, _, _) = **k {
let lvalue_expr = unwrap_or_return_codegen_unimplemented_stmt!(
self,
self.codegen_place(l)
)
.goto_expr;
self.codegen_enum_assignment(
lvalue_expr,
&variant_index,
&operands,
rty,
location,
)
} else {
unreachable!()
}
}
_ => unwrap_or_return_codegen_unimplemented_stmt!(
self,
self.codegen_place(l)
)
.goto_expr
.assign(self.codegen_rvalue(r, location), location)
.assign(self.codegen_rvalue(r, location), location),
}
}
}
StatementKind::Deinit(place) => self.codegen_deinit(place, location),
Expand Down
Loading

0 comments on commit 88cb51b

Please sign in to comment.