diff --git a/core/src/typecheck/mod.rs b/core/src/typecheck/mod.rs index a003386d68..7291847254 100644 --- a/core/src/typecheck/mod.rs +++ b/core/src/typecheck/mod.rs @@ -81,6 +81,7 @@ pub mod reporting; #[macro_use] pub mod mk_uniftype; pub mod eq; +pub mod subtyping; pub mod unif; use eq::{SimpleTermEnvironment, TermEnvironment}; @@ -90,6 +91,8 @@ use operation::{get_bop_type, get_nop_type, get_uop_type}; use pattern::{PatternTypeData, PatternTypes}; use unif::*; +use self::subtyping::SubsumedBy; + /// The max depth parameter used to limit the work performed when inferring the type of the stdlib. const INFER_RECORD_MAX_DEPTH: u8 = 4; @@ -2166,8 +2169,9 @@ fn check( | Term::Annotated(..) => { let inferred = infer(state, ctxt.clone(), visitor, rt)?; - // We call to `subsumption` to perform the switch from infer mode to checking mode. - subsumption(state, ctxt, inferred, ty) + // We apply the subsumption rule when switching from infer mode to checking mode. + inferred + .subsumed_by(ty, state, ctxt) .map_err(|err| err.into_typecheck_err(state, rt.pos)) } Term::Enum(id) => { @@ -2363,102 +2367,6 @@ fn check( } } -/// Change from inference mode to checking mode, and apply a potential subsumption rule. -/// -/// Currently, there is record/dictionary subtyping, if we are not in this case we fallback to perform -/// polymorphic type instantiation with unification variable on the left (on the inferred type), -/// and then simply performs unification (put differently, the subtyping relation when it is not -/// a record/dictionary subtyping is the equality -/// relation). -/// -/// The type instantiation corresponds to the zero-ary case of application in the current -/// specification (which is based on [A Quick Look at Impredicativity][quick-look], although we -/// currently don't support impredicative polymorphism). -/// -/// In the future, this function might implement a other non-trivial subsumption rule. -/// -/// [quick-look]: https://www.microsoft.com/en-us/research/uploads/prod/2020/01/quick-look-icfp20-fixed.pdf -pub fn subsumption( - state: &mut State, - mut ctxt: Context, - inferred: UnifType, - checked: UnifType, -) -> Result<(), UnifError> { - let inferred_inst = instantiate_foralls(state, &mut ctxt, inferred, ForallInst::UnifVar); - let checked = checked.into_root(state.table); - match (inferred_inst, checked) { - ( - UnifType::Concrete { - typ: TypeF::Record(rrows), - .. - }, - UnifType::Concrete { - typ: - TypeF::Dict { - type_fields, - flavour, - }, - var_levels_data, - }, - ) => { - for row in rrows.iter() { - match row { - GenericUnifRecordRowsIteratorItem::Row(a) => { - subsumption(state, ctxt.clone(), a.typ.clone(), *type_fields.clone())? - } - GenericUnifRecordRowsIteratorItem::TailUnifVar { id, .. } => - // We don't need to perform any variable level checks when unifying a free - // unification variable with a ground type - // We close the tail because there is no garanty that - // { a : Number, b : Number, _ : a?} <= { _ : Number} - { - state - .table - .assign_rrows(id, UnifRecordRows::concrete(RecordRowsF::Empty)) - } - GenericUnifRecordRowsIteratorItem::TailConstant(id) => { - let checked = UnifType::Concrete { - typ: TypeF::Dict { - type_fields: type_fields.clone(), - flavour, - }, - var_levels_data, - }; - Err(UnifError::WithConst { - var_kind: VarKindDiscriminant::RecordRows, - expected_const_id: id, - inferred: checked, - })? - } - _ => (), - } - } - Ok(()) - } - ( - UnifType::Concrete { - typ: TypeF::Array(a), - .. - }, - UnifType::Concrete { - typ: TypeF::Array(b), - .. - }, - ) - | ( - UnifType::Concrete { - typ: TypeF::Dict { type_fields: a, .. }, - .. - }, - UnifType::Concrete { - typ: TypeF::Dict { type_fields: b, .. }, - .. - }, - ) => subsumption(state, ctxt.clone(), *a, *b), - (inferred_inst, checked) => checked.unify(inferred_inst, state, &ctxt), - } -} - fn check_field( state: &mut State, ctxt: Context, @@ -2489,7 +2397,9 @@ fn check_field( field.value.as_ref(), )?; - subsumption(state, ctxt, inferred, ty).map_err(|err| err.into_typecheck_err(state, pos)) + inferred + .subsumed_by(ty, state, ctxt) + .map_err(|err| err.into_typecheck_err(state, pos)) } } @@ -2573,10 +2483,6 @@ fn infer_with_annot( // An empty value is a record field without definition. We don't check anything, and infer // its type to be either the first annotation defined if any, or `Dyn` otherwise. // We can only hit this case for record fields. - // - // TODO: we might have something to do with the visitor to clear the current metadata. It - // looks like it may be unduly attached to the next field definition, which is not - // critical, but still a bug. _ => { let inferred = annot .first() diff --git a/core/src/typecheck/subtyping.rs b/core/src/typecheck/subtyping.rs new file mode 100644 index 0000000000..bbcaa86669 --- /dev/null +++ b/core/src/typecheck/subtyping.rs @@ -0,0 +1,254 @@ +//! Type subsumption (subtyping) +//! +//! Subtyping is a relation between types that allows a value of one type to be used at a place +//! where another type is expected, because the value's actual type is subsumed by the expected +//! type. +//! +//! The subsumption rule is applied when from inference mode to checking mode, as customary in +//! bidirectional type checking. +//! +//! Currently, there is one core subtyping axiom: +//! +//! - Record / Dictionary : `{a1 : T1,...,an : Tn} <: {_ : U}` if for every n `Tn <: U` +//! +//! The subtyping relation is extended to a congruence on other type constructors in the obvious +//! way: +//! +//! - `Array T <: Array U` if `T <: U` +//! - `{_ : T} <: {_ : U}` if `T <: U` +//! - `{a1 : T1,...,an : Tn} <: {b1 : U1,...,bn : Un}` if for every n `Tn <: Un` +//! +//! In all other cases, we fallback to unification (although we instantiate polymorphic types as +//! needed before). That is, we try to apply reflexivity: `T <: U` if `T = U`. +//! +//! The type instantiation corresponds to the zero-ary case of application in the current +//! specification (which is based on [A Quick Look at Impredicativity][quick-look], although we +//! currently don't support impredicative polymorphism). +//! +//! [quick-look]: https://www.microsoft.com/en-us/research/uploads/prod/2020/01/quick-look-icfp20-fixed.pdf +use super::*; + +pub(super) trait SubsumedBy { + type Error; + + /// Checks if `self` is subsumed by `t2`, that is if `self <: t2`. Returns an error otherwise. + fn subsumed_by(self, t2: Self, state: &mut State, ctxt: Context) -> Result<(), Self::Error>; +} + +impl SubsumedBy for UnifType { + type Error = UnifError; + + fn subsumed_by( + self, + t2: Self, + state: &mut State, + mut ctxt: Context, + ) -> Result<(), Self::Error> { + let inferred = instantiate_foralls(state, &mut ctxt, self, ForallInst::UnifVar); + let checked = t2.into_root(state.table); + + match (inferred, checked) { + // {a1 : T1,...,an : Tn} <: {_ : U} if for every n `Tn <: U` + ( + UnifType::Concrete { + typ: TypeF::Record(rrows), + .. + }, + UnifType::Concrete { + typ: + TypeF::Dict { + type_fields, + flavour, + }, + var_levels_data, + }, + ) => { + for row in rrows.iter() { + match row { + GenericUnifRecordRowsIteratorItem::Row(a) => { + a.typ + .clone() + .subsumed_by(*type_fields.clone(), state, ctxt.clone())? + } + GenericUnifRecordRowsIteratorItem::TailUnifVar { id, .. } => + // We don't need to perform any variable level checks when unifying a free + // unification variable with a ground type + // We close the tail because there is no guarantee that + // { a : Number, b : Number, _ : a?} <= { _ : Number} + { + state + .table + .assign_rrows(id, UnifRecordRows::concrete(RecordRowsF::Empty)) + } + GenericUnifRecordRowsIteratorItem::TailConstant(id) => { + let checked = UnifType::Concrete { + typ: TypeF::Dict { + type_fields: type_fields.clone(), + flavour, + }, + var_levels_data, + }; + Err(UnifError::WithConst { + var_kind: VarKindDiscriminant::RecordRows, + expected_const_id: id, + inferred: checked, + })? + } + _ => (), + } + } + Ok(()) + } + // Array T <: Array U if T <: U + ( + UnifType::Concrete { + typ: TypeF::Array(a), + .. + }, + UnifType::Concrete { + typ: TypeF::Array(b), + .. + }, + ) + // Dict T <: Dict U if T <: U + | ( + UnifType::Concrete { + typ: TypeF::Dict { type_fields: a, .. }, + .. + }, + UnifType::Concrete { + typ: TypeF::Dict { type_fields: b, .. }, + .. + }, + ) => a.subsumed_by(*b, state, ctxt), + // {a1 : T1,...,an : Tn} <: {b1 : U1,...,bn : Un} if for every n `Tn <: Un` + ( + UnifType::Concrete { + typ: TypeF::Record(rrows1), + .. + }, + UnifType::Concrete { + typ: TypeF::Record(rrows2), + .. + }, + ) => rrows1 + .clone() + .subsumed_by(rrows2.clone(), state, ctxt) + .map_err(|err| err.into_unif_err(mk_uty_record!(;rrows2), mk_uty_record!(;rrows1))), + // T <: U if T = U + (inferred, checked) => checked.unify(inferred, state, &ctxt), + } + } +} + +impl SubsumedBy for UnifRecordRows { + type Error = RowUnifError; + + fn subsumed_by(self, t2: Self, state: &mut State, ctxt: Context) -> Result<(), Self::Error> { + // This code is almost taken verbatim fro `unify`, but where some recursive calls are + // changed to be `subsumed_by` instead of `unify`. We can surely factorize both into a + // generic function, but this is left for future work. + let inferred = self.into_root(state.table); + let checked = t2.into_root(state.table); + + match (inferred, checked) { + ( + UnifRecordRows::Concrete { rrows: rrows1, .. }, + UnifRecordRows::Concrete { + rrows: rrows2, + var_levels_data: levels2, + }, + ) => match (rrows1, rrows2) { + (RecordRowsF::Extend { row, tail }, rrows2 @ RecordRowsF::Extend { .. }) => { + let urrows2 = UnifRecordRows::Concrete { + rrows: rrows2, + var_levels_data: levels2, + }; + let (ty_res, urrows_without_ty_res) = urrows2 + .remove_row(&row.id, &row.typ, state, ctxt.var_level) + .map_err(|err| match err { + RemoveRowError::Missing => RowUnifError::MissingRow(row.id), + RemoveRowError::Conflict => { + RowUnifError::RecordRowConflict(row.clone()) + } + })?; + if let RemoveRowResult::Extracted(ty) = ty_res { + row.typ + .subsumed_by(ty, state, ctxt.clone()) + .map_err(|err| RowUnifError::RecordRowMismatch { + id: row.id, + cause: Box::new(err), + })?; + } + tail.subsumed_by(urrows_without_ty_res, state, ctxt) + } + (RecordRowsF::TailVar(id), _) | (_, RecordRowsF::TailVar(id)) => { + Err(RowUnifError::UnboundTypeVariable(id)) + } + (RecordRowsF::Empty, RecordRowsF::Empty) + | (RecordRowsF::TailDyn, RecordRowsF::TailDyn) => Ok(()), + (RecordRowsF::Empty, RecordRowsF::TailDyn) + | (RecordRowsF::TailDyn, RecordRowsF::Empty) => Err(RowUnifError::ExtraDynTail), + ( + RecordRowsF::Empty, + RecordRowsF::Extend { + row: UnifRecordRow { id, .. }, + .. + }, + ) + | ( + RecordRowsF::TailDyn, + RecordRowsF::Extend { + row: UnifRecordRow { id, .. }, + .. + }, + ) => Err(RowUnifError::MissingRow(id)), + ( + RecordRowsF::Extend { + row: UnifRecordRow { id, .. }, + .. + }, + RecordRowsF::TailDyn, + ) + | ( + RecordRowsF::Extend { + row: UnifRecordRow { id, .. }, + .. + }, + RecordRowsF::Empty, + ) => Err(RowUnifError::ExtraRow(id)), + }, + (UnifRecordRows::UnifVar { id, .. }, urrows) + | (urrows, UnifRecordRows::UnifVar { id, .. }) => { + if let UnifRecordRows::Constant(cst_id) = urrows { + let constant_level = state.table.get_rrows_level(cst_id); + state.table.force_rrows_updates(constant_level); + if state.table.get_rrows_level(id) < constant_level { + return Err(RowUnifError::VarLevelMismatch { + constant_id: cst_id, + var_kind: VarKindDiscriminant::RecordRows, + }); + } + } + urrows.propagate_constrs(state.constr, id)?; + state.table.assign_rrows(id, urrows); + Ok(()) + } + (UnifRecordRows::Constant(i1), UnifRecordRows::Constant(i2)) if i1 == i2 => Ok(()), + (UnifRecordRows::Constant(i1), UnifRecordRows::Constant(i2)) => { + Err(RowUnifError::ConstMismatch { + var_kind: VarKindDiscriminant::RecordRows, + expected_const_id: i2, + inferred_const_id: i1, + }) + } + (urrows, UnifRecordRows::Constant(i)) | (UnifRecordRows::Constant(i), urrows) => { + Err(RowUnifError::WithConst { + var_kind: VarKindDiscriminant::RecordRows, + expected_const_id: i, + inferred: UnifType::concrete(TypeF::Record(urrows)), + }) + } + } + } +} diff --git a/core/src/typecheck/unif.rs b/core/src/typecheck/unif.rs index 2ec4cb6f4e..627ea06a5b 100644 --- a/core/src/typecheck/unif.rs +++ b/core/src/typecheck/unif.rs @@ -1043,7 +1043,7 @@ impl UnifTable { /// the map to be rather sparse, we use a `HashMap` instead of a `Vec`. pub type RowConstrs = HashMap>; -trait PropagateConstrs { +pub(super) trait PropagateConstrs { /// Check that unifying a variable with a type doesn't violate rows constraints, and update the /// row constraints of the unified type accordingly if needed. /// @@ -1599,7 +1599,7 @@ impl Unify for UnifRecordRows { } #[derive(Clone, Copy, Debug)] -enum RemoveRowError { +pub(super) enum RemoveRowError { // The row to add was missing and the row type was closed (no free unification variable in tail // position). Missing, @@ -1613,7 +1613,7 @@ pub enum RemoveRowResult { Extended, } -trait RemoveRow: Sized { +pub(super) trait RemoveRow: Sized { /// The row data minus the identifier. type RowContent: Clone; diff --git a/core/tests/integration/inputs/typecheck/record_subtyping.ncl b/core/tests/integration/inputs/typecheck/record_subtyping.ncl new file mode 100644 index 0000000000..8915cf85e1 --- /dev/null +++ b/core/tests/integration/inputs/typecheck/record_subtyping.ncl @@ -0,0 +1,6 @@ +# test.type = 'pass' +let test : _ = + let test_func : {a : {_ : Number}} -> {a : {_ : Number}} = fun a => a in + test_func {a = {foo = 5}} +in +true diff --git a/core/tests/integration/inputs/typecheck/record_subtyping_multiple_components.ncl b/core/tests/integration/inputs/typecheck/record_subtyping_multiple_components.ncl new file mode 100644 index 0000000000..3940f45da4 --- /dev/null +++ b/core/tests/integration/inputs/typecheck/record_subtyping_multiple_components.ncl @@ -0,0 +1,6 @@ +# test.type = 'pass' +let test : _ = + let test_func : {a : {_ : Number}, b : {_ : String}} -> {a : {_ : Number}, b : {_ : String}} = fun a => a in + test_func {a = {foo = 5}, b = {a = "test"}} +in +true diff --git a/core/tests/integration/inputs/typecheck/record_subtyping_with_tail.ncl b/core/tests/integration/inputs/typecheck/record_subtyping_with_tail.ncl new file mode 100644 index 0000000000..a24d95160c --- /dev/null +++ b/core/tests/integration/inputs/typecheck/record_subtyping_with_tail.ncl @@ -0,0 +1,6 @@ +# test.type = 'pass' +let test : _ = + let test_func : forall b. {a : {_ : Number}; b} -> {a : {_ : Number}; b} = fun a => a in + test_func {a = {foo = 5}, b = 5} +in +true diff --git a/doc/manual/typing.md b/doc/manual/typing.md index 68894b21ed..b964e2c40d 100644 --- a/doc/manual/typing.md +++ b/doc/manual/typing.md @@ -330,6 +330,24 @@ Subtyping extends to type constructors in the following way: Here, `{_ : {a : Number}}` is accepted where `{_ : {_ : Number}}` is expected, because `{a : Number} <: { _ : Number }`. +- **Record**: `{a1 : T1, ..., an : Tn} <: {a1 : U1, ..., an : Un}` if for each + `i`, `Ti <: Ui` + + Example: + + ```nickel + let block : _ = + let record_of_records : {a: {b : Number}} = {a = {b = 5}} in + let inject_c_in_a : {a : {_ : Number}} -> {a : {_ : Number}} + = fun x => {a = std.record.insert "c" 5 (std.record.get "a" x)} + in + + inject_c_in_a record_of_records in + block + ``` + + Here, `{a : {b : Number}}` is accepted where `{a : {_ : Number}}` is expected, + because `{b : Number} <: { _ : Number }`. **Remark**: if you've used languages with subtyping before, you might expect the presence of a rule for function types, namely that `T -> U <: S -> V` if `S <: diff --git a/lsp/nls/tests/snapshots/main__lsp__nls__tests__inputs__hover_field_typed_block_regression_1574.ncl.snap b/lsp/nls/tests/snapshots/main__lsp__nls__tests__inputs__hover_field_typed_block_regression_1574.ncl.snap index 61ad23654f..2e9e52340c 100644 --- a/lsp/nls/tests/snapshots/main__lsp__nls__tests__inputs__hover_field_typed_block_regression_1574.ncl.snap +++ b/lsp/nls/tests/snapshots/main__lsp__nls__tests__inputs__hover_field_typed_block_regression_1574.ncl.snap @@ -2,7 +2,9 @@ source: lsp/nls/tests/main.rs expression: output --- -<0:24-0:37>[Applies a function to every element in the given array. That is, +<0:24-0:37>[```nickel +(_a -> Number) -> Array _a -> Array Number +```, Applies a function to every element in the given array. That is, `map f [ x1, x2, ..., xn ]` is `[ f x1, f x2, ..., f xn ]`. # Examples @@ -13,4 +15,3 @@ std.array.map (fun x => x + 1) [ 1, 2, 3 ] => ```, ```nickel forall a b. (a -> b) -> Array a -> Array b ```] -