From 758fb443b0577ef9a76591e6963e702dcf987d5b Mon Sep 17 00:00:00 2001 From: Nadrieril Date: Sat, 18 Nov 2023 03:52:54 +0100 Subject: [PATCH] Track place validity --- .../src/thir/pattern/check_match.rs | 124 ++++++++++++++++-- .../src/thir/pattern/usefulness.rs | 109 ++++++++++++--- 2 files changed, 204 insertions(+), 29 deletions(-) diff --git a/compiler/rustc_mir_build/src/thir/pattern/check_match.rs b/compiler/rustc_mir_build/src/thir/pattern/check_match.rs index ce788a289944..6d7c1f630f3d 100644 --- a/compiler/rustc_mir_build/src/thir/pattern/check_match.rs +++ b/compiler/rustc_mir_build/src/thir/pattern/check_match.rs @@ -42,7 +42,7 @@ pub(crate) fn check_match(tcx: TyCtxt<'_>, def_id: LocalDefId) -> Result<(), Err for param in thir.params.iter() { if let Some(box ref pattern) = param.pat { - visitor.check_binding_is_irrefutable(pattern, "function argument", None); + visitor.check_binding_is_irrefutable(pattern, "function argument", None, None); } } visitor.error @@ -254,10 +254,11 @@ impl<'thir, 'p, 'tcx> MatchVisitor<'thir, 'p, 'tcx> { self.with_lint_level(lint_level, |this| this.visit_land_rhs(&this.thir[value])) } ExprKind::Let { box ref pat, expr } => { + let expr = &self.thir()[expr]; self.with_let_source(LetSource::None, |this| { - this.visit_expr(&this.thir()[expr]); + this.visit_expr(expr); }); - Ok(Some((ex.span, self.is_let_irrefutable(pat)?))) + Ok(Some((ex.span, self.is_let_irrefutable(pat, Some(expr))?))) } _ => { self.with_let_source(LetSource::None, |this| { @@ -287,35 +288,114 @@ impl<'thir, 'p, 'tcx> MatchVisitor<'thir, 'p, 'tcx> { } } + /// Inspects the match scrutinee expression to determine whether the place it evaluates to may + /// hold invalid data. + fn is_known_valid_scrutinee(&self, scrutinee: &Expr<'tcx>) -> bool { + use ExprKind::*; + match &scrutinee.kind { + // Both pointers and references can validly point to a place with invalid data. + Deref { .. } => false, + // Inherit validity of the parent place, unless the parent is an union. + Field { lhs, .. } => { + let lhs = &self.thir()[*lhs]; + match lhs.ty.kind() { + ty::Adt(def, _) if def.is_union() => false, + _ => self.is_known_valid_scrutinee(lhs), + } + } + // Essentially a field access. + Index { lhs, .. } => { + let lhs = &self.thir()[*lhs]; + self.is_known_valid_scrutinee(lhs) + } + + // No-op. + Scope { value, .. } => self.is_known_valid_scrutinee(&self.thir()[*value]), + + // Casts don't cause a load. + NeverToAny { source } + | Cast { source } + | Use { source } + | PointerCoercion { source, .. } + | PlaceTypeAscription { source, .. } + | ValueTypeAscription { source, .. } => { + self.is_known_valid_scrutinee(&self.thir()[*source]) + } + + // These diverge. + Become { .. } | Break { .. } | Continue { .. } | Return { .. } => true, + + // These are statements that evaluate to `()`. + Assign { .. } | AssignOp { .. } | InlineAsm { .. } | Let { .. } => true, + + // These evaluate to a value. + AddressOf { .. } + | Adt { .. } + | Array { .. } + | Binary { .. } + | Block { .. } + | Borrow { .. } + | Box { .. } + | Call { .. } + | Closure { .. } + | ConstBlock { .. } + | ConstParam { .. } + | If { .. } + | Literal { .. } + | LogicalOp { .. } + | Loop { .. } + | Match { .. } + | NamedConst { .. } + | NonHirLiteral { .. } + | OffsetOf { .. } + | Repeat { .. } + | StaticRef { .. } + | ThreadLocalRef { .. } + | Tuple { .. } + | Unary { .. } + | UpvarRef { .. } + | VarRef { .. } + | ZstLiteral { .. } + | Yield { .. } => true, + } + } + fn new_cx( &self, refutability: RefutableFlag, - match_span: Option, + whole_match_span: Option, + scrutinee: Option<&Expr<'tcx>>, scrut_span: Span, ) -> MatchCheckCtxt<'p, 'tcx> { let refutable = match refutability { Irrefutable => false, Refutable => true, }; + // If we don't have a scrutinee we're either a function parameter or a `let x;`. Both cases + // require validity. + let known_valid_scrutinee = + scrutinee.map(|scrut| self.is_known_valid_scrutinee(scrut)).unwrap_or(true); MatchCheckCtxt { tcx: self.tcx, param_env: self.param_env, module: self.tcx.parent_module(self.lint_level).to_def_id(), pattern_arena: self.pattern_arena, match_lint_level: self.lint_level, - match_span, + whole_match_span, scrut_span, refutable, + known_valid_scrutinee, } } #[instrument(level = "trace", skip(self))] fn check_let(&mut self, pat: &Pat<'tcx>, scrutinee: Option, span: Span) { assert!(self.let_source != LetSource::None); + let scrut = scrutinee.map(|id| &self.thir[id]); if let LetSource::PlainLet = self.let_source { - self.check_binding_is_irrefutable(pat, "local binding", Some(span)) + self.check_binding_is_irrefutable(pat, "local binding", scrut, Some(span)) } else { - let Ok(refutability) = self.is_let_irrefutable(pat) else { return }; + let Ok(refutability) = self.is_let_irrefutable(pat, scrut) else { return }; if matches!(refutability, Irrefutable) { report_irrefutable_let_patterns( self.tcx, @@ -336,7 +416,7 @@ impl<'thir, 'p, 'tcx> MatchVisitor<'thir, 'p, 'tcx> { expr_span: Span, ) { let scrut = &self.thir[scrut]; - let cx = self.new_cx(Refutable, Some(expr_span), scrut.span); + let cx = self.new_cx(Refutable, Some(expr_span), Some(scrut), scrut.span); let mut tarms = Vec::with_capacity(arms.len()); for &arm in arms { @@ -377,7 +457,12 @@ impl<'thir, 'p, 'tcx> MatchVisitor<'thir, 'p, 'tcx> { debug_assert_eq!(pat.span.desugaring_kind(), Some(DesugaringKind::ForLoop)); let PatKind::Variant { ref subpatterns, .. } = pat.kind else { bug!() }; let [pat_field] = &subpatterns[..] else { bug!() }; - self.check_binding_is_irrefutable(&pat_field.pattern, "`for` loop binding", None); + self.check_binding_is_irrefutable( + &pat_field.pattern, + "`for` loop binding", + None, + None, + ); } else { self.error = Err(report_non_exhaustive_match( &cx, self.thir, scrut_ty, scrut.span, witnesses, arms, expr_span, @@ -457,16 +542,21 @@ impl<'thir, 'p, 'tcx> MatchVisitor<'thir, 'p, 'tcx> { &mut self, pat: &Pat<'tcx>, refutability: RefutableFlag, + scrut: Option<&Expr<'tcx>>, ) -> Result<(MatchCheckCtxt<'p, 'tcx>, UsefulnessReport<'p, 'tcx>), ErrorGuaranteed> { - let cx = self.new_cx(refutability, None, pat.span); + let cx = self.new_cx(refutability, None, scrut, pat.span); let pat = self.lower_pattern(&cx, pat)?; let arms = [MatchArm { pat, hir_id: self.lint_level, has_guard: false }]; let report = compute_match_usefulness(&cx, &arms, pat.ty()); Ok((cx, report)) } - fn is_let_irrefutable(&mut self, pat: &Pat<'tcx>) -> Result { - let (cx, report) = self.analyze_binding(pat, Refutable)?; + fn is_let_irrefutable( + &mut self, + pat: &Pat<'tcx>, + scrut: Option<&Expr<'tcx>>, + ) -> Result { + let (cx, report) = self.analyze_binding(pat, Refutable, scrut)?; // Report if the pattern is unreachable, which can only occur when the type is uninhabited. // This also reports unreachable sub-patterns. report_arm_reachability(&cx, &report); @@ -476,10 +566,16 @@ impl<'thir, 'p, 'tcx> MatchVisitor<'thir, 'p, 'tcx> { } #[instrument(level = "trace", skip(self))] - fn check_binding_is_irrefutable(&mut self, pat: &Pat<'tcx>, origin: &str, sp: Option) { + fn check_binding_is_irrefutable( + &mut self, + pat: &Pat<'tcx>, + origin: &str, + scrut: Option<&Expr<'tcx>>, + sp: Option, + ) { let pattern_ty = pat.ty; - let Ok((cx, report)) = self.analyze_binding(pat, Irrefutable) else { return }; + let Ok((cx, report)) = self.analyze_binding(pat, Irrefutable, scrut) else { return }; let witnesses = report.non_exhaustiveness_witnesses; if witnesses.is_empty() { // The pattern is irrefutable. diff --git a/compiler/rustc_mir_build/src/thir/pattern/usefulness.rs b/compiler/rustc_mir_build/src/thir/pattern/usefulness.rs index 6c2f633709de..f0fc50547dd1 100644 --- a/compiler/rustc_mir_build/src/thir/pattern/usefulness.rs +++ b/compiler/rustc_mir_build/src/thir/pattern/usefulness.rs @@ -550,6 +550,7 @@ //! I (Nadrieril) prefer to put new tests in `ui/pattern/usefulness` unless there's a specific //! reason not to, for example if they crucially depend on a particular feature like `or_patterns`. +use self::ValidityConstraint::*; use super::deconstruct_pat::{ Constructor, ConstructorSet, DeconstructedPat, IntRange, MaybeInfiniteInt, SplitConstructorSet, WitnessPat, @@ -586,11 +587,14 @@ pub(crate) struct MatchCheckCtxt<'p, 'tcx> { /// Lint level at the match. pub(crate) match_lint_level: HirId, /// The span of the whole match, if applicable. - pub(crate) match_span: Option, + pub(crate) whole_match_span: Option, /// Span of the scrutinee. pub(crate) scrut_span: Span, /// Only produce `NON_EXHAUSTIVE_OMITTED_PATTERNS` lint on refutable patterns. pub(crate) refutable: bool, + /// Whether the data at the scrutinee is known to be valid. This is false if the scrutinee comes + /// from a union field, a pointer deref, or a reference deref (pending opsem decisions). + pub(crate) known_valid_scrutinee: bool, } impl<'a, 'tcx> MatchCheckCtxt<'a, 'tcx> { @@ -619,12 +623,63 @@ pub(super) struct PatCtxt<'a, 'p, 'tcx> { pub(super) is_top_level: bool, } +impl<'a, 'p, 'tcx> PatCtxt<'a, 'p, 'tcx> { + /// A `PatCtxt` when code other than `is_useful` needs one. + fn new_dummy(cx: &'a MatchCheckCtxt<'p, 'tcx>, ty: Ty<'tcx>) -> Self { + PatCtxt { cx, ty, is_top_level: false } + } +} + impl<'a, 'p, 'tcx> fmt::Debug for PatCtxt<'a, 'p, 'tcx> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PatCtxt").field("ty", &self.ty).finish() } } +/// In the matrix, tracks whether a given place (aka column) is known to contain a valid value or +/// not. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub(super) enum ValidityConstraint { + ValidOnly, + MaybeInvalid, +} + +impl ValidityConstraint { + pub(super) fn from_bool(is_valid_only: bool) -> Self { + if is_valid_only { ValidOnly } else { MaybeInvalid } + } + + /// If the place has validity given by `self` and we read that the value at the place has + /// constructor `ctor`, this computes what we can assume about the validity of the constructor + /// fields. + /// + /// Pending further opsem decisions, the current behavior is: validity is preserved, except + /// under `&` where validity is reset to `MaybeInvalid`. + pub(super) fn specialize<'tcx>( + self, + pcx: &PatCtxt<'_, '_, 'tcx>, + ctor: &Constructor<'tcx>, + ) -> Self { + // We preserve validity except when we go under a reference. + if matches!(ctor, Constructor::Single) && matches!(pcx.ty.kind(), ty::Ref(..)) { + // Validity of `x: &T` does not imply validity of `*x: T`. + MaybeInvalid + } else { + self + } + } +} + +impl fmt::Display for ValidityConstraint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + ValidOnly => "✓", + MaybeInvalid => "?", + }; + write!(f, "{s}") + } +} + /// Represents a pattern-tuple under investigation. #[derive(Clone)] struct PatStack<'p, 'tcx> { @@ -769,10 +824,15 @@ impl<'p, 'tcx> fmt::Debug for MatrixRow<'p, 'tcx> { /// the matrix will correspond to `scrutinee.0.Some.0` and the second column to `scrutinee.1`. #[derive(Clone)] struct Matrix<'p, 'tcx> { + /// Vector of rows. The rows must form a rectangular 2D array. Moreover, all the patterns of + /// each column must have the same type. Each column corresponds to a place within the + /// scrutinee. rows: Vec>, /// Stores an extra fictitious row full of wildcards. Mostly used to keep track of the type of /// each column. This must obey the same invariants as the real rows. wildcard_row: PatStack<'p, 'tcx>, + /// Track for each column/place whether it contains a known valid value. + place_validity: SmallVec<[ValidityConstraint; 2]>, } impl<'p, 'tcx> Matrix<'p, 'tcx> { @@ -794,13 +854,15 @@ impl<'p, 'tcx> Matrix<'p, 'tcx> { cx: &MatchCheckCtxt<'p, 'tcx>, iter: impl Iterator>, scrut_ty: Ty<'tcx>, + scrut_validity: ValidityConstraint, ) -> Self where 'p: 'a, { let wild_pattern = cx.pattern_arena.alloc(DeconstructedPat::wildcard(scrut_ty, DUMMY_SP)); let wildcard_row = PatStack::from_pattern(wild_pattern); - let mut matrix = Matrix { rows: vec![], wildcard_row }; + let mut matrix = + Matrix { rows: vec![], wildcard_row, place_validity: smallvec![scrut_validity] }; for (row_id, arm) in iter.enumerate() { let v = MatrixRow { pats: PatStack::from_pattern(arm.pat), @@ -864,7 +926,12 @@ impl<'p, 'tcx> Matrix<'p, 'tcx> { ctor: &Constructor<'tcx>, ) -> Matrix<'p, 'tcx> { let wildcard_row = self.wildcard_row.pop_head_constructor(pcx, ctor); - let mut matrix = Matrix { rows: vec![], wildcard_row }; + let new_validity = self.place_validity[0].specialize(pcx, ctor); + let new_place_validity = std::iter::repeat(new_validity) + .take(ctor.arity(pcx)) + .chain(self.place_validity[1..].iter().copied()) + .collect(); + let mut matrix = Matrix { rows: vec![], wildcard_row, place_validity: new_place_validity }; for (i, row) in self.rows().enumerate() { if ctor.is_covered_by(pcx, row.head().ctor()) { let new_row = row.pop_head_constructor(pcx, ctor, i); @@ -883,27 +950,38 @@ impl<'p, 'tcx> Matrix<'p, 'tcx> { /// + true + [Second(true)] + /// + false + [_] + /// + _ + [_, _, tail @ ..] + +/// | ✓ | ? | // column validity /// ``` impl<'p, 'tcx> fmt::Debug for Matrix<'p, 'tcx> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "\n")?; - let Matrix { rows, .. } = self; - let pretty_printed_matrix: Vec> = - rows.iter().map(|row| row.iter().map(|pat| format!("{pat:?}")).collect()).collect(); + let mut pretty_printed_matrix: Vec> = self + .rows + .iter() + .map(|row| row.iter().map(|pat| format!("{pat:?}")).collect()) + .collect(); + pretty_printed_matrix + .push(self.place_validity.iter().map(|validity| format!("{validity}")).collect()); - let column_count = rows.iter().map(|row| row.len()).next().unwrap_or(0); - assert!(rows.iter().all(|row| row.len() == column_count)); + let column_count = self.column_count(); + assert!(self.rows.iter().all(|row| row.len() == column_count)); + assert!(self.place_validity.len() == column_count); let column_widths: Vec = (0..column_count) .map(|col| pretty_printed_matrix.iter().map(|row| row[col].len()).max().unwrap_or(0)) .collect(); - for row in pretty_printed_matrix { - write!(f, "+")?; + for (row_i, row) in pretty_printed_matrix.into_iter().enumerate() { + let is_validity_row = row_i == self.rows.len(); + let sep = if is_validity_row { "|" } else { "+" }; + write!(f, "{sep}")?; for (column, pat_str) in row.into_iter().enumerate() { write!(f, " ")?; write!(f, "{:1$}", pat_str, column_widths[column])?; - write!(f, " +")?; + write!(f, " {sep}")?; + } + if is_validity_row { + write!(f, " // column validity")?; } write!(f, "\n")?; } @@ -1293,7 +1371,7 @@ fn collect_nonexhaustive_missing_variants<'p, 'tcx>( let Some(ty) = column.head_ty() else { return Vec::new(); }; - let pcx = &PatCtxt { cx, ty, is_top_level: false }; + let pcx = &PatCtxt::new_dummy(cx, ty); let set = column.analyze_ctors(pcx); if set.present.is_empty() { @@ -1342,7 +1420,7 @@ fn lint_overlapping_range_endpoints<'p, 'tcx>( let Some(ty) = column.head_ty() else { return; }; - let pcx = &PatCtxt { cx, ty, is_top_level: false }; + let pcx = &PatCtxt::new_dummy(cx, ty); let set = column.analyze_ctors(pcx); @@ -1445,7 +1523,8 @@ pub(crate) fn compute_match_usefulness<'p, 'tcx>( arms: &[MatchArm<'p, 'tcx>], scrut_ty: Ty<'tcx>, ) -> UsefulnessReport<'p, 'tcx> { - let mut matrix = Matrix::new(cx, arms.iter(), scrut_ty); + let scrut_validity = ValidityConstraint::from_bool(cx.known_valid_scrutinee); + let mut matrix = Matrix::new(cx, arms.iter(), scrut_ty, scrut_validity); let non_exhaustiveness_witnesses = compute_exhaustiveness_and_usefulness(cx, &mut matrix, true); let non_exhaustiveness_witnesses: Vec<_> = non_exhaustiveness_witnesses.single_column(); @@ -1502,7 +1581,7 @@ pub(crate) fn compute_match_usefulness<'p, 'tcx>( if !matches!(lint_level, rustc_session::lint::Level::Allow) { let decorator = NonExhaustiveOmittedPatternLintOnArm { lint_span: lint_level_source.span(), - suggest_lint_on_match: cx.match_span.map(|span| span.shrink_to_lo()), + suggest_lint_on_match: cx.whole_match_span.map(|span| span.shrink_to_lo()), lint_level: lint_level.as_str(), lint_name: "non_exhaustive_omitted_patterns", };