Skip to content

Commit

Permalink
Adjust enum aggregation
Browse files Browse the repository at this point in the history
Simplify the code by reusing some of the existing logic
  • Loading branch information
celinval authored and tautschnig committed Apr 14, 2023
1 parent 6693f04 commit c34d9a5
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 345 deletions.
11 changes: 11 additions & 0 deletions kani-compiler/src/codegen_cprover_gotoc/codegen/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,17 @@ impl<'tcx> GotocCtx<'tcx> {
}
}

/// Given a projection, generate an lvalue that represents the given variant index.
pub fn codegen_variant_lvalue(
&mut self,
initial_projection: ProjectedPlace<'tcx>,
variant_idx: VariantIdx,
) -> ProjectedPlace<'tcx> {
debug!(?initial_projection, ?variant_idx, "codegen_variant_lvalue");
let downcast = ProjectionElem::Downcast(None, variant_idx);
self.codegen_projection(Ok(initial_projection), downcast).unwrap()
}

// https://doc.rust-lang.org/nightly/nightly-rustc/rustc_middle/mir/enum.ProjectionElem.html
// ConstantIndex
// [−]
Expand Down
314 changes: 80 additions & 234 deletions kani-compiler/src/codegen_cprover_gotoc/codegen/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 OR MIT

use super::typ::pointee_type;
use crate::codegen_cprover_gotoc::codegen::place::{ProjectedPlace, TypeOrVariant};
use crate::codegen_cprover_gotoc::codegen::PropertyClass;
use crate::codegen_cprover_gotoc::utils::{dynamic_fat_ptr, slice_fat_ptr};
use crate::codegen_cprover_gotoc::{GotocCtx, VtableCtx};
Expand All @@ -13,14 +14,13 @@ use cbmc::goto_program::{Expr, Location, Stmt, Symbol, Type};
use cbmc::MachineModel;
use cbmc::{btree_string_map, InternString, InternedString};
use num::bigint::BigInt;
use rustc_index::vec::IndexVec;
use rustc_middle::mir::{AggregateKind, BinOp, CastKind, NullOp, Operand, Place, Rvalue, UnOp};
use rustc_middle::ty::adjustment::PointerCast;
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, AdtDef, Instance, IntTy, Ty, TyCtxt, UintTy, VtblEntry};
use rustc_target::abi::{FieldsShape, LayoutS, Size, TagEncoding, VariantIdx, Variants};
use rustc_middle::ty::{self, Instance, IntTy, Ty, TyCtxt, UintTy, VtblEntry};
use rustc_target::abi::{FieldsShape, Size, TagEncoding, VariantIdx, Variants};
use std::collections::BTreeMap;
use tracing::{debug, warn};
use tracing::{debug, trace, warn};

impl<'tcx> GotocCtx<'tcx> {
fn codegen_comparison(&mut self, op: &BinOp, e1: &Operand<'tcx>, e2: &Operand<'tcx>) -> Expr {
Expand Down Expand Up @@ -280,230 +280,6 @@ impl<'tcx> GotocCtx<'tcx> {
}
}

/// Create a struct expression for an enum using `Variants::Single` as layout (an enum where
/// only one variant has data).
fn codegen_rvalue_enum_single(
&mut self,
variant_index: &VariantIdx,
operands: &[Operand<'tcx>],
overall_t: Type,
adt: &AdtDef<'_>,
) -> Expr {
let variant = &adt.variants()[*variant_index];
let components = overall_t.lookup_components(&self.symbol_table).unwrap().clone();
Expr::struct_expr_with_nondet_fields(
overall_t,
variant
.fields
.iter()
.zip(operands.iter().zip(components.iter()))
.filter_map(|(f, (o, c))| {
let op_expr = self.codegen_operand(o);
let op_width = op_expr.typ().sizeof_in_bits(&self.symbol_table);
if op_width == 0 {
None
} else {
Some((
InternedString::from(f.name.to_string()),
op_expr.transmute_to(c.typ(), &self.symbol_table),
))
}
})
.collect(),
&self.symbol_table,
)
}

/// Create a struct expression for an enum using `Variants::Multiple` with direct encoding as
/// layout (a tagged enum).
fn codegen_rvalue_enum_direct(
&mut self,
variant_index: &VariantIdx,
operands: &[Operand<'tcx>],
overall_t: Type,
adt: &AdtDef<'_>,
variants: &IndexVec<VariantIdx, LayoutS<VariantIdx>>,
) -> Expr {
let fields = overall_t.lookup_components(&self.symbol_table).unwrap().clone();
assert_eq!(fields.len(), 2, "TagEncoding::Direct encountered for enum with empty variants");
assert_eq!(
fields[0].name().to_string(),
"case",
"Unexpected field in enum/generator. Please report your failing case at https://github.com/model-checking/kani/issues/1465"
);
let case_value = Expr::int_constant(variant_index.index(), fields[0].typ());
assert_eq!(
fields[1].name().to_string(),
"cases",
"Unexpected field in enum/generator. Please report your failing case at https://github.com/model-checking/kani/issues/1465"
);
assert!(matches!(variants[*variant_index].variants, Variants::Single { .. }));
let variant = &adt.variants()[*variant_index];
if variant.fields.is_empty() {
Expr::struct_expr_with_nondet_fields(
overall_t,
btree_string_map![("case", case_value)],
&self.symbol_table,
)
} else {
let target_component = fields[1]
.typ()
.lookup_field(variant.name.to_string(), &self.symbol_table)
.unwrap()
.clone();
let cases_value = Expr::union_expr(
fields[1].typ(),
target_component.name(),
Expr::struct_expr_from_values(
target_component.typ(),
variants[*variant_index]
.fields
.index_by_increasing_offset()
.map(|idx| self.codegen_operand(&operands[idx]))
.collect(),
&self.symbol_table,
),
&self.symbol_table,
);
Expr::struct_expr_from_values(
overall_t,
vec![case_value, cases_value],
&self.symbol_table,
)
}
}

/// Create an initializer for an enum using niche encoding. This is done while having access to
/// the lvalue so that it can be selectively updated when a variant is being used that is
/// smaller than the maximum (also known as untagged) variant.
pub fn codegen_enum_assignment(
&mut self,
lvalue: Expr,
variant_index: &VariantIdx,
operands: &[Operand<'tcx>],
ty: Ty<'tcx>,
location: Location,
) -> Stmt {
let overall_t = self.codegen_ty(ty);
let layout = self.layout_of(ty);
let adt = match &ty.kind() {
ty::Adt(adt, _) => adt,
_ => unreachable!(),
};
match &layout.variants {
Variants::Single { .. } => lvalue.assign(
self.codegen_rvalue_enum_single(variant_index, operands, overall_t, adt),
location,
),
Variants::Multiple { tag_encoding: TagEncoding::Direct, variants, .. } => lvalue
.assign(
self.codegen_rvalue_enum_direct(
variant_index,
operands,
overall_t,
adt,
variants,
),
location,
),
Variants::Multiple {
tag,
tag_encoding:
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
variants,
tag_field,
..
} => {
let variant_layout = &variants[*variant_index];
let target_and_member = {
let variant = &adt.variants()[*variant_index];
if variant_layout.size.bytes() == 0 {
None
} else {
let target_component = overall_t
.lookup_field(variant.name.to_string(), &self.symbol_table)
.unwrap()
.clone();
let member_expr = Expr::struct_expr_from_values(
target_component.typ(),
variant_layout
.fields
.index_by_increasing_offset()
.map(|idx| self.codegen_operand(&operands[idx]))
.collect(),
&self.symbol_table,
);
Some((target_component, member_expr))
}
};
if variant_index == untagged_variant {
// from codegen_enum_niche
let non_zst_count =
variants.iter().filter(|layout| layout.size.bytes() > 0).count();
let (target_component, member_expr) = match target_and_member {
Some((target_component, member_expr)) => (target_component, member_expr),
_ => unreachable!(),
};
if non_zst_count > 1 {
lvalue.assign(
Expr::union_expr(
overall_t,
target_component.name(),
member_expr,
&self.symbol_table,
),
location,
)
} else {
lvalue.assign(
Expr::struct_expr_with_nondet_fields(
overall_t,
btree_string_map![(target_component.name(), member_expr)],
&self.symbol_table,
),
location,
)
}
} else {
let niche_value = self.compute_enum_niche_value(
ty,
variant_index,
tag,
niche_variants,
niche_start,
);
let niche_offset = layout.fields.offset(*tag_field);
let lvalue_niche = self.codegen_get_niche(
lvalue.clone(),
niche_offset,
niche_value.typ().clone(),
);
if variant_layout.size.bytes() > 0 {
let (target_component, member_expr) = match target_and_member {
Some((target_component, member_expr)) => {
(target_component, member_expr)
}
_ => unreachable!(),
};
let lvalue_data = lvalue.reinterpret_cast(target_component.typ());
// Assign the tag (niche) _after_ the data as the latter may include
// non-deterministic padding values at the address where the nice is. The
// subsequent niche assignment will then set those to a fixed value.
Stmt::block(
vec![
lvalue_data.assign(member_expr, location),
lvalue_niche.assign(niche_value, location),
],
location,
)
} else {
lvalue_niche.assign(niche_value, location)
}
}
}
}
}

/// Create an initializer for a generator struct.
fn codegen_rvalue_generator(&mut self, operands: &[Operand<'tcx>], ty: Ty<'tcx>) -> Expr {
let layout = self.layout_of(ty);
Expand Down Expand Up @@ -540,13 +316,84 @@ impl<'tcx> GotocCtx<'tcx> {
Expr::union_expr(overall_t, "direct_fields", direct_fields_expr, &self.symbol_table)
}

/// This code will generate an expression that initializes an enumeration.
///
/// It will first create a temporary variant with the same enum type.
/// Initialize the case structure and set its discriminant.
/// Finally, it will return the temporary value.
fn codegen_rvalue_enum_aggregate(
&mut self,
variant_index: VariantIdx,
operands: &[Operand<'tcx>],
res_ty: Ty<'tcx>,
loc: Location,
) -> Expr {
let mut stmts = vec![];
let typ = self.codegen_ty(res_ty);
// 1- Create a temporary value of the enum type.
tracing::debug!(?typ, ?res_ty, "aggregate_enum");
let (temp_var, decl) = self.decl_temp_variable(typ.clone(), None, loc);
stmts.push(decl);
if !operands.is_empty() {
// 2- Initialize the members of the temporary variant.
let initial_projection = ProjectedPlace::try_new(
temp_var.clone(),
TypeOrVariant::Type(res_ty),
None,
None,
self,
)
.unwrap();
let variant_proj = self.codegen_variant_lvalue(initial_projection, variant_index);
let variant_expr = variant_proj.goto_expr.clone();
let layout = self.layout_of(res_ty);
let fields = match &layout.variants {
Variants::Single { index } => {
if *index != variant_index {
// This may occur if all variants except for the one pointed by
// index can never be constructed. Generic code might still try
// to initialize the non-existing invariant.
trace!(?res_ty, ?variant_index, "Unreachable invariant");
return Expr::nondet(typ);
}
&layout.fields
}
Variants::Multiple { variants, .. } => &variants[variant_index].fields,
};

debug!(?variant_expr, ?fields, ?operands, "codegen_aggregate enum");
let init_struct = Expr::struct_expr_from_values(
variant_expr.typ().clone(),
fields
.index_by_increasing_offset()
.map(|idx| {
let op = self.codegen_operand(&operands[idx]);
debug!(?op, ?idx, "codegen_aggregate enum op");
op
})
.collect(),
&self.symbol_table,
);
let assign_case = variant_proj.goto_expr.assign(init_struct, loc);
stmts.push(assign_case);
}
// 3- Set discriminant.
let set_discriminant =
self.codegen_set_discriminant(res_ty, temp_var.clone(), variant_index, loc);
stmts.push(set_discriminant);
// 4- Return temporary variable.
stmts.push(temp_var.as_stmt(loc));
Expr::statement_expression(stmts, typ)
}

fn codegen_rvalue_aggregate(
&mut self,
k: &AggregateKind<'tcx>,
aggregate: &AggregateKind<'tcx>,
operands: &[Operand<'tcx>],
res_ty: Ty<'tcx>,
loc: Location,
) -> Expr {
match *k {
match *aggregate {
AggregateKind::Array(et) => {
if et.is_unit() {
Expr::struct_expr_from_values(
Expand Down Expand Up @@ -597,9 +444,8 @@ impl<'tcx> GotocCtx<'tcx> {
.collect(),
)
}
AggregateKind::Adt(..) if res_ty.is_enum() => {
// codegen_statement handles this case so that we have access to the lvalue
unreachable!()
AggregateKind::Adt(_, variant_index, ..) if res_ty.is_enum() => {
self.codegen_rvalue_enum_aggregate(variant_index, operands, res_ty, loc)
}
AggregateKind::Adt(..) | AggregateKind::Closure(..) | AggregateKind::Tuple => {
let typ = self.codegen_ty(res_ty);
Expand Down Expand Up @@ -701,7 +547,7 @@ impl<'tcx> GotocCtx<'tcx> {
self.codegen_get_discriminant(place, pt, res_ty)
}
Rvalue::Aggregate(ref k, operands) => {
self.codegen_rvalue_aggregate(k, operands, res_ty)
self.codegen_rvalue_aggregate(k, operands, res_ty, loc)
}
Rvalue::ThreadLocalRef(def_id) => {
// Since Kani is single-threaded, we treat a thread local like a static variable:
Expand Down
Loading

0 comments on commit c34d9a5

Please sign in to comment.