diff --git a/chalk-integration/src/lowering/program_lowerer.rs b/chalk-integration/src/lowering/program_lowerer.rs index 1f39cd9bf27..0a7a192d2c3 100644 --- a/chalk-integration/src/lowering/program_lowerer.rs +++ b/chalk-integration/src/lowering/program_lowerer.rs @@ -1,7 +1,8 @@ use chalk_ir::cast::Cast; use chalk_ir::{ self, AdtId, AssocTypeId, BoundVar, ClosureId, DebruijnIndex, FnDefId, ForeignDefId, - GeneratorId, ImplId, OpaqueTyId, TraitId, TyVariableKind, VariableKinds, + GeneratorId, ImplId, OpaqueTyId, ProgramClauseData, ProgramClauseImplication, TraitId, + TyVariableKind, VariableKinds, }; use chalk_parse::ast::*; use chalk_solve::rust_ir::{ @@ -160,6 +161,10 @@ impl ProgramLowerer { let mut hidden_opaque_types = BTreeMap::new(); let mut custom_clauses = Vec::new(); + // We can't check which impls are closure overrides until we've accumulated well-known + // traits + let mut possible_closure_overrides = Vec::new(); + for (item, &raw_id) in program.items.iter().zip(raw_ids) { let empty_env = Env { adt_ids: &self.adt_ids, @@ -345,6 +350,13 @@ impl ProgramLowerer { }), ); } + if let Some(GenericArg::Ty(Ty::Id { name } | Ty::Apply { name, .. })) = + impl_defn.trait_ref.args.get(0) + { + if self.closure_ids.contains_key(&name.str) { + possible_closure_overrides.push(impl_datum); + } + } } Item::Clause(ref clause) => { custom_clauses.extend(clause.lower(&empty_env)?); @@ -467,6 +479,33 @@ impl ProgramLowerer { } } + for imp in possible_closure_overrides { + if [ + chalk_solve::rust_ir::WellKnownTrait::FnOnce, + chalk_solve::rust_ir::WellKnownTrait::FnMut, + chalk_solve::rust_ir::WellKnownTrait::Fn, + ] + .iter() + .filter_map(|t| well_known_traits.get(t)) + .any(|id| id == &imp.binders.skip_binders().trait_ref.trait_id) + { + custom_clauses.push( + ProgramClauseData(chalk_ir::Binders::new( + imp.binders.binders.clone(), + ProgramClauseImplication { + consequence: chalk_ir::DomainGoal::LocalImplAllowed( + imp.binders.skip_binders().trait_ref.clone(), + ), + conditions: chalk_ir::Goals::empty(&ChalkIr), + constraints: chalk_ir::Constraints::empty(&ChalkIr), + priority: chalk_ir::ClausePriority::Low, + }, + )) + .intern(&ChalkIr), + ); + } + } + Ok(LoweredProgram { adt_ids: self.adt_ids, fn_def_ids: self.fn_def_ids, diff --git a/chalk-solve/src/clauses/builtin_traits/fn_family.rs b/chalk-solve/src/clauses/builtin_traits/fn_family.rs index 4645b139125..26029deb358 100644 --- a/chalk-solve/src/clauses/builtin_traits/fn_family.rs +++ b/chalk-solve/src/clauses/builtin_traits/fn_family.rs @@ -3,8 +3,8 @@ use crate::rust_ir::{ClosureKind, FnDefInputsAndOutputDatum, WellKnownTrait}; use crate::{Interner, RustIrDatabase, TraitRef}; use chalk_ir::cast::Cast; use chalk_ir::{ - AliasTy, Binders, Floundered, Normalize, ProjectionTy, Safety, Substitution, TraitId, Ty, - TyKind, + AliasTy, Binders, DomainGoal, Floundered, Normalize, ProjectionTy, Safety, Substitution, + TraitId, Ty, TyKind, }; fn push_clauses( @@ -110,6 +110,21 @@ pub fn add_fn_trait_program_clauses( Ok(()) } TyKind::Closure(closure_id, substitution) => { + for custom in db.custom_clauses() { + if let DomainGoal::LocalImplAllowed(tr) = + &custom.data(interner).0.skip_binders().consequence + { + if tr.trait_id == db.well_known_trait_id(WellKnownTrait::FnOnce).unwrap() { + if let TyKind::Closure(cl_id, _) = + tr.self_type_parameter(interner).data(interner).kind + { + if cl_id == *closure_id { + return Ok(()); + } + } + } + } + } let closure_kind = db.closure_kind(*closure_id, &substitution); let trait_matches = match (well_known, closure_kind) { (WellKnownTrait::Fn, ClosureKind::Fn) => true, diff --git a/chalk-solve/src/wf.rs b/chalk-solve/src/wf.rs index 3a610bdde8f..17e29aa0e52 100644 --- a/chalk-solve/src/wf.rs +++ b/chalk-solve/src/wf.rs @@ -421,11 +421,27 @@ where ) } WellKnownTrait::Clone | WellKnownTrait::Unpin => true, + // Manual implementations are only allowed during testing + WellKnownTrait::Fn | WellKnownTrait::FnOnce | WellKnownTrait::FnMut => { + let interner = self.db.interner(); + self.db.custom_clauses().iter().any(|custom| { + if let DomainGoal::LocalImplAllowed(tr) = + &custom.data(interner).0.skip_binders().consequence + { + tr.trait_id == self.db.well_known_trait_id(well_known).unwrap() + && tr.self_type_parameter(interner) + == impl_datum + .binders + .skip_binders() + .trait_ref + .self_type_parameter(interner) + } else { + false + } + }) + } // You can't add a manual implementation for the following traits: - WellKnownTrait::Fn - | WellKnownTrait::FnOnce - | WellKnownTrait::FnMut - | WellKnownTrait::Unsize + WellKnownTrait::Unsize | WellKnownTrait::Sized | WellKnownTrait::DiscriminantKind | WellKnownTrait::Generator => false, diff --git a/tests/test/closures.rs b/tests/test/closures.rs index 37fe647fb07..b67229b95ed 100644 --- a/tests/test/closures.rs +++ b/tests/test/closures.rs @@ -280,3 +280,120 @@ fn closure_implements_fn_traits() { } } } + +#[test] +fn can_override_closure_traits() { + test! { + program { + #[lang(fn_once)] + trait FnOnce { + type Output; + } + + #[lang(fn_mut)] + trait FnMut where Self: FnOnce { } + + #[lang(fn)] + trait Fn where Self: FnMut { } + + closure foo(&self,) {} + closure bar(&self,) {} + closure baz(&mut self,) {} + + impl FnOnce<()> for foo { + type Output = i32; + } + impl FnMut<()> for foo {} + impl Fn<()> for foo {} + + impl FnOnce for bar { + type Output = T; + } + impl FnMut for bar {} + + impl FnOnce for baz { + type Output = i32; + } + } + // All 3 traits implemented + goal { + foo: Fn<()> + } yields { + "Unique" + } + goal { + foo: FnMut<()> + } yields { + "Unique" + } + goal { + foo: FnOnce<()> + } yields { + "Unique" + } + goal { + Normalize(>::Output -> i32) + } yields { + "Unique" + } + // and the impl is not wider than expected + goal { + foo: FnOnce + } yields { + "No possible solution" + } + + // Do not implement `Fn` on `&self` closure if there's an override + goal { + bar: Fn<()> + } yields { + "No possible solution" + } + goal { + bar: FnMut<()> + } yields { + "Unique" + } + goal { + bar: FnOnce<()> + } yields { + "Unique" + } + // also the generic impl still does something reasonable + goal { + exists { + Normalize(>::Output -> T) + } + } yields { + "Unique; for { substitution [?0 := ^0.0], lifetime constraints [] }" + } + goal { + exists { + Normalize(>::Output -> ()) + } + } yields { + "Unique; substitution [?0 := 0], lifetime constraints []" + } + + // Do not implement `Fn` or 'FnMut' on `&mut self` closure if there's an override + goal { + exists { + baz: Fn + } + } yields { + "No possible solution" + } + goal { + exists { + baz: FnMut + } + } yields { + "No possible solution" + } + goal { + baz<()>: FnOnce<()> + } yields { + "Unique" + } + } +}