Skip to content

Commit

Permalink
Factor out niche value computation
Browse files Browse the repository at this point in the history
There will be future re-use of this code, so move it to a public
function.
  • Loading branch information
tautschnig committed Apr 13, 2023
1 parent 868b13b commit 4e59e23
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 19 deletions.
31 changes: 13 additions & 18 deletions kani-compiler/src/codegen_cprover_gotoc/codegen/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use rustc_middle::ty;
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{Instance, InstanceDef, Ty};
use rustc_span::Span;
use rustc_target::abi::{FieldsShape, Primitive, TagEncoding, Variants};
use rustc_target::abi::{FieldsShape, TagEncoding, Variants};
use tracing::{debug, debug_span, trace};

impl<'tcx> GotocCtx<'tcx> {
Expand Down Expand Up @@ -88,28 +88,23 @@ impl<'tcx> GotocCtx<'tcx> {
}
TagEncoding::Niche { untagged_variant, niche_variants, niche_start } => {
if untagged_variant != variant_index {
let offset = match &layout.fields {
FieldsShape::Arbitrary { offsets, .. } => offsets[0],
_ => unreachable!("niche encoding must have arbitrary fields"),
};
let discr_ty = self.codegen_enum_discr_typ(pt);
let discr_ty = self.codegen_ty(discr_ty);
let niche_value =
variant_index.as_u32() - niche_variants.start().as_u32();
let niche_value = (niche_value as u128).wrapping_add(*niche_start);
let value = if niche_value == 0
&& matches!(tag.primitive(), Primitive::Pointer(_))
{
discr_ty.null()
} else {
Expr::int_constant(niche_value, discr_ty.clone())
};
let value = self.compute_enum_niche_value(
pt,
variant_index,
tag,
niche_variants,
niche_start,
);
let place = unwrap_or_return_codegen_unimplemented_stmt!(
self,
self.codegen_place(place)
)
.goto_expr;
self.codegen_get_niche(place, offset, discr_ty)
let offset = match &layout.fields {
FieldsShape::Arbitrary { offsets, .. } => offsets[0],
_ => unreachable!("niche encoding must have arbitrary fields"),
};
self.codegen_get_niche(place, offset, value.typ().clone())
.assign(value, location)
} else {
Stmt::skip(location)
Expand Down
27 changes: 26 additions & 1 deletion kani-compiler/src/codegen_cprover_gotoc/codegen/typ.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ use rustc_middle::ty::{
use rustc_middle::ty::{List, TypeFoldable};
use rustc_span::def_id::DefId;
use rustc_target::abi::{
Abi::Vector, FieldsShape, Integer, LayoutS, Primitive, Size, TagEncoding, TyAndLayout,
Abi::Vector, FieldsShape, Integer, LayoutS, Primitive, Scalar, Size, TagEncoding, TyAndLayout,
VariantIdx, Variants,
};
use rustc_target::spec::abi::Abi;
use std::iter;
use std::ops::RangeInclusive;
use tracing::{debug, trace, warn};
use ty::layout::HasParamEnv;

Expand Down Expand Up @@ -1572,6 +1573,30 @@ impl<'tcx> GotocCtx<'tcx> {
}
}

/// Compute the discriminant expression for an enum that uses niche optimization.
///
/// We follow the logic of the SSA and Cranelift back-ends in doing the computation:
/// https://github.com/rust-lang/rust/blob/master/compiler/rustc_codegen_ssa/src/mir/place.rs#L455
/// https://github.com/rust-lang/rust/blob/d37e2f74afd131cda7b08520d37426bfbb622b5c/compiler/rustc_codegen_cranelift/src/discriminant.rs#L52
pub fn compute_enum_niche_value(
&mut self,
enum_ty: Ty<'tcx>,
variant_index: &VariantIdx,
tag: &Scalar,
niche_variants: &RangeInclusive<VariantIdx>,
niche_start: &u128,
) -> Expr {
let discr_ty = self.codegen_enum_discr_typ(enum_ty);
let discr_ty = self.codegen_ty(discr_ty);
let niche_value = variant_index.as_u32() - niche_variants.start().as_u32();
let niche_value = (niche_value as u128).wrapping_add(*niche_start);
if niche_value == 0 && matches!(tag.primitive(), Primitive::Pointer(_)) {
discr_ty.null()
} else {
Expr::int_constant(niche_value, discr_ty)
}
}

pub(crate) fn variant_min_offset(
&self,
variants: &IndexVec<VariantIdx, LayoutS<VariantIdx>>,
Expand Down

0 comments on commit 4e59e23

Please sign in to comment.