diff --git a/kani-compiler/src/codegen_cprover_gotoc/codegen/statement.rs b/kani-compiler/src/codegen_cprover_gotoc/codegen/statement.rs index dbae4073451b..d952dcfc50c8 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/codegen/statement.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/codegen/statement.rs @@ -16,7 +16,7 @@ use rustc_middle::ty; use rustc_middle::ty::layout::LayoutOf; use rustc_middle::ty::{Instance, InstanceDef, Ty}; use rustc_span::Span; -use rustc_target::abi::{FieldsShape, Primitive, TagEncoding, Variants}; +use rustc_target::abi::{FieldsShape, TagEncoding, Variants}; use tracing::{debug, debug_span, trace}; impl<'tcx> GotocCtx<'tcx> { @@ -88,28 +88,23 @@ impl<'tcx> GotocCtx<'tcx> { } TagEncoding::Niche { untagged_variant, niche_variants, niche_start } => { if untagged_variant != variant_index { - let offset = match &layout.fields { - FieldsShape::Arbitrary { offsets, .. } => offsets[0], - _ => unreachable!("niche encoding must have arbitrary fields"), - }; - let discr_ty = self.codegen_enum_discr_typ(pt); - let discr_ty = self.codegen_ty(discr_ty); - let niche_value = - variant_index.as_u32() - niche_variants.start().as_u32(); - let niche_value = (niche_value as u128).wrapping_add(*niche_start); - let value = if niche_value == 0 - && matches!(tag.primitive(), Primitive::Pointer(_)) - { - discr_ty.null() - } else { - Expr::int_constant(niche_value, discr_ty.clone()) - }; + let value = self.compute_enum_niche_value( + pt, + variant_index, + tag, + niche_variants, + niche_start, + ); let place = unwrap_or_return_codegen_unimplemented_stmt!( self, self.codegen_place(place) ) .goto_expr; - self.codegen_get_niche(place, offset, discr_ty) + let offset = match &layout.fields { + FieldsShape::Arbitrary { offsets, .. } => offsets[0], + _ => unreachable!("niche encoding must have arbitrary fields"), + }; + self.codegen_get_niche(place, offset, value.typ().clone()) .assign(value, location) } else { Stmt::skip(location) diff --git a/kani-compiler/src/codegen_cprover_gotoc/codegen/typ.rs b/kani-compiler/src/codegen_cprover_gotoc/codegen/typ.rs index 278f246161eb..1aa26df57cb2 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/codegen/typ.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/codegen/typ.rs @@ -20,11 +20,12 @@ use rustc_middle::ty::{ use rustc_middle::ty::{List, TypeFoldable}; use rustc_span::def_id::DefId; use rustc_target::abi::{ - Abi::Vector, FieldsShape, Integer, LayoutS, Primitive, Size, TagEncoding, TyAndLayout, + Abi::Vector, FieldsShape, Integer, LayoutS, Primitive, Scalar, Size, TagEncoding, TyAndLayout, VariantIdx, Variants, }; use rustc_target::spec::abi::Abi; use std::iter; +use std::ops::RangeInclusive; use tracing::{debug, trace, warn}; use ty::layout::HasParamEnv; @@ -1572,6 +1573,30 @@ impl<'tcx> GotocCtx<'tcx> { } } + /// Compute the discriminant expression for an enum that uses niche optimization. + /// + /// We follow the logic of the SSA and Cranelift back-ends in doing the computation: + /// https://github.com/rust-lang/rust/blob/master/compiler/rustc_codegen_ssa/src/mir/place.rs#L455 + /// https://github.com/rust-lang/rust/blob/d37e2f74afd131cda7b08520d37426bfbb622b5c/compiler/rustc_codegen_cranelift/src/discriminant.rs#L52 + pub fn compute_enum_niche_value( + &mut self, + enum_ty: Ty<'tcx>, + variant_index: &VariantIdx, + tag: &Scalar, + niche_variants: &RangeInclusive, + niche_start: &u128, + ) -> Expr { + let discr_ty = self.codegen_enum_discr_typ(enum_ty); + let discr_ty = self.codegen_ty(discr_ty); + let niche_value = variant_index.as_u32() - niche_variants.start().as_u32(); + let niche_value = (niche_value as u128).wrapping_add(*niche_start); + if niche_value == 0 && matches!(tag.primitive(), Primitive::Pointer(_)) { + discr_ty.null() + } else { + Expr::int_constant(niche_value, discr_ty.clone()) + } + } + pub(crate) fn variant_min_offset( &self, variants: &IndexVec>,