From 002f0fc92fea8f328314e58bea68825508161598 Mon Sep 17 00:00:00 2001 From: Yann Hamdaoui Date: Mon, 23 Dec 2024 12:04:57 +0100 Subject: [PATCH] [RFC007] Migrate the typechecker to the new AST - Part I (#2121) * Copy the code from typechecker, get rid of generic term env * Switch to bytecode::ast::typ::*, add a whole cargaison of lifetimes * Move Traverse in its own module * [WIP] Pass missing allocator to conv functions * Continue migration of the typechecker to the new AST * More typecheck conversion, typecheck::operation conversion * Introduce record-typechecking-related infrastructure * Some fixes related to term environment populating * End of first step for record typechecking * Implement Traverse for various new AST components * Fix various compiler errors * Reset unintentional changes to mainline typecheck module * Fix more compiler errors (remaining: tc::eq and tc::error) * Migrate type equality to the new AST * Fix compilation errors in bytecode::typecheck::error * Fix compilation errors in bytecode::typecheck::subtyping * Fix compilation errors in bytecode::typecheck::reporting * Fix more compiler errors * Update unif to use new TypeEq trait * Fix compilation errors and warnings * Fix typo in comment * Fix clippy errors * Fix cargo doc warnings * Fix more clippy warnings --- cli/src/doctest.rs | 6 +- core/src/bytecode/ast/mod.rs | 410 ++- core/src/bytecode/ast/pattern/mod.rs | 157 +- core/src/bytecode/ast/record.rs | 93 +- core/src/bytecode/ast/typ.rs | 172 +- core/src/bytecode/mod.rs | 1 + core/src/bytecode/typecheck/eq.rs | 674 +++++ core/src/bytecode/typecheck/error.rs | 555 ++++ core/src/bytecode/typecheck/mk_uniftype.rs | 167 ++ core/src/bytecode/typecheck/mod.rs | 2979 ++++++++++++++++++++ core/src/bytecode/typecheck/operation.rs | 633 +++++ core/src/bytecode/typecheck/pattern.rs | 655 +++++ core/src/bytecode/typecheck/record.rs | 543 ++++ core/src/bytecode/typecheck/reporting.rs | 269 ++ core/src/bytecode/typecheck/subtyping.rs | 264 ++ 
core/src/bytecode/typecheck/unif.rs | 1832 ++++++++++++ core/src/combine.rs | 2 +- core/src/identifier.rs | 20 +- core/src/label.rs | 1 + core/src/lib.rs | 1 + core/src/parser/uniterm.rs | 17 +- core/src/position.rs | 9 + core/src/repl/mod.rs | 48 +- core/src/term/mod.rs | 79 +- core/src/transform/import_resolution.rs | 11 +- core/src/transform/mod.rs | 3 +- core/src/transform/substitute_wildcards.rs | 3 +- core/src/traverse.rs | 139 + core/src/typ.rs | 5 +- core/src/typecheck/mod.rs | 3 +- lsp/nls/src/analysis.rs | 3 +- lsp/nls/src/cache.rs | 5 +- lsp/nls/src/position.rs | 3 +- lsp/nls/src/usage.rs | 3 +- 34 files changed, 9595 insertions(+), 170 deletions(-) create mode 100644 core/src/bytecode/typecheck/eq.rs create mode 100644 core/src/bytecode/typecheck/error.rs create mode 100644 core/src/bytecode/typecheck/mk_uniftype.rs create mode 100644 core/src/bytecode/typecheck/mod.rs create mode 100644 core/src/bytecode/typecheck/operation.rs create mode 100644 core/src/bytecode/typecheck/pattern.rs create mode 100644 core/src/bytecode/typecheck/record.rs create mode 100644 core/src/bytecode/typecheck/reporting.rs create mode 100644 core/src/bytecode/typecheck/subtyping.rs create mode 100644 core/src/bytecode/typecheck/unif.rs create mode 100644 core/src/traverse.rs diff --git a/cli/src/doctest.rs b/cli/src/doctest.rs index d6c2be37b2..9427b67c20 100644 --- a/cli/src/doctest.rs +++ b/cli/src/doctest.rs @@ -19,10 +19,8 @@ use nickel_lang_core::{ label::Label, match_sharedterm, mk_app, mk_fun, program::Program, - term::{ - make, record::RecordData, LabeledType, RichTerm, Term, Traverse as _, TraverseOrder, - TypeAnnotation, - }, + term::{make, record::RecordData, LabeledType, RichTerm, Term, TypeAnnotation}, + traverse::{Traverse as _, TraverseOrder}, typ::{Type, TypeF}, typecheck::TypecheckMode, }; diff --git a/core/src/bytecode/ast/mod.rs b/core/src/bytecode/ast/mod.rs index 092cb3fdec..3c584f4d87 100644 --- a/core/src/bytecode/ast/mod.rs +++ 
b/core/src/bytecode/ast/mod.rs @@ -23,6 +23,7 @@ use crate::{ error::ParseError, identifier::{Ident, LocIdent}, position::TermPos, + traverse::*, }; // For now, we reuse those types from the term module. @@ -52,7 +53,7 @@ use typ::*; /// Using an arena has another advantage: the data is allocated in the same order as the AST is /// built. This means that even if there are reference indirections, the children of a node are /// most likely close to the node itself in memory, which should be good for cache locality. -#[derive(Clone, Debug, PartialEq, Default)] +#[derive(Clone, Debug, PartialEq, Eq, Default)] pub enum Node<'ast> { /// The null value. #[default] @@ -154,7 +155,7 @@ pub enum Node<'ast> { } /// An individual binding in a let block. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct LetBinding<'ast> { pub pattern: Pattern<'ast>, pub metadata: LetMetadata<'ast>, @@ -162,7 +163,7 @@ pub struct LetBinding<'ast> { } /// The metadata that can be attached to a let. It's a subset of [record::FieldMetadata]. -#[derive(Debug, Default, Clone, PartialEq)] +#[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct LetMetadata<'ast> { pub doc: Option>, pub annotation: Annotation<'ast>, @@ -219,7 +220,7 @@ impl<'ast> Node<'ast> { /// //TODO: we don't expect to access the span much on the happy path. Should we add an indirection //through a reference? -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct Ast<'ast> { pub node: Node<'ast>, pub pos: TermPos, @@ -233,7 +234,7 @@ impl Ast<'_> { } /// A branch of a match expression. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct MatchBranch<'ast> { /// The pattern on the left hand side of `=>`. pub pattern: Pattern<'ast>, @@ -245,7 +246,7 @@ pub struct MatchBranch<'ast> { } /// Content of a match expression. 
-#[derive(Debug, PartialEq, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct Match<'ast> { /// Branches of the match expression, where the first component is the pattern on the left hand /// side of `=>` and the second component is the body of the branch. @@ -253,7 +254,7 @@ pub struct Match<'ast> { } /// A type and/or contract annotation. -#[derive(Debug, PartialEq, Clone, Default)] +#[derive(Debug, PartialEq, Eq, Clone, Default)] pub struct Annotation<'ast> { /// The type annotation (using `:`). pub typ: Option>, @@ -265,12 +266,12 @@ pub struct Annotation<'ast> { impl<'ast> Annotation<'ast> { /// Returns the main annotation, which is either the type annotation if any, or the first /// contract annotation. - pub fn first(&'ast self) -> Option<&'ast Type<'ast>> { + pub fn first<'a>(&'a self) -> Option<&'a Type<'ast>> { self.typ.as_ref().or(self.contracts.iter().next()) } /// Iterates over the annotations, starting by the type and followed by the contracts. - pub fn iter(&'ast self) -> impl Iterator> { + pub fn iter<'a>(&'a self) -> impl Iterator> { self.typ.iter().chain(self.contracts.iter()) } @@ -306,6 +307,380 @@ pub enum Import<'ast> { Package { id: Ident }, } +impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for Ast<'ast> { + /// Traverse through all [Ast] in the tree. + /// + /// This also recurses into the terms that are contained in [typ::Type] subtrees. 
+ fn traverse( + self, + alloc: &'ast AstAlloc, + f: &mut F, + order: TraverseOrder, + ) -> Result, E> + where + F: FnMut(Ast<'ast>) -> Result, E>, + { + let ast = match order { + TraverseOrder::TopDown => f(self)?, + TraverseOrder::BottomUp => self, + }; + let pos = ast.pos; + + let result = match &ast.node { + Node::Fun { args, body } => { + let args = traverse_alloc_many(alloc, args.iter().cloned(), f, order)?; + let body = alloc.alloc((*body).clone().traverse(alloc, f, order)?); + + Ast { + node: Node::Fun { args, body }, + pos, + } + } + Node::Let { + bindings, + body, + rec, + } => { + let bindings = traverse_alloc_many(alloc, bindings.iter().cloned(), f, order)?; + let body = alloc.alloc((*body).clone().traverse(alloc, f, order)?); + + Ast { + node: Node::Let { + bindings, + body, + rec: *rec, + }, + pos, + } + } + Node::App { head, args } => { + let head = alloc.alloc((*head).clone().traverse(alloc, f, order)?); + let args = traverse_alloc_many(alloc, args.iter().cloned(), f, order)?; + + Ast { + node: Node::App { head, args }, + pos, + } + } + Node::Match(data) => { + let branches = traverse_alloc_many(alloc, data.branches.iter().cloned(), f, order)?; + + Ast { + node: Node::Match(Match { branches }), + pos, + } + } + Node::PrimOpApp { op, args } => { + let args = traverse_alloc_many(alloc, args.iter().cloned(), f, order)?; + + Ast { + node: Node::PrimOpApp { op, args }, + pos, + } + } + Node::Record(record) => { + let field_defs = + traverse_alloc_many(alloc, record.field_defs.iter().cloned(), f, order)?; + + Ast { + node: Node::Record(alloc.alloc(record::Record { + field_defs, + open: record.open, + })), + pos, + } + } + Node::Array(elts) => { + let elts = traverse_alloc_many(alloc, elts.iter().cloned(), f, order)?; + + Ast { + node: Node::Array(elts), + pos, + } + } + Node::StringChunks(chunks) => { + let chunks_res: Result>>, E> = chunks + .iter() + .cloned() + .map(|chunk| match chunk { + chunk @ StringChunk::Literal(_) => Ok(chunk), + 
StringChunk::Expr(ast, indent) => { + Ok(StringChunk::Expr(ast.traverse(alloc, f, order)?, indent)) + } + }) + .collect(); + + Ast { + node: Node::StringChunks(alloc.alloc_many(chunks_res?)), + pos, + } + } + Node::Annotated { annot, inner } => { + let annot = alloc.alloc((*annot).clone().traverse(alloc, f, order)?); + let inner = alloc.alloc((*inner).clone().traverse(alloc, f, order)?); + + Ast { + node: Node::Annotated { annot, inner }, + pos, + } + } + Node::Type(typ) => { + let typ = alloc.alloc((*typ).clone().traverse(alloc, f, order)?); + + Ast { + node: Node::Type(typ), + pos, + } + } + _ => ast, + }; + + match order { + TraverseOrder::TopDown => Ok(result), + TraverseOrder::BottomUp => f(result), + } + } + + fn traverse_ref( + &self, + f: &mut dyn FnMut(&Ast<'ast>, &S) -> TraverseControl, + state: &S, + ) -> Option { + let child_state = match f(self, state) { + TraverseControl::Continue => None, + TraverseControl::ContinueWithScope(s) => Some(s), + TraverseControl::SkipBranch => { + return None; + } + TraverseControl::Return(ret) => { + return Some(ret); + } + }; + let state = child_state.as_ref().unwrap_or(state); + + match self.node { + Node::Null + | Node::Bool(_) + | Node::Number(_) + | Node::String(_) + | Node::Var(_) + | Node::Import(_) + | Node::ParseError(_) => None, + Node::IfThenElse { + cond, + then_branch, + else_branch, + } => cond + .traverse_ref(f, state) + .or_else(|| then_branch.traverse_ref(f, state)) + .or_else(|| else_branch.traverse_ref(f, state)), + Node::EnumVariant { tag: _, arg } => arg?.traverse_ref(f, state), + Node::StringChunks(chunks) => chunks.iter().find_map(|chk| { + if let StringChunk::Expr(term, _) = chk { + term.traverse_ref(f, state) + } else { + None + } + }), + Node::Fun { args, body } => args + .iter() + .find_map(|arg| arg.traverse_ref(f, state)) + .or_else(|| body.traverse_ref(f, state)), + Node::PrimOpApp { op: _, args } => { + args.iter().find_map(|arg| arg.traverse_ref(f, state)) + } + Node::Let { + bindings, + 
body, + rec: _, + } => bindings + .iter() + .find_map(|binding| binding.traverse_ref(f, state)) + .or_else(|| body.traverse_ref(f, state)), + Node::App { head, args } => head + .traverse_ref(f, state) + .or_else(|| args.iter().find_map(|arg| arg.traverse_ref(f, state))), + Node::Record(data) => data + .field_defs + .iter() + .find_map(|field_def| field_def.traverse_ref(f, state)), + Node::Match(data) => data.branches.iter().find_map( + |MatchBranch { + pattern, + guard, + body, + }| { + pattern + .traverse_ref(f, state) + .or_else(|| { + if let Some(cond) = guard.as_ref() { + cond.traverse_ref(f, state) + } else { + None + } + }) + .or_else(|| body.traverse_ref(f, state)) + }, + ), + Node::Array(elts) => elts.iter().find_map(|t| t.traverse_ref(f, state)), + Node::Annotated { annot, inner } => annot + .traverse_ref(f, state) + .or_else(|| inner.traverse_ref(f, state)), + Node::Type(typ) => typ.traverse_ref(f, state), + } + } +} + +impl<'ast> TraverseAlloc<'ast, Type<'ast>> for Ast<'ast> { + fn traverse( + self, + alloc: &'ast AstAlloc, + f: &mut F, + order: TraverseOrder, + ) -> Result, E> + where + F: FnMut(Type<'ast>) -> Result, E>, + { + self.traverse( + alloc, + &mut |ast: Ast<'ast>| match &ast.node { + Node::Type(typ) => { + let typ = alloc.alloc((*typ).clone().traverse(alloc, f, order)?); + Ok(Ast { + node: Node::Type(typ), + pos: ast.pos, + }) + } + _ => Ok(ast), + }, + order, + ) + } + + fn traverse_ref( + &self, + f: &mut dyn FnMut(&Type<'ast>, &S) -> TraverseControl, + state: &S, + ) -> Option { + self.traverse_ref( + &mut |ast: &Ast<'ast>, state: &S| match &ast.node { + Node::Type(typ) => typ.traverse_ref(f, state).into(), + _ => TraverseControl::Continue, + }, + state, + ) + } +} + +impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for Annotation<'ast> { + fn traverse( + self, + alloc: &'ast AstAlloc, + f: &mut F, + order: TraverseOrder, + ) -> Result + where + F: FnMut(Ast<'ast>) -> Result, E>, + { + let typ = self + .typ + .map(|typ| typ.traverse(alloc, f, 
order)) + .transpose()?; + let contracts = traverse_alloc_many(alloc, self.contracts.iter().cloned(), f, order)?; + + Ok(Annotation { typ, contracts }) + } + + fn traverse_ref( + &self, + f: &mut dyn FnMut(&Ast<'ast>, &S) -> TraverseControl, + scope: &S, + ) -> Option { + self.typ + .iter() + .chain(self.contracts.iter()) + .find_map(|c| c.traverse_ref(f, scope)) + } +} + +impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for LetBinding<'ast> { + fn traverse( + self, + alloc: &'ast AstAlloc, + f: &mut F, + order: TraverseOrder, + ) -> Result + where + F: FnMut(Ast<'ast>) -> Result, E>, + { + let pattern = self.pattern.traverse(alloc, f, order)?; + + let metadata = LetMetadata { + annotation: self.metadata.annotation.traverse(alloc, f, order)?, + doc: self.metadata.doc, + }; + + let value = self.value.traverse(alloc, f, order)?; + + Ok(LetBinding { + pattern, + metadata, + value, + }) + } + + fn traverse_ref( + &self, + f: &mut dyn FnMut(&Ast<'ast>, &S) -> TraverseControl, + scope: &S, + ) -> Option { + self.metadata + .annotation + .traverse_ref(f, scope) + .or_else(|| self.value.traverse_ref(f, scope)) + } +} + +impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for MatchBranch<'ast> { + fn traverse( + self, + alloc: &'ast AstAlloc, + f: &mut F, + order: TraverseOrder, + ) -> Result + where + F: FnMut(Ast<'ast>) -> Result, E>, + { + let pattern = self.pattern.traverse(alloc, f, order)?; + let body = self.body.traverse(alloc, f, order)?; + let guard = self + .guard + .map(|guard| guard.traverse(alloc, f, order)) + .transpose()?; + + Ok(MatchBranch { + pattern, + guard, + body, + }) + } + + fn traverse_ref( + &self, + f: &mut dyn FnMut(&Ast<'ast>, &S) -> TraverseControl, + scope: &S, + ) -> Option { + self.pattern + .traverse_ref(f, scope) + .or_else(|| self.body.traverse_ref(f, scope)) + .or_else(|| { + self.guard + .as_ref() + .and_then(|guard| guard.traverse_ref(f, scope)) + }) + } +} + /// Marker trait for AST nodes that don't need to be dropped (in practice, it's often 
equivalent to /// not owning any heap allocated data) and can be used with [allocator][AstAlloc::alloc]. The /// current exceptions are [Number] and [crate::error::ParseError], which must be allocated through @@ -317,7 +692,9 @@ impl Allocable for StringChunk {} impl Allocable for LetBinding<'_> {} impl Allocable for PrimOp {} impl Allocable for Annotation<'_> {} +impl Allocable for MatchBranch<'_> {} +impl Allocable for Record<'_> {} impl Allocable for record::FieldPathElem<'_> {} impl Allocable for FieldDef<'_> {} @@ -636,3 +1013,18 @@ impl<'ast> From> for Ast<'ast> { } } } + +/// Similar to `TryFrom`, but takes an additional allocator for conversion from and to +/// [crate::bytecode::ast::Ast] that requires to thread an explicit allocator. +/// +/// We chose a different name than `try_from` for the method - although it has a different +/// signature from the standard `TryFrom` (two arguments vs one) - to avoid confusing the compiler +/// which would otherwise have difficulties disambiguating calls like `Ast::try_from`. +pub(crate) trait TryConvert<'ast, T> +where + Self: Sized, +{ + type Error; + + fn try_convert(alloc: &'ast AstAlloc, from: T) -> Result; +} diff --git a/core/src/bytecode/ast/pattern/mod.rs b/core/src/bytecode/ast/pattern/mod.rs index 51d0123af2..50907eaf65 100644 --- a/core/src/bytecode/ast/pattern/mod.rs +++ b/core/src/bytecode/ast/pattern/mod.rs @@ -3,11 +3,11 @@ use std::collections::{hash_map::Entry, HashMap}; use super::{Annotation, Ast, Number}; -use crate::{identifier::LocIdent, parser::error::ParseError, position::TermPos}; +use crate::{identifier::LocIdent, parser::error::ParseError, position::TermPos, traverse::*}; pub mod bindings; -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] pub enum PatternData<'ast> { /// A wildcard pattern, matching any value. As opposed to any, this pattern doesn't bind any /// variable. 
@@ -29,7 +29,7 @@ pub enum PatternData<'ast> { /// A generic pattern, that can appear in a match expression (not yet implemented) or in a /// destructuring let-binding. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct Pattern<'ast> { /// The content of this pattern pub data: PatternData<'ast>, @@ -41,7 +41,7 @@ pub struct Pattern<'ast> { } /// An enum pattern, including both an enum tag and an enum variant. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct EnumPattern<'ast> { pub tag: LocIdent, pub pattern: Option>, @@ -50,7 +50,7 @@ pub struct EnumPattern<'ast> { /// A field pattern inside a record pattern. Every field can be annotated with a type, contracts or /// with a default value. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct FieldPattern<'ast> { /// The name of the matched field. For example, in `{..., foo = {bar, baz}, ...}`, the matched /// identifier is `foo`. @@ -67,7 +67,7 @@ pub struct FieldPattern<'ast> { } /// A record pattern. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct RecordPattern<'ast> { /// The patterns for each field in the record. pub patterns: &'ast [FieldPattern<'ast>], @@ -78,7 +78,7 @@ pub struct RecordPattern<'ast> { } /// An array pattern. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct ArrayPattern<'ast> { /// The patterns of the elements of the array. pub patterns: &'ast [Pattern<'ast>], @@ -97,13 +97,13 @@ impl ArrayPattern<'_> { } /// A constant pattern, matching a constant value. 
-#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct ConstantPattern<'ast> { pub data: ConstantPatternData<'ast>, pub pos: TermPos, } -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] pub enum ConstantPatternData<'ast> { Bool(bool), Number(&'ast Number), @@ -111,7 +111,7 @@ pub enum ConstantPatternData<'ast> { Null, } -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct OrPattern<'ast> { pub patterns: &'ast [Pattern<'ast>], pub pos: TermPos, @@ -119,7 +119,7 @@ pub struct OrPattern<'ast> { /// The tail of a data structure pattern (record or array) which might capture the rest of said /// data structure. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] pub enum TailPattern { /// The pattern is closed, i.e. it doesn't allow more fields. For example, `{foo, bar}`. Empty, @@ -131,6 +131,7 @@ pub enum TailPattern { } impl Pattern<'_> { + /// Creates an `Any` pattern with the corresponding capture name. pub fn any(id: LocIdent) -> Self { let pos = id.pos; @@ -140,6 +141,15 @@ impl Pattern<'_> { pos, } } + + /// Returns `Some(id)` if this pattern is an [PatternData::Any] pattern, `None` otherwise. 
+ pub fn try_as_any(&self) -> Option { + if let PatternData::Any(id) = &self.data { + Some(*id) + } else { + None + } + } } impl TailPattern { @@ -200,6 +210,131 @@ impl RecordPattern<'_> { } } +impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for Pattern<'ast> { + fn traverse( + self, + alloc: &'ast super::AstAlloc, + f: &mut F, + order: TraverseOrder, + ) -> Result + where + F: FnMut(Ast<'ast>) -> Result, E>, + { + match self.data { + data @ (PatternData::Wildcard | PatternData::Any(_) | PatternData::Constant(_)) => { + Ok(Pattern { data, ..self }) + } + PatternData::Record(record) => { + let record = record.clone(); + let patterns = + traverse_alloc_many(alloc, record.patterns.iter().cloned(), f, order)?; + + Ok(Pattern { + data: PatternData::Record(alloc.alloc(RecordPattern { patterns, ..record })), + ..self + }) + } + PatternData::Array(array) => { + let array = array.clone(); + let patterns = + traverse_alloc_many(alloc, array.patterns.iter().cloned(), f, order)?; + + Ok(Pattern { + data: PatternData::Array(alloc.alloc(ArrayPattern { patterns, ..array })), + ..self + }) + } + PatternData::Enum(enum_pat) => { + let enum_pat = enum_pat.clone(); + let pattern = enum_pat + .pattern + .map(|p| p.traverse(alloc, f, order)) + .transpose()?; + + Ok(Pattern { + data: PatternData::Enum(alloc.alloc(EnumPattern { + pattern, + ..enum_pat + })), + ..self + }) + } + PatternData::Or(or) => { + let or = or.clone(); + let patterns = traverse_alloc_many(alloc, or.patterns.iter().cloned(), f, order)?; + + Ok(Pattern { + data: PatternData::Or(alloc.alloc(OrPattern { patterns, ..or })), + ..self + }) + } + } + } + + fn traverse_ref( + &self, + f: &mut dyn FnMut(&Ast<'ast>, &S) -> TraverseControl, + scope: &S, + ) -> Option { + match &self.data { + PatternData::Wildcard | PatternData::Any(_) | PatternData::Constant(_) => None, + PatternData::Record(record) => record + .patterns + .iter() + .find_map(|field_pat| field_pat.traverse_ref(f, scope)), + PatternData::Array(array) => array + 
.patterns + .iter() + .find_map(|pat| pat.traverse_ref(f, scope)), + PatternData::Enum(enum_pat) => enum_pat + .pattern + .as_ref() + .and_then(|pat| pat.traverse_ref(f, scope)), + PatternData::Or(or) => or + .patterns + .iter() + .find_map(|pat| pat.traverse_ref(f, scope)), + } + } +} + +impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for FieldPattern<'ast> { + fn traverse( + self, + alloc: &'ast super::AstAlloc, + f: &mut F, + order: TraverseOrder, + ) -> Result + where + F: FnMut(Ast<'ast>) -> Result, E>, + { + let annotation = self.annotation.traverse(alloc, f, order)?; + let default = self + .default + .map(|d| d.traverse(alloc, f, order)) + .transpose()?; + let pattern = self.pattern.traverse(alloc, f, order)?; + + Ok(FieldPattern { + annotation, + default, + pattern, + ..self + }) + } + + fn traverse_ref( + &self, + f: &mut dyn FnMut(&Ast<'ast>, &S) -> TraverseControl, + scope: &S, + ) -> Option { + self.annotation + .traverse_ref(f, scope) + .or_else(|| self.default.as_ref().and_then(|d| d.traverse_ref(f, scope))) + .or_else(|| self.pattern.traverse_ref(f, scope)) + } +} + //TODO: restore Pretty and Display. //impl_display_from_pretty!(PatternData); //impl_display_from_pretty!(Pattern); diff --git a/core/src/bytecode/ast/record.rs b/core/src/bytecode/ast/record.rs index e0e9aedb78..79f8324a02 100644 --- a/core/src/bytecode/ast/record.rs +++ b/core/src/bytecode/ast/record.rs @@ -1,4 +1,4 @@ -use super::{Annotation, Ast, AstAlloc}; +use super::{Annotation, Ast, AstAlloc, TraverseAlloc, TraverseControl, TraverseOrder}; use crate::{identifier::LocIdent, position::TermPos}; @@ -9,7 +9,7 @@ use std::rc::Rc; /// Element of a record field path in a record field definition. For example, in `{ a."%{"hello-" /// ++ "world"}".c = true }`, the path `a."%{b}".c` is composed of three elements: an identifier /// `a`, an expression `"hello" ++ "world"`, and another identifier `c`. 
-#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum FieldPathElem<'ast> { /// A statically known identifier. Ident(LocIdent), @@ -38,7 +38,7 @@ impl<'ast> FieldPathElem<'ast> { alloc.alloc_singleton(FieldPathElem::Ident(ident)) } - /// Crate a path composed of a single dynamic expression. + /// Create a path composed of a single dynamic expression. pub fn single_expr_path(alloc: &'ast AstAlloc, expr: Ast<'ast>) -> &'ast [FieldPathElem<'ast>] { alloc.alloc_singleton(FieldPathElem::Expr(expr)) } @@ -48,19 +48,24 @@ impl<'ast> FieldPathElem<'ast> { pub fn try_as_ident(&self) -> Option { match self { FieldPathElem::Ident(ident) => Some(*ident), - FieldPathElem::Expr(expr) => { - expr.node.try_str_chunk_as_static_str().map(LocIdent::from) - } + FieldPathElem::Expr(expr) => expr + .node + .try_str_chunk_as_static_str() + .map(|s| LocIdent::from(s).with_pos(expr.pos)), } } } /// A field definition. A field is defined by a dot-separated path of identifier or interpolated /// strings, a potential value, and associated metadata. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct FieldDef<'ast> { /// A sequence of field path elements, composing the left hand side (with respect to the `=`) /// of the field definition. + /// + /// # Invariants + /// + /// **Important**: The path must be non-empty, or some of `FieldDef` methods will panic. pub path: &'ast [FieldPathElem<'ast>], /// The metadata and the optional value bundled as a field. pub metadata: FieldMetadata<'ast>, @@ -79,10 +84,16 @@ impl FieldDef<'_> { None } } + + /// Try to get the declared field name, that is the last element of the path, as a static + /// identifier. + pub fn name_as_ident(&self) -> Option { + self.path.last().expect("empty field path").try_as_ident() + } } /// The metadata attached to record fields. 
-#[derive(Debug, PartialEq, Clone, Default)] +#[derive(Debug, PartialEq, Eq, Clone, Default)] pub struct FieldMetadata<'ast> { /// The documentation of the field. This is allocated once and for all and shared through a /// reference-counted pointer. @@ -121,7 +132,7 @@ impl<'ast> From> for FieldMetadata<'ast> { } /// A nickel record literal. -#[derive(Clone, Debug, Default, PartialEq)] +#[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct Record<'ast> { /// Field definitions. pub field_defs: &'ast [FieldDef<'ast>], @@ -139,4 +150,68 @@ impl Record<'_> { pub fn open(self) -> Self { Record { open: true, ..self } } + + /// Returns `false` if at least one field in the first layer of the record (that is the first + /// element of each field path) is defined dynamically, and `true` otherwise. + pub fn has_static_structure(&self) -> bool { + self.field_defs + .iter() + .all(|field| field.path.iter().any(|elem| elem.try_as_ident().is_some())) + } +} + +impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for FieldDef<'ast> { + fn traverse( + self, + alloc: &'ast AstAlloc, + f: &mut F, + order: TraverseOrder, + ) -> Result + where + F: FnMut(Ast<'ast>) -> Result, E>, + { + let path: Result, E> = self + .path + .iter() + .map(|elem| match elem { + FieldPathElem::Ident(ident) => Ok(FieldPathElem::Ident(*ident)), + FieldPathElem::Expr(expr) => expr + .clone() + .traverse(alloc, f, order) + .map(FieldPathElem::Expr), + }) + .collect(); + + let metadata = FieldMetadata { + annotation: self.metadata.annotation.traverse(alloc, f, order)?, + ..self.metadata + }; + + let value = self + .value + .map(|v| v.traverse(alloc, f, order)) + .transpose()?; + + Ok(FieldDef { + path: alloc.alloc_many(path?), + metadata, + value, + pos: self.pos, + }) + } + + fn traverse_ref( + &self, + f: &mut dyn FnMut(&Ast<'ast>, &S) -> TraverseControl, + scope: &S, + ) -> Option { + self.path + .iter() + .find_map(|elem| match elem { + FieldPathElem::Ident(_) => None, + FieldPathElem::Expr(expr) => 
expr.traverse_ref(f, scope), + }) + .or_else(|| self.metadata.annotation.traverse_ref(f, scope)) + .or_else(|| self.value.as_ref().and_then(|v| v.traverse_ref(f, scope))) + } } diff --git a/core/src/bytecode/ast/typ.rs b/core/src/bytecode/ast/typ.rs index 807064cdc5..76ff51d015 100644 --- a/core/src/bytecode/ast/typ.rs +++ b/core/src/bytecode/ast/typ.rs @@ -1,7 +1,7 @@ //! Representation of Nickel types in the AST. -use super::{Ast, TermPos}; -use crate::typ as mainline_typ; +use super::{Ast, AstAlloc, TermPos}; +use crate::{traverse::*, typ as mainline_typ}; pub use mainline_typ::{EnumRowF, EnumRowsF, RecordRowF, RecordRowsF, TypeF}; /// The recursive unrolling of a type, that is when we "peel off" the top-level layer to find the actual @@ -17,16 +17,16 @@ pub type RecordRowsUnr<'ast> = RecordRowsF<&'ast Type<'ast>, &'ast RecordRows<'a /// Concrete, recursive definition for an enum row. pub type EnumRow<'ast> = EnumRowF<&'ast Type<'ast>>; /// Concrete, recursive definition for enum rows. -#[derive(Clone, PartialEq, Debug)] +#[derive(Clone, PartialEq, Eq, Debug)] pub struct EnumRows<'ast>(pub EnumRowsUnr<'ast>); /// Concrete, recursive definition for a record row. pub type RecordRow<'ast> = RecordRowF<&'ast Type<'ast>>; -#[derive(Clone, PartialEq, Debug)] +#[derive(Clone, PartialEq, Eq, Debug)] /// Concrete, recursive definition for record rows. pub struct RecordRows<'ast>(pub RecordRowsUnr<'ast>); /// Concrete, recursive type for a Nickel type. -#[derive(Clone, PartialEq, Debug)] +#[derive(Clone, PartialEq, Eq, Debug)] pub struct Type<'ast> { pub typ: TypeUnr<'ast>, pub pos: TermPos, @@ -46,6 +46,14 @@ impl<'ast> Type<'ast> { pub fn with_pos(self, pos: TermPos) -> Type<'ast> { Type { pos, ..self } } + + /// Searches for a [crate::typ::TypeF]. If one is found, returns the term it contains. 
+ pub fn find_contract(&self) -> Option<&'ast Ast<'ast>> { + self.find_map(|ty: &Type| match &ty.typ { + TypeF::Contract(f) => Some(*f), + _ => None, + }) + } } impl<'ast> TypeUnr<'ast> { @@ -53,3 +61,157 @@ impl<'ast> TypeUnr<'ast> { Type { typ: self, pos } } } + +impl<'ast> TraverseAlloc<'ast, Type<'ast>> for Type<'ast> { + fn traverse( + self, + alloc: &'ast AstAlloc, + f: &mut F, + order: TraverseOrder, + ) -> Result + where + F: FnMut(Type<'ast>) -> Result, E>, + { + let pre_map = match order { + TraverseOrder::TopDown => f(self)?, + TraverseOrder::BottomUp => self, + }; + + // traverse keeps track of state in the FnMut function. try_map_state + // keeps track of it in a separate state variable. we can pass the + // former into the latter by treating the function itself as the state + let typ = pre_map.typ.try_map_state( + |ty, f| Ok(alloc.alloc(ty.clone().traverse(alloc, f, order)?)), + |rrows, f| rrows.traverse(alloc, f, order), + |erows, _| Ok(erows), + |ctr, _| Ok(ctr), + f, + )?; + + let post_map = Type { typ, ..pre_map }; + + match order { + TraverseOrder::TopDown => Ok(post_map), + TraverseOrder::BottomUp => f(post_map), + } + } + + fn traverse_ref( + &self, + f: &mut dyn FnMut(&Type<'ast>, &S) -> TraverseControl, + state: &S, + ) -> Option { + let child_state = match f(self, state) { + TraverseControl::Continue => None, + TraverseControl::ContinueWithScope(s) => Some(s), + TraverseControl::SkipBranch => { + return None; + } + TraverseControl::Return(ret) => { + return Some(ret); + } + }; + let state = child_state.as_ref().unwrap_or(state); + + match &self.typ { + TypeF::Dyn + | TypeF::Number + | TypeF::Bool + | TypeF::String + | TypeF::ForeignId + | TypeF::Symbol + | TypeF::Var(_) + | TypeF::Enum(_) + | TypeF::Wildcard(_) => None, + TypeF::Contract(ast) => ast.traverse_ref(f, state), + TypeF::Arrow(t1, t2) => t1 + .traverse_ref(f, state) + .or_else(|| t2.traverse_ref(f, state)), + TypeF::Forall { body: t, .. } + | TypeF::Dict { type_fields: t, .. 
} + | TypeF::Array(t) => t.traverse_ref(f, state), + TypeF::Record(rrows) => rrows.traverse_ref(f, state), + } + } +} + +impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for Type<'ast> { + fn traverse( + self, + alloc: &'ast AstAlloc, + f: &mut F, + order: TraverseOrder, + ) -> Result + where + F: FnMut(Ast<'ast>) -> Result, E>, + { + self.traverse( + alloc, + &mut |ty: Type| match ty.typ { + TypeF::Contract(t) => t + .clone() + .traverse(alloc, f, order) + .map(|t| Type::from(TypeF::Contract(alloc.alloc(t))).with_pos(ty.pos)), + _ => Ok(ty), + }, + order, + ) + } + + fn traverse_ref( + &self, + f: &mut dyn FnMut(&Ast<'ast>, &S) -> TraverseControl, + state: &S, + ) -> Option { + self.traverse_ref( + &mut |ty: &Type, s: &S| match &ty.typ { + TypeF::Contract(t) => { + if let Some(ret) = t.traverse_ref(f, s) { + TraverseControl::Return(ret) + } else { + TraverseControl::SkipBranch + } + } + _ => TraverseControl::Continue, + }, + state, + ) + } +} + +impl<'ast> TraverseAlloc<'ast, Type<'ast>> for RecordRows<'ast> { + fn traverse( + self, + alloc: &'ast AstAlloc, + f: &mut F, + order: TraverseOrder, + ) -> Result, E> + where + F: FnMut(Type<'ast>) -> Result, E>, + { + // traverse keeps track of state in the FnMut function. try_map_state + // keeps track of it in a separate state variable. 
we can pass the + // former into the latter by treating the function itself as the state + let rows = self.0.try_map_state( + |ty, f| Ok(alloc.alloc(ty.clone().traverse(alloc, f, order)?)), + |rrows, f| Ok(alloc.alloc(rrows.clone().traverse(alloc, f, order)?)), + f, + )?; + + Ok(RecordRows(rows)) + } + + fn traverse_ref( + &self, + f: &mut dyn FnMut(&Type<'ast>, &S) -> TraverseControl, + state: &S, + ) -> Option { + match &self.0 { + RecordRowsF::Extend { row, tail } => row + .typ + .traverse_ref(f, state) + .or_else(|| tail.traverse_ref(f, state)), + _ => None, + } + } +} diff --git a/core/src/bytecode/mod.rs b/core/src/bytecode/mod.rs index b65bb35b55..be69349a14 100644 --- a/core/src/bytecode/mod.rs +++ b/core/src/bytecode/mod.rs @@ -4,3 +4,4 @@ //! default in mainline Nickel. pub mod ast; +pub mod typecheck; diff --git a/core/src/bytecode/typecheck/eq.rs b/core/src/bytecode/typecheck/eq.rs new file mode 100644 index 0000000000..5f289ea8e0 --- /dev/null +++ b/core/src/bytecode/typecheck/eq.rs @@ -0,0 +1,674 @@ +//! Computation of type equality for contracts. +//! +//! Determine if two contracts are equal as opaque types. Used to decide if two contract should +//! unify. +//! +//! ## Aliases +//! +//! One basic case we want to handle is aliases, which come in handy for parametrized contracts. +//! For example, in the following: +//! +//! `let Alias = Foo "bar" "baz" in ...`. +//! +//! We want to equate `Alias` with `Foo "bar" "baz"`. +//! +//! We also want to equate different aliases with the same definition: +//! `let Alias' = Foo "bar" "baz" in ...`, or `let Alias' = Alias in ...`. +//! +//! We want that `Alias type_eq Alias'`. +//! +//! ## Recursion +//! +//! We must refrain from following all variables links blindly, as there could be cycles in the +//! graph leading to an infinite loop: +//! +//! ```nickel +//! { +//! Foo = Bar, +//! Bar = Foo, +//! } +//! ``` +//! +//! 
Because we just follow variables, and don't apply functions, we can detect cycles while +//! walking the graph. Still, as it is potentially performed many times during typechecking, type +//! equality ought to stay reasonably cheap. We choose to just set an arbitrary limit (the gas) on +//! the number of variable links that the type equality may follow. Doing so, we don't have to +//! worry about loops anymore. +//! +//! Note: we currently don't support recursive let or recursive record definitions, so this example +//! wouldn't work anyway, and there's no way to have an infinite cycle. Still, gas is a simple way +//! to bound the work done by the type equality computation. +//! +//! ## Equality on terms +//! +//! The terms inside a type may be arbitrarily complex. Primops applications, `match`, and the like +//! are quite unlikely to appear inside an annotation (they surely appear inside contract +//! definitions, but are usually put in a variable first - that they are unlikely to appear _inline_ in a +//! type or contract annotation is what we mean) +//! +//! We don't want to compare functions syntactically either. The spirit of this implementation is +//! to structurally equate aliases and simple constructs that are likely to appear inlined inside +//! an annotation (applications, records, primitive constants and arrays, mostly). +//! +//! We first test for physical equality (both as an optimization and to detect two variables +//! pointing to the same contract definition in the AST). If the comparison fails, we perform +//! structural recursion, unfolding simple forms and following variables a limited number of +//! times. For anything more complex, we bail out (returning `false`). + +use super::*; +use crate::{ + bytecode::ast::{primop::PrimOp, record}, + identifier::LocIdent, + term::IndexMap, + typ::VarKind, +}; + +/// The maximal number of variable links we want to unfold before abandoning the check.
It should +/// stay low, but has been fixed arbitrarily: feel free to increase reasonably if it turns out +/// legitimate type equalities between simple contracts are unduly rejected in practice. +pub const MAX_GAS: u8 = 12; + +/// State threaded through the type equality computation. +#[derive(Copy, Clone, Default)] +struct State { + /// Used to generate temporary rigid type variables for substituting type variables when + /// comparing foralls. Those ids never escape the type equality computations and are used + /// solely as rigid type variables: this is why they don't need proper allocation in the + /// unification table or to care about those ids clashing with the ones generated by the + /// typechecker. Generated type constants simply need to be unique for the duration of the + /// type equality computation. + var_uid: usize, + + /// The current gas remaining for variable substitutions. Once it reaches zero and we encounter + /// a variable, we abort the computation and return false. + gas: u8, +} + +impl State { + fn new() -> Self { + State { + var_uid: 0, + gas: MAX_GAS, + } + } + + /// Create a fresh unique id for a rigid type variable. + fn fresh_cst_id(&mut self) -> VarId { + let result = self.var_uid; + self.var_uid += 1; + result + } + + /// Try to consume one unit of gas for a variable substitution. Return true in case of success, + /// or false if the gas was already at zero. + fn use_gas(&mut self) -> bool { + if self.gas == 0 { + false + } else { + self.gas -= 1; + true + } + } +} + +pub trait TypeEq<'ast> { + /// Compute type equality. + /// + /// # Parameters + /// + /// - `env1`: an environment mapping the variables of `self` to their definition + /// - `env2`: an environment mapping the variables of `other` to their definition + fn type_eq(&self, other: &Self, env1: &TermEnv<'ast>, env2: &TermEnv<'ast>) -> bool; +} + +/// Values that can be statically compared for equality as Nickel types. This trait provides the +/// "internal" implementation that does the actual work but isn't public facing.
[TypeEq] +trait TypeEqBounded<'ast> { + /// Compute type equality with a bounded number of variable links, stored in `state`. + fn type_eq_bounded( + &self, + other: &Self, + state: &mut State, + env1: &TermEnv<'ast>, + env2: &TermEnv<'ast>, + ) -> bool; +} + +impl<'ast, T> TypeEqBounded<'ast> for [T] +where + T: TypeEqBounded<'ast>, +{ + fn type_eq_bounded( + &self, + other: &Self, + state: &mut State, + env1: &TermEnv<'ast>, + env2: &TermEnv<'ast>, + ) -> bool { + self.len() == other.len() + && self + .iter() + .zip(other.iter()) + .all(|(x1, x2)| x1.type_eq_bounded(x2, state, env1, env2)) + } +} + +impl<'ast, T> TypeEqBounded<'ast> for Option +where + T: TypeEqBounded<'ast>, +{ + fn type_eq_bounded( + &self, + other: &Self, + state: &mut State, + env1: &TermEnv<'ast>, + env2: &TermEnv<'ast>, + ) -> bool { + match (self, other) { + (Some(x1), Some(x2)) => x1.type_eq_bounded(x2, state, env1, env2), + (None, None) => true, + _ => false, + } + } +} + +impl<'ast, T> TypeEqBounded<'ast> for &T +where + T: TypeEqBounded<'ast>, +{ + fn type_eq_bounded( + &self, + other: &Self, + state: &mut State, + env1: &TermEnv<'ast>, + env2: &TermEnv<'ast>, + ) -> bool { + (*self).type_eq_bounded(*other, state, env1, env2) + } +} + +impl<'ast> TypeEqBounded<'ast> for Ast<'ast> { + fn type_eq_bounded( + &self, + other: &Self, + state: &mut State, + env1: &TermEnv<'ast>, + env2: &TermEnv<'ast>, + ) -> bool { + // Test for physical equality as both an optimization and a way to cheaply equate complex + // contracts that happen to point to the same definition (while the purposely limited + // structural checks below may reject the equality) + if std::ptr::eq(self, other) && Environment::ptr_eq(&env1.0, &env2.0) { + return true; + } + + match (&self.node, &other.node) { + (Node::Null, Node::Null) => true, + (Node::Bool(b1), Node::Bool(b2)) => b1 == b2, + (Node::Number(n1), Node::Number(n2)) => n1 == n2, + (Node::String(s1), Node::String(s2)) => s1 == s2, + ( + Node::EnumVariant { + tag: 
tag1, + arg: arg1, + }, + Node::EnumVariant { + tag: tag2, + arg: arg2, + }, + ) => { + let arg_eq = match (arg1.as_ref(), arg2.as_ref()) { + (Some(arg1), Some(arg2)) => arg1.type_eq_bounded(arg2, state, env1, env2), + (None, None) => true, + _ => false, + }; + + tag1 == tag2 && arg_eq + } + // We only compare string chunks when they represent a plain string (they don't contain any + // interpolated expression), as static string may be currently parsed as such. We return + // false for anything more complex. + (Node::StringChunks(scs1), Node::StringChunks(scs2)) => { + scs1.len() == scs2.len() + && scs1 + .iter() + .zip(scs2.iter()) + .all(|(chunk1, chunk2)| match (chunk1, chunk2) { + (StringChunk::Literal(s1), StringChunk::Literal(s2)) => s1 == s2, + _ => false, + }) + } + ( + Node::App { + head: head1, + args: args1, + }, + Node::App { + head: head2, + args: args2, + }, + ) => { + head1.type_eq_bounded(head2, state, env1, env2) + && args1.type_eq_bounded(args2, state, env1, env2) + } + // All variables must be bound at this stage. This is checked by the typechecker when + // walking annotations. However, we may assume that `env` is a local environment (e.g. that + // it doesn't include the stdlib). In that case, free variables (unbound) may be deemed + // equal if they have the same identifier: whatever global environment the term will be put + // in, free variables are not redefined locally and will be bound to the same value in any + // case. + (Node::Var(id1), Node::Var(id2)) + if env1.0.get(&id1.ident()).is_none() && env2.0.get(&id2.ident()).is_none() => + { + id1 == id2 + } + // If both variables are equal and their environment are physically equal, then they point + // to the same thing. + // + // This case is supposed to handle co-recursive contracts such as `{Foo = Bar, Bar = Foo}`. + // Although we don't build recursive environment yet, we might in the future. 
+ (Node::Var(id1), Node::Var(id2)) + if id1 == id2 && Environment::ptr_eq(&env1.0, &env2.0) => + { + true + } + (Node::Var(id), _) => { + state.use_gas() + && env1 + .0 + .get(&id.ident()) + .map(|(ast1, env1)| ast1.type_eq_bounded(other, state, env1, env2)) + .unwrap_or(false) + } + (_, Node::Var(id)) => { + state.use_gas() + && env2 + .0 + .get(&id.ident()) + .map(|(ast2, env2)| self.type_eq_bounded(ast2, state, env1, env2)) + .unwrap_or(false) + } + (Node::Record(r1), Node::Record(r2)) => r1.type_eq_bounded(r2, state, env1, env2), + (Node::Array(elts1), Node::Array(elts2)) => { + elts1.type_eq_bounded(elts2, state, env1, env2) + } + // We must compare the inner values as well as the corresponding contracts or type + // annotations. + ( + Node::Annotated { + annot: annot1, + inner: inner1, + }, + Node::Annotated { + annot: annot2, + inner: inner2, + }, + ) => { + // Questions: + // - does it really make sense to compare the annotations? + // - does it even happen to have contracts having themselves type annotations? + // - and in the latter case, should they be declared unequal because of that? + // + // The answer to the last question is probably yes, because contracts are fundamentally + // as powerful as function application, so they can change their argument. + + annot1.type_eq_bounded(annot2, state, env1, env2) + && inner1.type_eq_bounded(inner2, state, env1, env2) + } + ( + Node::PrimOpApp { + op: PrimOp::RecordStatAccess(id1), + args: args1, + }, + Node::PrimOpApp { + op: PrimOp::RecordStatAccess(id2), + args: args2, + }, + ) => id1 == id2 && args1.type_eq_bounded(args2, state, env1, env2), + (Node::Type(ty1), Node::Type(ty2)) => ty1.type_eq_bounded(ty2, state, env1, env2), + // We don't treat imports, parse errors, nor pairs of terms that don't have the same shape + _ => false, + } + } +} + +impl<'ast> TypeEqBounded<'ast> for Type<'ast> { + /// Perform the type equality comparison on types. 
Structurally recurse into type constructors and + /// test that subtypes or subterms (contracts) are equals. + /// + /// This function piggy backs on the type equality for [super::UnifType] implementation, + /// because we need to instantiate `foralls` with rigid type variables to properly compare + /// them. + /// + fn type_eq_bounded( + &self, + other: &Self, + state: &mut State, + env1: &TermEnv<'ast>, + env2: &TermEnv<'ast>, + ) -> bool { + let self_as_utype = UnifType::from_type(self.clone(), env1); + let other_as_utype = UnifType::from_type(other.clone(), env2); + self_as_utype.type_eq_bounded(&other_as_utype, state, env1, env2) + } +} + +impl<'ast, T> TypeEqBounded<'ast> for IndexMap +where + T: TypeEqBounded<'ast>, +{ + fn type_eq_bounded( + &self, + other: &Self, + state: &mut State, + env1: &TermEnv<'ast>, + env2: &TermEnv<'ast>, + ) -> bool { + self.len() == other.len() + && self.iter().all(|(id, v1)| { + other + .get(id) + .map(|v2| v1.type_eq_bounded(v2, state, env1, env2)) + .unwrap_or(false) + }) + } +} + +impl<'ast> TypeEqBounded<'ast> for UnifEnumRows<'ast> { + fn type_eq_bounded( + &self, + other: &Self, + state: &mut State, + env1: &TermEnv<'ast>, + env2: &TermEnv<'ast>, + ) -> bool { + let map_self: Option>> = self + .iter() + .map(|item| match item { + EnumRowsElt::Row(EnumRowF { id, typ: types }) => Some((id, types)), + _ => None, + }) + .collect(); + + let map_other: Option>> = other + .iter() + .map(|item| match item { + EnumRowsElt::Row(EnumRowF { id, typ: types }) => Some((id, types)), + _ => None, + }) + .collect(); + + let (Some(map_self), Some(map_other)) = (map_self, map_other) else { + return false; + }; + + map_self.type_eq_bounded(&map_other, state, env1, env2) + } +} + +impl<'ast> TypeEqBounded<'ast> for UnifRecordRows<'ast> { + fn type_eq_bounded( + &self, + other: &Self, + state: &mut State, + env1: &TermEnv<'ast>, + env2: &TermEnv<'ast>, + ) -> bool { + let map_self: Option> = self + .iter() + .map(|item| match item { + 
RecordRowsElt::Row(RecordRowF { id, typ: types }) => Some((id, types)), + _ => None, + }) + .collect(); + + let map_other: Option> = other + .iter() + .map(|item| match item { + RecordRowsElt::Row(RecordRowF { id, typ: types }) => Some((id, types)), + _ => None, + }) + .collect(); + + let (Some(map_self), Some(map_other)) = (map_self, map_other) else { + return false; + }; + + map_self.type_eq_bounded(&map_other, state, env1, env2) + } +} + +impl<'ast> TypeEqBounded<'ast> for Annotation<'ast> { + fn type_eq_bounded( + &self, + other: &Self, + state: &mut State, + env1: &TermEnv<'ast>, + env2: &TermEnv<'ast>, + ) -> bool { + self.typ.type_eq_bounded(&other.typ, state, env1, env2) + && self + .contracts + .type_eq_bounded(other.contracts, state, env1, env2) + } +} + +impl<'ast> TypeEqBounded<'ast> for record::FieldMetadata<'ast> { + fn type_eq_bounded( + &self, + other: &Self, + state: &mut State, + env1: &TermEnv<'ast>, + env2: &TermEnv<'ast>, + ) -> bool { + self.annotation + .type_eq_bounded(&other.annotation, state, env1, env2) + && self.opt == other.opt + && self.not_exported == other.not_exported + && self.priority == other.priority + } +} + +impl<'ast> TypeEqBounded<'ast> for record::FieldPathElem<'ast> { + fn type_eq_bounded( + &self, + other: &Self, + _state: &mut State, + _env1: &TermEnv<'ast>, + _env2: &TermEnv<'ast>, + ) -> bool { + // For now, we don't even try to compare interpolated expressions at all, and only compare + // static field definitions. + match (self.try_as_ident(), other.try_as_ident()) { + (Some(id1), Some(id2)) => id1 == id2, + _ => false, + } + } +} + +impl<'ast> TypeEqBounded<'ast> for record::Record<'ast> { + fn type_eq_bounded( + &self, + other: &Self, + state: &mut State, + env1: &TermEnv<'ast>, + env2: &TermEnv<'ast>, + ) -> bool { + use crate::identifier::FastOrdIdent; + + // We sort the field definitions based on their path. 
For dynamic fields, we don't have a + // good ordering (we could derive it, albeit it would be a bit artificial), so we just + // ignore this part - it might lead to equate a bit less than we could, in presence of + // dynamic fields on both side that are in different order, but this is at least sound. + fn sort_field_defs<'ast>(field_defs: &mut [&'ast FieldDef<'ast>]) { + field_defs.sort_by_cached_key(|field| -> Vec> { + field + .path + .iter() + .map(|path_elem| { + path_elem + .try_as_ident() + .as_ref() + .map(LocIdent::ident) + .map(FastOrdIdent) + }) + .collect() + }); + } + + let mut sorted_self: Vec<_> = self.field_defs.iter().collect(); + let mut sorted_other: Vec<_> = other.field_defs.iter().collect(); + sort_field_defs(&mut sorted_self); + sort_field_defs(&mut sorted_other); + + sorted_self + .as_slice() + .type_eq_bounded(sorted_other.as_slice(), state, env1, env2) + && self.open == other.open + } +} + +impl<'ast> TypeEqBounded<'ast> for FieldDef<'ast> { + /// Check for contract equality between record fields. Fields are equal if they are both without a + /// definition, or are both defined and their values are equal. + /// + /// The attached metadata must be equal as well: most record contracts are written as field with + /// metadata but without definition. For example, take `{ foo | {bar | Number}}` and `{foo | {bar | + /// String}}`. Those two record contracts are obviously not equal, but to know that, we have to + /// look at the contracts of each bar field. 
+ fn type_eq_bounded( + &self, + other: &Self, + state: &mut State, + env1: &TermEnv<'ast>, + env2: &TermEnv<'ast>, + ) -> bool { + self.metadata + .type_eq_bounded(&other.metadata, state, env1, env2) + && self.path.type_eq_bounded(other.path, state, env1, env2) + && self.value.type_eq_bounded(&other.value, state, env1, env2) + } +} + +impl<'ast> TypeEqBounded<'ast> for UnifType<'ast> { + /// This function is used internally by the implementation of [TypeEq] for + /// [crate::bytecode::ast::typ::Type], but it makes a number of assumptions and isn't supposed + /// to be called from the outside. And indeed, computing type equality recursively on unifiable + /// type doesn't make a lot of sense: we can unify them instead. + /// + /// For example, we expect to never meet unification variables, and that all the rigid type + /// variables encountered have been introduced by `type_eq_bounded` itself. This is why we don't + /// need unique identifiers that are distinct from the one used during typechecking, and we can + /// just start from `0` when creating a new [State]. + fn type_eq_bounded( + &self, + other: &Self, + state: &mut State, + env1: &TermEnv<'ast>, + env2: &TermEnv<'ast>, + ) -> bool { + match (self, other) { + (UnifType::Concrete { typ: s1, .. }, UnifType::Concrete { typ: s2, .. 
}) => { + match (s1, s2) { + (TypeF::Wildcard(id1), TypeF::Wildcard(id2)) => id1 == id2, + (TypeF::Dyn, TypeF::Dyn) + | (TypeF::Number, TypeF::Number) + | (TypeF::Bool, TypeF::Bool) + | (TypeF::Symbol, TypeF::Symbol) + | (TypeF::String, TypeF::String) => true, + ( + TypeF::Dict { + type_fields: uty1, + flavour: attrs1, + }, + TypeF::Dict { + type_fields: uty2, + flavour: attrs2, + }, + ) if attrs1 == attrs2 => uty1.type_eq_bounded(uty2, state, env1, env2), + (TypeF::Array(uty1), TypeF::Array(uty2)) => { + uty1.type_eq_bounded(uty2, state, env1, env2) + } + (TypeF::Arrow(s1, t1), TypeF::Arrow(s2, t2)) => { + s1.type_eq_bounded(s2, state, env1, env2) + && t1.type_eq_bounded(t2, state, env1, env2) + } + (TypeF::Enum(uty1), TypeF::Enum(uty2)) => { + uty1.type_eq_bounded(uty2, state, env1, env2) + } + (TypeF::Record(uty1), TypeF::Record(uty2)) => { + uty1.type_eq_bounded(uty2, state, env1, env2) + } + (TypeF::Contract((t1, env1)), TypeF::Contract((t2, env2))) => { + t1.type_eq_bounded(t2, state, env1, env2) + } + ( + TypeF::Forall { + var: var1, + var_kind: var_kind1, + body: body1, + }, + TypeF::Forall { + var: var2, + var_kind: var_kind2, + body: body2, + }, + ) => { + let cst_id = state.fresh_cst_id(); + + if var_kind1 != var_kind2 { + return false; + } + + let body1 = body1.clone(); + let body2 = body2.clone(); + + let (uty1_subst, uty2_subst) = match var_kind1 { + VarKind::Type => ( + body1.subst(var1, &UnifType::Constant(cst_id)), + body2.subst(var2, &UnifType::Constant(cst_id)), + ), + VarKind::RecordRows { .. } => ( + body1.subst(var1, &UnifRecordRows::Constant(cst_id)), + body2.subst(var2, &UnifRecordRows::Constant(cst_id)), + ), + VarKind::EnumRows { .. 
} => ( + body1.subst(var1, &UnifEnumRows::Constant(cst_id)), + body2.subst(var2, &UnifEnumRows::Constant(cst_id)), + ), + }; + + uty1_subst.type_eq_bounded(&uty2_subst, state, env1, env2) + } + // We can't compare type variables without knowing what they are instantiated to, + // and all type variables should have been substituted at this point, so we bail + // out. + _ => false, + } + } + (UnifType::UnifVar { id: id1, .. }, UnifType::UnifVar { id: id2, .. }) => { + debug_assert!( + false, + "we shouldn't come across unification variables during type equality computation" + ); + id1 == id2 + } + (UnifType::Constant(i1), UnifType::Constant(i2)) => i1 == i2, + _ => false, + } + } +} + +/// Derive a [TypeEq] implementation from the internal [TypeEqBounded] implementation. We don't +/// necessarily want to do that for every type that implements [TypeEqBounded], for example the +/// implementation for [UnifType] makes some assumptions that make it unsuited for public usage. +macro_rules! derive_type_eq { + ($ty:ty) => { + impl<'ast> TypeEq<'ast> for $ty { + fn type_eq(&self, other: &Self, env1: &TermEnv<'ast>, env2: &TermEnv<'ast>) -> bool { + self.type_eq_bounded(other, &mut State::new(), env1, env2) + } + } + }; +} + +derive_type_eq!(Ast<'ast>); diff --git a/core/src/bytecode/typecheck/error.rs b/core/src/bytecode/typecheck/error.rs new file mode 100644 index 0000000000..94329fb349 --- /dev/null +++ b/core/src/bytecode/typecheck/error.rs @@ -0,0 +1,555 @@ +//! Internal error types for typechecking. +use super::{ + reporting::{self, ToType}, + State, UnifEnumRow, UnifRecordRow, UnifType, VarId, +}; + +use crate::{ + bytecode::ast::compat::ToMainline, + error::TypecheckError, + identifier::LocIdent, + label::ty_path, + position::TermPos, + typ::{TypeF, VarKindDiscriminant}, +}; + +/// Error during the unification of two row types. +#[derive(Debug, PartialEq)] +pub enum RowUnifError<'ast> { + /// The LHS had a binding that was missing in the RHS. 
+ MissingRow(LocIdent), + /// The LHS had a `Dyn` tail that was missing in the RHS. + MissingDynTail, + /// The RHS had a binding that was not in the LHS. + ExtraRow(LocIdent), + /// The RHS had an additional `Dyn` tail. + ExtraDynTail, + /// There were two incompatible definitions for the same record row. + RecordRowMismatch { + id: LocIdent, + /// The underlying unification error that caused the mismatch. + cause: Box>, + }, + /// There were two incompatible definitions for the same enum row. + /// + /// Because enum rows have an optional argument, there might not be any underlying unification + /// error (e.g. one of the rows has an argument, and the other does not). This is why the + /// underlying unification error is optional, as opposed to record rows. + EnumRowMismatch { + id: LocIdent, + /// The underlying unification error that caused the mismatch. + cause: Option>>, + }, + /// A [row constraint][super::RowConstrs] was violated. + RecordRowConflict(UnifRecordRow<'ast>), + /// A [row constraint][super::RowConstrs] was violated. + EnumRowConflict(UnifEnumRow<'ast>), + /// Tried to unify a type constant with another different type. + WithConst { + var_kind: VarKindDiscriminant, + expected_const_id: VarId, + inferred: UnifType<'ast>, + }, + /// Tried to unify two distinct type constants. + ConstMismatch { + var_kind: VarKindDiscriminant, + expected_const_id: usize, + inferred_const_id: usize, + }, + /// An unbound type variable was referenced. + UnboundTypeVariable(LocIdent), + /// Tried to unify a constant with a unification variable with a strictly lower level. + VarLevelMismatch { + constant_id: VarId, + var_kind: VarKindDiscriminant, + }, +} + +impl<'ast> RowUnifError<'ast> { + /// Convert a row unification error to a unification error.
+ /// + /// There is a hierarchy between error types, from the most local/specific to the most + /// high-level: + /// - [`RowUnifError<'ast>`] + /// - [`UnifError<'ast>`] + /// - [`crate::error::TypecheckError`] + /// + /// Each level usually adds information (such as types or positions) and group different + /// specific errors into most general ones. + pub fn into_unif_err( + self, + expected: UnifType<'ast>, + inferred: UnifType<'ast>, + ) -> UnifError<'ast> { + match self { + RowUnifError::MissingRow(id) => UnifError::MissingRow { + id, + expected, + inferred, + }, + RowUnifError::MissingDynTail => UnifError::MissingDynTail { expected, inferred }, + RowUnifError::ExtraRow(id) => UnifError::ExtraRow { + id, + expected, + inferred, + }, + RowUnifError::ExtraDynTail => UnifError::ExtraDynTail { expected, inferred }, + RowUnifError::RecordRowMismatch { id, cause } => UnifError::RecordRowMismatch { + id, + expected, + inferred, + cause, + }, + RowUnifError::EnumRowMismatch { id, cause } => UnifError::EnumRowMismatch { + id, + expected, + inferred, + cause, + }, + RowUnifError::RecordRowConflict(row) => UnifError::RecordRowConflict { + row, + expected, + inferred, + }, + RowUnifError::EnumRowConflict(row) => UnifError::EnumRowConflict { + row, + expected, + inferred, + }, + RowUnifError::WithConst { + var_kind, + expected_const_id, + inferred, + } => UnifError::WithConst { + var_kind, + expected_const_id, + inferred, + }, + RowUnifError::ConstMismatch { + var_kind, + expected_const_id, + inferred_const_id, + } => UnifError::ConstMismatch { + var_kind, + expected_const_id, + inferred_const_id, + }, + RowUnifError::UnboundTypeVariable(id) => UnifError::UnboundTypeVariable(id), + RowUnifError::VarLevelMismatch { + constant_id, + var_kind, + } => UnifError::VarLevelMismatch { + constant_id, + var_kind, + }, + } + } +} + +/// Error during the unification of two types. +/// +/// In each variant, `expected` and `inferred` refers to the two types that failed to unify. 
+#[derive(Debug, PartialEq)] +pub enum UnifError<'ast> { + /// Tried to unify two incompatible types. + TypeMismatch { + expected: UnifType<'ast>, + inferred: UnifType<'ast>, + }, + /// There are two incompatible definitions for the same row. + RecordRowMismatch { + id: LocIdent, + expected: UnifType<'ast>, + inferred: UnifType<'ast>, + /// The underlying unification error (`expected` and `inferred` should be the record types + /// that failed to unify, while this error is the specific cause of the mismatch for the + /// `id` row) + cause: Box>, + }, + /// There are two incompatible definitions for the same row. + /// + /// Because enum rows have an optional argument, there might not be any underlying unification + /// error (e.g. one of the rows has an argument, and the other does not). This is why the + /// underlying unification error is optional, as opposed to record rows. + EnumRowMismatch { + id: LocIdent, + expected: UnifType<'ast>, + inferred: UnifType<'ast>, + cause: Option>>, + }, + /// Tried to unify two distinct type constants. + ConstMismatch { + var_kind: VarKindDiscriminant, + expected_const_id: VarId, + inferred_const_id: VarId, + }, + /// Tried to unify two rows, but a row from the expected type was absent from the inferred type. + MissingRow { + id: LocIdent, + expected: UnifType<'ast>, + inferred: UnifType<'ast>, + }, + /// Tried to unify two rows, but a row from the inferred type was absent from the expected type. + ExtraRow { + id: LocIdent, + expected: UnifType<'ast>, + inferred: UnifType<'ast>, + }, + /// Tried to unify two rows, but the `Dyn` tail of the expected type was absent from the + /// inferred type. + MissingDynTail { + expected: UnifType<'ast>, + inferred: UnifType<'ast>, + }, + /// Tried to unify two rows, but the `Dyn` tail of the RHS was absent from the LHS.
+ ExtraDynTail { + expected: UnifType<'ast>, + inferred: UnifType<'ast>, + }, + /// Tried to unify a unification variable with a row type violating the [row + /// constraints][super::RowConstrs] of the variable. + RecordRowConflict { + /// The row that conflicts with an existing one. + row: UnifRecordRow<'ast>, + /// The original expected type that led to the row conflict (when unified with the inferred + /// type). + expected: UnifType<'ast>, + /// The original inferred type that led to the row conflict (when unified with the expected + /// type). + inferred: UnifType<'ast>, + }, + /// Tried to unify a unification variable with a row type violating the [row + /// constraints][super::RowConstrs] of the variable. + EnumRowConflict { + /// The row that conflicts with an existing one. + row: UnifEnumRow<'ast>, + /// The original expected type that led to the row conflict (when unified with the inferred + /// type). + expected: UnifType<'ast>, + /// The original inferred type that led to the row conflict (when unified with the expected + /// type). + inferred: UnifType<'ast>, + }, + /// Tried to unify a type constant with another different type. + WithConst { + var_kind: VarKindDiscriminant, + expected_const_id: VarId, + inferred: UnifType<'ast>, + }, + /// An unbound type variable was referenced. + UnboundTypeVariable(LocIdent), + /// An error occurred when unifying the domains of two arrows. + DomainMismatch { + expected: UnifType<'ast>, + inferred: UnifType<'ast>, + cause: Box>, + }, + /// An error occurred when unifying the codomains of two arrows. + CodomainMismatch { + expected: UnifType<'ast>, + inferred: UnifType<'ast>, + cause: Box>, + }, + /// Tried to unify a constant with a unification variable with a strictly lower level. + VarLevelMismatch { + constant_id: VarId, + var_kind: VarKindDiscriminant, + }, +} + +impl<'ast> UnifError<'ast> { + /// Convert a unification error to a typechecking error. 
There is a hierarchy between error + /// types, from the most local/specific to the most high-level: + /// - [`RowUnifError<'ast>`] + /// - [`UnifError<'ast>`] + /// - [`crate::error::TypecheckError`] + /// + /// Each level usually adds information (such as types or positions) and group different + /// specific errors into most general ones. + /// + /// # Parameters + /// + /// - `state`: the state of unification. Used to access the unification table, and the original + /// names of unification variables or type constants. + /// - `pos_opt`: the position span of the expression that failed to typecheck. + pub fn into_typecheck_err(self, state: &State<'ast, '_>, pos_opt: TermPos) -> TypecheckError { + let mut names = reporting::NameReg::new(state.names.clone()); + self.into_typecheck_err_(state, &mut names, pos_opt) + } + + /// Convert a unification error to a typechecking error, given a populated [name + /// registry][reporting::NameReg]. Actual meat of the implementation of + /// [`Self::into_typecheck_err`].
+ fn into_typecheck_err_( + self, + state: &State<'ast, '_>, + names_reg: &mut reporting::NameReg, + pos: TermPos, + ) -> TypecheckError { + match self { + UnifError::TypeMismatch { expected, inferred } => TypecheckError::TypeMismatch { + expected: expected + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + inferred: inferred + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + pos, + }, + UnifError::RecordRowMismatch { + id, + expected, + inferred, + cause, + } => TypecheckError::RecordRowMismatch { + id, + expected: expected + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + inferred: inferred + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + cause: Box::new((*cause).into_typecheck_err_(state, names_reg, TermPos::None)), + pos, + }, + UnifError::EnumRowMismatch { + id, + expected, + inferred, + cause, + } => TypecheckError::EnumRowMismatch { + id, + expected: expected + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + inferred: inferred + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + cause: cause.map(|err| { + Box::new((*err).into_typecheck_err_(state, names_reg, TermPos::None)) + }), + pos, + }, + // TODO: for now, failure to unify with a type constant causes the same error as a + // usual type mismatch. It could be nice to have a specific error message in the + // future. 
+ UnifError::ConstMismatch { + var_kind, + expected_const_id, + inferred_const_id, + } => TypecheckError::TypeMismatch { + expected: UnifType::from_constant_of_kind(expected_const_id, var_kind) + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + inferred: UnifType::from_constant_of_kind(inferred_const_id, var_kind) + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + pos, + }, + UnifError::WithConst { + var_kind: VarKindDiscriminant::Type, + expected_const_id, + inferred, + } => TypecheckError::TypeMismatch { + expected: UnifType::Constant(expected_const_id) + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + inferred: inferred + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + pos, + }, + UnifError::WithConst { + var_kind, + expected_const_id, + inferred, + } => TypecheckError::ForallParametricityViolation { + kind: var_kind, + tail: UnifType::from_constant_of_kind(expected_const_id, var_kind) + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + violating_type: inferred + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + pos, + }, + UnifError::MissingRow { + id, + expected, + inferred, + } => TypecheckError::MissingRow { + id, + expected: expected + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + inferred: inferred + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + pos, + }, + UnifError::MissingDynTail { expected, inferred } => TypecheckError::MissingDynTail { + expected: expected + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + inferred: inferred + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + pos, + }, + UnifError::ExtraRow { + id, + expected, + inferred, + } => TypecheckError::ExtraRow { + id, + expected: expected + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + inferred: inferred + .to_type(state.ast_alloc, names_reg, state.table) + 
.to_mainline(), + pos, + }, + UnifError::ExtraDynTail { expected, inferred } => TypecheckError::ExtraDynTail { + expected: expected + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + inferred: inferred + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + pos, + }, + #[allow(unused_variables)] + #[allow(unreachable_code)] + UnifError::RecordRowConflict { + row: _, + expected, + inferred, + } => TypecheckError::RecordRowConflict { + // We won't convert to mainline when we'll plug-in the migrated typechecker, so it doesn't make sense to try to fix this line now - the error will go away. + row: todo!(), //row.to_type(&state.ast_alloc, names_reg, state.table), + expected: expected + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + inferred: inferred + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + pos, + }, + #[allow(unused_variables)] + #[allow(unreachable_code)] + UnifError::EnumRowConflict { + row: _, + expected, + inferred, + } => TypecheckError::EnumRowConflict { + // We won't convert to mainline when we'll plug-in the migrated typechecker, so it doesn't make sense to try to fix this line now - the error will go away. + row: todo!(), + expected: expected + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + inferred: inferred + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + pos, + }, + UnifError::UnboundTypeVariable(ident) => TypecheckError::UnboundTypeVariable(ident), + err @ UnifError::CodomainMismatch { .. } | err @ UnifError::DomainMismatch { .. 
} => { + let (expected, inferred, type_path, err_final) = err.into_type_path().unwrap(); + TypecheckError::ArrowTypeMismatch { + expected: expected + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + inferred: inferred + .to_type(state.ast_alloc, names_reg, state.table) + .to_mainline(), + type_path, + cause: Box::new(err_final.into_typecheck_err_(state, names_reg, TermPos::None)), + pos, + } + } + UnifError::VarLevelMismatch { + constant_id, + var_kind, + } => TypecheckError::VarLevelMismatch { + type_var: names_reg.gen_cst_name(constant_id, var_kind).into(), + pos, + }, + } + } + + /// Transform a `(Co)DomainMismatch` into a type path and other data. + /// + /// `(Co)DomainMismatch` can be nested: when unifying `Num -> Num -> Num` with `Num -> Bool -> + /// Num`, the resulting error is of the form `CodomainMismatch(.., DomainMismatch(.., + /// TypeMismatch(..)))`. The heading sequence of `(Co)DomainMismatch` is better represented as + /// a type path, here `[Codomain, Domain]`, while the last error of the chain -- which thus + /// cannot be a `(Co)DomainMismatch` -- is the actual cause of the unification failure. + /// + /// This function breaks down a `(Co)Domain` mismatch into a more convenient representation. + /// + /// # Return + /// + /// Return `None` if `self` is not a `DomainMismatch` nor a `CodomainMismatch`. + /// + /// Otherwise, return the following tuple: + /// - the original expected type. + /// - the original inferred type. + /// - a type path pointing at the subtypes which failed to be unified. + /// - the final error, which is the actual cause of that failure. + pub fn into_type_path(self) -> Option<(UnifType<'ast>, UnifType<'ast>, ty_path::Path, Self)> { + let mut curr: Self = self; + let mut path = ty_path::Path::new(); + // The original expected and inferred type. They are just updated once, in the first + // iteration of the loop below. 
+ let mut utys: Option<(UnifType<'ast>, UnifType<'ast>)> = None; + + loop { + match curr { + UnifError::DomainMismatch { + expected: + expected @ UnifType::Concrete { + typ: TypeF::Arrow(_, _), + .. + }, + inferred: + inferred @ UnifType::Concrete { + typ: TypeF::Arrow(_, _), + .. + }, + cause: mismatch, + } => { + utys = utys.or(Some((expected, inferred))); + path.push(ty_path::Elem::Domain); + curr = *mismatch; + } + UnifError::DomainMismatch { .. } => panic!( + "typechecking::to_type_path(): domain mismatch error on a non arrow type" + ), + UnifError::CodomainMismatch { + expected: + expected @ UnifType::Concrete { + typ: TypeF::Arrow(_, _), + .. + }, + inferred: + inferred @ UnifType::Concrete { + typ: TypeF::Arrow(_, _), + .. + }, + cause: mismatch, + } => { + utys = utys.or(Some((expected, inferred))); + path.push(ty_path::Elem::Codomain); + curr = *mismatch; + } + UnifError::CodomainMismatch { .. } => panic!( + "typechecking::to_type_path(): codomain mismatch error on a non arrow type" + ), + // utys equals to `None` iff we did not even enter the case above once, i.e. if + // `self` was indeed neither a `DomainMismatch` nor a `CodomainMismatch` + _ => break utys.map(|(expected, inferred)| (expected, inferred, path, curr)), + } + } + } +} diff --git a/core/src/bytecode/typecheck/mk_uniftype.rs b/core/src/bytecode/typecheck/mk_uniftype.rs new file mode 100644 index 0000000000..a315f4cb33 --- /dev/null +++ b/core/src/bytecode/typecheck/mk_uniftype.rs @@ -0,0 +1,167 @@ +//! Helpers for building `TypeWrapper`s. +use super::{UnifType, VarLevelsData}; +use crate::typ::{DictTypeFlavour, TypeF}; + +/// Multi-ary arrow constructor for types implementing `Into`. +#[macro_export] +macro_rules! 
mk_buty_arrow { + ($left:expr, $right:expr) => { + $crate::bytecode::typecheck::UnifType::concrete( + $crate::typ::TypeF::Arrow( + Box::new($crate::bytecode::typecheck::UnifType::from($left)), + Box::new($crate::bytecode::typecheck::UnifType::from($right)) + ) + ) + }; + ( $fst:expr, $snd:expr , $( $types:expr ),+ ) => { + $crate::mk_buty_arrow!($fst, $crate::mk_buty_arrow!($snd, $( $types ),+)) + }; +} + +/// Multi-ary enum row constructor for types implementing `Into`. +/// `mk_buty_enum_row!(id1, .., idn; tail)` corresponds to `[| 'id1, .., 'idn; tail |]`. With the +/// addition of algebraic data types (enum variants), individual rows can also take an additional +/// type parameter, specified as a tuple: for example, `mk_buty_enum_row!(id1, (id2, ty2); tail)` +/// is `[| 'id1, 'id2 ty2; tail |]`. +#[macro_export] +macro_rules! mk_buty_enum_row { + () => { + $crate::bytecode::typecheck::UnifEnumRows::Concrete { + erows: $crate::typ::EnumRowsF::Empty, + var_levels_data: $crate::bytecode::typecheck::VarLevelsData::new_no_uvars(), + } + }; + (; $tail:expr) => { + $crate::bytecode::typecheck::UnifEnumRows::from($tail) + }; + ( ($id:expr, $ty:expr) $(, $rest:tt )* $(; $tail:expr)? ) => { + $crate::bytecode::typecheck::UnifEnumRows::concrete( + $crate::typ::EnumRowsF::Extend { + row: $crate::typ::EnumRowF { + id: $crate::identifier::LocIdent::from($id), + typ: Some(Box::new($ty.into())), + }, + tail: Box::new($crate::mk_buty_enum_row!($( $rest ),* $(; $tail)?)) + } + ) + }; + ( $id:expr $(, $rest:tt )* $(; $tail:expr)? ) => { + $crate::bytecode::typecheck::UnifEnumRows::concrete( + $crate::typ::EnumRowsF::Extend { + row: $crate::typ::EnumRowF { + id: $crate::identifier::LocIdent::from($id), + typ: None, + }, + tail: Box::new($crate::mk_buty_enum_row!($( $rest ),* $(; $tail)?)) + } + ) + }; +} + +/// Multi-ary record row constructor for types implementing `Into`. `mk_buty_record_row!((id1, +/// ty1), .., (idn, tyn); tail)` corresponds to `{id1: ty1, .., idn: tyn; tail}`.
The tail can be +/// omitted, in which case the empty row is used as a tail instead. +#[macro_export] +macro_rules! mk_buty_record_row { + () => { + $crate::bytecode::typecheck::UnifRecordRows::Concrete { + rrows: $crate::typ::RecordRowsF::Empty, + var_levels_data: $crate::bytecode::typecheck::VarLevelsData::new_no_uvars() + } + }; + (; $tail:expr) => { + $crate::bytecode::typecheck::UnifRecordRows::from($tail) + }; + (($id:expr, $ty:expr) $(,($ids:expr, $tys:expr))* $(; $tail:expr)?) => { + $crate::bytecode::typecheck::UnifRecordRows::concrete( + $crate::typ::RecordRowsF::Extend { + row: $crate::typ::RecordRowF { + id: $crate::identifier::LocIdent::from($id), + typ: Box::new($ty.into()), + }, + tail: Box::new($crate::mk_buty_record_row!($(($ids, $tys)),* $(; $tail)?)), + } + ) + }; +} + +/// Wrapper around `mk_buty_enum_row!` to build an enum type from an enum row. +#[macro_export] +macro_rules! mk_buty_enum { + ($( $args:tt )*) => { + $crate::bytecode::typecheck::UnifType::concrete( + $crate::typ::TypeF::Enum( + $crate::mk_buty_enum_row!($( $args )*) + ) + ) + }; +} + +/// Wrapper around `mk_buty_record_row!` to build a record type from a record row. +#[macro_export] +macro_rules! mk_buty_record { + ($(($ids:expr, $tys:expr)),* $(; $tail:expr)?) => { + $crate::bytecode::typecheck::UnifType::concrete( + $crate::typ::TypeF::Record( + $crate::mk_buty_record_row!($(($ids, $tys)),* $(; $tail)?) + ) + ) + }; +} + +/// Generate a helper function to build a 0-ary type. +macro_rules!
generate_builder { + ($fun:ident, $var:ident) => { + pub fn $fun<'ast>() -> UnifType<'ast> { + UnifType::Concrete { + typ: TypeF::$var, + var_levels_data: VarLevelsData::new_no_uvars(), + } + } + }; +} + +pub fn dict<'ast, T>(ty: T) -> UnifType<'ast> +where + T: Into>, +{ + UnifType::concrete(TypeF::Dict { + type_fields: Box::new(ty.into()), + flavour: DictTypeFlavour::Type, + }) +} + +pub fn array<'ast, T>(ty: T) -> UnifType<'ast> +where + T: Into>, +{ + UnifType::concrete(TypeF::Array(Box::new(ty.into()))) +} + +pub fn arrow<'ast>( + domain: impl Into>, + codomain: impl Into>, +) -> UnifType<'ast> { + UnifType::concrete(TypeF::Arrow( + Box::new(domain.into()), + Box::new(codomain.into()), + )) +} + +pub fn nary_arrow<'ast, I, U>(args: I, codomain: U) -> UnifType<'ast> +where + U: Into>, + I: IntoIterator>, IntoIter: std::iter::DoubleEndedIterator>, +{ + args.into_iter() + .rev() + .fold(codomain.into(), |acc, ty| mk_buty_arrow!(ty.into(), acc)) +} + +// dyn is a reserved keyword +generate_builder!(dynamic, Dyn); +generate_builder!(str, String); +generate_builder!(num, Number); +generate_builder!(bool, Bool); +generate_builder!(sym, Symbol); +generate_builder!(foreign_id, ForeignId); diff --git a/core/src/bytecode/typecheck/mod.rs b/core/src/bytecode/typecheck/mod.rs new file mode 100644 index 0000000000..2c99895532 --- /dev/null +++ b/core/src/bytecode/typecheck/mod.rs @@ -0,0 +1,2979 @@ +//! Typechecking and type inference. +//! +//! Nickel uses a mix of a bidirectional typechecking algorithm, together with standard +//! unification-based type inference. Nickel is gradually typed, and dynamic typing is the default. +//! Static typechecking is triggered by a type annotation. +//! +//! # Modes +//! +//! The typechecking algorithm runs in two separate modes, corresponding to static and dynamic +//! typing: +//! +//! - **enforce** corresponds to traditional typechecking in a statically typed language. This +//! happens inside a statically typed block. 
Such blocks are introduced by the type ascription +//! operator `:`, as in `1 + 1 : Number` or `let f : Number -> Number = fun x => x + 1 in ..`. +//! Enforce mode is implemented by [`type_check`] and variants. +//! - **walk** doesn't enforce any typing but traverses the AST looking for typed blocks to +//! typecheck. Walk mode also stores the annotations of bound identifiers in the environment. This +//! is implemented by the `walk` function. +//! +//! The algorithm usually starts in walk mode, although this can be configured. A typed block +//! (an expression annotated with a type) switches to enforce mode, and is switched back to walk +//! mode when entering an expression annotated with a contract. Type and contract annotations thus +//! serve as a switch for the typechecking mode. +//! +//! Note that the static typing part (enforce mode) is based on the bidirectional typing framework, +//! which defines two different modes. Thus, the enforce mode is itself divided again into +//! **checking** mode and **inference** mode. +//! +//! # Type inference +//! +//! Type inference is done via a form of bidirectional typechecking coupled with unification, in the +//! same spirit as GHC (Haskell), although the type system of Nickel is simpler. The type of +//! un-annotated let-bound expressions (the type of `bound_exp` in `let x = bound_exp in body`) is +//! inferred in enforce mode, but it is never implicitly generalized. For example, the following +//! program is rejected: +//! +//! ```nickel +//! # Rejected +//! (let id = fun x => x in std.seq (id "a") (id 5)) : Number +//! ``` +//! +//! Indeed, `id` is given the type `_a -> _a`, where `_a` is a unification variable, but is not +//! generalized to `forall a. a -> a`. At the first call site, `_a` is unified with `String`, and at +//! the second call site the typechecker complains that `5` is not of type `String`. +//! +//! This restriction is on purpose, as generalization is not trivial to implement efficiently and +//! 
more importantly can interact with other components of the type system and type inference. If +//! polymorphism is required, the user can simply add an annotation: +//! +//! ```nickel +//! # Accepted +//! (let id : forall a. a -> a = fun x => x in std.seq (id "a") (id 5)) : Number +//! ``` +//! +//! In walk mode, the type of let-bound expressions is inferred in a shallow way (see +//! [`apparent_type`]). +use super::ast::{ + pattern::bindings::Bindings as _, record::FieldDef, typ::*, Annotation, Ast, AstAlloc, + MatchBranch, Node, StringChunk, TryConvert, +}; + +use crate::{ + cache::ImportResolver, + environment::Environment, + error::TypecheckError, + identifier::{Ident, LocIdent}, + mk_buty_arrow, mk_buty_enum, mk_buty_record, mk_buty_record_row, stdlib as nickel_stdlib, + traverse::TraverseAlloc, + typ::{EnumRowsIterator, RecordRowsIterator, VarKind, VarKindDiscriminant}, +}; + +use std::{ + cmp::max, + collections::{HashMap, HashSet}, + num::NonZeroU16, +}; + +pub mod error; +pub mod operation; +mod pattern; +pub mod reporting; +#[macro_use] +pub mod mk_uniftype; +pub mod eq; +pub mod record; +pub mod subtyping; +pub mod unif; + +use error::*; +use operation::PrimOpType; +use pattern::{PatternTypeData, PatternTypes}; +use record::Resolve; +use unif::*; + +use self::subtyping::SubsumedBy; + +/// The max depth parameter used to limit the work performed when inferring the type of the stdlib. +const INFER_RECORD_MAX_DEPTH: u8 = 4; + +/// The typechecker has two modes, one for statically typed code and one for dynamically typed code. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TypecheckMode { + /// In `Walk` mode, the typechecker traverses the AST looking for typed blocks. + Walk, + /// In `Enforce` mode, the typechecker checks types. + Enforce, +} + +/// The typing environment. +pub type TypeEnv<'ast> = Environment>; + +/// A term environment defined as a mapping from identifiers to a tuple of a term and an +/// environment (i.e. a closure).
Used to compute contract equality. +#[derive(PartialEq, Clone, Debug)] +pub struct TermEnv<'ast>(pub Environment, TermEnv<'ast>)>); + +impl TermEnv<'_> { + pub fn new() -> Self { + TermEnv(Environment::new()) + } +} + +impl Default for TermEnv<'_> { + fn default() -> Self { + Self::new() + } +} + +impl<'ast> std::iter::FromIterator<(Ident, (Ast<'ast>, TermEnv<'ast>))> for TermEnv<'ast> { + fn from_iter(iter: T) -> Self + where + T: IntoIterator, TermEnv<'ast>))>, + { + TermEnv(Environment::, TermEnv<'ast>)>::from_iter( + iter, + )) + } +} + +/// Mapping from wildcard IDs to inferred types +pub type Wildcards<'ast> = Vec>; + +/// A table mapping type variables and their kind to names. Used for reporting. +pub type NameTable = HashMap<(VarId, VarKindDiscriminant), Ident>; + +/// A unifiable record row. +pub type UnifRecordRow<'ast> = RecordRowF>>; +pub type UnifRecordRowsUnr<'ast> = RecordRowsF>, Box>>; + +/// Unifiable record rows. Same shape as [`crate::bytecode::ast::typ::RecordRows`], but where each +/// type is unifiable, and each tail may be a unification variable (or a constant). +#[derive(Clone, PartialEq, Debug)] +pub enum UnifRecordRows<'ast> { + Concrete { + rrows: UnifRecordRowsUnr<'ast>, + /// Additional metadata related to unification variable levels update. See [VarLevelsData]. + var_levels_data: VarLevelsData, + }, + Constant(VarId), + /// A unification variable. + UnifVar { + /// The unique identifier of this variable in the unification table. + id: VarId, + /// The initial variable level at which the variable was created. See + /// [UnifType::UnifVar]. + init_level: VarLevel, + }, +} + +pub type UnifEnumRow<'ast> = EnumRowF>>; +pub type UnifEnumRowsUnr<'ast> = EnumRowsF>, Box>>; + +/// Unifiable enum rows. Same shape as [`crate::typ::EnumRows`] but where each tail may be a +/// unification variable (or a constant). 
+#[derive(Clone, PartialEq, Debug)] +pub enum UnifEnumRows<'ast> { + Concrete { + erows: UnifEnumRowsUnr<'ast>, + /// Additional metadata related to unification variable levels update. See [VarLevelsData]. + var_levels_data: VarLevelsData, + }, + Constant(VarId), + UnifVar { + /// The unique identifier of this variable in the unification table. + id: VarId, + /// The initial variable level at which the variable was created. See + /// [UnifType::UnifVar]. + init_level: VarLevel, + }, +} + +/// Metadata attached to unification types, which are used to delay and optimize potentially costly +/// type traversals when updating the levels of the free unification variables of a type. Based on +/// Didier Remy's algorithm for the OCaml typechecker, see [Efficient and insightful +/// generalization](http://web.archive.org/web/20230525023637/https://okmij.org/ftp/ML/generalization.html). +/// +/// When unifying a variable with a composite type, we have to update the levels of all the free +/// unification variables contained in that type, which naively incurs a full traversal of the type. +/// The idea behind Didier Remy's algorithm is to delay such traversals, and use the values of +/// [VarLevelsData] to group traversals and avoid unneeded ones. This makes variable unification run +/// in constant time again, as long as we don't unify with a rigid type variable. +/// +/// Variable levels data might correspond to different variable kinds (type, record rows and enum +/// rows) depending on where they appear (in a [UnifType<'ast>], [UnifRecordRows<'ast>] or [UnifEnumRows<'ast>]). +#[derive(Clone, Copy, Eq, PartialEq, Debug)] +pub struct VarLevelsData { + /// Upper bound on the variable levels of free unification variables contained in this type. + upper_bound: VarLevel, + /// Pending variable level update, which must satisfy `pending_level <= old_level`.
+ pending: Option, +} + +impl Default for VarLevelsData { + fn default() -> Self { + VarLevelsData::new_from_bound(VarLevel::MAX_LEVEL) + } +} + +impl VarLevelsData { + pub fn new() -> Self { + Self::default() + } + + /// Create new variable levels data with the given upper bound and no pending level update. + pub fn new_from_bound(upper_bound: VarLevel) -> Self { + VarLevelsData { + upper_bound, + pending: None, + } + } + + /// Create new variable levels data with an upper bound which indicates that there is no + /// unification variable in the attached type and no pending level update. + pub fn new_no_uvars() -> Self { + Self::new_from_bound(VarLevel::NO_VAR) + } +} + +/// Unification types and variants that store an upper bound on the level of the unification +/// variables they contain, or for which an upper bound can be computed quickly (in constant time). +trait VarLevelUpperBound { + // Return an upper bound on the level of the unification variables contained in `self`. + // Depending on the implementer, the level might refer to different kind of unification + // variables (type, record rows or enum rows). + fn var_level_upper_bound(&self) -> VarLevel; +} + +impl VarLevelUpperBound for UnifType<'_> { + fn var_level_upper_bound(&self) -> VarLevel { + match self { + UnifType::Concrete { + var_levels_data, .. + } => var_levels_data.upper_bound, + UnifType::UnifVar { init_level, .. } => *init_level, + UnifType::Constant(_) => VarLevel::NO_VAR, + } + } +} + +impl VarLevelUpperBound for UnifTypeUnr<'_> { + fn var_level_upper_bound(&self) -> VarLevel { + match self { + TypeF::Dyn + | TypeF::Bool + | TypeF::Number + | TypeF::String + | TypeF::ForeignId + | TypeF::Symbol => VarLevel::NO_VAR, + TypeF::Arrow(domain, codomain) => max( + domain.var_level_upper_bound(), + codomain.var_level_upper_bound(), + ), + TypeF::Forall { body, .. 
} => body.var_level_upper_bound(), + TypeF::Enum(erows) => erows.var_level_upper_bound(), + TypeF::Record(rrows) => rrows.var_level_upper_bound(), + TypeF::Dict { type_fields, .. } => type_fields.var_level_upper_bound(), + TypeF::Array(ty_elts) => ty_elts.var_level_upper_bound(), + TypeF::Wildcard(_) | TypeF::Var(_) | TypeF::Contract(_) => VarLevel::NO_VAR, + } + } +} + +impl VarLevelUpperBound for UnifEnumRows<'_> { + fn var_level_upper_bound(&self) -> VarLevel { + match self { + UnifEnumRows::Concrete { + var_levels_data, .. + } => var_levels_data.upper_bound, + UnifEnumRows::UnifVar { init_level, .. } => *init_level, + UnifEnumRows::Constant(_) => VarLevel::NO_VAR, + } + } +} + +impl VarLevelUpperBound for UnifEnumRowsUnr<'_> { + fn var_level_upper_bound(&self) -> VarLevel { + match self { + // A var that hasn't been instantiated yet isn't a unification variable + EnumRowsF::Empty | EnumRowsF::TailVar(_) => VarLevel::NO_VAR, + // An enum row may carry a payload type (enum variant) which can itself contain + // unification variables: it must be accounted for, as done for record rows below. + EnumRowsF::Extend { row, tail } => { + let tail_bound = tail.var_level_upper_bound(); + row.typ + .as_ref() + .map_or(tail_bound, |typ| max(tail_bound, typ.var_level_upper_bound())) + } + } + } +} + +impl VarLevelUpperBound for UnifRecordRows<'_> { + fn var_level_upper_bound(&self) -> VarLevel { + match self { + UnifRecordRows::Concrete { + var_levels_data, .. + } => var_levels_data.upper_bound, + UnifRecordRows::UnifVar { init_level, .. } => *init_level, + UnifRecordRows::Constant(_) => VarLevel::NO_VAR, + } + } +} + +impl VarLevelUpperBound for UnifRecordRowsUnr<'_> { + fn var_level_upper_bound(&self) -> VarLevel { + match self { + // A var that hasn't been instantiated yet isn't a unification variable + RecordRowsF::Empty | RecordRowsF::TailVar(_) | RecordRowsF::TailDyn => VarLevel::NO_VAR, + RecordRowsF::Extend { + row: RecordRowF { id: _, typ }, + tail, + } => max(tail.var_level_upper_bound(), typ.var_level_upper_bound()), + } + } +} + +/// The types on which the unification algorithm operates, which may be either a concrete type, a +/// type constant or a unification variable.
+/// +/// Contracts store an additional term environment for contract equality checking. +/// +/// # Invariants +/// +/// **Important**: the following invariant must always be satisfied: for any free unification +/// variable[^free-unif-var] part of a concrete unification type, the level of this variable must +/// be smaller than or equal to `var_levels_data.upper_bound`. Otherwise, the typechecking algorithm +/// might not be correct. Be careful when creating new concrete [UnifType] +/// values manually. All `from` and `try_from` implementations, the `concrete` method as well as +/// builders from the [mk_uniftype] module all correctly compute the upper bound (given that the +/// upper bounds of the subcomponents are correct). +/// +/// The default value for `var_levels_data`, although it can incur more work, is at least always +/// correct (by setting `upper_bound = VarLevel::MAX_LEVEL`). +/// +/// [^free-unif-var]: A free unification variable is a unification variable that isn't assigned to +/// any type yet, i.e. verifying `uty.root_type(..) == uty` (adapt with the corresponding +/// `root_xxx` method for rows). +#[derive(Clone, PartialEq, Debug)] +pub enum UnifType<'ast> { + /// A concrete type (like `Number` or `String -> String`). Note that subcomponents of a + /// concrete type can still be free unification variables, such as the type `a -> a`, but the + /// top-level node is a concrete type constructor. + Concrete { + typ: UnifTypeUnr<'ast>, + /// Additional metadata related to unification variable levels update. See [VarLevelsData]. + var_levels_data: VarLevelsData, + }, + /// A rigid type constant which cannot be unified with anything but itself. + Constant(VarId), + /// A unification variable. + UnifVar { + /// The unique identifier of this variable in the unification table.
+ id: VarId, + /// An upper bound of this variable's level, which usually corresponds to the initial level at + /// which the variable was allocated, although this value might be bumped for some + /// variables by level updates. + /// + /// In a model where unification variables directly store a mutable level attribute, we + /// wouldn't need to duplicate this level information both here at the variable level and + /// in the unification table. `init_level` is used to compute upper bounds without having + /// to thread the unification table around (in the `from`/`try_from` implementation for + /// unification types, typically). + /// + /// Note that the actual level of this variable is stored in the unification table, which + /// is the source of truth. The actual level must satisfy `current_level <= init_level` + /// (the level of a variable can only decrease with time). + init_level: VarLevel, + }, +} + +type UnifTypeUnr<'ast> = TypeF< + Box>, + UnifRecordRows<'ast>, + UnifEnumRows<'ast>, + (&'ast Ast<'ast>, TermEnv<'ast>), +>; + +impl<'ast> UnifType<'ast> { + /// Create a concrete generic unification type. Compute the variable levels data from the + /// subcomponents. + pub fn concrete(typ: UnifTypeUnr<'ast>) -> Self { + let upper_bound = typ.var_level_upper_bound(); + + UnifType::Concrete { + typ, + var_levels_data: VarLevelsData::new_from_bound(upper_bound), + } + } + + /// Create a [`UnifType<'ast>`] from a [`Type`]. + pub fn from_type(ty: Type<'ast>, env: &TermEnv<'ast>) -> Self { + UnifType::concrete(ty.typ.map( + |ty| Box::new(UnifType::from_type(ty.clone(), env)), + |rrows| UnifRecordRows::from_record_rows(rrows, env), + |erows| UnifEnumRows::from_enum_rows(erows, env), + |term| (term, env.clone()), + )) + } + + /// Create a [`UnifType<'ast>`] from an [`ApparentType`]. As for [`UnifType::from_type`], this + /// function requires the current term environment.
+ pub fn from_apparent_type(at: ApparentType<'ast>, env: &TermEnv<'ast>) -> Self { + match at { + ApparentType::Annotated(ty) if has_wildcards(&ty) => UnifType::concrete(TypeF::Dyn), + ApparentType::Annotated(ty) + | ApparentType::Inferred(ty) + | ApparentType::Approximated(ty) => UnifType::from_type(ty, env), + ApparentType::FromEnv(uty) => uty, + } + } + + pub fn from_constant_of_kind(c: usize, k: VarKindDiscriminant) -> Self { + match k { + VarKindDiscriminant::Type => UnifType::Constant(c), + VarKindDiscriminant::EnumRows => UnifType::Concrete { + typ: TypeF::Enum(UnifEnumRows::Constant(c)), + var_levels_data: VarLevelsData::new_no_uvars(), + }, + VarKindDiscriminant::RecordRows => UnifType::Concrete { + typ: TypeF::Record(UnifRecordRows::Constant(c)), + var_levels_data: VarLevelsData::new_no_uvars(), + }, + } + } + + /// Extract the concrete type corresponding to a unifiable type. Free unification variables as + /// well as type constants are replaced with the type `Dyn`. + fn into_type(self, alloc: &'ast AstAlloc, table: &UnifTable<'ast>) -> Type<'ast> { + match self { + UnifType::UnifVar { id, init_level } => match table.root_type(id, init_level) { + t @ UnifType::Concrete { .. } => t.into_type(alloc, table), + _ => Type::from(TypeF::Dyn), + }, + UnifType::Constant(_) => Type::from(TypeF::Dyn), + UnifType::Concrete { typ, .. } => { + let mapped = typ.map( + |btyp| alloc.alloc(btyp.into_type(alloc, table)), + |urrows| urrows.into_rrows(alloc, table), + |uerows| uerows.into_erows(alloc, table), + |(term, _env)| term, + ); + Type::from(mapped) + } + } + } + + /// Return the unification root associated with this type. If the type is a unification + /// variable, return the result of `table.root_type`. Return `self` otherwise. 
+ fn into_root(self, table: &UnifTable<'ast>) -> Self { + match self { + UnifType::UnifVar { id, init_level } => table.root_type(id, init_level), + uty => uty, + } + } +} + +impl<'ast> UnifRecordRows<'ast> { + /// Create concrete generic record rows. Compute the variable levels data from the + /// subcomponents. + pub fn concrete(typ: UnifRecordRowsUnr<'ast>) -> Self { + let upper_bound = typ.var_level_upper_bound(); + + UnifRecordRows::Concrete { + rrows: typ, + var_levels_data: VarLevelsData::new_from_bound(upper_bound), + } + } + + /// Extract the concrete [`RecordRows`] corresponding to a [`UnifRecordRows<'ast>`]. Free unification + /// variables as well as type constants are replaced with the empty row. + fn into_rrows(self, alloc: &'ast AstAlloc, table: &UnifTable<'ast>) -> RecordRows<'ast> { + match self { + UnifRecordRows::UnifVar { id, init_level } => match table.root_rrows(id, init_level) { + t @ UnifRecordRows::Concrete { .. } => t.into_rrows(alloc, table), + _ => RecordRows(RecordRowsF::Empty), + }, + UnifRecordRows::Constant(_) => RecordRows(RecordRowsF::Empty), + UnifRecordRows::Concrete { rrows, .. } => { + let mapped = rrows.map( + |ty| alloc.alloc(ty.into_type(alloc, table)), + |rrows| alloc.alloc(rrows.into_rrows(alloc, table)), + ); + RecordRows(mapped) + } + } + } + + /// Return the unification root associated with these record rows. If the rows are a unification + /// variable, return the result of `table.root_rrows`. Return `self` otherwise. + fn into_root(self, table: &UnifTable<'ast>) -> Self { + match self { + UnifRecordRows::UnifVar { id, init_level } => table.root_rrows(id, init_level), + urrows => urrows, + } + } +} + +impl<'ast> UnifEnumRows<'ast> { + /// Create concrete generic enum rows. Compute the variable levels data from the subcomponents. 
+ pub fn concrete(typ: UnifEnumRowsUnr<'ast>) -> Self { + let upper_bound = typ.var_level_upper_bound(); + + UnifEnumRows::Concrete { + erows: typ, + var_levels_data: VarLevelsData::new_from_bound(upper_bound), + } + } + + /// Extract the concrete [`EnumRows`] corresponding to a [`UnifEnumRows<'ast>`]. Free unification + /// variables as well as type constants are replaced with the empty row. + fn into_erows(self, alloc: &'ast AstAlloc, table: &UnifTable<'ast>) -> EnumRows<'ast> { + match self { + UnifEnumRows::UnifVar { id, init_level } => match table.root_erows(id, init_level) { + t @ UnifEnumRows::Concrete { .. } => t.into_erows(alloc, table), + _ => EnumRows(EnumRowsF::Empty), + }, + UnifEnumRows::Constant(_) => EnumRows(EnumRowsF::Empty), + UnifEnumRows::Concrete { erows, .. } => { + let mapped = erows.map( + |ty| alloc.alloc(ty.into_type(alloc, table)), + |erows| alloc.alloc(erows.into_erows(alloc, table)), + ); + EnumRows(mapped) + } + } + } + + /// Return the unification root associated with these enum rows. If the rows are a unification + /// variable, return the result of `table.root_erows`. Return `self` otherwise. + fn into_root(self, table: &UnifTable<'ast>) -> Self { + match self { + UnifEnumRows::UnifVar { id, init_level } => table.root_erows(id, init_level), + uerows => uerows, + } + } +} + +impl<'ast> TryConvert<'ast, UnifRecordRows<'ast>> for RecordRows<'ast> { + type Error = (); + + fn try_convert( + alloc: &'ast AstAlloc, + urrows: UnifRecordRows<'ast>, + ) -> Result, ()> { + match urrows { + UnifRecordRows::Concrete { rrows, .. 
} => { + let converted: RecordRowsF<&'ast Type<'ast>, &'ast RecordRows<'ast>> = rrows + .try_map( + |uty| Ok(alloc.alloc(Type::try_convert(alloc, *uty)?)), + |urrows| Ok(alloc.alloc(RecordRows::try_convert(alloc, *urrows)?)), + )?; + Ok(RecordRows(converted)) + } + _ => Err(()), + } + } +} + +impl<'ast> TryConvert<'ast, UnifEnumRows<'ast>> for EnumRows<'ast> { + type Error = (); + + fn try_convert( + alloc: &'ast AstAlloc, + uerows: UnifEnumRows<'ast>, + ) -> Result, ()> { + match uerows { + UnifEnumRows::Concrete { erows, .. } => { + let converted: EnumRowsF<&'ast Type<'ast>, &'ast EnumRows<'ast>> = erows.try_map( + |uty| Ok(alloc.alloc(Type::try_convert(alloc, *uty)?)), + |uerows| Ok(alloc.alloc(EnumRows::try_convert(alloc, *uerows)?)), + )?; + Ok(EnumRows(converted)) + } + _ => Err(()), + } + } +} + +impl<'ast> TryConvert<'ast, UnifType<'ast>> for Type<'ast> { + type Error = (); + + fn try_convert(alloc: &'ast AstAlloc, utype: UnifType<'ast>) -> Result, ()> { + match utype { + UnifType::Concrete { typ, .. } => { + let converted: TypeF< + &'ast Type<'ast>, + RecordRows<'ast>, + EnumRows<'ast>, + &'ast Ast<'ast>, + > = typ.try_map( + |uty_boxed| { + let ty = Type::try_convert(alloc, *uty_boxed)?; + Ok(alloc.alloc(ty)) + }, + |urrows| RecordRows::try_convert(alloc, urrows), + |uerows| EnumRows::try_convert(alloc, uerows), + |(term, _env)| Ok(term), + )?; + Ok(Type::from(converted)) + } + _ => Err(()), + } + } +} + +impl<'ast> UnifEnumRows<'ast> { + pub fn from_enum_rows(erows: EnumRows<'ast>, env: &TermEnv<'ast>) -> Self { + let f_erow = |ty: &'ast Type<'ast>| Box::new(UnifType::from_type(ty.clone(), env)); + let f_erows = |erows: &'ast EnumRows<'ast>| { + Box::new(UnifEnumRows::from_enum_rows(erows.clone(), env)) + }; + + UnifEnumRows::concrete(erows.0.map(f_erow, f_erows)) + } +} + +impl<'ast> UnifEnumRows<'ast> { + /// Return an iterator producing immutable references to individual rows. 
+ pub(super) fn iter(&self) -> EnumRowsIterator, UnifEnumRows<'ast>> { + EnumRowsIterator { + erows: Some(self), + ty: std::marker::PhantomData, + } + } +} + +impl<'ast> UnifRecordRows<'ast> { + /// Create [UnifRecordRows<'ast>] from [RecordRows]. + pub fn from_record_rows(rrows: RecordRows<'ast>, env: &TermEnv<'ast>) -> Self { + let f_rrow = |ty: &'ast Type<'ast>| Box::new(UnifType::from_type(ty.clone(), env)); + let f_rrows = |rrows: &'ast RecordRows<'ast>| { + Box::new(UnifRecordRows::from_record_rows(rrows.clone(), env)) + }; + + UnifRecordRows::concrete(rrows.0.map(f_rrow, f_rrows)) + } +} + +impl<'ast> UnifRecordRows<'ast> { + pub(super) fn iter(&self) -> RecordRowsIterator, UnifRecordRows<'ast>> { + RecordRowsIterator { + rrows: Some(self), + ty: std::marker::PhantomData, + } + } +} + +/// A type which contains variables that can be substituted with values of type `T`. +trait Subst: Sized { + /// Substitute all variables of identifier `id` with `to`. + fn subst(self, id: &LocIdent, to: &T) -> Self { + self.subst_levels(id, to).0 + } + + /// Must be filled by implementers of this trait. + /// In addition to performing substitution, this method threads variable levels upper bounds to + /// compute new upper bounds efficiently. + fn subst_levels(self, id: &LocIdent, to: &T) -> (Self, VarLevel); +} + +impl<'ast> Subst> for UnifType<'ast> { + fn subst_levels(self, id: &LocIdent, to: &UnifType<'ast>) -> (Self, VarLevel) { + match self { + UnifType::Concrete { + typ: TypeF::Var(var_id), + var_levels_data, + } if var_id == id.ident() => { + // A free type variable isn't (yet) a unification variable, so it shouldn't have a + // level set at this point. During instantiation, it might be substituted for a + // unification variable by this very function, and will then inherit this level. 
+ debug_assert!(var_levels_data.upper_bound == VarLevel::NO_VAR); + + (to.clone(), to.var_level_upper_bound()) + } + UnifType::Concrete { + typ, + var_levels_data, + } => { + let mut upper_bound = VarLevel::NO_VAR; + + let new_ty = UnifType::Concrete { + typ: typ.map_state( + |ty, upper_bound| { + let (new_type, new_ub) = ty.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_type) + }, + |rrows, upper_bound| { + let (new_rrows, new_ub) = rrows.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + new_rrows + }, + |erows, upper_bound| { + let (new_erows, new_ub) = erows.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + new_erows + }, + // Substitution doesn't cross the contract boundaries + |ctr, _upper_bound| ctr, + &mut upper_bound, + ), + var_levels_data: VarLevelsData { + upper_bound, + ..var_levels_data + }, + }; + + (new_ty, upper_bound) + } + _ => { + let upper_bound = self.var_level_upper_bound(); + (self, upper_bound) + } + } + } +} + +impl<'ast> Subst> for UnifRecordRows<'ast> { + fn subst_levels(self, id: &LocIdent, to: &UnifType<'ast>) -> (Self, VarLevel) { + match self { + UnifRecordRows::Concrete { + rrows, + var_levels_data, + } => { + let mut upper_bound = VarLevel::NO_VAR; + + let new_rrows = rrows.map_state( + |ty, upper_bound| { + let (new_ty, new_ub) = ty.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_ty) + }, + |rrows, upper_bound| { + let (new_rrows, new_ub) = rrows.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_rrows) + }, + &mut upper_bound, + ); + + let new_urrows = UnifRecordRows::Concrete { + rrows: new_rrows, + var_levels_data: VarLevelsData { + upper_bound, + ..var_levels_data + }, + }; + + (new_urrows, upper_bound) + } + _ => { + let upper_bound = self.var_level_upper_bound(); + (self, upper_bound) + } + } + } +} + +impl<'ast> Subst> for UnifEnumRows<'ast> { + fn subst_levels(self, id: &LocIdent, to: 
&UnifType<'ast>) -> (Self, VarLevel) { + match self { + UnifEnumRows::Concrete { + erows, + var_levels_data, + } => { + let mut upper_bound = VarLevel::NO_VAR; + + let new_erows = erows.map_state( + |ty, upper_bound| { + let (new_ty, new_ub) = ty.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_ty) + }, + |erows, upper_bound| { + let (new_erows, new_ub) = erows.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_erows) + }, + &mut upper_bound, + ); + + let new_uerows = UnifEnumRows::Concrete { + erows: new_erows, + var_levels_data: VarLevelsData { + upper_bound, + ..var_levels_data + }, + }; + + (new_uerows, upper_bound) + } + _ => { + let upper_bound = self.var_level_upper_bound(); + (self, upper_bound) + } + } + } +} + +impl<'ast> Subst> for UnifType<'ast> { + fn subst_levels(self, id: &LocIdent, to: &UnifRecordRows<'ast>) -> (Self, VarLevel) { + match self { + UnifType::Concrete { + typ, + var_levels_data, + } => { + let mut upper_bound = VarLevel::NO_VAR; + + let new_ty = typ.map_state( + |ty, upper_bound| { + let (new_ty, new_ub) = ty.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_ty) + }, + |rrows, upper_bound| { + let (new_rrows, new_ub) = rrows.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + new_rrows + }, + |erows, upper_bound| { + let (new_erows, new_ub) = erows.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + new_erows + }, + |ctr, _upper_bound| ctr, + &mut upper_bound, + ); + + let new_uty = UnifType::Concrete { + typ: new_ty, + var_levels_data, + }; + + (new_uty, upper_bound) + } + _ => { + let upper_bound = self.var_level_upper_bound(); + (self, upper_bound) + } + } + } +} + +impl<'ast> Subst> for UnifRecordRows<'ast> { + fn subst_levels(self, id: &LocIdent, to: &UnifRecordRows<'ast>) -> (Self, VarLevel) { + match self { + UnifRecordRows::Concrete { + rrows: RecordRowsF::TailVar(var_id), + var_levels_data, + } if var_id 
== *id => { + debug_assert!(var_levels_data.upper_bound == VarLevel::NO_VAR); + (to.clone(), to.var_level_upper_bound()) + } + UnifRecordRows::Concrete { + rrows, + var_levels_data, + } => { + let mut upper_bound = VarLevel::NO_VAR; + + let new_rrows = rrows.map_state( + |ty, upper_bound| { + let (new_ty, new_ub) = ty.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_ty) + }, + |rrows, upper_bound| { + let (new_rrows, new_ub) = rrows.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_rrows) + }, + &mut upper_bound, + ); + + let new_urrows = UnifRecordRows::Concrete { + rrows: new_rrows, + var_levels_data: VarLevelsData { + upper_bound, + ..var_levels_data + }, + }; + + (new_urrows, upper_bound) + } + _ => { + let upper_bound = self.var_level_upper_bound(); + (self, upper_bound) + } + } + } +} + +impl<'ast> Subst> for UnifEnumRows<'ast> { + fn subst_levels(self, id: &LocIdent, to: &UnifRecordRows<'ast>) -> (Self, VarLevel) { + match self { + UnifEnumRows::Concrete { + erows, + var_levels_data, + } => { + let mut upper_bound = VarLevel::NO_VAR; + + let new_erows = erows.map_state( + |ty, upper_bound| { + let (new_ty, new_ub) = ty.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_ty) + }, + |erows, upper_bound| { + let (new_erows, new_ub) = erows.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_erows) + }, + &mut upper_bound, + ); + + let new_uerows = UnifEnumRows::Concrete { + erows: new_erows, + var_levels_data, + }; + + (new_uerows, upper_bound) + } + _ => { + let upper_bound = self.var_level_upper_bound(); + (self, upper_bound) + } + } + } +} + +impl<'ast> Subst> for UnifType<'ast> { + fn subst_levels(self, id: &LocIdent, to: &UnifEnumRows<'ast>) -> (Self, VarLevel) { + match self { + UnifType::Concrete { + typ, + var_levels_data, + } => { + let mut upper_bound = VarLevel::NO_VAR; + + let new_ty = typ.map_state( + |ty, upper_bound| { + let 
(new_ty, new_ub) = ty.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_ty) + }, + |rrows, upper_bound| { + let (new_rrows, new_ub) = rrows.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + new_rrows + }, + |erows, upper_bound| { + let (new_erows, new_ub) = erows.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + new_erows + }, + |ctr, _upper_bound| ctr, + &mut upper_bound, + ); + + let new_uty = UnifType::Concrete { + typ: new_ty, + var_levels_data: VarLevelsData { + upper_bound, + ..var_levels_data + }, + }; + + (new_uty, upper_bound) + } + _ => { + let upper_bound = self.var_level_upper_bound(); + (self, upper_bound) + } + } + } +} + +impl<'ast> Subst> for UnifRecordRows<'ast> { + fn subst_levels(self, id: &LocIdent, to: &UnifEnumRows<'ast>) -> (Self, VarLevel) { + match self { + UnifRecordRows::Concrete { + rrows, + var_levels_data, + } => { + let mut upper_bound = VarLevel::NO_VAR; + + let new_rrows = rrows.map_state( + |ty, upper_bound| { + let (new_ty, new_ub) = ty.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_ty) + }, + |rrows, upper_bound| { + let (new_rrows, new_ub) = rrows.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_rrows) + }, + &mut upper_bound, + ); + + let new_urrows = UnifRecordRows::Concrete { + rrows: new_rrows, + var_levels_data: VarLevelsData { + upper_bound, + ..var_levels_data + }, + }; + + (new_urrows, upper_bound) + } + + _ => { + let upper_bound = self.var_level_upper_bound(); + (self, upper_bound) + } + } + } +} + +impl<'ast> Subst> for UnifEnumRows<'ast> { + fn subst_levels(self, id: &LocIdent, to: &UnifEnumRows<'ast>) -> (Self, VarLevel) { + match self { + UnifEnumRows::Concrete { + erows: EnumRowsF::TailVar(var_id), + var_levels_data, + } if var_id == *id => { + debug_assert!(var_levels_data.upper_bound == VarLevel::NO_VAR); + + (to.clone(), to.var_level_upper_bound()) + } + UnifEnumRows::Concrete { + 
erows, + var_levels_data, + } => { + let mut upper_bound = VarLevel::NO_VAR; + + let new_erows = erows.map_state( + |ty, upper_bound| { + let (new_ty, new_ub) = ty.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_ty) + }, + |erows, upper_bound| { + let (new_erows, new_ub) = erows.subst_levels(id, to); + *upper_bound = max(*upper_bound, new_ub); + Box::new(new_erows) + }, + &mut upper_bound, + ); + + let new_uerows = UnifEnumRows::Concrete { + erows: new_erows, + var_levels_data: VarLevelsData { + upper_bound, + ..var_levels_data + }, + }; + + (new_uerows, upper_bound) + } + + _ => { + let upper_bound = self.var_level_upper_bound(); + (self, upper_bound) + } + } + } +} + +impl<'ast> From> for UnifType<'ast> { + fn from(typ: UnifTypeUnr<'ast>) -> Self { + let var_level_max = typ.var_level_upper_bound(); + + UnifType::Concrete { + typ, + var_levels_data: VarLevelsData::new_from_bound(var_level_max), + } + } +} + +impl<'ast> From>, Box>>> + for UnifRecordRows<'ast> +{ + fn from(rrows: RecordRowsF>, Box>>) -> Self { + let var_level_max = rrows.var_level_upper_bound(); + + UnifRecordRows::Concrete { + rrows, + var_levels_data: VarLevelsData::new_from_bound(var_level_max), + } + } +} + +impl<'ast> From>, Box>>> for UnifEnumRows<'ast> { + fn from(erows: EnumRowsF>, Box>>) -> Self { + UnifEnumRows::concrete(erows) + } +} + +/// Iterator items produced by [RecordRowsIterator] on [UnifRecordRows<'ast>]. +pub enum RecordRowsElt<'a, 'ast> { + TailDyn, + TailVar(&'a LocIdent), + TailUnifVar { id: VarId, init_level: VarLevel }, + TailConstant(VarId), + Row(RecordRowF<&'a UnifType<'ast>>), +} + +impl<'a, 'ast> Iterator for RecordRowsIterator<'a, UnifType<'ast>, UnifRecordRows<'ast>> { + type Item = RecordRowsElt<'a, 'ast>; + + fn next(&mut self) -> Option { + self.rrows.and_then(|next| match next { + UnifRecordRows::Concrete { rrows, .. 
} => match rrows { + RecordRowsF::Empty => { + self.rrows = None; + None + } + RecordRowsF::TailDyn => { + self.rrows = None; + Some(RecordRowsElt::TailDyn) + } + RecordRowsF::TailVar(id) => { + self.rrows = None; + Some(RecordRowsElt::TailVar(id)) + } + RecordRowsF::Extend { row, tail } => { + self.rrows = Some(tail); + Some(RecordRowsElt::Row(RecordRowF { + id: row.id, + typ: row.typ.as_ref(), + })) + } + }, + UnifRecordRows::UnifVar { id, init_level } => { + self.rrows = None; + Some(RecordRowsElt::TailUnifVar { + id: *id, + init_level: *init_level, + }) + } + UnifRecordRows::Constant(var_id) => { + self.rrows = None; + Some(RecordRowsElt::TailConstant(*var_id)) + } + }) + } +} + +/// Iterator items produced by [`EnumRowsIterator`]. +pub enum EnumRowsElt<'a, 'ast> { + TailVar(&'a LocIdent), + TailUnifVar { id: VarId, init_level: VarLevel }, + TailConstant(VarId), + Row(EnumRowF<&'a UnifType<'ast>>), +} + +impl<'a, 'ast> Iterator for EnumRowsIterator<'a, UnifType<'ast>, UnifEnumRows<'ast>> { + type Item = EnumRowsElt<'a, 'ast>; + + fn next(&mut self) -> Option { + self.erows.and_then(|next| match next { + UnifEnumRows::Concrete { erows, .. 
} => match erows { + EnumRowsF::Empty => { + self.erows = None; + None + } + EnumRowsF::TailVar(id) => { + self.erows = None; + Some(EnumRowsElt::TailVar(id)) + } + EnumRowsF::Extend { row, tail } => { + self.erows = Some(tail); + Some(EnumRowsElt::Row(EnumRowF { + id: row.id, + typ: row.typ.as_ref().map(|ty| ty.as_ref()), + })) + } + }, + UnifEnumRows::UnifVar { id, init_level } => { + self.erows = None; + Some(EnumRowsElt::TailUnifVar { + id: *id, + init_level: *init_level, + }) + } + UnifEnumRows::Constant(var_id) => { + self.erows = None; + Some(EnumRowsElt::TailConstant(*var_id)) + } + }) + } +} + +pub trait ReifyAsUnifType<'ast> { + fn unif_type() -> UnifType<'ast>; +} + +impl<'ast> ReifyAsUnifType<'ast> for crate::label::TypeVarData { + fn unif_type() -> UnifType<'ast> { + mk_buty_record!(("polarity", crate::label::Polarity::unif_type())) + } +} + +impl<'ast> ReifyAsUnifType<'ast> for crate::label::Polarity { + fn unif_type() -> UnifType<'ast> { + mk_buty_enum!("Positive", "Negative") + } +} + +/// The typing context is a structure holding the scoped, environment-like data structures required +/// to perform typechecking. +/// +#[derive(Debug, PartialEq, Clone)] +pub struct Context<'ast> { + /// The typing environment. + pub type_env: TypeEnv<'ast>, + /// The term environment, used to decide type equality over contracts. + pub term_env: TermEnv<'ast>, + /// The current variable level, incremented each time we instantiate a polymorphic type and + /// thus introduce a new block of variables (either unification variables or rigid type + /// variables). 
+ pub var_level: VarLevel, +} + +impl Context<'_> { + pub fn new() -> Self { + Context { + type_env: TypeEnv::new(), + term_env: TermEnv::new(), + var_level: VarLevel::MIN_LEVEL, + } + } +} + +impl Default for Context<'_> { + fn default() -> Self { + Self::new() + } +} + +#[derive(Clone, Debug)] +pub enum EnvBuildError<'ast> { + NotARecord(Ast<'ast>), +} + +/// Populate the initial typing environment from a `Vec` of parsed files. +pub fn mk_initial_ctxt<'ast>( + ast_alloc: &'ast AstAlloc, + initial_env: &[(nickel_stdlib::StdlibModule, Ast<'ast>)], +) -> Result, EnvBuildError<'ast>> { + // Collect the bindings for each module, clone them and flatten the result to a single list. + let mut bindings = Vec::new(); + + for (module, ast) in initial_env { + match (module, &ast.node) { + // The internals module is special: it is required to be syntactically a record, + // and is added directly to the top-level environment. + (nickel_stdlib::StdlibModule::Internals, Node::Record(record)) => { + // We reject fields without a value (that would be a stdlib module without + // defintion). We also assume that the top-level modules of the stdlib aren't + // defined piecewise, so that all path have length exactly one, and that those are + // static. + bindings.extend(record.field_defs.iter().map(|field_def| { + // unwrap(s)(): see assumptions above about the structure of the stdlib. 
+ debug_assert!( + field_def.path.len() == 1, + "unexpected piecewise definition in stdlib internals module" + ); + + let id = field_def.path.first().unwrap().try_as_ident().unwrap(); + + ( + id, + field_def + .value + .as_ref() + .unwrap_or_else(|| { + panic!("expected stdlib module {id} to have a definition") + }) + .clone(), + ) + })); + } + (nickel_stdlib::StdlibModule::Internals, _) => { + return Err(EnvBuildError::NotARecord(ast.clone())); + } + // Otherwise, we insert a value in the environment bound to the name of the module + (module, _) => bindings.push((module.name().into(), ast.clone())), + } + } + + let term_env = bindings + .iter() + .cloned() + .map(|(id, ast)| (id.ident(), (ast, TermEnv::new()))) + .collect(); + + let type_env = bindings + .into_iter() + .map(|(id, ast)| { + ( + id.ident(), + infer_record_type(ast_alloc, &ast, &term_env, INFER_RECORD_MAX_DEPTH), + ) + }) + .collect(); + + Ok(Context { + type_env, + term_env, + var_level: VarLevel::MIN_LEVEL, + }) +} + +/// Add the bindings of a record to a typing environment. Ignore fields whose name are defined +/// through interpolation. +//TODO: support the case of a record with a type annotation. +pub fn env_add_term<'ast>( + ast_alloc: &'ast AstAlloc, + env: &mut TypeEnv<'ast>, + ast: &Ast<'ast>, + term_env: &TermEnv<'ast>, + resolver: &dyn ImportResolver, +) -> Result<(), EnvBuildError<'ast>> { + match &ast.node { + Node::Record(record) => { + for field_def in record.field_defs.iter() { + if let Some(id) = field_def.path.first().unwrap().try_as_ident() { + let uty = UnifType::from_apparent_type( + field_def.apparent_type(ast_alloc, Some(env), Some(resolver)), + term_env, + ); + + env.insert(id.ident(), uty); + } + } + + Ok(()) + } + _ => Err(EnvBuildError::NotARecord(ast.clone())), + } +} + +/// Bind one term in a typing environment. 
+pub fn env_add<'ast>( + ast_alloc: &'ast AstAlloc, + env: &mut TypeEnv<'ast>, + id: LocIdent, + ast: &Ast<'ast>, + term_env: &TermEnv<'ast>, + resolver: &dyn ImportResolver, +) { + env.insert( + id.ident(), + UnifType::from_apparent_type( + apparent_type(ast_alloc, &ast.node, Some(env), Some(resolver)), + term_env, + ), + ); +} + +/// The shared state of unification. +pub struct State<'ast, 'tc> { + /// The import resolver, to retrieve and typecheck imports. + resolver: &'tc dyn ImportResolver, + /// The unification table. + table: &'tc mut UnifTable<'ast>, + /// Row constraints. + constr: &'tc mut RowConstrs, + /// A mapping from unification variables or constants together with their + /// kind to the name of the corresponding type variable which introduced it, + /// if any. + /// + /// Used for error reporting. + names: &'tc mut NameTable, + /// A mapping from wildcard ID to unification variable. + wildcard_vars: &'tc mut Vec>, + /// The AST allocator. + ast_alloc: &'ast AstAlloc, +} + +/// Immutable and owned data, required by the LSP to carry out specific analysis. +/// It is basically an owned-subset of the typechecking state. +pub struct TypeTables<'ast> { + pub table: UnifTable<'ast>, + pub names: NameTable, + pub wildcards: Vec>, +} + +/// Typecheck a term. +/// +/// Return the inferred type in case of success. This is just a wrapper that calls +/// `type_check_with_visitor` with a blanket implementation for the visitor. +/// +/// Note that this function doesn't recursively typecheck imports (anymore), but just the current +/// file. It however still needs the resolver to get the apparent type of imports. +/// +/// Return the type inferred for type wildcards. 
+pub fn type_check<'ast>( + alloc: &'ast AstAlloc, + ast: &Ast<'ast>, + initial_ctxt: Context<'ast>, + resolver: &impl ImportResolver, + initial_mode: TypecheckMode, +) -> Result, TypecheckError> { + type_check_with_visitor(alloc, ast, initial_ctxt, resolver, &mut (), initial_mode) + .map(|tables| tables.wildcards) +} + +/// Typecheck a term while providing the type information to a visitor. +pub fn type_check_with_visitor<'ast, V>( + ast_alloc: &'ast AstAlloc, + ast: &Ast<'ast>, + initial_ctxt: Context<'ast>, + resolver: &impl ImportResolver, + visitor: &mut V, + initial_mode: TypecheckMode, +) -> Result, TypecheckError> +where + V: TypecheckVisitor<'ast>, +{ + let (mut table, mut names) = (UnifTable::new(), HashMap::new()); + let mut wildcard_vars = Vec::new(); + + { + let mut state = State { + resolver, + table: &mut table, + constr: &mut RowConstrs::new(), + names: &mut names, + wildcard_vars: &mut wildcard_vars, + ast_alloc, + }; + + if initial_mode == TypecheckMode::Enforce { + let uty = state.table.fresh_type_uvar(initial_ctxt.var_level); + check(&mut state, initial_ctxt, visitor, ast, uty)?; + } else { + walk(&mut state, initial_ctxt, visitor, ast)?; + } + } + + let wildcards = wildcard_vars_to_type(ast_alloc, wildcard_vars.clone(), &table); + + Ok(TypeTables { + table, + names, + wildcards, + }) +} + +/// Walk the AST of a term looking for statically typed block to check. Fill the linearization +/// alongside and store the apparent type of variable inside the typing environment. 
+fn walk<'ast, V: TypecheckVisitor<'ast>>( + state: &mut State<'ast, '_>, + mut ctxt: Context<'ast>, + visitor: &mut V, + ast: &Ast<'ast>, +) -> Result<(), TypecheckError> { + let Ast { node, pos } = ast; + + visitor.visit_term( + ast, + UnifType::from_apparent_type( + apparent_type( + state.ast_alloc, + node, + Some(&ctxt.type_env), + Some(state.resolver), + ), + &ctxt.term_env, + ), + ); + + match node { + Node::ParseError(_) + | Node::Null + | Node::Bool(_) + | Node::Number(_) + | Node::String(_) + | Node::EnumVariant {arg: None, ..} + // This function doesn't recursively typecheck imports: this is the responsibility of the + // caller. + | Node::Import(_) => Ok(()), + Node::Var(x) => ctxt.type_env + .get(&x.ident()) + .ok_or(TypecheckError::UnboundIdentifier { id: *x, pos: *pos }) + .map(|_| ()), + Node::StringChunks(chunks) => { + for chunk in chunks.iter() { + if let StringChunk::Expr(t, _) = chunk { + walk(state, ctxt.clone(), visitor, t)?; + } + } + + Ok(()) + } + Node::Fun { args, body } => { + // The parameter of an unannotated function is always assigned type `Dyn`, unless the + // function is directly annotated with a function contract (see the special casing in + // `walk_with_annot`). + for arg in args.iter() { + let PatternTypeData { bindings: pat_bindings, ..} = arg.pattern_types(state, &ctxt, TypecheckMode::Walk)?; + ctxt.type_env.extend(pat_bindings.into_iter().map(|(id, typ)| (id.ident(), typ))); + } + + walk(state, ctxt, visitor, body) + } + Node::Array(array) => { + for elt in array.iter() { + walk(state, ctxt.clone(), visitor, elt)?; + } + + Ok(()) + } + Node::Let { bindings, body, rec } => { + // For a recursive let block, shadow all the names we're about to bind, so + // we aren't influenced by variables defined in an outer scope. 
+ if *rec { + for binding in bindings.iter() { + for pat_binding in binding.pattern.bindings() { + ctxt.type_env + .insert(pat_binding.id.ident(), mk_uniftype::dynamic()); + } + } + } + + let start_ctxt = ctxt.clone(); + + for binding in bindings.iter() { + let ty_let = binding_type(state, &binding.value.node, &start_ctxt, false); + + // In the case of a let-binding, we want to guess a better type than `Dyn` when we can + // do so cheaply for the whole pattern. + if let Some(alias) = &binding.pattern.alias { + visitor.visit_ident(alias, ty_let.clone()); + ctxt.type_env.insert(alias.ident(), ty_let); + ctxt.term_env.0.insert(alias.ident(), (binding.value.clone(), start_ctxt.term_env.clone())); + } + + // [^separate-alias-treatment]: Note that we call `pattern_types` on the inner pattern + // data, which doesn't take into account the potential heading alias `x @ `. + // This is on purpose, as the alias has been treated separately, so we don't want to + // shadow it with a less precise type. + // + // The use of start_ctxt here looks like it might be wrong for let rec, but in fact + // the context is unused in mode `TypecheckMode::Walk` anyway. + let PatternTypeData {bindings: pat_bindings, ..} = binding.pattern.data.pattern_types(state, &start_ctxt, TypecheckMode::Walk)?; + + for (id, typ) in pat_bindings { + visitor.visit_ident(&id, typ.clone()); + ctxt.type_env.insert(id.ident(), typ); + // [^term-env-rec-bindings]: we don't support recursive binding when checking + // for contract equality. + // + // This would quickly lead to `Rc` cycles, which are hard to deal with without + // leaking memory. The best way out would be to allocate all the term + // environments inside an arena, local to each statically typed block, and use + // bare references to represent cycles. Everything would be cleaned at the end + // of the block. 
+ ctxt.term_env.0.insert(id.ident(), (binding.value.clone(), start_ctxt.term_env.clone())); + } + } + + let value_ctxt = if *rec { ctxt.clone() } else { start_ctxt.clone() }; + + for binding in bindings.iter() { + walk(state, value_ctxt.clone(), visitor, &binding.value)?; + } + + walk(state, ctxt, visitor, body) + } + Node::App { head, args } => { + walk(state, ctxt.clone(), visitor, head)?; + + for arg in args.iter() { + walk(state, ctxt.clone(), visitor, arg)?; + } + + Ok(()) + } + Node::Match(match_data) => { + for MatchBranch { pattern, guard, body } in match_data.branches.iter() { + let mut local_ctxt = ctxt.clone(); + let PatternTypeData { bindings: pat_bindings, .. } = pattern.data.pattern_types(state, &ctxt, TypecheckMode::Walk)?; + + if let Some(alias) = &pattern.alias { + visitor.visit_ident(alias, mk_uniftype::dynamic()); + local_ctxt.type_env.insert(alias.ident(), mk_uniftype::dynamic()); + } + + for (id, typ) in pat_bindings { + visitor.visit_ident(&id, typ.clone()); + local_ctxt.type_env.insert(id.ident(), typ); + } + + if let Some(guard) = guard { + walk(state, local_ctxt.clone(), visitor, guard)?; + } + + walk(state, local_ctxt, visitor, body)?; + } + + Ok(()) + } + Node::IfThenElse { + cond, + then_branch, + else_branch, + } => { + walk(state, ctxt.clone(), visitor, cond)?; + walk(state, ctxt.clone(), visitor, then_branch)?; + walk(state, ctxt, visitor, else_branch) + } + Node::Record(record) => { + for field_def in record.field_defs.iter() { + let field_type = field_type( + state, + field_def, + &ctxt, + false, + ); + + if let Some(id) = field_def.name_as_ident() { + ctxt.type_env.insert(id.ident(), field_type.clone()); + visitor.visit_ident(&id, field_type); + } + } + + // Walk the type and contract annotations + + // We don't bind the fields in the term environment used to check for contract + // equality. See [^term-env-rec-bindings]. 
+ record.field_defs + .iter() + .try_for_each(|field_def| -> Result<(), TypecheckError> { + walk_field(state, ctxt.clone(), visitor, field_def) + }) + } + Node::EnumVariant { arg: Some(arg), ..} => walk(state, ctxt, visitor, arg), + Node::PrimOpApp { args, .. } => { + args.iter().try_for_each(|arg| -> Result<(), TypecheckError> { + walk(state, ctxt.clone(), visitor, arg) + })?; + + Ok(()) + } + Node::Annotated { annot, inner } => { + walk_annotated(state, ctxt, visitor, annot, inner) + } + Node::Type(typ) => walk_type(state, ctxt, visitor, typ), + } +} + +/// Same as [`walk`] but operate on a type, which can contain terms as contracts +/// ([crate::typ::TypeF::Contract]), instead of a term. +fn walk_type<'ast, V: TypecheckVisitor<'ast>>( + state: &mut State<'ast, '_>, + ctxt: Context<'ast>, + visitor: &mut V, + ty: &Type<'ast>, +) -> Result<(), TypecheckError> { + match &ty.typ { + TypeF::Dyn + | TypeF::Number + | TypeF::Bool + | TypeF::String + | TypeF::ForeignId + | TypeF::Symbol + // Currently, the parser can't generate unbound type variables by construction. Thus we + // don't check here for unbound type variables again. + | TypeF::Var(_) + // An enum type can't contain a contract. + // TODO: the assertion above isn't true anymore (ADTs). Need fixing? + | TypeF::Enum(_) + | TypeF::Wildcard(_) => Ok(()), + TypeF::Arrow(ty1, ty2) => { + walk_type(state, ctxt.clone(), visitor, ty1)?; + walk_type(state, ctxt, visitor, ty2) + } + TypeF::Record(rrows) => walk_rrows(state, ctxt, visitor, rrows), + TypeF::Contract(t) => walk(state, ctxt, visitor, t), + TypeF::Dict { type_fields: ty2, .. } + | TypeF::Array(ty2) + | TypeF::Forall {body: ty2, ..} => walk_type(state, ctxt, visitor, ty2), + } +} + +/// Same as [`walk_type`] but operate on record rows. 
+fn walk_rrows<'ast, V: TypecheckVisitor<'ast>>( + state: &mut State<'ast, '_>, + ctxt: Context<'ast>, + visitor: &mut V, + rrows: &RecordRows<'ast>, +) -> Result<(), TypecheckError> { + match rrows.0 { + RecordRowsF::Empty + // Currently, the parser can't generate unbound type variables by construction. Thus we + // don't check here for unbound type variables again. + | RecordRowsF::TailVar(_) + | RecordRowsF::TailDyn => Ok(()), + RecordRowsF::Extend { ref row, tail } => { + walk_type(state, ctxt.clone(), visitor, row.typ)?; + walk_rrows(state, ctxt, visitor, tail) + } + } +} + +fn walk_field<'ast, V: TypecheckVisitor<'ast>>( + state: &mut State<'ast, '_>, + ctxt: Context<'ast>, + visitor: &mut V, + field_def: &FieldDef<'ast>, +) -> Result<(), TypecheckError> { + walk_with_annot( + state, + ctxt, + visitor, + &field_def.metadata.annotation, + field_def.value.as_ref(), + ) +} + +fn walk_annotated<'ast, V: TypecheckVisitor<'ast>>( + state: &mut State<'ast, '_>, + ctxt: Context<'ast>, + visitor: &mut V, + annot: &Annotation<'ast>, + ast: &Ast<'ast>, +) -> Result<(), TypecheckError> { + walk_with_annot(state, ctxt, visitor, annot, Some(ast)) +} + +/// Walk an annotated term, either via [crate::term::record::FieldMetadata], or via a standalone +/// type or contract annotation. A type annotation switches the typechecking mode to _enforce_. +fn walk_with_annot<'ast, V: TypecheckVisitor<'ast>>( + state: &mut State<'ast, '_>, + mut ctxt: Context<'ast>, + visitor: &mut V, + annot: &Annotation<'ast>, + value: Option<&Ast<'ast>>, +) -> Result<(), TypecheckError> { + annot + .iter() + .try_for_each(|ty| walk_type(state, ctxt.clone(), visitor, ty))?; + + match (annot, value) { + (Annotation { typ: Some(ty2), .. 
}, Some(value)) => { + let uty2 = UnifType::from_type(ty2.clone(), &ctxt.term_env); + check(state, ctxt, visitor, value, uty2) + } + ( + Annotation { + typ: None, + contracts, + }, + Some(value), + ) => { + // If we see a function annotated with a function contract, we can get the type of the + // arguments for free. We use this information both for typechecking (you could see it + // as an extension of the philosophy of apparent types, but for function arguments + // instead of let-bindings) and for the LSP, to provide better type information and + // completion. + if let Node::Fun { args, body } = value.node { + // We look for the first contract of the list that is a function contract. + let domains = contracts.iter().find_map(|c| { + if let TypeF::Arrow(mut domain, _) = &c.typ { + let mut domains = + vec![UnifType::from_type((*domain).clone(), &ctxt.term_env)]; + + while let TypeF::Arrow(next_domain, _) = &domain.typ { + domains.push(UnifType::from_type((*domain).clone(), &ctxt.term_env)); + domain = next_domain; + } + + Some(domains) + } else { + None + } + }); + + if let Some(domains) = domains { + for (arg, uty) in args.iter().zip(domains) { + // Because the normal code path in `walk` sets the function argument to `Dyn`, + // we need to short-circuit it. We manually visit the argument, augment the + // typing environment and walk the body of the function. + if let Some(id) = arg.try_as_any() { + visitor.visit_ident(&id, uty.clone()); + ctxt.type_env.insert(id.ident(), uty); + } + } + + return walk(state, ctxt, visitor, body); + } + } + + walk(state, ctxt, visitor, value) + } + _ => Ok(()), + } +} + +/// Check a term against a given type. Although this method mostly corresponds to checking mode in +/// the classical bidirectional framework, it combines both checking and inference modes in +/// practice, to avoid duplicating rules (that is, code) as detailed below. +/// +/// # Literals +/// +/// Checking a literal (a number, a string, a boolean, etc.) 
unifies the checked type with the +/// corresponding primitive type (`Number`, `String`, `Bool`, etc.). If the checked type is a +/// unification variable, `check` acts as an inference rule. If the type is concrete, unification +/// enforces equality, and `check` acts as a checking rule. +/// +/// # Introduction rules +/// +/// Following Pfenning's recipe (see [Bidirectional Typing][bidirectional-typing]), introduction +/// rules (e.g. typechecking a record) are checking. `check` follows the same logic here: it uses +/// unification to "match" on the expected type (for example in the case of records, a record type +/// or a dictionary type) and pushes typechecking down the record fields. +/// +/// # Elimination rules +/// +/// Elimination rules (such as function application or primitive operator application) only exist in +/// inference mode (still following Pfenning's recipe). `check` follows the inference mode here +/// (typically on function application, where we first call to `infer` on the function part, and +/// then check the argument). +/// +/// Still, `check` is supposed to be implementing checking mode from the outside. We thus also +/// apply the typing rule which switches from inference to checking mode. Currently, subtyping +/// isn't supported yet in Nickel but is planned as part of RFC004. When subtyping lands, as the +/// name suggests, [`subsumption`] will be the place where we apply subsumption, as customary in +/// bidirectional type systems with subtyping. +/// +/// To sum up, elimination rules inside `check` correspond to an inference rule composed with the +/// switching/subsumption rule, resulting in a composite checking rule. +/// +/// # Parameters +/// +/// - `state`: the unification state (see [`State`]). +/// - `env`: the typing environment, mapping free variable to types. +/// - `lin`: The current building linearization of building state `S` +/// - `visitor`: A visitor that can modify the linearization +/// - `t`: the term to check. 
+/// - `ty`: the type to check the term against. +/// +/// # Linearization (LSP) +/// +/// `check` is in charge of registering every term with the `visitor` and makes sure to scope +/// the visitor accordingly +/// +/// [bidirectional-typing]: (https://arxiv.org/abs/1908.05839) +fn check<'ast, V: TypecheckVisitor<'ast>>( + state: &mut State<'ast, '_>, + mut ctxt: Context<'ast>, + visitor: &mut V, + ast: &Ast<'ast>, + ty: UnifType<'ast>, +) -> Result<(), TypecheckError> { + let Ast { node, pos } = ast; + + visitor.visit_term(ast, ty.clone()); + + // When checking against a polymorphic type, we immediatly instantiate potential heading + // foralls. Otherwise, this polymorphic type wouldn't unify much with other types. If we infer + // a polymorphic type for `ast`, the subsumption rule will take care of instantiating this type + // with unification variables, such that terms like `(fun x => x : forall a. a -> a) : forall + // b. b -> b` typecheck correctly. + let ty = instantiate_foralls(state, &mut ctxt, ty, ForallInst::Constant); + + match node { + Node::ParseError(_) => Ok(()), + // null is inferred to be of type Dyn + Node::Null => ty + .unify(mk_uniftype::dynamic(), state, &ctxt) + .map_err(|err| err.into_typecheck_err(state, ast.pos)), + Node::Bool(_) => ty + .unify(mk_uniftype::bool(), state, &ctxt) + .map_err(|err| err.into_typecheck_err(state, ast.pos)), + Node::Number(_) => ty + .unify(mk_uniftype::num(), state, &ctxt) + .map_err(|err| err.into_typecheck_err(state, ast.pos)), + Node::String(_) => ty + .unify(mk_uniftype::str(), state, &ctxt) + .map_err(|err| err.into_typecheck_err(state, ast.pos)), + Node::StringChunks(chunks) => { + ty.unify(mk_uniftype::str(), state, &ctxt) + .map_err(|err| err.into_typecheck_err(state, ast.pos))?; + + chunks + .iter() + .try_for_each(|chunk| -> Result<(), TypecheckError> { + match chunk { + StringChunk::Literal(_) => Ok(()), + StringChunk::Expr(t, _) => { + check(state, ctxt.clone(), visitor, t, mk_uniftype::str()) + } + 
} + }) + } + Node::IfThenElse { + cond, + then_branch, + else_branch, + } => { + check(state, ctxt.clone(), visitor, cond, mk_uniftype::bool())?; + check(state, ctxt.clone(), visitor, then_branch, ty.clone())?; + check(state, ctxt, visitor, else_branch, ty) + } + // Fun is an introduction rule for the arrow type. The target type is thus expected to be + // of the form `T1 -> ... -> Tn -> U`, which we enforce by unification. We then check the + // body of the function against `U`, after adding the relevant argument types in the + // environment. + Node::Fun { args, body } => { + let codomain = state.table.fresh_type_uvar(ctxt.var_level); + let fun_type = args.iter().rev().try_fold( + codomain.clone(), + |fun_type, arg| -> Result<_, TypecheckError> { + // See [^separate-alias-treatment]. + let pat_types = arg + .data + .pattern_types(state, &ctxt, TypecheckMode::Enforce)?; + // In the destructuring case, there's no alternative pattern, and we must thus + // immediately close all the row types. 
+ pattern::close_all_enums(pat_types.enum_open_tails, state); + let arg_type = pat_types.typ; + + if let Some(id) = arg.alias { + visitor.visit_ident(&id, arg_type.clone()); + ctxt.type_env.insert(id.ident(), arg_type.clone()); + } + + for (id, typ) in pat_types.bindings { + visitor.visit_ident(&id, typ.clone()); + ctxt.type_env.insert(id.ident(), typ); + } + + Ok(mk_buty_arrow!(arg_type, fun_type)) + }, + )?; + + ty.unify(fun_type, state, &ctxt) + .map_err(|err| err.into_typecheck_err(state, ast.pos))?; + + check(state, ctxt, visitor, body, codomain) + } + Node::Array(elts) => { + let ty_elts = state.table.fresh_type_uvar(ctxt.var_level); + + ty.unify(mk_uniftype::array(ty_elts.clone()), state, &ctxt) + .map_err(|err| err.into_typecheck_err(state, ast.pos))?; + + elts.iter() + .try_for_each(|elt| -> Result<(), TypecheckError> { + check(state, ctxt.clone(), visitor, elt, ty_elts.clone()) + }) + } + Node::Let { + bindings, + body, + rec, + } => { + // For a recursive let block, shadow all the names we're about to bind, so + // we aren't influenced by variables defined in an outer scope. + if *rec { + for binding in bindings.iter() { + for pat_binding in binding.pattern.bindings() { + ctxt.type_env.insert( + pat_binding.id.ident(), + state.table.fresh_type_uvar(ctxt.var_level), + ); + } + } + } + + let start_ctxt = ctxt.clone(); + + let typed_bindings: Result, _> = bindings + .iter() + .map(|binding| -> Result<_, TypecheckError> { + // See [^separate-alias-treatment]. + let pat_types = binding.pattern.pattern_types( + state, + &start_ctxt, + TypecheckMode::Enforce, + )?; + + // In the destructuring case, there's no alternative pattern, and we must thus + // immediatly close all the row types. 
+ pattern::close_all_enums(pat_types.enum_open_tails, state); + + // The inferred type of the expr being bound + let ty_let = binding_type(state, &binding.value.node, &start_ctxt, true); + + pat_types + .typ + .unify(ty_let.clone(), state, &start_ctxt) + .map_err(|e| e.into_typecheck_err(state, binding.value.pos))?; + + if let Some(alias) = &binding.pattern.alias { + visitor.visit_ident(alias, ty_let.clone()); + ctxt.type_env.insert(alias.ident(), ty_let.clone()); + ctxt.term_env.0.insert( + alias.ident(), + (binding.value.clone(), start_ctxt.term_env.clone()), + ); + } + + for (id, typ) in pat_types.bindings { + visitor.visit_ident(&id, typ.clone()); + ctxt.type_env.insert(id.ident(), typ); + // See [^term-env-rec-bindings] for why we use `start_ctxt` independently + // from `rec`. + ctxt.term_env.0.insert( + id.ident(), + (binding.value.clone(), start_ctxt.term_env.clone()), + ); + } + + Ok((&binding.value, ty_let)) + }) + .collect(); + + let re_ctxt = if *rec { &ctxt } else { &start_ctxt }; + + for (value, ty_let) in typed_bindings? { + check(state, re_ctxt.clone(), visitor, value, ty_let)?; + } + + check(state, ctxt, visitor, body, ty) + } + Node::Match(data) => { + // [^typechecking-match-expression]: We can associate a type to each pattern of each + // case of the match expression. From there, the type of a valid argument for the match + // expression is ideally the union of each pattern type. + // + // For record types, we don't have a good way to express union: for example, what could + // be the type of something that is either `{x : a}` or `{y : a}`? In the case of + // record types, we thus just take the intersection of the types, which amounts to + // unify all pattern types together. 
While it might fail most of the time (including + // for the `{x}` and `{y}` example), it can still typecheck interesting expressions + // when the record pattern are similar enough: + // + // ```nickel + // x |> match { + // {foo, bar: 'Baz} => + // {foo, bar: 'Qux} => + // } + // ``` + // + // We can definitely find a type for `x`: `{foo: a, bar: [| 'Baz, 'Qux |]}`. + // + // For enum types, we can express union: for example, the union of `[|'Foo, 'Bar|]` and + // `[|'Bar, 'Baz|]` is `[|'Foo, 'Bar, 'Baz|]`. We can even turn this into a unification + // problem: "open" the initial row types as `[| 'Foo, 'Bar; ?a |]` and `[|'Bar, 'Baz; + // ?b |]`, unify them together, and close the result (unify the tail with an empty row + // tail). The advantage of this approach is that unification takes care of descending + // into record types and sub-patterns to perform this operation, and we're back to the + // same procedure (almost) than for record patterns: simply unify all pattern types. + // Although we have additional bookkeeping to perform (remember the tail variables + // introduced to open enum rows and close the corresponding rows at the end of the + // procedure). + + // We zip the pattern types with each branch + let with_pat_types = data + .branches + .iter() + .map(|branch| -> Result<_, TypecheckError> { + Ok(( + branch, + branch + .pattern + .pattern_types(state, &ctxt, TypecheckMode::Enforce)?, + )) + }) + .collect::)>, _>>()?; + + // A match expression is a special kind of function. Thus it's typed as `a -> b`, where + // `a` is a type determined by the patterns and `b` is the type of each match arm. + let arg_type = state.table.fresh_type_uvar(ctxt.var_level); + let return_type = state.table.fresh_type_uvar(ctxt.var_level); + + // Express the constraint that all the arms of the match expression should have a + // compatible type and that each guard must be a boolean. 
+ for ( + MatchBranch { + pattern, + guard, + body, + }, + pat_types, + ) in with_pat_types.iter() + { + if let Some(alias) = &pattern.alias { + visitor.visit_ident(alias, return_type.clone()); + ctxt.type_env.insert(alias.ident(), return_type.clone()); + } + + for (id, typ) in pat_types.bindings.iter() { + visitor.visit_ident(id, typ.clone()); + ctxt.type_env.insert(id.ident(), typ.clone()); + } + + if let Some(guard) = guard { + check(state, ctxt.clone(), visitor, guard, mk_uniftype::bool())?; + } + + check(state, ctxt.clone(), visitor, body, return_type.clone())?; + } + + let pat_types = with_pat_types.into_iter().map(|(_, pat_types)| pat_types); + + // Unify all the pattern types with the argument's type, and build the list of all open + // tail vars + let mut enum_open_tails = Vec::with_capacity( + pat_types + .clone() + .map(|pat_type| pat_type.enum_open_tails.len()) + .sum(), + ); + + // Build the list of all wildcard pattern occurrences + let mut wildcard_occurrences = HashSet::with_capacity( + pat_types + .clone() + .map(|pat_type| pat_type.wildcard_occurrences.len()) + .sum(), + ); + + // We don't immediately return if an error occurs while unifying the patterns together. + // For error reporting purposes, it's best to first close the tail variables (if + // needed), to avoid cluttering the reported types with free unification variables + // which are mostly an artifact of our implementation of typechecking pattern matching. 
+ let pat_unif_result: Result<(), UnifError> = + pat_types.into_iter().try_for_each(|pat_type| { + arg_type.clone().unify(pat_type.typ, state, &ctxt)?; + + for (id, typ) in pat_type.bindings { + visitor.visit_ident(&id, typ.clone()); + ctxt.type_env.insert(id.ident(), typ); + } + + enum_open_tails.extend(pat_type.enum_open_tails); + wildcard_occurrences.extend(pat_type.wildcard_occurrences); + + Ok(()) + }); + + // Once we have accumulated all the information about enum rows and wildcard + // occurrences, we can finally close the tails that need to be. + pattern::close_enums(enum_open_tails, &wildcard_occurrences, state); + + // And finally fail if there was an error. + pat_unif_result.map_err(|err| err.into_typecheck_err(state, ast.pos))?; + + // We unify the expected type of the match expression with `arg_type -> return_type`. + // + // This must happen last, or at least after having closed the tails: otherwise, the + // enum type inferred for the argument could be unduly generalized. For example, take: + // + // ``` + // let exp : forall r. [| 'Foo; r |] -> Dyn = match { 'Foo => null } + // ``` + // + // This must not typecheck, as the match expression doesn't have a default case, and + // its type is thus `[| 'Foo |] -> Dyn`. However, during the typechecking of the match + // expression, before tails are closed, the working type is `[| 'Foo; _erows_a |]`, + // which can definitely unify with `[| 'Foo; r |]` while the tail is still open. If we + // close the tail first, then the type becomes [| 'Foo |] and the generalization fails + // as desired. + // + // As a safety net, the tail closing code panics (in debug mode) if it finds a rigid + // type variable at the end of the tail of a pattern type, which would happen if we + // somehow generalized an enum row type variable before properly closing the tails + // before. 
+ ty.unify( + mk_buty_arrow!(arg_type.clone(), return_type.clone()), + state, + &ctxt, + ) + .map_err(|err| err.into_typecheck_err(state, ast.pos))?; + + Ok(()) + } + // Elimination forms (variable, function application and primitive operator application) + // follow the inference discipline, following the Pfennig recipe and the current type + // system specification (as far as typechecking is concerned, primitive operator + // application is the same as function application). + Node::Var(_) | Node::App { .. } | Node::PrimOpApp { .. } | Node::Annotated { .. } => { + let inferred = infer(state, ctxt.clone(), visitor, ast)?; + + // We apply the subsumption rule when switching from infer mode to checking mode. + inferred + .subsumed_by(ty, state, ctxt) + .map_err(|err| err.into_typecheck_err(state, ast.pos)) + } + Node::EnumVariant { tag, arg: None } => { + let row = state.table.fresh_erows_uvar(ctxt.var_level); + ty.unify(mk_buty_enum!(*tag; row), state, &ctxt) + .map_err(|err| err.into_typecheck_err(state, ast.pos)) + } + Node::EnumVariant { + tag, + arg: Some(arg), + } => { + let tail = state.table.fresh_erows_uvar(ctxt.var_level); + let ty_arg = state.table.fresh_type_uvar(ctxt.var_level); + + // We match the expected type against `[| 'id ty_arg; row_tail |]`, where `row_tail` is + // a free unification variable, to ensure it has the right shape and extract the + // components. + ty.unify(mk_buty_enum!((*tag, ty_arg.clone()); tail), state, &ctxt) + .map_err(|err| err.into_typecheck_err(state, ast.pos))?; + + // Once we have a type for the argument, we check the variant's data against it. 
+ check(state, ctxt, visitor, arg, ty_arg) + } + Node::Record(record) => record + .resolve() + .with_pos(*pos) + .check(state, ctxt, visitor, ty), + Node::Import(_) => todo!("need to figure out import resolution with the new AST first"), + // Node::Import(_) => ty + // .unify(mk_uniftype::dynamic(), state, &ctxt) + // .map_err(|err| err.into_typecheck_err(state, ast.pos)), + // We use the apparent type of the import for checking. This function doesn't recursively + // typecheck imports: this is the responsibility of the caller. + // Term::ResolvedImport(file_id) => { + // let t = state + // .resolver + // .get(*file_id) + // .expect("Internal error: resolved import not found during typechecking."); + // let ty_import: UnifType<'ast> = UnifType::from_apparent_type( + // apparent_type(t.as_ref(), Some(&ctxt.type_env), Some(state.resolver)), + // &ctxt.term_env, + // ); + // ty.unify(ty_import, state, &ctxt) + // .map_err(|err| err.into_typecheck_err(state, ast.pos)) + // } + Node::Type(typ) => { + if let Some(_contract) = typ.find_contract() { + todo!("needs to update `error::TypecheckError` first, but not ready to switch to the new typechecker yet") + // Err(TypecheckError::CtrTypeInTermPos { + // contract, + // pos: *pos, + // }) + } else { + Ok(()) + } + } + } +} + +fn check_field<'ast, V: TypecheckVisitor<'ast>>( + state: &mut State<'ast, '_>, + ctxt: Context<'ast>, + visitor: &mut V, + def: &FieldDef<'ast>, + ty: UnifType<'ast>, +) -> Result<(), TypecheckError> { + //unwrap(): a field path is always assumed to be non-empty + let pos_id = def.path.last().unwrap().pos(); + + // If there's no annotation, we simply check the underlying value, if any. + if def.metadata.annotation.is_empty() { + if let Some(value) = def.value.as_ref() { + check(state, ctxt, visitor, value, ty) + } else { + // It might make sense to accept any type for a value without definition (which would + // act a bit like a function parameter). 
But for now, we play safe and implement a more + // restrictive rule, which is that a value without a definition has type `Dyn` + ty.unify(mk_uniftype::dynamic(), state, &ctxt) + .map_err(|err| err.into_typecheck_err(state, pos_id)) + } + } else { + let pos = def.value.as_ref().map(|v| v.pos).unwrap_or(pos_id); + + let inferred = infer_with_annot( + state, + ctxt.clone(), + visitor, + &def.metadata.annotation, + def.value.as_ref(), + )?; + + inferred + .subsumed_by(ty, state, ctxt) + .map_err(|err| err.into_typecheck_err(state, pos)) + } +} + +fn infer_annotated<'ast, V: TypecheckVisitor<'ast>>( + state: &mut State<'ast, '_>, + ctxt: Context<'ast>, + visitor: &mut V, + annot: &Annotation<'ast>, + ast: &Ast<'ast>, +) -> Result, TypecheckError> { + infer_with_annot(state, ctxt, visitor, annot, Some(ast)) +} + +/// Function handling the common part of inferring the type of terms with type or contract +/// annotation, with or without definitions. This encompasses both standalone type annotation +/// (where `value` is always `Some(_)`) as well as field definitions (where `value` may or may not +/// be defined). +/// +/// As for [check_visited] and [infer_visited], the additional `item_id` is provided when the term +/// has been added to the visitor before but can still benefit from updating its information +/// with the inferred type. +fn infer_with_annot<'ast, V: TypecheckVisitor<'ast>>( + state: &mut State<'ast, '_>, + ctxt: Context<'ast>, + visitor: &mut V, + annot: &Annotation<'ast>, + value: Option<&Ast<'ast>>, +) -> Result, TypecheckError> { + for ty in annot.iter() { + walk_type(state, ctxt.clone(), visitor, ty)?; + } + + match (annot, value) { + (Annotation { typ: Some(ty2), .. 
}, Some(value)) => { + let uty2 = UnifType::from_type(ty2.clone(), &ctxt.term_env); + + visitor.visit_term(value, uty2.clone()); + + check(state, ctxt, visitor, value, uty2.clone())?; + Ok(uty2) + } + // An annotation without a type but with a contract switches the typechecker back to walk + // mode. If there are several contracts, we arbitrarily chose the first one as the apparent + // type (the most precise type would be the intersection of all contracts, but Nickel's + // type system doesn't feature intersection types). + ( + Annotation { + typ: None, + contracts, + }, + value_opt, + ) if !contracts.is_empty() => { + let ty2 = contracts.first().unwrap(); + let uty2 = UnifType::from_type(ty2.clone(), &ctxt.term_env); + + if let Some(value) = &value_opt { + visitor.visit_term(value, uty2.clone()); + } + + // If there's an inner value, we have to walk it, as it may contain statically typed + // blocks. + if let Some(value) = value_opt { + walk(state, ctxt, visitor, value)?; + } + + Ok(uty2) + } + // A non-empty value without a type or a contract annotation is typechecked in the same way + // as its inner value. This case should only happen for record fields, as the parser can't + // produce an annotated term without an actual annotation. Still, such terms could be + // produced programmatically, and aren't necessarily an issue. + (_, Some(value)) => infer(state, ctxt, visitor, value), + // An empty value is a record field without definition. We don't check anything, and infer + // its type to be either the first annotation defined if any, or `Dyn` otherwise. + // We can only hit this case for record fields. + _ => { + let inferred = annot + .first() + .map(|ty| UnifType::from_type(ty.clone(), &ctxt.term_env)) + .unwrap_or_else(mk_uniftype::dynamic); + Ok(inferred) + } + } +} + +/// Infer a type for an expression. +/// +/// `infer` corresponds to the inference mode of bidirectional typechecking. 
Nickel uses a mix of +/// bidirectional typechecking and traditional ML-like unification. +fn infer<'ast, V: TypecheckVisitor<'ast>>( + state: &mut State<'ast, '_>, + mut ctxt: Context<'ast>, + visitor: &mut V, + ast: &Ast<'ast>, +) -> Result, TypecheckError> { + let Ast { node, pos } = ast; + + match node { + Node::Var(x) => { + let x_ty = ctxt + .type_env + .get(&x.ident()) + .cloned() + .ok_or(TypecheckError::UnboundIdentifier { id: *x, pos: *pos })?; + + visitor.visit_term(ast, x_ty.clone()); + + Ok(x_ty) + } + // Theoretically, we need to instantiate the type of the head of the primop application, + // that is, the primop itself. In practice, + // [crate::bytecode::typecheck::operation::PrimOpType::primop_type] returns types that are + // already instantiated with free unification variables, to save building a polymorphic + // type that would be instantiated immediately. Thus, the type of a primop is currently + // always monomorphic. + Node::PrimOpApp { op, args } => { + let (tys_args, ty_res) = op.primop_type(state, ctxt.var_level)?; + + visitor.visit_term(ast, ty_res.clone()); + + for (ty_arg, arg) in tys_args.into_iter().zip(args.iter()) { + check(state, ctxt.clone(), visitor, arg, ty_arg)?; + } + + Ok(ty_res) + } + Node::App { head, args } => { + // If we go the full Quick Look route (cf [quick-look] and the Nickel type system + // specification), we will have a more advanced and specific rule to guess the + // instantiation of the potentially polymorphic type of the head of the application. + // Currently, we limit ourselves to predicative instantiation, and we can get away + // with eagerly instantiating heading `foralls` with fresh unification variables. 
+ let head_poly = infer(state, ctxt.clone(), visitor, head)?; + let head_type = instantiate_foralls(state, &mut ctxt, head_poly, ForallInst::UnifVar); + + let arg_types: Vec<_> = + std::iter::repeat_with(|| state.table.fresh_type_uvar(ctxt.var_level)) + .take(args.len()) + .collect(); + let codomain = state.table.fresh_type_uvar(ctxt.var_level); + let fun_type = mk_uniftype::nary_arrow(arg_types.clone(), codomain.clone()); + + // "Match" the type of the head with `dom -> codom` + fun_type + .unify(head_type, state, &ctxt) + .map_err(|err| err.into_typecheck_err(state, head.pos))?; + + visitor.visit_term(ast, codomain.clone()); + + for (arg, arg_type) in args.iter().zip(arg_types.into_iter()) { + check(state, ctxt.clone(), visitor, arg, arg_type)?; + } + + Ok(codomain) + } + Node::Annotated { annot, inner } => infer_annotated(state, ctxt, visitor, annot, inner), + _ => { + // The remaining cases can't produce polymorphic types, and thus we can reuse the + // checking code. Inferring the type for those rules is equivalent to checking against + // a free unification variable. This saves use from duplicating all the remaining + // cases. + let inferred = state.table.fresh_type_uvar(ctxt.var_level); + + visitor.visit_term(ast, inferred.clone()); + + check(state, ctxt, visitor, ast, inferred.clone())?; + Ok(inferred.into_root(state.table)) + } + } +} + +/// Determine the type of a let-bound expression. +/// +/// Call [`apparent_type`] to see if the binding is annotated. If it is, return this type as a +/// [`UnifType<'ast>`]. Otherwise: +/// +/// - in walk mode, we won't (and possibly can't) infer the type of `bound_exp`: just return `Dyn`. +/// - in typecheck mode, we will typecheck `bound_exp`: return a new unification variable to be +/// associated to `bound_exp`. +/// +/// As this function is always called in a context where an `ImportResolver` is present, expect it +/// passed in arguments. 
+/// +/// If the annotated type contains any wildcard: +/// +/// - in non strict mode, wildcards are assigned `Dyn`. +/// - in strict mode, the wildcard is typechecked, and we return the unification variable +/// corresponding to it. +fn binding_type<'ast>( + state: &mut State<'ast, '_>, + ast: &Node<'ast>, + ctxt: &Context<'ast>, + strict: bool, +) -> UnifType<'ast> { + apparent_or_infer( + state, + apparent_type( + state.ast_alloc, + ast, + Some(&ctxt.type_env), + Some(state.resolver), + ), + ctxt, + strict, + ) +} + +/// Same as `binding_type` but for record field definition. +fn field_type<'ast>( + state: &mut State<'ast, '_>, + field_def: &FieldDef<'ast>, + ctxt: &Context<'ast>, + strict: bool, +) -> UnifType<'ast> { + apparent_or_infer( + state, + field_def.apparent_type(state.ast_alloc, Some(&ctxt.type_env), Some(state.resolver)), + ctxt, + strict, + ) +} + +/// Either returns the exact type annotation extracted as an apparent type, or return a fresh +/// unification variable, for the type to be inferred by the typechecker, in enforce mode. +/// +/// In walk mode, returns the type as approximated by [`apparent_type`]. +fn apparent_or_infer<'ast>( + state: &mut State<'ast, '_>, + aty: ApparentType<'ast>, + ctxt: &Context<'ast>, + strict: bool, +) -> UnifType<'ast> { + match aty { + ApparentType::Annotated(ty) if strict => { + replace_wildcards_with_var(state.table, ctxt, state.wildcard_vars, ty) + } + ApparentType::Approximated(_) if strict => state.table.fresh_type_uvar(ctxt.var_level), + ty_apt => UnifType::from_apparent_type(ty_apt, &ctxt.term_env), + } +} + +/// Substitute wildcards in a type for their unification variable. 
+fn replace_wildcards_with_var<'ast>( + table: &mut UnifTable<'ast>, + ctxt: &Context<'ast>, + wildcard_vars: &mut Vec>, + ty: Type<'ast>, +) -> UnifType<'ast> { + fn replace_rrows<'ast>( + table: &mut UnifTable<'ast>, + ctxt: &Context<'ast>, + wildcard_vars: &mut Vec>, + rrows: RecordRows<'ast>, + ) -> UnifRecordRows<'ast> { + UnifRecordRows::concrete(rrows.0.map_state( + |ty, (table, wildcard_vars)| { + Box::new(replace_wildcards_with_var( + table, + ctxt, + wildcard_vars, + ty.clone(), + )) + }, + |rrows, (table, wildcard_vars)| { + Box::new(replace_rrows(table, ctxt, wildcard_vars, rrows.clone())) + }, + &mut (table, wildcard_vars), + )) + } + + fn replace_erows<'ast>( + table: &mut UnifTable<'ast>, + ctxt: &Context<'ast>, + wildcard_vars: &mut Vec>, + erows: EnumRows<'ast>, + ) -> UnifEnumRows<'ast> { + UnifEnumRows::concrete(erows.0.map_state( + |ty, (table, wildcard_vars)| { + Box::new(replace_wildcards_with_var( + table, + ctxt, + wildcard_vars, + ty.clone(), + )) + }, + |erows, (table, wildcard_vars)| { + Box::new(replace_erows(table, ctxt, wildcard_vars, erows.clone())) + }, + &mut (table, wildcard_vars), + )) + } + + match ty.typ { + TypeF::Wildcard(i) => get_wildcard_var(table, ctxt.var_level, wildcard_vars, i), + _ => UnifType::concrete(ty.typ.map_state( + |ty, (table, wildcard_vars)| { + Box::new(replace_wildcards_with_var( + table, + ctxt, + wildcard_vars, + ty.clone(), + )) + }, + |rrows, (table, wildcard_vars)| replace_rrows(table, ctxt, wildcard_vars, rrows), + // Enum rows contain neither wildcards nor contracts + |erows, (table, wildcard_vars)| replace_erows(table, ctxt, wildcard_vars, erows), + |ctr, _| (ctr, ctxt.term_env.clone()), + &mut (table, wildcard_vars), + )), + } +} + +/// Different kinds of apparent types (see [`apparent_type`]). +/// +/// Indicate the nature of an apparent type. In particular, when in enforce mode, the typechecker +/// throws away approximations as it can do better and infer the actual type of an expression. 
In +/// walk mode, however, the approximation is the best we can do. This type allows the caller of +/// `apparent_type` to determine which situation it is. +#[derive(Debug)] +pub enum ApparentType<'ast> { + /// The apparent type is given by a user-provided annotation. + Annotated(Type<'ast>), + /// The apparent type has been inferred from a simple expression. + Inferred(Type<'ast>), + /// The term is a variable and its type was retrieved from the typing environment. + FromEnv(UnifType<'ast>), + /// The apparent type wasn't trivial to determine, and an approximation (most of the time, + /// `Dyn`) has been returned. + Approximated(Type<'ast>), +} + +impl<'ast> TryConvert<'ast, ApparentType<'ast>> for Type<'ast> { + type Error = std::convert::Infallible; + + fn try_convert(alloc: &'ast AstAlloc, at: ApparentType<'ast>) -> Result { + Ok(match at { + ApparentType::Annotated(ty) if has_wildcards(&ty) => Type::from(TypeF::Dyn), + ApparentType::Annotated(ty) + | ApparentType::Inferred(ty) + | ApparentType::Approximated(ty) => ty, + ApparentType::FromEnv(uty) => Type::try_convert(alloc, uty) + .ok() + .unwrap_or(Type::from(TypeF::Dyn)), + }) + } +} + +// Since there's already an enum named `ApparentType`, we can't use it as a trait name. +trait HasApparentType<'ast> { + fn apparent_type( + &self, + ast_alloc: &'ast AstAlloc, + env: Option<&TypeEnv<'ast>>, + resolver: Option<&dyn ImportResolver>, + ) -> ApparentType<'ast>; +} + +impl<'ast> HasApparentType<'ast> for FieldDef<'ast> { + // Return the apparent type of a field, by first looking at the type annotation, if any, then at + // the contracts annotation, and if there is none, fall back to the apparent type of the value. If + // there is no value, `Approximated(Dyn)` is returned. 
+ fn apparent_type( + &self, + ast_alloc: &'ast AstAlloc, + env: Option<&TypeEnv<'ast>>, + resolver: Option<&dyn ImportResolver>, + ) -> ApparentType<'ast> { + self.metadata + .annotation + .first() + .cloned() + .map(ApparentType::Annotated) + .or_else(|| { + self.value + .as_ref() + .map(|v| apparent_type(ast_alloc, &v.node, env, resolver)) + }) + .unwrap_or(ApparentType::Approximated(Type::from(TypeF::Dyn))) + } +} + +/// Determine the apparent type of a let-bound expression. +/// +/// When a let-binding `let x = bound_exp in body` is processed, the type of `bound_exp` must be +/// determined in order to be bound to the variable `x` in the typing environment. +/// Then, future occurrences of `x` can be given this type when used in a statically typed block. +/// +/// The role of `apparent_type` is precisely to determine the type of `bound_exp`: +/// - if `bound_exp` is annotated by a type or contract annotation, return the user-provided type, +/// unless that type is a wildcard. +/// - if `bound_exp` is a constant (string, number, boolean or symbol) which type can be deduced +/// directly without unfolding the expression further, return the corresponding exact type. +/// - if `bound_exp` is an array, return `Array Dyn`. +/// - if `bound_exp` is a resolved import, return the apparent type of the imported term. Returns +/// `Dyn` if the resolver is not passed as a parameter to the function. +/// - Otherwise, return an approximation of the type (currently `Dyn`, but could be more precise in +/// the future, such as `Dyn -> Dyn` for functions, `{ | Dyn}` for records, and so on). +pub fn apparent_type<'ast>( + ast_alloc: &'ast AstAlloc, + node: &Node<'ast>, + env: Option<&TypeEnv<'ast>>, + resolver: Option<&dyn ImportResolver>, +) -> ApparentType<'ast> { + use crate::files::FileId; + + // Check the apparent type while avoiding cycling through direct imports loops. Indeed, + // `apparent_type` tries to see through imported terms. 
But doing so can lead to an infinite + // loop, for example with the trivial program which imports itself: + // + // ```nickel + // # foo.ncl + // import "foo.ncl" + // ``` + // + // The following function thus remembers what imports have been seen already, and simply + // returns `Dyn` if it detects a cycle. + fn apparent_type_check_cycle<'ast>( + ast_alloc: &'ast AstAlloc, + node: &Node<'ast>, + env: Option<&TypeEnv<'ast>>, + resolver: Option<&dyn ImportResolver>, + _imports_seen: HashSet, + ) -> ApparentType<'ast> { + match node { + Node::Annotated { annot, inner } => annot + .first() + .map(|typ| ApparentType::Annotated(typ.clone())) + .unwrap_or_else(|| apparent_type(ast_alloc, &inner.node, env, resolver)), + Node::Number(_) => ApparentType::Inferred(Type::from(TypeF::Number)), + Node::Bool(_) => ApparentType::Inferred(Type::from(TypeF::Bool)), + Node::String(_) | Node::StringChunks(_) => { + ApparentType::Inferred(Type::from(TypeF::String)) + } + Node::Array(_) => ApparentType::Approximated(Type::from(TypeF::Array( + ast_alloc.alloc(Type::from(TypeF::Dyn)), + ))), + Node::Var(id) => env + .and_then(|envs| envs.get(&id.ident()).cloned()) + .map(ApparentType::FromEnv) + .unwrap_or(ApparentType::Approximated(Type::from(TypeF::Dyn))), + //TODO: import + // Node::ResolvedImport(file_id) => match resolver { + // Some(r) if !imports_seen.contains(file_id) => { + // imports_seen.insert(*file_id); + // + // let t = r + // .get(*file_id) + // .expect("Internal error: resolved import not found during typechecking."); + // apparent_type_check_cycle(&t.term, env, Some(r), imports_seen) + // } + // _ => ApparentType::Approximated(Type::from(TypeF::Dyn)), + // }, + _ => ApparentType::Approximated(Type::from(TypeF::Dyn)), + } + } + + apparent_type_check_cycle(ast_alloc, node, env, resolver, HashSet::new()) +} + +/// Infer the type of a non-annotated record by recursing inside gathering the apparent type of the +/// fields. It's currently used essentially to type the stdlib. 
+/// +/// # Parameters +/// +/// - `ast`: the term to infer a type for +/// - `term_env`: the current term environment, used for contracts equality +/// - `max_depth`: the max recursion depth. `infer_record_type` descends into sub-records, as long +/// as it only encounters nested record literals. `max_depth` is used to control this behavior +/// and cap the work that `infer_record_type` might do. +/// +/// # Preconditions +/// +/// The recourd shouldn't have any dynamic fields. They are ignored anyway, so if the record has +/// some, the inferred type could be wrong. +pub fn infer_record_type<'ast>( + ast_alloc: &'ast AstAlloc, + ast: &Ast<'ast>, + term_env: &TermEnv<'ast>, + max_depth: u8, +) -> UnifType<'ast> { + match &ast.node { + Node::Record(record) if max_depth > 0 => UnifType::from(TypeF::Record( + UnifRecordRows::concrete(record.field_defs.iter().fold( + RecordRowsF::Empty, + |rtype, field_def| { + if let Some(id) = field_def.path.first().unwrap().try_as_ident() { + let uty = match field_def.apparent_type(ast_alloc, None, None) { + ApparentType::Annotated(ty) => UnifType::from_type(ty, term_env), + ApparentType::FromEnv(uty) => uty, + // Since we haven't reached max_depth yet, and the type is only + // approximated, we try to recursively infer a better type. + ApparentType::Inferred(ty) | ApparentType::Approximated(ty) => { + field_def + .value + .as_ref() + .map(|v| { + infer_record_type(ast_alloc, v, term_env, max_depth - 1) + }) + .unwrap_or(UnifType::from_type(ty, term_env)) + } + }; + + RecordRowsF::Extend { + row: UnifRecordRow { + id, + typ: Box::new(uty), + }, + tail: Box::new(rtype.into()), + } + } else { + rtype + } + }, + )), + )), + node => UnifType::from_apparent_type( + apparent_type(ast_alloc, node, None, None), + &TermEnv::new(), + ), + } +} + +/// Deeply check whether a type contains a wildcard. 
+fn has_wildcards(ty: &Type<'_>) -> bool { + ty.find_map(&mut |ty: &Type| ty.typ.is_wildcard().then_some(())) + .is_some() +} + +/// Type of the parameter controlling instantiation of foralls. +/// +/// See [`instantiate_foralls`]. +#[derive(Copy, Clone, Debug, PartialEq)] +enum ForallInst { + Constant, + UnifVar, +} + +/// Instantiate the type variables which are quantified in head position with either unification +/// variables or type constants. +/// +/// For example, if `inst` is `Constant`, `forall a. forall b. a -> (forall c. b -> c)` is +/// transformed to `cst1 -> (forall c. cst2 -> c)` where `cst1` and `cst2` are fresh type +/// constants. This is used when typechecking `forall`s: all quantified type variables in head +/// position are replaced by rigid type constants, and the term is then typechecked normally. As +/// these constants cannot be unified with anything, this forces all the occurrences of a type +/// variable to be the same type. +/// +/// # Parameters +/// +/// - `state`: the unification state +/// - `ty`: the polymorphic type to instantiate +/// - `inst`: the type of instantiation, either by a type constant or by a unification variable +fn instantiate_foralls<'ast>( + state: &mut State<'ast, '_>, + ctxt: &mut Context<'ast>, + mut ty: UnifType<'ast>, + inst: ForallInst, +) -> UnifType<'ast> { + ty = ty.into_root(state.table); + + // We are instantiating a polymorphic type: it's precisely the place where we have to increment + // the variable level, to prevent already existing unification variables to unify with the + // rigid type variables introduced here. + // + // As this function can be called on monomorphic types, we only increment the level when we + // really introduce a new block of rigid type variables. + if matches!( + ty, + UnifType::Concrete { + typ: TypeF::Forall { .. }, + .. + } + ) { + ctxt.var_level.incr(); + } + + while let UnifType::Concrete { + typ: TypeF::Forall { + var, + var_kind, + body, + }, + .. 
+ } = ty + { + let kind: VarKindDiscriminant = (&var_kind).into(); + + match var_kind { + VarKind::Type => { + let fresh_uid = state.table.fresh_type_var_id(ctxt.var_level); + let uvar = match inst { + ForallInst::Constant => UnifType::Constant(fresh_uid), + ForallInst::UnifVar => UnifType::UnifVar { + id: fresh_uid, + init_level: ctxt.var_level, + }, + }; + state.names.insert((fresh_uid, kind), var.ident()); + ty = body.subst(&var, &uvar); + } + VarKind::RecordRows { excluded } => { + let fresh_uid = state.table.fresh_rrows_var_id(ctxt.var_level); + let uvar = match inst { + ForallInst::Constant => UnifRecordRows::Constant(fresh_uid), + ForallInst::UnifVar => UnifRecordRows::UnifVar { + id: fresh_uid, + init_level: ctxt.var_level, + }, + }; + state.names.insert((fresh_uid, kind), var.ident()); + ty = body.subst(&var, &uvar); + + if inst == ForallInst::UnifVar { + state.constr.insert(fresh_uid, excluded); + } + } + VarKind::EnumRows { excluded } => { + let fresh_uid = state.table.fresh_erows_var_id(ctxt.var_level); + let uvar = match inst { + ForallInst::Constant => UnifEnumRows::Constant(fresh_uid), + ForallInst::UnifVar => UnifEnumRows::UnifVar { + id: fresh_uid, + init_level: ctxt.var_level, + }, + }; + state.names.insert((fresh_uid, kind), var.ident()); + ty = body.subst(&var, &uvar); + + if inst == ForallInst::UnifVar { + state.constr.insert(fresh_uid, excluded); + } + } + }; + } + + ty +} + +/// Get the type unification variable associated with a given wildcard ID. +fn get_wildcard_var<'ast>( + table: &mut UnifTable<'ast>, + var_level: VarLevel, + wildcard_vars: &mut Vec>, + id: VarId, +) -> UnifType<'ast> { + // If `id` is not in `wildcard_vars`, populate it with fresh vars up to `id` + if id >= wildcard_vars.len() { + wildcard_vars.extend((wildcard_vars.len()..=id).map(|_| table.fresh_type_uvar(var_level))); + } + wildcard_vars[id].clone() +} + +/// Convert a mapping from wildcard ID to type var, into a mapping from wildcard ID to concrete +/// type. 
+fn wildcard_vars_to_type<'ast>( + alloc: &'ast AstAlloc, + wildcard_vars: Vec>, + table: &UnifTable<'ast>, +) -> Wildcards<'ast> { + wildcard_vars + .into_iter() + .map(|var| var.into_type(alloc, table)) + .collect() +} + +/// A visitor trait for receiving callbacks during typechecking. +pub trait TypecheckVisitor<'ast> { + /// Record the type of a term. + /// + /// It's possible for a single term to be visited multiple times, for example, if type + /// inference kicks in. + fn visit_term(&mut self, _ast: &Ast<'ast>, _ty: UnifType<'ast>) {} + + /// Record the type of a bound identifier. + fn visit_ident(&mut self, _ident: &LocIdent, _new_type: UnifType<'ast>) {} +} + +/// A do-nothing `TypeCheckVisitor` for when you don't want one. +impl TypecheckVisitor<'_> for () {} diff --git a/core/src/bytecode/typecheck/operation.rs b/core/src/bytecode/typecheck/operation.rs new file mode 100644 index 0000000000..d651a356de --- /dev/null +++ b/core/src/bytecode/typecheck/operation.rs @@ -0,0 +1,633 @@ +//! Typing of primitive operations. 
+use super::*; +use crate::{ + bytecode::ast::{builder, primop::PrimOp, AstAlloc}, + error::TypecheckError, + label::{Polarity, TypeVarData}, + typ::TypeF, +}; + +use crate::{mk_buty_arrow, mk_buty_enum, mk_buty_record}; + +pub trait PrimOpType { + fn primop_type<'ast>( + &self, + state: &mut State<'ast, '_>, + var_level: VarLevel, + ) -> Result<(Vec>, UnifType<'ast>), TypecheckError>; +} + +impl PrimOpType for PrimOp { + fn primop_type<'ast>( + &self, + state: &mut State<'ast, '_>, + var_level: VarLevel, + ) -> Result<(Vec>, UnifType<'ast>), TypecheckError> { + Ok(match self { + // Dyn -> [| 'Number, 'Bool, 'String, 'Enum, 'Function, 'Array, 'Record, 'Label, + // 'ForeignId, 'Type, 'Other |] + PrimOp::Typeof => ( + vec![mk_uniftype::dynamic()], + mk_buty_enum!( + "Number", + "Bool", + "String", + "Enum", + "Function", + "CustomContract", + "Array", + "Record", + "Label", + "ForeignId", + "Type", + "Other" + ), + ), + // Bool -> Bool -> Bool + PrimOp::BoolAnd | PrimOp::BoolOr => ( + vec![mk_uniftype::bool()], + mk_buty_arrow!(TypeF::Bool, TypeF::Bool), + ), + // Bool -> Bool + PrimOp::BoolNot => (vec![mk_uniftype::bool()], mk_uniftype::bool()), + // forall a. Dyn -> a + PrimOp::Blame => { + let res = state.table.fresh_type_uvar(var_level); + + (vec![mk_uniftype::dynamic()], res) + } + // Dyn -> Polarity + PrimOp::LabelPol => ( + vec![mk_uniftype::dynamic()], + mk_buty_enum!("Positive", "Negative"), + ), + // forall rows. 
[| ; rows |] -> [| id ; rows |] + PrimOp::EnumEmbed(id) => { + let row_var_id = state.table.fresh_erows_var_id(var_level); + let row = UnifEnumRows::UnifVar { + id: row_var_id, + init_level: var_level, + }; + + let domain = mk_buty_enum!(; row.clone()); + let codomain = mk_buty_enum!(*id; row); + + (vec![domain], codomain) + } + // Morally, Label -> Label + // Dyn -> Dyn + PrimOp::LabelFlipPol + | PrimOp::LabelGoDom + | PrimOp::LabelGoCodom + | PrimOp::LabelGoArray + | PrimOp::LabelGoDict => (vec![mk_uniftype::dynamic()], mk_uniftype::dynamic()), + // forall rows a. { id: a | rows} -> a + PrimOp::RecordStatAccess(id) => { + let rows = state.table.fresh_rrows_uvar(var_level); + let res = state.table.fresh_type_uvar(var_level); + + (vec![mk_buty_record!((*id, res.clone()); rows)], res) + } + // forall a b. Array a -> (a -> b) -> Array b + PrimOp::ArrayMap => { + let a = state.table.fresh_type_uvar(var_level); + let b = state.table.fresh_type_uvar(var_level); + + let f_type = mk_buty_arrow!(a.clone(), b.clone()); + ( + vec![mk_uniftype::array(a)], + mk_buty_arrow!(f_type, mk_uniftype::array(b)), + ) + } + // forall a. Num -> (Num -> a) -> Array a + PrimOp::ArrayGen => { + let a = state.table.fresh_type_uvar(var_level); + + let f_type = mk_buty_arrow!(TypeF::Number, a.clone()); + ( + vec![mk_uniftype::num()], + mk_buty_arrow!(f_type, mk_uniftype::array(a)), + ) + } + // forall a b. { _ : a} -> (Str -> a -> b) -> { _ : b } + PrimOp::RecordMap => { + // Assuming f has type Str -> a -> b, + // this has type Dict(a) -> Dict(b) + + let a = state.table.fresh_type_uvar(var_level); + let b = state.table.fresh_type_uvar(var_level); + + let f_type = mk_buty_arrow!(TypeF::String, a.clone(), b.clone()); + ( + vec![mk_uniftype::dict(a)], + mk_buty_arrow!(f_type, mk_uniftype::dict(b)), + ) + } + // forall a b. 
a -> b -> b + PrimOp::Seq | PrimOp::DeepSeq => { + let fst = state.table.fresh_type_uvar(var_level); + let snd = state.table.fresh_type_uvar(var_level); + + (vec![fst], mk_buty_arrow!(snd.clone(), snd)) + } + // forall a. Array a -> Num + PrimOp::ArrayLength => { + let ty_elt = state.table.fresh_type_uvar(var_level); + (vec![mk_uniftype::array(ty_elt)], mk_uniftype::num()) + } + // forall a. { _: a } -> Array Str + PrimOp::RecordFields(_) => { + let ty_a = state.table.fresh_type_uvar(var_level); + + ( + vec![mk_uniftype::dict(ty_a)], + mk_uniftype::array(mk_uniftype::str()), + ) + } + // forall a. { _: a } -> Array a + PrimOp::RecordValues => { + let ty_a = state.table.fresh_type_uvar(var_level); + + ( + vec![mk_uniftype::dict(ty_a.clone())], + mk_uniftype::array(ty_a), + ) + } + // Str -> Str + PrimOp::StringTrim => (vec![mk_uniftype::str()], mk_uniftype::str()), + // Str -> Array Str + PrimOp::StringChars => ( + vec![mk_uniftype::str()], + mk_uniftype::array(mk_uniftype::str()), + ), + // Str -> Str + PrimOp::StringUppercase => (vec![mk_uniftype::str()], mk_uniftype::str()), + // Str -> Str + PrimOp::StringLowercase => (vec![mk_uniftype::str()], mk_uniftype::str()), + // Str -> Num + PrimOp::StringLength => (vec![mk_uniftype::str()], mk_uniftype::num()), + // Dyn -> Str + PrimOp::ToString => (vec![mk_uniftype::dynamic()], mk_uniftype::str()), + // Str -> Num + PrimOp::NumberFromString => (vec![mk_uniftype::str()], mk_uniftype::num()), + // Str -> < | a> for a rigid type variable a + PrimOp::EnumFromString => ( + vec![mk_uniftype::str()], + mk_buty_enum!(; state.table.fresh_erows_const(var_level)), + ), + // Str -> Str -> Bool + PrimOp::StringIsMatch => ( + vec![mk_uniftype::str()], + mk_buty_arrow!(mk_uniftype::str(), mk_uniftype::bool()), + ), + // Str -> Str -> {matched: Str, index: Num, groups: Array Str} + PrimOp::StringFind => ( + vec![mk_uniftype::str()], + mk_buty_arrow!( + mk_uniftype::str(), + mk_buty_record!( + ("matched", TypeF::String), + ("index", 
TypeF::Number), + ("groups", mk_uniftype::array(TypeF::String)) + ) + ), + ), + // String -> String -> Array { matched: String, index: Number, groups: Array String } + PrimOp::StringFindAll => ( + vec![mk_uniftype::str()], + mk_buty_arrow!( + mk_uniftype::str(), + mk_uniftype::array(mk_buty_record!( + ("matched", TypeF::String), + ("index", TypeF::Number), + ("groups", mk_uniftype::array(TypeF::String)) + )) + ), + ), + // Dyn -> Dyn + PrimOp::Force { .. } => (vec![mk_uniftype::dynamic()], mk_uniftype::dynamic()), + PrimOp::RecordEmptyWithTail => (vec![mk_uniftype::dynamic()], mk_uniftype::dynamic()), + // forall a. { _ : a} -> { _ : a } + PrimOp::RecordFreeze => { + let dict = mk_uniftype::dict(state.table.fresh_type_uvar(var_level)); + (vec![dict.clone()], dict) + } + // forall a. Str -> a -> a + PrimOp::Trace => { + let ty = state.table.fresh_type_uvar(var_level); + (vec![mk_uniftype::str()], mk_buty_arrow!(ty.clone(), ty)) + } + // Morally: Lbl -> Lbl + // Actual: Dyn -> Dyn + PrimOp::LabelPushDiag => (vec![mk_uniftype::dynamic()], mk_uniftype::dynamic()), + // Str -> Dyn + #[cfg(feature = "nix-experimental")] + PrimOp::EvalNix => (vec![mk_uniftype::str()], mk_uniftype::dynamic()), + // Because the tag isn't fixed, we can't really provide a proper static type for this + // primop. + // This isn't a problem, as this operator is mostly internal and pattern matching should be + // used to destructure enum variants. + // Dyn -> Dyn + PrimOp::EnumGetArg => (vec![mk_uniftype::dynamic()], mk_uniftype::dynamic()), + // String -> (Dyn -> Dyn) + PrimOp::EnumMakeVariant => ( + vec![mk_uniftype::str()], + mk_uniftype::arrow(mk_uniftype::dynamic(), mk_uniftype::dynamic()), + ), + // Same as `EnumGetArg` just above. + // Dyn -> Dyn + PrimOp::EnumGetTag => (vec![mk_uniftype::dynamic()], mk_uniftype::dynamic()), + // Note that is_variant breaks parametricity, so it can't get a polymorphic type. 
+ // Dyn -> Bool + PrimOp::EnumIsVariant => (vec![mk_uniftype::dynamic()], mk_uniftype::bool()), + // // [crate::term::PrimOp::PatternBranch] shouldn't appear anywhere in actual code, because its + // // second argument can't be properly typechecked: it has unbound variables. However, it's + // // not hard to come up with a vague working type for it, so we do. + // // forall a. {_ : a} -> Dyn -> Dyn + // PrimOp::PatternBranch => { + // let ty_elt = state.table.fresh_type_uvar(var_level); + // ( + // mk_uniftype::dict(ty_elt), + // mk_buty_arrow!(mk_uniftype::dynamic(), mk_uniftype::dynamic()), + // ) + // } + // -> Dyn + PrimOp::ContractCustom => ( + vec![custom_contract_type(state.ast_alloc)], + mk_uniftype::dynamic(), + ), + // Number -> Number + PrimOp::NumberCos + | PrimOp::NumberSin + | PrimOp::NumberTan + | PrimOp::NumberArcCos + | PrimOp::NumberArcSin + | PrimOp::NumberArcTan => (vec![mk_uniftype::num()], mk_uniftype::num()), + + // Binary ops + + // Number -> Number -> Number + PrimOp::Plus | PrimOp::Sub | PrimOp::Mult | PrimOp::Div | PrimOp::Modulo => ( + vec![mk_uniftype::num(), mk_uniftype::num()], + mk_uniftype::num(), + ), + // Sym -> Dyn -> Dyn -> Dyn + PrimOp::Seal => ( + vec![mk_uniftype::sym(), mk_uniftype::dynamic()], + mk_buty_arrow!(TypeF::Dyn, TypeF::Dyn), + ), + // String -> String -> String + PrimOp::StringConcat => ( + vec![mk_uniftype::str(), mk_uniftype::str()], + mk_uniftype::str(), + ), + // Ideally: Contract -> Label -> Dyn -> Dyn + // Currently: Dyn -> Dyn -> (Dyn -> Dyn) + PrimOp::ContractApply => ( + vec![mk_uniftype::dynamic(), mk_uniftype::dynamic()], + mk_buty_arrow!(mk_uniftype::dynamic(), mk_uniftype::dynamic()), + ), + // Ideally: Contract -> Label -> Dyn -> + // Currently: Dyn -> Dyn -> (Dyn -> ) + PrimOp::ContractCheck => ( + vec![mk_uniftype::dynamic(), mk_uniftype::dynamic()], + mk_buty_arrow!( + mk_uniftype::dynamic(), + custom_contract_ret_type(state.ast_alloc) + ), + ), + // Ideally: -> Label -> Dyn + // Currently: -> 
Dyn -> Dyn + PrimOp::LabelWithErrorData => ( + vec![error_data_type(state.ast_alloc), mk_uniftype::dynamic()], + mk_uniftype::dynamic(), + ), + // Sym -> Dyn -> Dyn -> Dyn + PrimOp::Unseal => ( + vec![mk_uniftype::sym(), mk_uniftype::dynamic()], + mk_buty_arrow!(TypeF::Dyn, TypeF::Dyn), + ), + // forall a b. a -> b -> Bool + PrimOp::Eq => ( + vec![ + state.table.fresh_type_uvar(var_level), + state.table.fresh_type_uvar(var_level), + ], + mk_uniftype::bool(), + ), + // Num -> Num -> Bool + PrimOp::LessThan | PrimOp::LessOrEq | PrimOp::GreaterThan | PrimOp::GreaterOrEq => ( + vec![mk_uniftype::num(), mk_uniftype::num()], + mk_uniftype::bool(), + ), + // Str -> Dyn -> Dyn + PrimOp::LabelGoField => ( + vec![mk_uniftype::str(), mk_uniftype::dynamic()], + mk_uniftype::dynamic(), + ), + // forall a. Str -> { _ : a} -> a + PrimOp::RecordGet => { + let res = state.table.fresh_type_uvar(var_level); + + ( + vec![mk_uniftype::str(), mk_uniftype::dict(res.clone())], + res, + ) + } + // forall a. Str -> {_ : a} -> a -> {_ : a} + PrimOp::RecordInsert(_) => { + let res = state.table.fresh_type_uvar(var_level); + ( + vec![mk_uniftype::str(), mk_uniftype::dict(res.clone())], + mk_buty_arrow!(res.clone(), mk_uniftype::dict(res)), + ) + } + // forall a. Str -> { _ : a } -> { _ : a} + PrimOp::RecordRemove(_) => { + let res = state.table.fresh_type_uvar(var_level); + ( + vec![mk_uniftype::str(), mk_uniftype::dict(res.clone())], + mk_uniftype::dict(res), + ) + } + // forall a. Str -> {_: a} -> Bool + PrimOp::RecordHasField(_) => { + let ty_elt = state.table.fresh_type_uvar(var_level); + ( + vec![mk_uniftype::str(), mk_uniftype::dict(ty_elt)], + mk_uniftype::bool(), + ) + } + // forall a. Str -> {_: a} -> Bool + PrimOp::RecordFieldIsDefined(_) => { + let ty_elt = state.table.fresh_type_uvar(var_level); + ( + vec![mk_uniftype::str(), mk_uniftype::dict(ty_elt)], + mk_uniftype::bool(), + ) + } + // forall a. 
Array a -> Array a -> Array a + PrimOp::ArrayConcat => { + let ty_elt = state.table.fresh_type_uvar(var_level); + let ty_array = mk_uniftype::array(ty_elt); + (vec![ty_array.clone(), ty_array.clone()], ty_array) + } + // forall a. Array a -> Num -> a + PrimOp::ArrayAt => { + let ty_elt = state.table.fresh_type_uvar(var_level); + ( + vec![mk_uniftype::array(ty_elt.clone()), mk_uniftype::num()], + ty_elt, + ) + } + // Dyn -> Dyn -> Dyn + PrimOp::Merge(_) => ( + vec![mk_uniftype::dynamic(), mk_uniftype::dynamic()], + mk_uniftype::dynamic(), + ), + // -> Str -> Str + PrimOp::Hash => ( + vec![ + mk_buty_enum!("Md5", "Sha1", "Sha256", "Sha512"), + mk_uniftype::str(), + ], + mk_uniftype::str(), + ), + // forall a. -> a -> Str + PrimOp::Serialize => { + let ty_input = state.table.fresh_type_uvar(var_level); + ( + vec![mk_buty_enum!("Json", "Yaml", "Toml"), ty_input], + mk_uniftype::str(), + ) + } + // -> Str -> Dyn + PrimOp::Deserialize => ( + vec![mk_buty_enum!("Json", "Yaml", "Toml"), mk_uniftype::str()], + mk_uniftype::dynamic(), + ), + // Num -> Num -> Num + PrimOp::NumberArcTan2 | PrimOp::NumberLog | PrimOp::Pow => ( + vec![mk_uniftype::num(), mk_uniftype::num()], + mk_uniftype::num(), + ), + // Str -> Str -> Bool + PrimOp::StringContains => ( + vec![mk_uniftype::str(), mk_uniftype::str()], + mk_uniftype::bool(), + ), + // Str -> Str -> + PrimOp::StringCompare => ( + vec![mk_uniftype::str(), mk_uniftype::str()], + mk_buty_enum!("Lesser", "Equal", "Greater"), + ), + // Str -> Str -> Array Str + PrimOp::StringSplit => ( + vec![mk_uniftype::str(), mk_uniftype::str()], + mk_uniftype::array(TypeF::String), + ), + // The first argument is a contract, the second is a label. + // forall a. 
Dyn -> Dyn -> Array a -> Array a + PrimOp::ContractArrayLazyApp => { + let ty_elt = state.table.fresh_type_uvar(var_level); + let ty_array = mk_uniftype::array(ty_elt); + ( + vec![mk_uniftype::dynamic(), mk_uniftype::dynamic()], + mk_buty_arrow!(ty_array.clone(), ty_array), + ) + } + // The first argument is a label, the third is a contract. + // forall a. Dyn -> {_: a} -> Dyn -> {_: a} + PrimOp::ContractRecordLazyApp => { + let ty_field = state.table.fresh_type_uvar(var_level); + let ty_dict = mk_uniftype::dict(ty_field); + ( + vec![mk_uniftype::dynamic(), ty_dict.clone()], + mk_buty_arrow!(mk_uniftype::dynamic(), ty_dict), + ) + } + // Morally: Str -> Lbl -> Lbl + // Actual: Str -> Dyn -> Dyn + PrimOp::LabelWithMessage => ( + vec![mk_uniftype::str(), mk_uniftype::dynamic()], + mk_uniftype::dynamic(), + ), + // Morally: Array Str -> Lbl -> Lbl + // Actual: Array Str -> Dyn -> Dyn + PrimOp::LabelWithNotes => ( + vec![mk_uniftype::array(TypeF::String), mk_uniftype::dynamic()], + mk_uniftype::dynamic(), + ), + // Morally: Str -> Lbl -> Lbl + // Actual: Str -> Dyn -> Dyn + PrimOp::LabelAppendNote => ( + vec![mk_uniftype::str(), mk_uniftype::dynamic()], + mk_uniftype::dynamic(), + ), + // Morally: Sym -> Lbl -> TypeVarData + // Actual: Sym -> Dyn -> TypeVarData + PrimOp::LabelLookupTypeVar => ( + vec![mk_uniftype::sym(), mk_uniftype::dynamic()], + TypeVarData::unif_type(), + ), + // {_ : a} -> {_ : a} + // -> { + // left_only: {_ : a}, + // right_only: {_ : a}, + // left_center: {_ : a}, + // right_center: {_ : a}, + // } + PrimOp::RecordSplitPair => { + let elt = state.table.fresh_type_uvar(var_level); + let dict = mk_uniftype::dict(elt.clone()); + + let split_result = mk_buty_record!( + ("left_only", dict.clone()), + ("right_only", dict.clone()), + ("left_center", dict.clone()), + ("right_center", dict.clone()) + ); + + (vec![dict.clone(), dict], split_result) + } + // {_ : a} -> {_ : a} -> {_ : a} + PrimOp::RecordDisjointMerge => { + let elt = 
state.table.fresh_type_uvar(var_level); + let dict = mk_uniftype::dict(elt.clone()); + + (vec![dict.clone(), dict.clone()], dict) + } + // Str -> Str -> Str -> Str + PrimOp::StringReplace | PrimOp::StringReplaceRegex => ( + vec![mk_uniftype::str(), mk_uniftype::str(), mk_uniftype::str()], + mk_uniftype::str(), + ), + // Str -> Num -> Num -> Str + PrimOp::StringSubstr => ( + vec![mk_uniftype::str(), mk_uniftype::num(), mk_uniftype::num()], + mk_uniftype::str(), + ), + // Dyn -> Dyn -> Dyn -> Dyn -> Dyn + PrimOp::RecordSealTail => ( + vec![ + mk_uniftype::dynamic(), + mk_uniftype::dynamic(), + mk_uniftype::dict(mk_uniftype::dynamic()), + mk_uniftype::dict(mk_uniftype::dynamic()), + ], + mk_uniftype::dynamic(), + ), + // Dyn -> Dyn -> Dyn -> Dyn + PrimOp::RecordUnsealTail => ( + vec![ + mk_uniftype::dynamic(), + mk_uniftype::dynamic(), + mk_uniftype::dict(mk_uniftype::dynamic()), + ], + mk_uniftype::dynamic(), + ), + // Num -> Num -> Array a -> Array a + PrimOp::ArraySlice => { + let element_type = state.table.fresh_type_uvar(var_level); + + ( + vec![ + mk_uniftype::num(), + mk_uniftype::num(), + mk_uniftype::array(element_type.clone()), + ], + mk_uniftype::array(element_type), + ) + } + // Morally: Label -> Record -> Record -> Record + // Actual: Dyn -> Dyn -> Dyn -> Dyn + PrimOp::MergeContract => ( + vec![ + mk_uniftype::dynamic(), + mk_uniftype::dynamic(), + mk_uniftype::dynamic(), + ], + mk_uniftype::dynamic(), + ), + // Morally: Sym -> Polarity -> Lbl -> Lbl + // Actual: Sym -> Polarity -> Dyn -> Dyn + PrimOp::LabelInsertTypeVar => ( + vec![ + mk_uniftype::sym(), + Polarity::unif_type(), + mk_uniftype::dynamic(), + ], + mk_uniftype::dynamic(), + ), + }) + } +} + +// pub fn get_nop_type( +// state: &mut State, +// var_level: VarLevel, +// op: &NAryOp, +// ) -> Result<(Vec, UnifType), TypecheckError> { +// Ok(match op { +// }) +// } + +/// The type of a custom contract. 
In nickel syntax, the returned type is: +/// +/// ```nickel +/// Dyn -> Dyn -> [| +/// 'Ok Dyn, +/// 'Error { message | String | optional, notes | Array String | optional } +/// |] +/// ``` +pub fn custom_contract_type(alloc: &AstAlloc) -> UnifType<'_> { + mk_buty_arrow!( + mk_uniftype::dynamic(), + mk_uniftype::dynamic(), + custom_contract_ret_type(alloc) + ) +} + +/// The return type of a custom contract. See [custom_contract_type]. +/// +/// ```nickel +/// [| +/// 'Ok Dyn, +/// 'Error { message | String | optional, notes | Array String | optional } +/// |] +/// ``` +pub fn custom_contract_ret_type(alloc: &AstAlloc) -> UnifType<'_> { + mk_buty_enum!( + ("Ok", mk_uniftype::dynamic()), + ("Error", error_data_type(alloc)) + ) +} + +/// The type of error data that can be returned by a custom contract: +/// +/// ```nickel +/// { +/// message +/// | String +/// | optional, +/// notes +/// | Array String +/// | optional +/// } +/// ``` +fn error_data_type(alloc: &AstAlloc) -> UnifType<'_> { + let error_data = builder::Record::new() + .field("message") + .optional(true) + .contract(TypeF::String) + .no_value(alloc) + .field("notes") + .contract(TypeF::Array(alloc.alloc(Type::from(TypeF::String)))) + .optional(true) + .no_value(alloc); + + UnifType::concrete(TypeF::Contract(( + alloc.alloc(error_data.build(alloc)), + TermEnv::new(), + ))) +} diff --git a/core/src/bytecode/typecheck/pattern.rs b/core/src/bytecode/typecheck/pattern.rs new file mode 100644 index 0000000000..87982c401b --- /dev/null +++ b/core/src/bytecode/typecheck/pattern.rs @@ -0,0 +1,655 @@ +use crate::{ + bytecode::ast::pattern::*, + error::TypecheckError, + identifier::{Ident, LocIdent}, + mk_buty_record_row, + typ::{EnumRowsF, RecordRowsF, TypeF}, +}; + +use super::*; + +/// A list of pattern variables and their associated type. +pub type TypeBindings<'ast> = Vec<(LocIdent, UnifType<'ast>)>; + +/// An element of a pattern path. 
A pattern path is a sequence of steps that can be used to +/// uniquely locate a sub-pattern within a pattern. +/// +/// For example, in the pattern `{foo={bar='Baz arg}}`: +/// +/// - The path of the full pattern within itself is the empty path. +/// - The path of the `arg` pattern is `[Field("foo"), Field("bar"), Variant]`. +#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)] +pub enum PatternPathElem { + Field(Ident), + Array(usize), + Variant, +} + +pub type PatternPath = Vec; + +/// The working state of [PatternType::pattern_types_inj]. +pub(super) struct PatTypeState<'ast, 'a> { + /// The list of pattern variables introduced so far and their inferred type. + bindings: &'a mut TypeBindings<'ast>, + /// The list of enum row tail variables that are left open when typechecking a match expression. + enum_open_tails: &'a mut Vec<(PatternPath, UnifEnumRows<'ast>)>, + /// Record, as a field path, the position of wildcard pattern encountered in a record. This + /// impact the final type of the pattern, as a wildcard pattern makes the corresponding row + /// open. + wildcard_pat_paths: &'a mut HashSet, +} + +/// Return value of [PatternTypes::pattern_types], which stores the overall type of a pattern, +/// together with the type of its bindings and additional information for the typechecking of match +/// expressions. +#[derive(Debug, Clone)] +pub struct PatternTypeData<'ast, T> { + /// The type of the pattern. + pub typ: T, + /// A list of pattern variables and their associated type. + pub bindings: Vec<(LocIdent, UnifType<'ast>)>, + /// A list of enum row tail variables that are left open when typechecking a match expression. + /// + /// Those variables (or their descendent in a row type) might need to be closed after the type + /// of all the patterns of a match expression have been unified, depending on the presence of a + /// wildcard pattern. 
The path of the corresponding sub-pattern is stored as well, since enum + /// patterns in different positions might need different treatment. For example: + /// + /// ```nickel + /// match { + /// 'Foo ('Bar x) => , + /// 'Foo ('Qux x) => , + /// _ => + /// } + /// ``` + /// + /// The presence of a default case means that the row variables of top-level enum patterns + /// might stay open. However, the type corresponding to the sub-patterns `'Bar x` and `'Qux x` + /// must be closed, because this match expression can't handle `'Foo ('Other 0)`. The type of + /// the match expression is thus `[| 'Foo [| 'Bar: a, 'Qux: b |]; c|] -> d`. + /// + /// Wildcard can occur anywhere, so the previous case can also happen within a record pattern: + /// + /// ```nickel + /// match { + /// {foo = 'Bar x} => , + /// {foo = 'Qux x} => , + /// {foo = _} => , + /// } + /// ``` + /// + /// Similarly, the type of the match expression is `{ foo: [| 'Bar: a, 'Qux: b; c |] } -> e`. + /// + /// See [^typechecking-match-expression] in [typecheck] for more details. + pub enum_open_tails: Vec<(PatternPath, UnifEnumRows<'ast>)>, + /// Paths of the occurrence of wildcard patterns encountered. This is used to determine which + /// tails in [Self::enum_open_tails] should be left open. + pub wildcard_occurrences: HashSet, +} +/// Close all the enum row types left open when typechecking a match expression. Special case of +/// `close_enums` for a single destructuring pattern (thus, where wildcard occurrences are not +/// relevant). +pub fn close_all_enums<'ast>( + enum_open_tails: Vec<(PatternPath, UnifEnumRows<'ast>)>, + state: &mut State<'ast, '_>, +) { + close_enums(enum_open_tails, &HashSet::new(), state); +} + +/// Close all the enum row types left open when typechecking a match expression, unless we recorded +/// a wildcard pattern somewhere in the same position. 
+pub fn close_enums<'ast>( + enum_open_tails: Vec<(PatternPath, UnifEnumRows<'ast>)>, + wildcard_occurrences: &HashSet, + state: &mut State<'ast, '_>, +) { + // Note: both for this function and for `close_enums`, for a given pattern path, all the tail + // variables should ultimately be part of the same enum type, and we just need to close it + // once. We might thus save a bit of work if we kept equivalence classes of tuples (path, tail) + // (equality being given by the equality of paths). Closing one arbitrary member per class + // should then be enough. It's not obvious that this would make any difference in practice, + // though. + for tail in enum_open_tails + .into_iter() + .filter_map(|(path, tail)| (!wildcard_occurrences.contains(&path)).then_some(tail)) + { + close_enum(tail, state); + } +} + +/// Take an enum row, find its final tail (in case of multiple indirection through unification +/// variables) and close it if it's a free unification variable. +fn close_enum<'ast>(tail: UnifEnumRows<'ast>, state: &mut State<'ast, '_>) { + let root = tail.into_root(state.table); + + if let UnifEnumRows::UnifVar { id, .. } = root { + // We don't need to perform any variable level checks when unifying a free + // unification variable with a ground type + state + .table + .assign_erows(id, UnifEnumRows::concrete(EnumRowsF::Empty)); + } else { + let tail = root.iter().find_map(|row_item| { + match row_item { + EnumRowsElt::TailUnifVar { id, init_level } => { + Some(UnifEnumRows::UnifVar { id, init_level }) + } + EnumRowsElt::TailVar(_) | EnumRowsElt::TailConstant(_) => { + // While unifying open enum rows coming from a pattern, we expect to always + // extend the enum row with other open rows such that the result should always + // stay open. So we expect to find a unification variable at the end of the + // enum row. 
+ // + // But in fact, all the tails for a given pattern path will point to the same + // enum row, so it might have been closed already by a previous call to + // `close_enum`, and that's fine. On the other hand, we should never encounter + // a rigid type variable here (or a non-substituted type variable, although it + // has nothing to do with patterns), so if we reach this point, something is + // wrong with the typechecking of match expression. + debug_assert!(false); + + None + } + _ => None, + } + }); + + if let Some(tail) = tail { + close_enum(tail, state) + } + } +} + +pub(super) trait PatternTypes<'ast> { + /// The type produced by the pattern. Depending on the nature of the pattern, this type may + /// vary: for example, a record pattern will produce record rows, while a general pattern will + /// produce a general [super::UnifType<'ast>] + type PatType; + + /// Builds the type associated to the whole pattern, as well as the types associated to each + /// binding introduced by this pattern. When matching a value against a pattern in a statically + /// typed code, either by destructuring or by applying a match expression, the type of the + /// value will be checked against the type generated by `pattern_type` and the bindings will be + /// added to the type environment. + /// + /// The type of each "leaf" identifier will be assigned based on the `mode` argument. The + /// current possibilities are for each leaf to have type `Dyn`, to use an explicit type + /// annotation, or to be assigned a fresh unification variable. 
+ fn pattern_types( + &self, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, + mode: TypecheckMode, + ) -> Result, TypecheckError> { + let mut bindings = Vec::new(); + let mut enum_open_tails = Vec::new(); + let mut wildcard_pat_paths = HashSet::new(); + + let typ = self.pattern_types_inj( + &mut PatTypeState { + bindings: &mut bindings, + enum_open_tails: &mut enum_open_tails, + wildcard_pat_paths: &mut wildcard_pat_paths, + }, + Vec::new(), + state, + ctxt, + mode, + )?; + + Ok(PatternTypeData { + typ, + bindings, + enum_open_tails, + wildcard_occurrences: wildcard_pat_paths, + }) + } + + /// Same as `pattern_types`, but inject the bindings in a working vector instead of returning + /// them. Implementors should implement this method whose signature avoids creating and + /// combining many short-lived vectors when walking recursively through a pattern. + fn pattern_types_inj( + &self, + pt_state: &mut PatTypeState<'ast, '_>, + path: PatternPath, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, + mode: TypecheckMode, + ) -> Result; +} + +impl<'ast> PatternTypes<'ast> for RecordPattern<'ast> { + type PatType = UnifRecordRows<'ast>; + + fn pattern_types_inj( + &self, + pt_state: &mut PatTypeState<'ast, '_>, + path: PatternPath, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, + mode: TypecheckMode, + ) -> Result { + let tail = if self.is_open() { + match mode { + // We use a dynamic tail here since we're in walk mode, + // but if/when we remove dynamic record tails this could + // likely be made an empty tail with no impact. 
+ TypecheckMode::Walk => mk_buty_record_row!(; RecordRowsF::TailDyn), + TypecheckMode::Enforce => state.table.fresh_rrows_uvar(ctxt.var_level), + } + } else { + UnifRecordRows::Concrete { + rrows: RecordRowsF::Empty, + var_levels_data: VarLevelsData::new_no_uvars(), + } + }; + + if let TailPattern::Capture(rest) = self.tail { + pt_state + .bindings + .push((rest, UnifType::concrete(TypeF::Record(tail.clone())))); + } + + self.patterns + .iter() + .map(|field_pat| field_pat.pattern_types_inj(pt_state, path.clone(), state, ctxt, mode)) + .try_fold(tail, |tail, row: Result| { + Ok(UnifRecordRows::concrete(RecordRowsF::Extend { + row: row?, + tail: Box::new(tail), + })) + }) + } +} + +impl<'ast> PatternTypes<'ast> for ArrayPattern<'ast> { + type PatType = UnifType<'ast>; + + fn pattern_types_inj( + &self, + pt_state: &mut PatTypeState<'ast, '_>, + path: PatternPath, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, + mode: TypecheckMode, + ) -> Result { + // We allocate a fresh unification variable and unify it with the type of each element + // pattern in enforce mode. + // + // In walk mode, we still iterate through the sub patterns to populate the bindings, but we + // eschew unification, which might fail if the elements are heterogeneous (say two record + // patterns with different shapes). In this case, we just return `Dyn` as the element type. 
+ let elem_type = match mode { + TypecheckMode::Enforce => state.table.fresh_type_uvar(ctxt.var_level), + TypecheckMode::Walk => mk_uniftype::dynamic(), + }; + + for (idx, subpat) in self.patterns.iter().enumerate() { + let mut path = path.clone(); + path.push(PatternPathElem::Array(idx)); + + let subpat_type = subpat.pattern_types_inj(pt_state, path, state, ctxt, mode)?; + + if let TypecheckMode::Enforce = mode { + elem_type + .clone() + .unify(subpat_type, state, ctxt) + .map_err(|e| e.into_typecheck_err(state, self.pos))?; + } + } + + if let TailPattern::Capture(rest) = &self.tail { + pt_state + .bindings + .push((*rest, mk_uniftype::array(elem_type.clone()))); + } + + Ok(elem_type) + } +} + +impl<'ast> PatternTypes<'ast> for Pattern<'ast> { + type PatType = UnifType<'ast>; + + fn pattern_types_inj( + &self, + pt_state: &mut PatTypeState<'ast, '_>, + path: PatternPath, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, + mode: TypecheckMode, + ) -> Result { + let typ = self + .data + .pattern_types_inj(pt_state, path, state, ctxt, mode)?; + + if let Some(alias) = self.alias { + pt_state.bindings.push((alias, typ.clone())); + } + + Ok(typ) + } +} + +// Depending on the mode, returns the type affected to patterns that match any value (`Any` and +// `Wildcard`): `Dyn` in walk mode, a fresh unification variable in enforce mode. 
+fn any_type<'ast>( + mode: TypecheckMode, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, +) -> UnifType<'ast> { + match mode { + TypecheckMode::Walk => mk_uniftype::dynamic(), + TypecheckMode::Enforce => state.table.fresh_type_uvar(ctxt.var_level), + } +} + +impl<'ast> PatternTypes<'ast> for PatternData<'ast> { + type PatType = UnifType<'ast>; + + fn pattern_types_inj( + &self, + pt_state: &mut PatTypeState<'ast, '_>, + path: PatternPath, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, + mode: TypecheckMode, + ) -> Result { + match self { + PatternData::Wildcard => { + pt_state.wildcard_pat_paths.insert(path); + Ok(any_type(mode, state, ctxt)) + } + PatternData::Any(id) => { + let typ = any_type(mode, state, ctxt); + pt_state.bindings.push((*id, typ.clone())); + + Ok(typ) + } + PatternData::Record(record_pat) => Ok(UnifType::concrete(TypeF::Record( + record_pat.pattern_types_inj(pt_state, path, state, ctxt, mode)?, + ))), + PatternData::Array(array_pat) => Ok(mk_uniftype::array( + array_pat.pattern_types_inj(pt_state, path, state, ctxt, mode)?, + )), + PatternData::Enum(enum_pat) => { + let row = enum_pat.pattern_types_inj(pt_state, path.clone(), state, ctxt, mode)?; + // We elaborate the type `[| row; a |]` where `a` is a fresh enum rows unification + // variable registered in `enum_open_tails`. 
+ let tail = state.table.fresh_erows_uvar(ctxt.var_level); + pt_state.enum_open_tails.push((path, tail.clone())); + + Ok(UnifType::concrete(TypeF::Enum(UnifEnumRows::concrete( + EnumRowsF::Extend { + row, + tail: Box::new(tail), + }, + )))) + } + PatternData::Constant(constant_pat) => { + constant_pat.pattern_types_inj(pt_state, path, state, ctxt, mode) + } + PatternData::Or(or_pat) => or_pat.pattern_types_inj(pt_state, path, state, ctxt, mode), + } + } +} + +impl<'ast> PatternTypes<'ast> for ConstantPattern<'ast> { + type PatType = UnifType<'ast>; + + fn pattern_types_inj( + &self, + pt_state: &mut PatTypeState<'ast, '_>, + path: PatternPath, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, + mode: TypecheckMode, + ) -> Result { + self.data + .pattern_types_inj(pt_state, path, state, ctxt, mode) + } +} + +impl<'ast> PatternTypes<'ast> for ConstantPatternData<'ast> { + type PatType = UnifType<'ast>; + + fn pattern_types_inj( + &self, + _pt_state: &mut PatTypeState<'ast, '_>, + _path: PatternPath, + _state: &mut State, + _ctxt: &Context, + _mode: TypecheckMode, + ) -> Result { + Ok(match self { + ConstantPatternData::Bool(_) => UnifType::concrete(TypeF::Bool), + ConstantPatternData::Number(_) => UnifType::concrete(TypeF::Number), + ConstantPatternData::String(_) => UnifType::concrete(TypeF::String), + ConstantPatternData::Null => UnifType::concrete(TypeF::Dyn), + }) + } +} + +impl<'ast> PatternTypes<'ast> for FieldPattern<'ast> { + type PatType = UnifRecordRow<'ast>; + + fn pattern_types_inj( + &self, + pt_state: &mut PatTypeState<'ast, '_>, + mut path: PatternPath, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, + mode: TypecheckMode, + ) -> Result { + path.push(PatternPathElem::Field(self.matched_id.ident())); + + // If there is a static type annotation in a nested record patterns then we need to unify + // them with the pattern type we've built to ensure (1) that they're mutually compatible + // and (2) that we assign the annotated types to the right 
unification variables. + let ty_row = match (&self.annotation.typ, &self.pattern.data, mode) { + // However, in walk mode, we only do that when the nested pattern isn't a leaf (i.e. + // `Any` or `Wildcard`) for backward-compatibility reasons. + // + // Before this function was refactored, Nickel has been allowing things like `let {foo + // : Number} = {foo = 1} in foo` in walk mode, which would fail to typecheck with the + // generic approach: the pattern is parsed as `{foo : Number = foo}`, the second + // occurrence of `foo` gets type `Dyn` in walk mode, but `Dyn` fails to unify with + // `Number`. In this case, we don't recursively call `pattern_types_inj` in the first + // place and just declare that the type of `foo` is `Number`. + // + // This special case should probably be ruled out, requiring the users to use `let {foo + // | Number}` instead, at least outside of a statically typed code block. But before + // this happens, we special case the old behavior and eschew unification. 
+ (Some(annot_ty), PatternData::Any(id), TypecheckMode::Walk) => { + let ty_row = UnifType::from_type(annot_ty.clone(), &ctxt.term_env); + pt_state.bindings.push((*id, ty_row.clone())); + ty_row + } + (Some(annot_ty), PatternData::Wildcard, TypecheckMode::Walk) => { + UnifType::from_type(annot_ty.clone(), &ctxt.term_env) + } + (Some(annot_ty), _, _) => { + let pos = annot_ty.pos; + let annot_uty = UnifType::from_type(annot_ty.clone(), &ctxt.term_env); + + let ty_row = self + .pattern + .pattern_types_inj(pt_state, path, state, ctxt, mode)?; + + ty_row + .clone() + .unify(annot_uty, state, ctxt) + .map_err(|e| e.into_typecheck_err(state, pos))?; + + ty_row + } + _ => self + .pattern + .pattern_types_inj(pt_state, path, state, ctxt, mode)?, + }; + + Ok(UnifRecordRow { + id: self.matched_id, + typ: Box::new(ty_row), + }) + } +} + +impl<'ast> PatternTypes<'ast> for EnumPattern<'ast> { + type PatType = UnifEnumRow<'ast>; + + fn pattern_types_inj( + &self, + pt_state: &mut PatTypeState<'ast, '_>, + mut path: PatternPath, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, + mode: TypecheckMode, + ) -> Result { + let typ_arg = self + .pattern + .as_ref() + .map(|pat| { + path.push(PatternPathElem::Variant); + pat.pattern_types_inj(pt_state, path, state, ctxt, mode) + }) + .transpose()? + .map(Box::new); + + Ok(UnifEnumRow { + id: self.tag, + typ: typ_arg, + }) + } +} + +impl<'ast> PatternTypes<'ast> for OrPattern<'ast> { + type PatType = UnifType<'ast>; + + fn pattern_types_inj( + &self, + pt_state: &mut PatTypeState<'ast, '_>, + path: PatternPath, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, + mode: TypecheckMode, + ) -> Result { + // When checking a sequence of or-patterns, we must combine their open tails and wildcard + // pattern positions - in fact, when typechecking a whole match expression, this is exactly + // what the typechecker is doing: it merges all those data. And a match expression is, + // similarly to an or-pattern, a disjunction of patterns. 
+ // + // However, the treatment of bindings is different. If any of the branch in an or-pattern + // matches, the same code path (the match branch) will be run, and thus they must agree on + // pattern variables. Which means: + // + // 1. All pattern branches must have the same set of variables + // 2. Each variable must have a compatible type across all or-pattern branches + // + // To do so, we call to `pattern_types_inj` with a fresh vector of bindings, so that we can + // post-process them afterward (enforcing 1. and 2. above) before actually adding them to + // the original overall bindings. + // + // `bindings` stores, for each or-pattern branch, the inferred type of the whole branch, + // the generated bindings and the position (the latter for error reporting). + let bindings: Result, _> = self + .patterns + .iter() + .map(|pat| -> Result<_, TypecheckError> { + let mut fresh_bindings = Vec::new(); + + let mut local_state = PatTypeState { + bindings: &mut fresh_bindings, + enum_open_tails: pt_state.enum_open_tails, + wildcard_pat_paths: pt_state.wildcard_pat_paths, + }; + + let typ = + pat.pattern_types_inj(&mut local_state, path.clone(), state, ctxt, mode)?; + + // We sort the bindings to check later that they are the same in all branches + fresh_bindings.sort_by_key(|(id, _typ)| *id); + + Ok((typ, fresh_bindings, pat.pos)) + }) + .collect(); + + let mut it = bindings?.into_iter(); + + // We need a reference set of variables (and their types for unification). We just pick the + // first bindings of the list. + let Some((model_typ, model, _pos)) = it.next() else { + // We should never generate empty `or` sequences (it's not possible to write them in + // the source language, at least). However, it doesn't cost much to support them: such + // a pattern never matches anything. Thus, we return the bottom type encoded as `forall + // a. a`. 
+ let free_var = Ident::from("a"); + + return Ok(UnifType::concrete(TypeF::Forall { + var: free_var.into(), + var_kind: VarKind::Type, + body: Box::new(UnifType::concrete(TypeF::Var(free_var))), + })); + }; + + for (typ, pat_bindings, pos) in it { + if model.len() != pat_bindings.len() { + // The two branches don't bind the same number of variables, so the longest list + // holds at least one variable that the other branch doesn't bind. We report the + // first such variable. Blindly taking the last binding of the longest list isn't + // enough: with `[a, b, c]` and `[b, c]`, `c` is bound in both branches. + let (longer, shorter) = if model.len() > pat_bindings.len() { + (&model, &pat_bindings) + } else { + (&pat_bindings, &model) + }; + + let witness = longer + .iter() + .map(|(id, _)| *id) + .find(|id| !shorter.iter().any(|(other, _)| other == id)) + // unwrap(): longer.len() > shorter.len() >= 0, so `longer` isn't empty. This + // fallback is only reached if a branch binds the same variable several times. + .unwrap_or_else(|| longer.last().unwrap().0); + + return Err(TypecheckError::OrPatternVarsMismatch { + var: witness, + pos: self.pos, + }); + } + + // We unify the type of the first or-branch with the current or-branch, to make sure + // all the subpatterns are matching values of the same type + if let TypecheckMode::Enforce = mode { + model_typ + .clone() + .unify(typ, state, ctxt) + .map_err(|e| e.into_typecheck_err(state, pos))?; + } + + // Finally, we unify the type of the bindings + for (idx, (id, typ)) in pat_bindings.into_iter().enumerate() { + let (model_id, model_ty) = &model[idx]; + + if *model_id != id { + // Once again, we must arbitrarily pick a variable to report.
We take the + // smaller one, which is guaranteed to be missing (indeed, the greater one + // could still appear later in the other list, but the smaller is necessarily + // missing in the list with the greater one) + return Err(TypecheckError::OrPatternVarsMismatch { + var: std::cmp::min(*model_id, id), + pos: self.pos, + }); + } + + if let TypecheckMode::Enforce = mode { + model_ty + .clone() + .unify(typ, state, ctxt) + .map_err(|e| e.into_typecheck_err(state, id.pos))?; + } + } + } + + // Once we have checked that all the bound variables are the same and we have unified their + // types, we can add them to the overall bindings (since they are unified, it doesn't + // matter which type we use - so we just reuse the model, which is still around) + pt_state.bindings.extend(model); + + Ok(model_typ) + } +} diff --git a/core/src/bytecode/typecheck/record.rs b/core/src/bytecode/typecheck/record.rs new file mode 100644 index 0000000000..e5e77ab477 --- /dev/null +++ b/core/src/bytecode/typecheck/record.rs @@ -0,0 +1,543 @@ +//! Typechecking records. +//! +//! Because record literal definitions are flexible in Nickel (piecewise definitions), they need +//! a bit of preprocessing before they can be typechecked. Preprocessing and typechecking of +//! records is handled in this module. +use super::*; +use crate::{ + bytecode::ast::record::{FieldDef, FieldPathElem, Record}, + combine::Combine, + position::TermPos, +}; + +use std::iter; + +use indexmap::{map::Entry, IndexMap}; + +pub(super) trait Resolve<'ast> { + type Resolved; + + fn resolve(&'ast self) -> Self::Resolved; +} + +/// A resolved record literal, without field paths or piecewise definitions. Piecewise definitions +/// of fields have been grouped together, path have been broken into proper levels and top-level +/// fields are partitioned between static and dynamic. +#[derive(Default)] +pub(super) struct ResolvedRecord<'ast> { + /// The static fields of the record. 
+ pub stat_fields: IndexMap>, + /// The dynamic fields of the record. + pub dyn_fields: Vec<(&'ast Ast<'ast>, ResolvedField<'ast>)>, + /// The position of the resolved record. + pub pos: TermPos, +} + +impl<'ast> ResolvedRecord<'ast> { + pub fn empty() -> Self { + Self::default() + } + + pub fn is_empty(&self) -> bool { + self.stat_fields.is_empty() && self.dyn_fields.is_empty() + } + + pub fn check>( + &self, + state: &mut State<'ast, '_>, + ctxt: Context<'ast>, + visitor: &mut V, + ty: UnifType<'ast>, + ) -> Result<(), TypecheckError> { + // If we have no dynamic fields, we can check the record against a record type or a + // dictionary type, depending on `ty`. + if self.dyn_fields.is_empty() { + self.check_stat(state, ctxt, visitor, ty) + } + // If some fields are defined dynamically, the only potential type that works is `{_ : a}` + // for some `a`. + else { + self.check_dyn(state, ctxt, visitor, ty) + } + } + + /// Checks a record with dynamic fields (and potentially static fields as well) against a type. + /// + /// # Preconditions + /// + /// This method assumes that `self.dyn_fields` is non-empty. Currently, violating this invariant + /// shouldn't cause panic or unsoundness, but will unduly enforce that `ty` is a dictionary + /// type. 
+ fn check_dyn>( + &self, + state: &mut State<'ast, '_>, + mut ctxt: Context<'ast>, + visitor: &mut V, + ty: UnifType<'ast>, + ) -> Result<(), TypecheckError> { + let ty_elts = state.table.fresh_type_uvar(ctxt.var_level); + + ty.unify(mk_uniftype::dict(ty_elts.clone()), state, &ctxt) + .map_err(|err| err.into_typecheck_err(state, self.pos))?; + + for id in self.stat_fields.keys() { + ctxt.type_env.insert(id.ident(), ty_elts.clone()); + visitor.visit_ident(id, ty_elts.clone()) + } + + for (expr, field) in &self.dyn_fields { + check(state, ctxt.clone(), visitor, expr, mk_uniftype::str())?; + field.check(state, ctxt.clone(), visitor, ty_elts.clone())?; + } + + // We don't bind recursive fields in the term environment used to check for contract. See + // [^term-env-rec-bindings] in `./mod.rs`. + for (_, field) in self.stat_fields.iter() { + field.check(state, ctxt.clone(), visitor, ty_elts.clone())?; + } + + Ok(()) + } + + /// Checks a record with only static fields against a type. + /// + /// # Preconditions + /// + /// This method assumes that `self.dyn_fields` is empty. Currently, violating this invariant + /// shouldn't cause a panic, but the dynamic fields would be silently left unchecked, since + /// this method only ever looks at `self.stat_fields`. + fn check_stat>( + &self, + state: &mut State<'ast, '_>, + mut ctxt: Context<'ast>, + visitor: &mut V, + ty: UnifType<'ast>, + ) -> Result<(), TypecheckError> { + let root_ty = ty.clone().into_root(state.table); + + if let UnifType::Concrete { + typ: TypeF::Dict { + type_fields: rec_ty, + .. + }, + .. + } = root_ty + { + // Checking mode for a dictionary + for (_, field) in self.stat_fields.iter() { + field.check(state, ctxt.clone(), visitor, (*rec_ty).clone())?; + } + + Ok(()) + } else { + // As records are recursive, we look at the apparent type of each field and bind it in ctxt + // before actually typechecking the content of fields. + // + // Fields defined by interpolation are ignored, because they can't be referred to + // recursively.
+ + // When we build the recursive environment, there are two different possibilities for each + // field: + // + // 1. The field is annotated. In this case, we use this type to build the type environment. + // We don't need to do any additional check that the field respects this annotation: + // this will be handled by `check_field` when processing the field. + // 2. The field isn't annotated. We are going to infer a concrete type later, but for now, + // we allocate a fresh unification variable in the type environment. In this case, once + // we have inferred an actual type for this field, we need to unify what's inside the + // environment with the actual type to ensure that they agree. + // + // `need_unif_step` stores the list of fields corresponding to the case 2, which require + // this additional unification step. Note that performing the additional unification in + // case 1. should be harmless, but it's wasteful, and is also not entirely trivial because + // of polymorphism (we need to make sure to instantiate polymorphic type annotations). At + // the end of the day, it's simpler to skip unneeded unifications. + let mut need_unif_step = HashSet::new(); + + for (id, field) in &self.stat_fields { + let uty_apprt = field.apparent_type( + state.ast_alloc, + Some(&ctxt.type_env), + Some(state.resolver), + ); + + // `Approximated` corresponds to the case where the type isn't obvious (annotation + // or constant), and thus to case 2. above + if matches!(uty_apprt, ApparentType::Approximated(_)) { + need_unif_step.insert(*id); + } + + let uty = apparent_or_infer(state, uty_apprt, &ctxt, true); + ctxt.type_env.insert(id.ident(), uty.clone()); + visitor.visit_ident(id, uty); + } + + // We build a vector of unification variables representing the type of the fields of + // the record. + // + // Since `IndexMap` guarantees a stable order of iteration, we use a vector instead of + // hashmap here. 
To find the type associated to the field `foo`, retrieve the index of + // `foo` in `self.stat_fields.keys()` and index into `field_types`. + let mut field_types: Vec> = + iter::repeat_with(|| state.table.fresh_type_uvar(ctxt.var_level)) + .take(self.stat_fields.len()) + .collect(); + + // Build the type {id1 : ?a1, id2: ?a2, .., idn: ?an}, which is the type of the whole + // record. + let rows = self.stat_fields.keys().zip(field_types.iter()).fold( + mk_buty_record_row!(), + |acc, (id, row_ty)| mk_buty_record_row!((*id, row_ty.clone()); acc), + ); + + ty.unify(mk_buty_record!(; rows), state, &ctxt) + .map_err(|err| err.into_typecheck_err(state, self.pos))?; + + // We reverse the order of `field_types`. The idea is that we can then pop each + // field type as we iterate a last time over the fields, taking ownership, instead of + // having to clone elements if we indexed instead. + field_types.reverse(); + + for (id, field) in self.stat_fields.iter() { + // unwrap(): `field_types` has exactly the same length as `self.stat_fields`, as it + // was constructed with `.take(self.stat_fields.len()).collect()`. + let field_type = field_types.pop().unwrap(); + + // For a recursive record and a field which requires the additional unification + // step (whose type wasn't known when building the recursive environment), we + // unify the actual type with the type affected in the typing environment + // (which started as a fresh unification variable, but might have been unified + // with a more concrete type if the current field has been used recursively + // from other fields). + if need_unif_step.contains(id) { + // unwrap(): if the field is in `need_unif_step`, it must be in the context. 
+ let affected_type = ctxt.type_env.get(&id.ident()).cloned().unwrap(); + + field_type + .clone() + .unify(affected_type, state, &ctxt) + .map_err(|err| { + err.into_typecheck_err( + state, + field.pos(), + // field.value.as_ref().map(|v| v.pos).unwrap_or_default(), + ) + })?; + } + + field.check(state, ctxt.clone(), visitor, field_type)?; + } + + Ok(()) + } + } +} + +impl<'ast> Combine for ResolvedRecord<'ast> { + fn combine(this: ResolvedRecord<'ast>, other: ResolvedRecord<'ast>) -> Self { + use crate::eval::merge::split; + + let split::SplitResult { + left, + center, + right, + } = split::split(this.stat_fields, other.stat_fields); + + let mut stat_fields = IndexMap::with_capacity(left.len() + center.len() + right.len()); + + stat_fields.extend(left); + stat_fields.extend(right); + + for (id, (field1, field2)) in center.into_iter() { + stat_fields.insert(id, Combine::combine(field1, field2)); + } + + let dyn_fields = this + .dyn_fields + .into_iter() + .chain(other.dyn_fields) + .collect(); + + let pos = match (this.pos, other.pos) { + // If only one of the two position is defined, we use it + (pos, TermPos::None) | (TermPos::None, pos) => pos, + // Otherwise, we don't know how to combine two disjoint positions of a piecewise + // definition, so we just return `TermPos::None`. + _ => TermPos::None, + }; + + ResolvedRecord { + stat_fields, + dyn_fields, + pos, + } + } +} + +/// A wrapper type around a record that has been resolved but hasn't yet got a position. This is +/// done to force the caller of [Record::resolve] to provide a position before doing anything else. 
+pub(super) struct PoslessResolvedRecord<'ast>(ResolvedRecord<'ast>); + +impl<'ast> PoslessResolvedRecord<'ast> { + pub(super) fn new( + stat_fields: IndexMap>, + dyn_fields: Vec<(&'ast Ast<'ast>, ResolvedField<'ast>)>, + ) -> Self { + PoslessResolvedRecord(ResolvedRecord { + stat_fields, + dyn_fields, + pos: TermPos::None, + }) + } + + pub(super) fn with_pos(self, pos: TermPos) -> ResolvedRecord<'ast> { + let PoslessResolvedRecord(record) = self; + + ResolvedRecord { pos, ..record } + } +} + +/// The field of a resolved record. +/// +/// A resolved field can be either: +/// +/// - another resolved record, for the fields coming from elaboration, as +/// `mid` in `{ outer.mid.inner = true }`. +/// - A final value, for the last field of path, as `inner` in `{ outer.mid.inner = true }` or in +/// `fun param => { outer.mid.inner = param}`. +/// - A combination of the previous cases, for a field defined piecewise with multiple +/// definitions, such as `mid` in `fun param => { outer.mid.inner = true, outer.mid = param}`. +/// +/// In the combined, the resolved field `mid` will have a resolved part `{inner = true}` and a +/// value part `param`. Values can't be combined statically in all generality (imagine adding +/// another piecewise definition `outer.mid = other_variable` in the previous example), hence we +/// keep accumulating them. However, resolved parts can be merged statically, so we only need one +/// that we update as we collect the pieces of the definition. +/// +/// Rather than having an ad-hoc enum with all those cases (that would just take up more memory), +/// we consider the general combined case directly. Others are special cases with an empty +/// `resolved`, or an empty or one-element `values`. +#[derive(Default)] +pub(super) struct ResolvedField<'ast> { + /// The resolved part of the field, coming from piecewise definitions where this field appears + /// in the middle of the path. 
+ resolved: ResolvedRecord<'ast>, + /// The accumulated values of the field, coming from piecewise definitions where this field + /// appears last in the path. + /// + /// We store the whole [crate::bytecode::ast::record::FieldDef] here, although we don't need + /// the path anymore, because it's easier and less costly than creating an ad-hoc structure to + /// store only the value and the metadata. + defs: Vec<&'ast FieldDef<'ast>>, +} + +impl<'ast> ResolvedField<'ast> { + /// Return the first type or contract annotation available in the definitions, if any. + /// + /// [ResolvedField::first_annot] first looks for a type annotation in all definitions. If we + /// can't find any, [ResolvedField::first_annot] will look for the first contract annotation. + /// If there is no annotation at all, `None` is returned. + /// + /// [ResolvedField::first_annot] is equivalent to calling + /// [crate::bytecode::ast::Annotation::first] on the combined metadata of all definitions. + pub fn first_annot(&self) -> Option> { + self.defs + .iter() + .find_map(|def| def.metadata.annotation.typ.as_ref().cloned()) + // `or_else` keeps the contract lookup lazy: it only runs when no type annotation + // was found. + .or_else(|| { + self.defs + .iter() + .find_map(|def| def.metadata.annotation.contracts.first().cloned()) + }) + } + + pub fn check>( + &self, + state: &mut State<'ast, '_>, + ctxt: Context<'ast>, + visitor: &mut V, + ty: UnifType<'ast>, + ) -> Result<(), TypecheckError> { + match (self.resolved.is_empty(), self.defs.as_slice()) { + // This shouldn't happen (fields present in the record should either have a definition + // or come from record resolution). + (true, []) => { + unreachable!("typechecker internal error: checking a vacant field") + } + (true, [def]) if def.metadata.is_empty() => check_field(state, ctxt, visitor, def, ty), + (false, []) => self.resolved.check(state, ctxt, visitor, ty), + // In all other cases, we have either several definitions or at least one definition + // and a resolved part. Those cases will result in a runtime merge, so we type + // everything as `Dyn`.
+ (_, defs) => { + for def in defs.iter() { + check_field(state, ctxt.clone(), visitor, def, mk_uniftype::dynamic())?; + } + + if !self.resolved.is_empty() { + // This will always raise an error, since the resolved part is equivalent to a + // record literal which doesn't type against `Dyn` (at least currently). We + // could raise the error directly, but it's simpler to call `check` on + // `self.resolved`, which will handle that for us. + // + // Another reason is that the error situation might change in the future, if we + // have proper subtyping for `Dyn`. + self.resolved + .check(state, ctxt, visitor, mk_uniftype::dynamic())?; + } + + Ok(()) + } + } + } + + /// Returns the position of this resolved field if and only if there is a single defined + /// position (among both the resolved part and the definitions). Otherwise, returns + /// [crate::position::TermPos::None]. + pub fn pos(&self) -> TermPos { + self.defs + .iter() + .fold(self.resolved.pos, |acc, def| acc.xor(def.pos)) + } +} + +impl Combine for ResolvedField<'_> { + fn combine(this: Self, other: Self) -> Self { + let mut defs = this.defs; + defs.extend(other.defs); + + ResolvedField { + resolved: Combine::combine(this.resolved, other.resolved), + defs, + } + } +} + +impl<'ast> From<&'ast FieldDef<'ast>> for ResolvedField<'ast> { + fn from(def: &'ast FieldDef<'ast>) -> Self { + ResolvedField { + resolved: ResolvedRecord::empty(), + defs: vec![def], + } + } +} + +impl<'ast> From> for ResolvedField<'ast> { + fn from(resolved: ResolvedRecord<'ast>) -> Self { + ResolvedField { + resolved, + defs: Vec::new(), + } + } +} + +impl<'ast> Resolve<'ast> for Record<'ast> { + type Resolved = PoslessResolvedRecord<'ast>; + + fn resolve(&self) -> PoslessResolvedRecord<'ast> { + fn insert_static_field<'ast>( + static_fields: &mut IndexMap>, + id: LocIdent, + field: ResolvedField<'ast>, + ) { + match static_fields.entry(id) { + Entry::Occupied(mut occpd) => { + // temporarily putting an empty field in the entry to
take the previous value. + let prev = occpd.insert(ResolvedField::default()); + + // Combine the previous definition(s) with the new one and store the result back. + occpd.insert(Combine::combine(prev, field)); + } + Entry::Vacant(vac) => { + vac.insert(field); + } + } + } + + let mut stat_fields = IndexMap::new(); + let mut dyn_fields = Vec::new(); + + for def in self.field_defs.iter() { + // expect(): the field path must have at least one element, it's an invariant. + let toplvl_field = def.path.first().expect("empty field path"); + let rfield = def.resolve(); + + if let Some(id) = toplvl_field.try_as_ident() { + insert_static_field(&mut stat_fields, id, rfield); + continue; + } else { + // unreachable!(): `try_as_ident` returns `None` only if the path element is an + // `Expr` + let FieldPathElem::Expr(expr) = toplvl_field else { + unreachable!() + }; + dyn_fields.push((expr, rfield)); + } + } + + PoslessResolvedRecord::new(stat_fields, dyn_fields) + } +} + +// This turns a field definition into potentially nested resolved fields. Note that the top-level +// field is left out, as it's already been processed by the caller: resolving `foo.bar.baz.qux = +// 42` will return nested resolved records of the form `{bar = {baz = {qux = 42}}}`. +impl<'ast> Resolve<'ast> for FieldDef<'ast> { + type Resolved = ResolvedField<'ast>; + + fn resolve(&'ast self) -> ResolvedField<'ast> { + self.path[1..]
+ .iter() + .rev() + .fold(self.into(), |acc, path_elem| { + if let Some(id) = path_elem.try_as_ident() { + let pos_acc = acc.pos(); + + ResolvedField::from(ResolvedRecord { + stat_fields: iter::once((id, acc)).collect(), + dyn_fields: Vec::new(), + pos: id.pos.fuse(pos_acc), + }) + } else { + // unreachable!(): `try_as_ident` returns `None` only if the path element is a + // `Expr` + let FieldPathElem::Expr(expr) = path_elem else { + unreachable!() + }; + + let pos_acc = acc.pos(); + + ResolvedField::from(ResolvedRecord { + stat_fields: IndexMap::new(), + dyn_fields: vec![(expr, acc)], + pos: expr.pos.fuse(pos_acc), + }) + } + }) + } +} + +impl<'ast> HasApparentType<'ast> for ResolvedField<'ast> { + // Return the apparent type of a field, by first looking at the type annotation, if any, then at + // the contracts annotation, and if there is none, fall back to the apparent type of the value. If + // there is no value, `Approximated(Dyn)` is returned. + fn apparent_type( + &self, + ast_alloc: &'ast AstAlloc, + env: Option<&TypeEnv<'ast>>, + resolver: Option<&dyn ImportResolver>, + ) -> ApparentType<'ast> { + match self.defs.as_slice() { + // If there is a resolved part, the apparent type is `Dyn`: a resolved part itself is a + // record literal without annotation, whose apparent type is indeed `Dyn`. If there are + // definitions as well, the result will be merged at runtime, and the apparent type of a + // merge expression is also `Dyn`. 
+ _ if !self.resolved.is_empty() => ApparentType::Approximated(Type::from(TypeF::Dyn)), + [] => ApparentType::Approximated(Type::from(TypeF::Dyn)), + [def] => def.apparent_type(ast_alloc, env, resolver), + _ => self + .first_annot() + .map(ApparentType::Annotated) + .unwrap_or(ApparentType::Approximated(Type::from(TypeF::Dyn))), + } + } +} diff --git a/core/src/bytecode/typecheck/reporting.rs b/core/src/bytecode/typecheck/reporting.rs new file mode 100644 index 0000000000..4d31497802 --- /dev/null +++ b/core/src/bytecode/typecheck/reporting.rs @@ -0,0 +1,269 @@ +//! Helpers to convert a `TypeWrapper` to a human-readable `Type` representation for error +//! reporting. +use super::*; + +/// A name registry used to replace unification variables and type constants with human-readable +/// and distinct names. +pub struct NameReg { + /// Currently allocated names, including both variables written by the user and generated + /// names. + names: NameTable, + /// A reverse name table, always kept in sync with `names`, in order to efficiently check if a + /// name is already taken. + taken: HashSet, + /// Counter used to generate fresh letters for unification variables. + var_count: usize, + /// Counter used to generate fresh letters for type constants. + cst_count: usize, +} + +impl NameReg { + /// Create a new registry from an initial table corresponding to user-written type constants. + pub fn new(names: NameTable) -> Self { + let taken = names.values().copied().collect(); + + NameReg { + names, + taken, + var_count: 0, + cst_count: 0, + } + } + + pub fn taken(&self, name: &str) -> bool { + self.taken.contains(&name.into()) + } + + fn insert(&mut self, var_id: VarId, discriminant: VarKindDiscriminant, name: Ident) { + self.names.insert((var_id, discriminant), name); + self.taken.insert(name); + } + + /// Create a fresh name candidate for a type variable or a type constant. + /// + /// Used to convert a unification type to a human-readable representation. 
+    ///
+    /// To select a candidate, first check in `names` if the variable or the constant corresponds
+    /// to a type variable written by the user. If it is, return the name of the variable.
+    /// Otherwise, use the given counter to generate a new single letter.
+    ///
+    /// A generated name is clearly not necessarily unique. [`select_uniq`] must then be applied.
+    fn gen_candidate_name(
+        names: &NameTable,
+        counter: &mut usize,
+        id: VarId,
+        kind: VarKindDiscriminant,
+    ) -> String {
+        match names.get(&(id, kind)) {
+            // First check if that constant or variable was introduced by a forall. If it was, try
+            // to use the same name.
+            Some(orig) => format!("{orig}"),
+            None => {
+                // Otherwise, generate a new character. The counter is bumped on every generated
+                // name and wraps around the 26 letters of the alphabet.
+                let next = *counter;
+                *counter += 1;
+
+                let prefix = match kind {
+                    VarKindDiscriminant::Type => "",
+                    VarKindDiscriminant::EnumRows => "erows_",
+                    VarKindDiscriminant::RecordRows => "rrows_",
+                };
+                // unwrap(): `'a' + (next % 26)` always lies in `'a'..='z'`, which are valid chars.
+                let character = std::char::from_u32(('a' as u32) + ((next % 26) as u32)).unwrap();
+                format!("{prefix}{character}")
+            }
+        }
+    }
+
+    /// Select a name distinct from all the others, starting from a candidate name for a type
+    /// variable or a type constant. Insert the corresponding name in the name table.
+    ///
+    /// If the name is already taken, it just iterates by adding a numeric suffix `1`, `2`, .., and
+    /// so on until a free name is found. See `var_to_type` and `cst_to_type`.
+    fn select_uniq(&mut self, mut name: String, id: VarId, kind: VarKindDiscriminant) -> Ident {
+        // To avoid clashing with already picked names, we add a numeric suffix to the picked
+        // letter.
+        if self.taken(&name) {
+            let mut suffix = 1;
+
+            // `name` must stay the bare candidate while probing `name1`, `name2`, ..: each
+            // iteration has to test a *different* string, otherwise the loop can never make
+            // progress once `name1` is taken.
+            while self.taken(&format!("{name}{suffix}")) {
+                suffix += 1;
+            }
+
+            name = format!("{name}{suffix}");
+        }
+
+        let sym = Ident::from(name);
+        self.insert(id, kind, sym);
+        sym
+    }
+
+    /// Either retrieve or generate a new fresh name for a unification variable for error reporting,
+    /// and wrap it as an identifier.
Unification variables are named `_a`, `_b`, .., `_a1`, `_b1`, + /// .. and so on. + pub fn gen_var_name(&mut self, id: VarId, kind: VarKindDiscriminant) -> Ident { + self.names.get(&(id, kind)).cloned().unwrap_or_else(|| { + // Select a candidate name and add a "_" prefix + let candidate = format!( + "_{}", + Self::gen_candidate_name(&self.names, &mut self.var_count, id, kind) + ); + // Add a suffix to make it unique if it has already been picked + self.select_uniq(candidate, id, kind) + }) + } + + /// Either retrieve or generate a new fresh name for a constant for error reporting, and wrap it + /// as type variable. Constant are named `a`, `b`, .., `a1`, `b1`, .. and so on. + pub fn gen_cst_name(&mut self, id: VarId, kind: VarKindDiscriminant) -> Ident { + self.names.get(&(id, kind)).cloned().unwrap_or_else(|| { + // Select a candidate name + let candidate = Self::gen_candidate_name(&self.names, &mut self.cst_count, id, kind); + // Add a suffix to make it unique if it has already been picked + self.select_uniq(candidate, id, kind) + }) + } +} + +pub trait ToType<'ast> { + /// The target type to convert to. If `Self` is `UnifXXX`, then `Target` is `XXX`. + type Target; + + /// Extract a concrete type corresponding to a unification type for error reporting purpose, + /// given a registry of currently allocated names. + /// + /// As opposed to [`crate::typ::Type::from`], free unification variables and type constants are + /// replaced by type variables which names are determined by this name registry. + /// + /// When reporting error, we want to distinguish occurrences of unification variables and type + /// constants in a human-readable way. 
+    fn to_type(
+        self,
+        alloc: &'ast AstAlloc,
+        reg: &mut NameReg,
+        table: &UnifTable<'ast>,
+    ) -> Self::Target;
+}
+
+/// Conversion of a unification type: the type is first resolved to its root in `table`, then
+/// variables and constants are replaced by named type variables, and concrete types are converted
+/// recursively.
+impl<'ast> ToType<'ast> for UnifType<'ast> {
+    type Target = Type<'ast>;
+
+    fn to_type(
+        self,
+        alloc: &'ast AstAlloc,
+        reg: &mut NameReg,
+        table: &UnifTable<'ast>,
+    ) -> Self::Target {
+        let ty = self.into_root(table);
+
+        match ty {
+            // A still-free unification variable is printed as a type variable with a `gen_var_name`
+            // name.
+            UnifType::UnifVar { id, .. } => {
+                TypeF::Var(reg.gen_var_name(id, VarKindDiscriminant::Type)).into()
+            }
+            // A type constant is printed as a type variable with a `gen_cst_name` name.
+            UnifType::Constant(id) => {
+                TypeF::Var(reg.gen_cst_name(id, VarKindDiscriminant::Type)).into()
+            }
+            // Concrete types are mapped component-wise; the registry is threaded as the state so
+            // that names stay consistent across the whole type.
+            UnifType::Concrete { typ, .. } => typ
+                .map_state(
+                    // Sub-types are converted and moved into the AST allocator.
+                    |btyp, reg| alloc.alloc(btyp.to_type(alloc, reg, table)),
+                    |rrows, reg| rrows.to_type(alloc, reg, table),
+                    |erows, reg| erows.to_type(alloc, reg, table),
+                    // The last component is a `(ctr, _env)` pair: only `ctr` is kept.
+                    |(ctr, _env), _reg| ctr,
+                    reg,
+                )
+                .into(),
+        }
+    }
+}
+
+/// Conversion of unification record rows. Free variables and constants in tail position become a
+/// `TailVar` carrying the generated name.
+impl<'ast> ToType<'ast> for UnifRecordRows<'ast> {
+    type Target = RecordRows<'ast>;
+
+    fn to_type(
+        self,
+        alloc: &'ast AstAlloc,
+        reg: &mut NameReg,
+        table: &UnifTable<'ast>,
+    ) -> Self::Target {
+        let rrows = self.into_root(table);
+
+        match rrows {
+            UnifRecordRows::UnifVar { id, .. } => RecordRows(RecordRowsF::TailVar(
+                reg.gen_var_name(id, VarKindDiscriminant::RecordRows).into(),
+            )),
+            UnifRecordRows::Constant(id) => RecordRows(RecordRowsF::TailVar(
+                reg.gen_cst_name(id, VarKindDiscriminant::RecordRows).into(),
+            )),
+            UnifRecordRows::Concrete { rrows, .. } => {
+                // Both components are converted recursively and allocated in `alloc`.
+                let mapped = rrows.map_state(
+                    |btyp, reg| alloc.alloc(btyp.to_type(alloc, reg, table)),
+                    |rrows, reg| alloc.alloc(rrows.to_type(alloc, reg, table)),
+                    reg,
+                );
+                RecordRows(mapped)
+            }
+        }
+    }
+}
+
+/// Conversion of unification enum rows. Mirrors the record rows case, using the `EnumRows` name
+/// kind for variables and constants.
+impl<'ast> ToType<'ast> for UnifEnumRows<'ast> {
+    type Target = EnumRows<'ast>;
+
+    fn to_type(
+        self,
+        alloc: &'ast AstAlloc,
+        reg: &mut NameReg,
+        table: &UnifTable<'ast>,
+    ) -> Self::Target {
+        let erows = self.into_root(table);
+
+        match erows {
+            UnifEnumRows::UnifVar { id, .. } => EnumRows(EnumRowsF::TailVar(
+                reg.gen_var_name(id, VarKindDiscriminant::EnumRows).into(),
+            )),
+            UnifEnumRows::Constant(id) => EnumRows(EnumRowsF::TailVar(
+                reg.gen_cst_name(id, VarKindDiscriminant::EnumRows).into(),
+            )),
+            UnifEnumRows::Concrete { erows, .. } => {
+                // Both components are converted recursively and allocated in `alloc`.
+                let mapped = erows.map_state(
+                    |btyp, reg| alloc.alloc(btyp.to_type(alloc, reg, table)),
+                    |erows, reg| alloc.alloc(erows.to_type(alloc, reg, table)),
+                    reg,
+                );
+                EnumRows(mapped)
+            }
+        }
+    }
+}
+
+/// Conversion of a single enum row: the identifier is kept and the optional type, when present,
+/// is converted and allocated in `alloc`.
+impl<'ast> ToType<'ast> for UnifEnumRow<'ast> {
+    type Target = EnumRow<'ast>;
+
+    fn to_type(
+        self,
+        alloc: &'ast AstAlloc,
+        reg: &mut NameReg,
+        table: &UnifTable<'ast>,
+    ) -> Self::Target {
+        EnumRow {
+            id: self.id,
+            typ: self
+                .typ
+                .map(|typ| alloc.alloc(typ.to_type(alloc, reg, table))),
+        }
+    }
+}
+
+/// Conversion of a single record row: the identifier is kept and the type is converted and
+/// allocated in `alloc`.
+impl<'ast> ToType<'ast> for UnifRecordRow<'ast> {
+    type Target = RecordRow<'ast>;
+
+    fn to_type(
+        self,
+        alloc: &'ast AstAlloc,
+        reg: &mut NameReg,
+        table: &UnifTable<'ast>,
+    ) -> Self::Target {
+        RecordRow {
+            id: self.id,
+            typ: alloc.alloc(self.typ.to_type(alloc, reg, table)),
+        }
+    }
+}
diff --git a/core/src/bytecode/typecheck/subtyping.rs b/core/src/bytecode/typecheck/subtyping.rs
new file mode 100644
index 0000000000..67f93b03cc
--- /dev/null
+++ b/core/src/bytecode/typecheck/subtyping.rs
@@ -0,0 +1,264 @@
+//! Type subsumption (subtyping)
+//!
+//! Subtyping is a relation between types that allows a value of one type to be used at a place
+//! where another type is expected, because the value's actual type is subsumed by the expected
+//! type.
+//!
+//! The subsumption rule is applied when switching from inference mode to checking mode, as
+//! customary in bidirectional type checking.
+//!
+//! Currently, there is one core subtyping axiom:
+//!
+//! - Record / Dictionary : `{a1 : T1,...,an : Tn} <: {_ : U}` if for every n `Tn <: U`
+//!
+//! The subtyping relation is extended to a congruence on other type constructors in the obvious
+//! way:
+//!
+//!
+//! - `Array T <: Array U` if `T <: U`
+//! - `{_ : T} <: {_ : U}` if `T <: U`
+//! - `{a1 : T1,...,an : Tn} <: {b1 : U1,...,bn : Un}` if for every n `Tn <: Un`
+//!
+//! In all other cases, we fall back to unification (although we instantiate polymorphic types as
+//! needed before). That is, we try to apply reflexivity: `T <: U` if `T = U`.
+//!
+//! The type instantiation corresponds to the zero-ary case of application in the current
+//! specification (which is loosely based on [A Quick Look at Impredicativity][quick-look],
+//! although we currently don't support impredicative polymorphism).
+//!
+//! [quick-look]: https://www.microsoft.com/en-us/research/uploads/prod/2020/01/quick-look-icfp20-fixed.pdf
+use super::*;
+
+pub(super) trait SubsumedBy<'ast> {
+    /// The error returned when `self` is not subsumed by the other type.
+    type Error;
+
+    /// Checks if `self` is subsumed by `t2`, that is if `self <: t2`. Returns an error otherwise.
+    fn subsumed_by(
+        self,
+        t2: Self,
+        state: &mut State<'ast, '_>,
+        ctxt: Context<'ast>,
+    ) -> Result<(), Self::Error>;
+}
+
+impl<'ast> SubsumedBy<'ast> for UnifType<'ast> {
+    type Error = UnifError<'ast>;
+
+    fn subsumed_by(
+        self,
+        t2: Self,
+        state: &mut State<'ast, '_>,
+        mut ctxt: Context<'ast>,
+    ) -> Result<(), Self::Error> {
+        // The inferred type has its foralls instantiated with unification variables, while the
+        // checked type is only resolved to its root in the unification table.
+        let inferred = instantiate_foralls(state, &mut ctxt, self, ForallInst::UnifVar);
+        let checked = t2.into_root(state.table);
+
+        match (inferred, checked) {
+            // {a1 : T1,...,an : Tn} <: {_ : U} if for every n `Tn <: U`
+            (
+                UnifType::Concrete {
+                    typ: TypeF::Record(rrows),
+                    ..
+                },
+                UnifType::Concrete {
+                    typ:
+                        TypeF::Dict {
+                            type_fields,
+                            flavour,
+                        },
+                    var_levels_data,
+                },
+            ) => {
+                for row in rrows.iter() {
+                    match row {
+                        // Each field type must itself be subsumed by the dictionary's field type.
+                        RecordRowsElt::Row(a) => {
+                            a.typ
+                                .clone()
+                                .subsumed_by(*type_fields.clone(), state, ctxt.clone())?
+                        }
+                        RecordRowsElt::TailUnifVar { id, .. } =>
+                        // We don't need to perform any variable level checks when unifying a free
+                        // unification variable with a ground type
+                        // We close the tail because there is no guarantee that
+                        // { a : Number, b : Number, _ : a?} <= { _ : Number}
+                        {
+                            state
+                                .table
+                                .assign_rrows(id, UnifRecordRows::concrete(RecordRowsF::Empty))
+                        }
+                        // A rigid tail can't be closed: rebuild the dictionary type for the error
+                        // and fail.
+                        RecordRowsElt::TailConstant(id) => {
+                            let checked = UnifType::Concrete {
+                                typ: TypeF::Dict {
+                                    type_fields: type_fields.clone(),
+                                    flavour,
+                                },
+                                var_levels_data,
+                            };
+                            Err(UnifError::WithConst {
+                                var_kind: VarKindDiscriminant::RecordRows,
+                                expected_const_id: id,
+                                inferred: checked,
+                            })?
+                        }
+                        _ => (),
+                    }
+                }
+                Ok(())
+            }
+            // Array T <: Array U if T <: U
+            (
+                UnifType::Concrete {
+                    typ: TypeF::Array(a),
+                    ..
+                },
+                UnifType::Concrete {
+                    typ: TypeF::Array(b),
+                    ..
+                },
+            )
+            // Dict T <: Dict U if T <: U
+            | (
+                UnifType::Concrete {
+                    typ: TypeF::Dict { type_fields: a, .. },
+                    ..
+                },
+                UnifType::Concrete {
+                    typ: TypeF::Dict { type_fields: b, .. },
+                    ..
+                },
+            ) => a.subsumed_by(*b, state, ctxt),
+            // {a1 : T1,...,an : Tn} <: {b1 : U1,...,bn : Un} if for every n `Tn <: Un`
+            (
+                UnifType::Concrete {
+                    typ: TypeF::Record(rrows1),
+                    ..
+                },
+                UnifType::Concrete {
+                    typ: TypeF::Record(rrows2),
+                    ..
+                },
+            ) => rrows1
+                .clone()
+                .subsumed_by(rrows2.clone(), state, ctxt)
+                .map_err(|err| err.into_unif_err(mk_buty_record!(;rrows2), mk_buty_record!(;rrows1))),
+            // T <: U if T = U
+            (inferred, checked) => checked.unify(inferred, state, &ctxt),
+        }
+    }
+}
+
+impl<'ast> SubsumedBy<'ast> for UnifRecordRows<'ast> {
+    type Error = RowUnifError<'ast>;
+
+    fn subsumed_by(
+        self,
+        t2: Self,
+        state: &mut State<'ast, '_>,
+        ctxt: Context<'ast>,
+    ) -> Result<(), Self::Error> {
+        // This code is almost taken verbatim from `unify`, but where some recursive calls are
+        // changed to be `subsumed_by` instead of `unify`. We can surely factorize both into a
+        // generic function, but this is left for future work.
+        let inferred = self.into_root(state.table);
+        let checked = t2.into_root(state.table);
+
+        match (inferred, checked) {
+            (
+                UnifRecordRows::Concrete { rrows: rrows1, .. },
+                UnifRecordRows::Concrete {
+                    rrows: rrows2,
+                    var_levels_data: levels2,
+                },
+            ) => match (rrows1, rrows2) {
+                // Both sides have at least one row: extract the row with the same identifier from
+                // the right-hand side, compare the field types, and recurse on what is left.
+                (RecordRowsF::Extend { row, tail }, rrows2 @ RecordRowsF::Extend { .. }) => {
+                    let urrows2 = UnifRecordRows::Concrete {
+                        rrows: rrows2,
+                        var_levels_data: levels2,
+                    };
+                    let (ty_res, urrows_without_ty_res) = urrows2
+                        .remove_row(&row.id, &row.typ, state, ctxt.var_level)
+                        .map_err(|err| match err {
+                            RemoveRowError::Missing => RowUnifError::MissingRow(row.id),
+                            RemoveRowError::Conflict => {
+                                RowUnifError::RecordRowConflict(row.clone())
+                            }
+                        })?;
+                    // The field types are only compared when `remove_row` reports that a type was
+                    // extracted.
+                    if let RemoveRowResult::Extracted(ty) = ty_res {
+                        row.typ
+                            .subsumed_by(ty, state, ctxt.clone())
+                            .map_err(|err| RowUnifError::RecordRowMismatch {
+                                id: row.id,
+                                cause: Box::new(err),
+                            })?;
+                    }
+                    tail.subsumed_by(urrows_without_ty_res, state, ctxt)
+                }
+                (RecordRowsF::TailVar(id), _) | (_, RecordRowsF::TailVar(id)) => {
+                    Err(RowUnifError::UnboundTypeVariable(id))
+                }
+                (RecordRowsF::Empty, RecordRowsF::Empty)
+                | (RecordRowsF::TailDyn, RecordRowsF::TailDyn) => Ok(()),
+                (RecordRowsF::Empty, RecordRowsF::TailDyn)
+                | (RecordRowsF::TailDyn, RecordRowsF::Empty) => Err(RowUnifError::ExtraDynTail),
+                (
+                    RecordRowsF::Empty,
+                    RecordRowsF::Extend {
+                        row: UnifRecordRow { id, .. },
+                        ..
+                    },
+                )
+                | (
+                    RecordRowsF::TailDyn,
+                    RecordRowsF::Extend {
+                        row: UnifRecordRow { id, .. },
+                        ..
+                    },
+                ) => Err(RowUnifError::MissingRow(id)),
+                (
+                    RecordRowsF::Extend {
+                        row: UnifRecordRow { id, .. },
+                        ..
+                    },
+                    RecordRowsF::TailDyn,
+                )
+                | (
+                    RecordRowsF::Extend {
+                        row: UnifRecordRow { id, .. },
+                        ..
+                    },
+                    RecordRowsF::Empty,
+                ) => Err(RowUnifError::ExtraRow(id)),
+            },
+            (UnifRecordRows::UnifVar { id, .. }, urrows)
+            | (urrows, UnifRecordRows::UnifVar { id, .. }) => {
+                // When the other side is a rigid variable, pending level updates are forced
+                // first, so that the level comparison below is made on up-to-date levels.
+                if let UnifRecordRows::Constant(cst_id) = urrows {
+                    let constant_level = state.table.get_rrows_level(cst_id);
+                    state.table.force_rrows_updates(constant_level);
+                    if state.table.get_rrows_level(id) < constant_level {
+                        return Err(RowUnifError::VarLevelMismatch {
+                            constant_id: cst_id,
+                            var_kind: VarKindDiscriminant::RecordRows,
+                        });
+                    }
+                }
+                urrows.propagate_constrs(state.constr, id)?;
+                state.table.assign_rrows(id, urrows);
+                Ok(())
+            }
+            (UnifRecordRows::Constant(i1), UnifRecordRows::Constant(i2)) if i1 == i2 => Ok(()),
+            (UnifRecordRows::Constant(i1), UnifRecordRows::Constant(i2)) => {
+                Err(RowUnifError::ConstMismatch {
+                    var_kind: VarKindDiscriminant::RecordRows,
+                    expected_const_id: i2,
+                    inferred_const_id: i1,
+                })
+            }
+            (urrows, UnifRecordRows::Constant(i)) | (UnifRecordRows::Constant(i), urrows) => {
+                Err(RowUnifError::WithConst {
+                    var_kind: VarKindDiscriminant::RecordRows,
+                    expected_const_id: i,
+                    inferred: UnifType::concrete(TypeF::Record(urrows)),
+                })
+            }
+        }
+    }
+}
diff --git a/core/src/bytecode/typecheck/unif.rs b/core/src/bytecode/typecheck/unif.rs
new file mode 100644
index 0000000000..68df32d0f7
--- /dev/null
+++ b/core/src/bytecode/typecheck/unif.rs
@@ -0,0 +1,1832 @@
+//! Types unification.
+
+use super::{eq::TypeEq, *};
+
+/// Unification variable or type constants unique identifier.
+pub type VarId = usize;
+
+/// Variable levels. Levels are used in order to implement polymorphism in a sound way: we need to
+/// associate to each unification variable and rigid type variable a level, which depends on when
+/// those variables were introduced, and to forbid some unifications if a condition on levels is
+/// not met.
+#[derive(Clone, Copy, Ord, Eq, PartialEq, PartialOrd, Debug)]
+pub struct VarLevel(NonZeroU16);
+
+impl VarLevel {
+    /// Special constant used for level upper bound to indicate that a type doesn't contain any
+    /// unification variable.
It's equal to `1` and strictly smaller than [VarLevel::MIN_LEVEL], so + /// it's strictly smaller than any concrete variable level. + pub const NO_VAR: Self = VarLevel(NonZeroU16::MIN); + /// The first available variable level, `2`. + // unsafe is required because `unwrap()` is not usable in `const fn` code as of today in stable + // Rust. + // unsafe(): we must enforce the invariant that the argument `n` of `new_unchecked(n)` verifies + // `0 < n`. Indeed `0 < 2`. + pub const MIN_LEVEL: Self = unsafe { VarLevel(NonZeroU16::new_unchecked(2)) }; + /// The maximum level. Used as an upper bound to indicate that nothing can be said about the + /// levels of the unification variables contained in a type. + pub const MAX_LEVEL: Self = VarLevel(NonZeroU16::MAX); + + /// Increment the variable level by one. Panic if the maximum capacity of the underlying + /// numeric type is reached (currently, `u16::MAX`). + pub fn incr(&mut self) { + let new_value = self + .0 + .checked_add(1) + .expect("reached the maxium unification variable level"); + self.0 = new_value; + } +} + +/// An element of the unification table. Contains the potential type this variable points to (or +/// `None` if the variable hasn't been unified with something yet), and the variable's level. +pub struct UnifSlot { + value: Option, + level: VarLevel, +} + +impl UnifSlot { + pub fn new(level: VarLevel) -> Self { + UnifSlot { value: None, level } + } +} + +/// The unification table. +/// +/// Map each unification variable to either another type variable or a concrete type it has been +/// unified with. Each binding `(ty, var)` in this map should be thought of an edge in a +/// unification graph. +/// +/// The unification table is really three separate tables, corresponding to the different kinds of +/// types: standard types, record rows, and enum rows. +/// +/// The unification table is a relatively low-level data structure, whose consumer has to ensure +/// specific invariants. 
It is used by the `unify` function and its variants, but you should avoid +/// using it directly, unless you know what you're doing. +#[derive(Default)] +pub struct UnifTable<'ast> { + types: Vec>>, + rrows: Vec>>, + erows: Vec>>, + pending_type_updates: Vec, + pending_rrows_updates: Vec, + pending_erows_updates: Vec, +} + +impl<'ast> UnifTable<'ast> { + pub fn new() -> Self { + UnifTable::default() + } + + /// Assign a type to a type unification variable. + /// + /// This method updates variables level, at least lazily, by pushing them to a stack of pending + /// traversals. + /// + /// # Preconditions + /// + /// - This method doesn't check for the variable level conditions. This is the responsibility + /// of the caller. + /// - If the target type is a unification variable as well, it must not be assigned to another + /// unification type. That is, `assign` should always be passed a root type. Otherwise, the + /// handling of variable levels will be messed up. + /// - This method doesn't force pending level updates when needed (calling to + /// `force_type_updates`), i.e. when `uty` is a rigid type variable. Having pending variable + /// level updates and using `assign_type` might make typechecking incorrect in some situation + /// by unduely allowing unsound generalization. This is the responsibility of the caller. + pub fn assign_type(&mut self, var: VarId, uty: UnifType<'ast>) { + // Unifying a free variable with itself is a no-op. + if matches!(uty, UnifType::UnifVar { id, ..} if id == var) { + return; + } + + debug_assert!({ + if let UnifType::UnifVar { id, init_level: _ } = &uty { + self.types[*id].value.is_none() + } else { + true + } + }); + debug_assert!(self.types[var].value.is_none()); + + let uty_lvl_updated = self.update_type_level(var, uty, self.types[var].level); + self.types[var].value = Some(uty_lvl_updated); + } + + // Lazily propagate a variable level to the unification variables contained in `uty`. 
+    // Either do
+    // a direct update in constant time when possible, or push a stack of delayed updates for
+    // composite types.
+    //
+    // `var` is the variable being assigned (recorded in `pending_type_updates` when the update is
+    // delayed) and `new_level` is the level to propagate. The possibly updated type is returned.
+    fn update_type_level(
+        &mut self,
+        var: VarId,
+        uty: UnifType<'ast>,
+        new_level: VarLevel,
+    ) -> UnifType<'ast> {
+        match uty {
+            // We can do the update right away. Levels are only ever lowered here.
+            UnifType::UnifVar { id, init_level } => {
+                if new_level < self.types[id].level {
+                    self.types[id].level = new_level;
+                }
+
+                UnifType::UnifVar { id, init_level }
+            }
+            // If a concrete type is a candidate for update, we push the pending update on the
+            // stack and remember the level to propagate in `pending`.
+            UnifType::Concrete {
+                typ,
+                var_levels_data,
+            } if var_levels_data.upper_bound >= new_level => {
+                self.pending_type_updates.push(var);
+
+                UnifType::Concrete {
+                    typ,
+                    var_levels_data: VarLevelsData {
+                        pending: Some(new_level),
+                        ..var_levels_data
+                    },
+                }
+            }
+            // The remaining types are either not concrete (constants), or concrete types whose
+            // level upper bound is strictly below `new_level`: there is nothing to lower.
+            _ => uty,
+        }
+    }
+
+    /// Assign record rows to a record rows unification variable.
+    ///
+    /// This method updates variables level, at least lazily, by pushing them to a stack of pending
+    /// traversals.
+    ///
+    /// # Preconditions
+    ///
+    /// - This method doesn't check for the variable level conditions. This is the responsibility
+    ///   of the caller.
+    /// - If the target type is a unification variable as well, it must not be assigned to another
+    ///   unification type. That is, `assign` should always be passed a root type. Otherwise, the
+    ///   handling of variable levels will be messed up.
+    /// - This method doesn't force pending level updates when needed (calling to
+    ///   `force_rrows_updates`), i.e. when `uty` is a rigid type variable. Having pending variable
+    ///   level updates and using `assign_rrows` might make typechecking incorrect in some situation
+    ///   by unduly allowing unsound generalization. This is the responsibility of the caller.
+    pub fn assign_rrows(&mut self, var: VarId, rrows: UnifRecordRows<'ast>) {
+        // Unifying a free variable with itself is a no-op.
+        if matches!(rrows, UnifRecordRows::UnifVar { id, ..} if id == var) {
+            return;
+        }
+
+        // Propagate the level of `var` (eagerly or lazily) before recording the assignment.
+        self.update_rrows_level(var, &rrows, self.rrows[var].level);
+        // `var` must not have been assigned before.
+        debug_assert!(self.rrows[var].value.is_none());
+        self.rrows[var].value = Some(rrows);
+    }
+
+    // cf `update_type_level()`
+    //
+    // As opposed to `update_type_level()`, the rows are only borrowed: a delayed update is
+    // recorded by pushing `var` on `pending_rrows_updates`, without touching the rows themselves.
+    fn update_rrows_level(&mut self, var: VarId, uty: &UnifRecordRows<'ast>, new_level: VarLevel) {
+        match uty {
+            // We can do the update right away
+            UnifRecordRows::UnifVar {
+                id: var_id,
+                init_level: _,
+            } => {
+                if new_level < self.rrows[*var_id].level {
+                    self.rrows[*var_id].level = new_level;
+                }
+            }
+            // If concrete rows are a candidate for update, we push the pending update on the stack
+            UnifRecordRows::Concrete {
+                var_levels_data, ..
+            } if var_levels_data.upper_bound >= new_level => self.pending_rrows_updates.push(var),
+            // The remaining rows are either not concrete (constants), or concrete rows whose
+            // level upper bound is strictly below `new_level`: there is nothing to lower.
+            _ => (),
+        }
+    }
+
+    /// Assign enum rows to an enum rows unification variable.
+    ///
+    /// This method updates variables level, at least lazily, by pushing them to a stack of pending
+    /// traversals.
+    ///
+    /// # Preconditions
+    ///
+    /// - This method doesn't check for the variable level conditions. This is the responsibility
+    ///   of the caller.
+    /// - If the target type is a unification variable as well, it must not be assigned to another
+    ///   unification type. That is, `assign` should always be passed a root type. Otherwise, the
+    ///   handling of variable levels will be messed up.
+    /// - This method doesn't force pending level updates when needed (calling to
+    ///   `force_erows_updates`), i.e. when `uty` is a rigid type variable. Having pending variable
+    ///   level updates and using `assign_erows` might make typechecking incorrect in some situation
+    ///   by unduly allowing unsound generalization.
+    /// This is the responsibility of the caller.
+    pub fn assign_erows(&mut self, var: VarId, erows: UnifEnumRows<'ast>) {
+        // Unifying a free variable with itself is a no-op.
+        if matches!(erows, UnifEnumRows::UnifVar { id, .. } if id == var) {
+            return;
+        }
+
+        // Propagate the level of `var` (eagerly or lazily) before recording the assignment.
+        self.update_erows_level(var, &erows, self.erows[var].level);
+        // `var` must not have been assigned before.
+        debug_assert!(self.erows[var].value.is_none());
+        self.erows[var].value = Some(erows);
+    }
+
+    // cf `update_type_level()`
+    //
+    // As opposed to `update_type_level()`, the rows are only borrowed: a delayed update is
+    // recorded by pushing `var` on `pending_erows_updates`, without touching the rows themselves.
+    fn update_erows_level(&mut self, var: VarId, uty: &UnifEnumRows<'ast>, new_level: VarLevel) {
+        match uty {
+            // We can do the update right away
+            UnifEnumRows::UnifVar {
+                id: var_id,
+                init_level: _,
+            } => {
+                if new_level < self.erows[*var_id].level {
+                    self.erows[*var_id].level = new_level;
+                }
+            }
+            // If concrete rows are a candidate for update, we push the pending update on the stack
+            UnifEnumRows::Concrete {
+                var_levels_data, ..
+            } if var_levels_data.upper_bound >= new_level => self.pending_erows_updates.push(var),
+            // The remaining rows are either not concrete (constants), or concrete rows whose
+            // level upper bound is strictly below `new_level`: there is nothing to lower.
+            _ => (),
+        }
+    }
+
+    /// Retrieve the current assignment of a type unification variable.
+    ///
+    /// Returns `None` if the variable hasn't been assigned yet. Panics if `var` has no slot in
+    /// the table.
+    pub fn get_type(&self, var: VarId) -> Option<&UnifType<'ast>> {
+        self.types[var].value.as_ref()
+    }
+
+    /// Retrieve the current level of a unification variable or a rigid type variable.
+    pub fn get_level(&self, var: VarId) -> VarLevel {
+        self.types[var].level
+    }
+
+    /// Retrieve the current assignment of a record rows unification variable.
+    ///
+    /// Returns `None` if the variable hasn't been assigned yet. Panics if `var` has no slot in
+    /// the table.
+    pub fn get_rrows(&self, var: VarId) -> Option<&UnifRecordRows<'ast>> {
+        self.rrows[var].value.as_ref()
+    }
+
+    /// Retrieve the current level of a record rows unification variable or a record rows rigid
+    /// type variable.
+    pub fn get_rrows_level(&self, var: VarId) -> VarLevel {
+        self.rrows[var].level
+    }
+
+    /// Retrieve the current assignment of an enum rows unification variable.
+    pub fn get_erows(&self, var: VarId) -> Option<&UnifEnumRows<'ast>> {
+        let slot = &self.erows[var];
+        slot.value.as_ref()
+    }
+
+    /// Return the level currently stored for an enum rows unification variable or an enum rows
+    /// rigid type variable.
+    pub fn get_erows_level(&self, var: VarId) -> VarLevel {
+        let slot = &self.erows[var];
+        slot.level
+    }
+
+    /// Allocate a new, unassigned slot in the types table at level `current_level` and return its
+    /// identifier, usable for either a unification variable or a constant.
+    pub fn fresh_type_var_id(&mut self, current_level: VarLevel) -> VarId {
+        // The identifier of a slot is its index in the table.
+        let id = self.types.len();
+        let slot = UnifSlot::new(current_level);
+        self.types.push(slot);
+        id
+    }
+
+    /// Allocate a new, unassigned slot in the record rows table at level `current_level` and
+    /// return its identifier, usable for either a unification variable or a constant.
+    pub fn fresh_rrows_var_id(&mut self, current_level: VarLevel) -> VarId {
+        // The identifier of a slot is its index in the table.
+        let id = self.rrows.len();
+        let slot = UnifSlot::new(current_level);
+        self.rrows.push(slot);
+        id
+    }
+
+    /// Allocate a new, unassigned slot in the enum rows table at level `current_level` and return
+    /// its identifier, usable for either a unification variable or a constant.
+    pub fn fresh_erows_var_id(&mut self, current_level: VarLevel) -> VarId {
+        // The identifier of a slot is its index in the table.
+        let id = self.erows.len();
+        let slot = UnifSlot::new(current_level);
+        self.erows.push(slot);
+        id
+    }
+
+    /// Allocate a slot for a new type unification variable and return that variable, with
+    /// `current_level` as its initial level.
+    pub fn fresh_type_uvar(&mut self, current_level: VarLevel) -> UnifType<'ast> {
+        let id = self.fresh_type_var_id(current_level);
+
+        UnifType::UnifVar {
+            id,
+            init_level: current_level,
+        }
+    }
+
+    /// Allocate a slot for a new record rows unification variable and return that variable, with
+    /// `current_level` as its initial level.
+    pub fn fresh_rrows_uvar(&mut self, current_level: VarLevel) -> UnifRecordRows<'ast> {
+        let id = self.fresh_rrows_var_id(current_level);
+
+        UnifRecordRows::UnifVar {
+            id,
+            init_level: current_level,
+        }
+    }
+
+    /// Allocate a slot for a new enum rows unification variable and return that variable, with
+    /// `current_level` as its initial level.
+    pub fn fresh_erows_uvar(&mut self, current_level: VarLevel) -> UnifEnumRows<'ast> {
+        let id = self.fresh_erows_var_id(current_level);
+
+        UnifEnumRows::UnifVar {
+            id,
+            init_level: current_level,
+        }
+    }
+
+    /// Allocate a slot for a new type constant at `current_level` and return that constant.
+    pub fn fresh_type_const(&mut self, current_level: VarLevel) -> UnifType<'ast> {
+        let id = self.fresh_type_var_id(current_level);
+        UnifType::Constant(id)
+    }
+
+    /// Allocate a slot for a new record rows constant at `current_level` and return that
+    /// constant.
+    pub fn fresh_rrows_const(&mut self, current_level: VarLevel) -> UnifRecordRows<'ast> {
+        let id = self.fresh_rrows_var_id(current_level);
+        UnifRecordRows::Constant(id)
+    }
+
+    /// Allocate a slot for a new enum rows constant at `current_level` and return that constant.
+    pub fn fresh_erows_const(&mut self, current_level: VarLevel) -> UnifEnumRows<'ast> {
+        let id = self.fresh_erows_var_id(current_level);
+        UnifEnumRows::Constant(id)
+    }
+
+    /// Find the representative of the equivalence class of the type unification variable
+    /// `var_id` by following the links stored in the unification table.
+    ///
+    /// This is the "find" operation of union-find. The result is either an unassigned
+    /// unification variable, or a clone of the first non-variable type reached.
+    // TODO This should be a union find like algorithm
+    pub fn root_type(&self, var_id: VarId, init_level: VarLevel) -> UnifType<'ast> {
+        // Every queried variable is expected to have a slot in the table. If it doesn't, the
+        // typechecking algorithm is incorrect, and the indexing below panics.
+        let mut current_id = var_id;
+        let mut current_level = init_level;
+
+        loop {
+            match self.types[current_id].value.as_ref() {
+                // Assigned to another variable: keep following the chain.
+                Some(UnifType::UnifVar { id, init_level }) => {
+                    current_id = *id;
+                    current_level = *init_level;
+                }
+                // Assigned to something that isn't a variable: this is the root.
+                Some(resolved) => return resolved.clone(),
+                // Unassigned variable: it is its own root.
+                None => {
+                    return UnifType::UnifVar {
+                        id: current_id,
+                        init_level: current_level,
+                    }
+                }
+            }
+        }
+    }
+
+    /// Find the representative of the equivalence class of the record rows unification variable
+    /// `var_id` by following the links stored in the unification table.
+    ///
+    /// This is the "find" operation of union-find.
+    // TODO This should be a union find like algorithm
+    pub fn root_rrows(&self, var_id: VarId, init_level: VarLevel) -> UnifRecordRows<'ast> {
+        // Every queried variable is expected to have a slot in the table. If it doesn't, the
+        // typechecking algorithm is incorrect, and the indexing below panics.
+        let mut current_id = var_id;
+        let mut current_level = init_level;
+
+        loop {
+            match self.rrows[current_id].value.as_ref() {
+                // Assigned to another variable: keep following the chain.
+                Some(UnifRecordRows::UnifVar { id, init_level }) => {
+                    current_id = *id;
+                    current_level = *init_level;
+                }
+                // Assigned to something that isn't a variable: this is the root.
+                Some(resolved) => return resolved.clone(),
+                // Unassigned variable: it is its own root.
+                None => {
+                    return UnifRecordRows::UnifVar {
+                        id: current_id,
+                        init_level: current_level,
+                    }
+                }
+            }
+        }
+    }
+
+    /// Find the representative of the equivalence class of the enum rows unification variable
+    /// `var_id` by following the links stored in the unification table.
+    ///
+    /// This is the "find" operation of union-find.
+    // TODO This should be a union find like algorithm
+    pub fn root_erows(&self, var_id: VarId, init_level: VarLevel) -> UnifEnumRows<'ast> {
+        // Every queried variable is expected to have a slot in the table. If it doesn't, the
+        // typechecking algorithm is incorrect, and the indexing below panics.
+        let mut current_id = var_id;
+        let mut current_level = init_level;
+
+        loop {
+            match self.erows[current_id].value.as_ref() {
+                // Assigned to another variable: keep following the chain.
+                Some(UnifEnumRows::UnifVar { id, init_level }) => {
+                    current_id = *id;
+                    current_level = *init_level;
+                }
+                // Assigned to something that isn't a variable: this is the root.
+                Some(resolved) => return resolved.clone(),
+                // Unassigned variable: it is its own root.
+                None => {
+                    return UnifEnumRows::UnifVar {
+                        id: current_id,
+                        init_level: current_level,
+                    }
+                }
+            }
+        }
+    }
+
+    /// Return a `VarId` that no currently allocated variable uses, whatever its kind (unification
+    /// or rigid type variable, rows or types). It is computed as the maximum of the lengths of
+    /// the three unification tables.
+    ///
+    /// Used inside [self::eq] to generate temporary rigid type variables that are guaranteed to
+    /// not conflict with existing variables.
+    pub fn max_uvars_count(&self) -> VarId {
+        let rows_len = max(self.rrows.len(), self.erows.len());
+        max(self.types.len(), rows_len)
+    }
+
+    /// Force the pending type level updates that matter before unifying a variable with a rigid
+    /// type variable of level `constant_level`. Updates that can't change the outcome of such a
+    /// unification stay delayed.
+    ///
+    /// Variable levels exist to forbid some unsound unifications of a unification variable with
+    /// a rigid type variable. For performance reasons, levels aren't propagated immediately when
+    /// a variable is unified with a concrete type: they are stored lazily at the level of types
+    /// (see [VarLevel]).
+    ///
+    /// Unifying with a rigid type variable does require up-to-date levels, though, so this
+    /// function must be called before checking variable levels in that case.
+    ///
+    /// # Parameters
+    ///
+    /// - `constant_level`: the level of the rigid type variable we're unifying with. It isn't
+    ///   strictly needed to propagate levels, but it lets us skip the updates that wouldn't
+    ///   change the outcome of the unification, which are kept for a later forced update.
+ fn force_type_updates(&mut self, constant_level: VarLevel) { + fn update_unr_with_lvl<'ast>( + table: &mut UnifTable<'ast>, + uty: UnifTypeUnr<'ast>, + level: VarLevel, + ) -> UnifTypeUnr<'ast> { + uty.map_state( + |uty, table| Box::new(update_utype_with_lvl(table, *uty, level)), + |rrows, table| update_rrows_with_lvl(table, rrows, level), + |erows, table| update_erows_with_lvl(table, erows, level), + |ctr, _| ctr, + table, + ) + } + + fn update_rrows_with_lvl<'ast>( + table: &mut UnifTable<'ast>, + rrows: UnifRecordRows<'ast>, + level: VarLevel, + ) -> UnifRecordRows<'ast> { + let rrows = rrows.into_root(table); + + match rrows { + UnifRecordRows::Concrete { + rrows, + var_levels_data, + } => { + let rrows = rrows.map_state( + |uty, table| Box::new(update_utype_with_lvl(table, *uty, level)), + |rrows, table| Box::new(update_rrows_with_lvl(table, *rrows, level)), + table, + ); + + // [^var-level-kinds]: Note that for `UnifRecordRows<'ast>` (and for enum rows as + // well), the variable levels data are concerned with record rows unification + // variables, not type unification variable. We thus let them untouched, as + // updating record rows variable levels is an orthogonal concern. + UnifRecordRows::Concrete { + rrows, + var_levels_data, + } + } + UnifRecordRows::UnifVar { .. } | UnifRecordRows::Constant(_) => rrows, + } + } + + fn update_erows_with_lvl<'ast>( + table: &mut UnifTable<'ast>, + erows: UnifEnumRows<'ast>, + level: VarLevel, + ) -> UnifEnumRows<'ast> { + let erows = erows.into_root(table); + + match erows { + UnifEnumRows::Concrete { + erows, + var_levels_data, + } => { + let erows = erows.map_state( + |uty, table| Box::new(update_utype_with_lvl(table, *uty, level)), + |erows, table| Box::new(update_erows_with_lvl(table, *erows, level)), + table, + ); + + // see [^var-level-kinds] + UnifEnumRows::Concrete { + erows, + var_levels_data, + } + } + UnifEnumRows::UnifVar { .. 
} | UnifEnumRows::Constant(_) => erows, + } + } + + fn update_utype_with_lvl<'ast>( + table: &mut UnifTable<'ast>, + uty: UnifType<'ast>, + level: VarLevel, + ) -> UnifType<'ast> { + let uty = uty.into_root(table); + + match uty { + UnifType::UnifVar { id, init_level } => { + if table.types[id].level > level { + table.types[id].level = level; + } + + UnifType::UnifVar { id, init_level } + } + UnifType::Concrete { + typ, + var_levels_data, + } if var_levels_data.upper_bound > level => { + let level = var_levels_data + .pending + .map(|pending_level| max(pending_level, level)) + .unwrap_or(level); + let typ = update_unr_with_lvl(table, typ, level); + + UnifType::Concrete { + typ, + var_levels_data: VarLevelsData { + upper_bound: level, + pending: None, + }, + } + } + UnifType::Constant(_) | UnifType::Concrete { .. } => uty, + } + } + + fn update_utype<'ast>( + table: &mut UnifTable<'ast>, + uty: UnifType<'ast>, + constant_level: VarLevel, + ) -> (UnifType<'ast>, bool) { + match uty { + UnifType::UnifVar { .. } => { + // We should never end up updating the level of a type variable, as this update + // is done on the spot. + debug_assert!(false); + + (uty, false) + } + UnifType::Concrete { + typ, + var_levels_data: + VarLevelsData { + pending: Some(pending_level), + upper_bound, + }, + } => { + // [^irrelevant-level-update]: A level update where the if-condition below is + // true wouldn't change the outcome of unifying a variable with a constant of + // level `constant_level`. + // + // Impactful updates are updates that might change the level of a variable from + // a value greater than or equal to `constant_level` to a new level strictly + // smaller, but: + // + // 1. If `upper_bound` < `constant_level`, then all unification variable levels + // are already strictly smaller than `constant_level`. An update won't change + // this inequality (level update can only decrease levels) + // 2. 
If `pending_level` >= `constant_level`, then the update might only + // decrease a level that was greater than `constant_level` to a + // `pending_level` which is still greater than `constant_level`. Once again, + // the update doesn't change the inequality with respect to constant_level. + // + // Thus, such updates might be delayed even more. + // + // NOTE(review): the early return below overwrites `upper_bound` with + // `pending_level`. Once both are equal, the condition below holds for every + // `constant_level`, so (unless the entry is reassigned elsewhere) the update + // looks like it stays delayed on all later calls - TODO confirm this is + // intended. The same pattern appears in the record rows and enum rows variants. + if upper_bound < constant_level || pending_level >= constant_level { + return ( + UnifType::Concrete { + typ, + var_levels_data: VarLevelsData { + upper_bound: pending_level, + pending: Some(pending_level), + }, + }, + true, + ); + } + + let typ = if upper_bound > pending_level { + update_unr_with_lvl(table, typ, pending_level) + } else { + typ + }; + + ( + UnifType::Concrete { + typ, + var_levels_data: VarLevelsData { + upper_bound: pending_level, + pending: None, + }, + }, + false, + ) + } + // [^ignore-no-pending-level] If there is no pending level, then this update has + // already been handled (possibly by a forced update on an enclosing type), and + // there's nothing to do. + // + // Note that this type might still contain other pending updates deeper inside, but + // those are registered as pending updates and will be processed in any case. + UnifType::Constant(_) | UnifType::Concrete { .. } => (uty, false), + } + } + + let rest = std::mem::take(&mut self.pending_type_updates) + .into_iter() + .filter(|id| { + // unwrap(): if a unification variable has been pushed on the update stack, it + // must have been done by `assign_type`, and thus MUST have been assigned to + // something. 
+ let typ = self.types[*id].value.take().unwrap(); + let (new_type, delayed) = update_utype(self, typ, constant_level); + self.types[*id].value = Some(new_type); + + delayed + }) + .collect(); + + self.pending_type_updates = rest; + } + + /// See `force_type_updates`. Same as `force_type_updates`, but when unifying a record row + /// unification variable.
+ pub fn force_rrows_updates(&mut self, constant_level: VarLevel) { + fn update_rrows_unr_with_lvl<'ast>( + table: &mut UnifTable<'ast>, + rrows: UnifRecordRowsUnr<'ast>, + level: VarLevel, + ) -> UnifRecordRowsUnr<'ast> { + rrows.map_state( + |uty, table| Box::new(update_utype_with_lvl(table, *uty, level)), + |rrows, table| Box::new(update_rrows_with_lvl(table, *rrows, level)), + table, + ) + } + + fn update_erows_unr_with_lvl<'ast>( + table: &mut UnifTable<'ast>, + erows: UnifEnumRowsUnr<'ast>, + level: VarLevel, + ) -> UnifEnumRowsUnr<'ast> { + erows.map_state( + |uty, table| Box::new(update_utype_with_lvl(table, *uty, level)), + |erows, table| Box::new(update_erows_with_lvl(table, *erows, level)), + table, + ) + } + + fn update_utype_with_lvl<'ast>( + table: &mut UnifTable<'ast>, + utype: UnifType<'ast>, + level: VarLevel, + ) -> UnifType<'ast> { + let utype = utype.into_root(table); + + match utype { + UnifType::Concrete { + typ, + var_levels_data, + } => { + let typ = typ.map_state( + |uty, table| Box::new(update_utype_with_lvl(table, *uty, level)), + |rrows, table| update_rrows_with_lvl(table, rrows, level), + |erows, table| update_erows_with_lvl(table, erows, level), + |ctr, _| ctr, + table, + ); + + // See [^var-level-kinds] + UnifType::Concrete { + typ, + var_levels_data, + } + } + UnifType::UnifVar { .. 
} | UnifType::Constant(_) => utype, + } + } + + fn update_rrows_with_lvl<'ast>( + table: &mut UnifTable<'ast>, + rrows: UnifRecordRows<'ast>, + level: VarLevel, + ) -> UnifRecordRows<'ast> { + let rrows = rrows.into_root(table); + + match rrows { + UnifRecordRows::UnifVar { id, init_level } => { + if table.rrows[id].level > level { + table.rrows[id].level = level; + } + + UnifRecordRows::UnifVar { id, init_level } + } + UnifRecordRows::Concrete { + rrows, + var_levels_data, + } if var_levels_data.upper_bound > level => { + let level = var_levels_data + .pending + .map(|pending_level| max(pending_level, level)) + .unwrap_or(level); + let rrows = update_rrows_unr_with_lvl(table, rrows, level); + + UnifRecordRows::Concrete { + rrows, + var_levels_data: VarLevelsData { + upper_bound: level, + pending: None, + }, + } + } + UnifRecordRows::Constant(_) | UnifRecordRows::Concrete { .. } => rrows, + } + } + + fn update_erows_with_lvl<'ast>( + table: &mut UnifTable<'ast>, + erows: UnifEnumRows<'ast>, + level: VarLevel, + ) -> UnifEnumRows<'ast> { + let erows = erows.into_root(table); + + match erows { + UnifEnumRows::Concrete { + erows, + var_levels_data, + } => { + let erows = update_erows_unr_with_lvl(table, erows, level); + + // See [^var-level-kinds] + UnifEnumRows::Concrete { + erows, + var_levels_data, + } + } + UnifEnumRows::UnifVar { .. } | UnifEnumRows::Constant(_) => erows, + } + } + + fn update_rrows<'ast>( + table: &mut UnifTable<'ast>, + rrows: UnifRecordRows<'ast>, + constant_level: VarLevel, + ) -> (UnifRecordRows<'ast>, bool) { + match rrows { + UnifRecordRows::UnifVar { .. } => { + // We should never end up updating the level of a unification variable, as this + // update is done on the spot. 
+ debug_assert!(false); + + (rrows, false) + } + UnifRecordRows::Concrete { + rrows, + var_levels_data: + VarLevelsData { + pending: Some(pending_level), + upper_bound, + }, + } => { + // See [^irrelevant-level-update] + // + // NOTE(review): the early return below overwrites `upper_bound` with + // `pending_level`. Once both are equal, the condition below holds for every + // `constant_level`, so (unless the entry is reassigned elsewhere) the update + // looks like it stays delayed on all later calls - TODO confirm this is + // intended. + if upper_bound < constant_level || pending_level >= constant_level { + return ( + UnifRecordRows::Concrete { + rrows, + var_levels_data: VarLevelsData { + upper_bound: pending_level, + pending: Some(pending_level), + }, + }, + true, + ); + } + + let rrows = if upper_bound > pending_level { + update_rrows_unr_with_lvl(table, rrows, pending_level) + } else { + rrows + }; + + ( + UnifRecordRows::Concrete { + rrows, + var_levels_data: VarLevelsData { + upper_bound: pending_level, + pending: None, + }, + }, + false, + ) + } + // See [^ignore-no-pending-level] + UnifRecordRows::Constant(_) | UnifRecordRows::Concrete { .. } => (rrows, false), + } + } + + let rest = std::mem::take(&mut self.pending_rrows_updates) + .into_iter() + .filter(|id| { + // unwrap(): if a unification variable has been pushed on the update stack, it + // must have been done by `assign_rrows`, and thus MUST have been assigned to + // something. + let rrows = self.rrows[*id].value.take().unwrap(); + let (new_rrows, delay) = update_rrows(self, rrows, constant_level); + self.rrows[*id].value = Some(new_rrows); + + delay + }) + .collect(); + + self.pending_rrows_updates = rest; + } + + /// See `force_type_updates`. Same as `force_type_updates`, but when unifying an enum row + /// unification variable.
+ pub fn force_erows_updates(&mut self, constant_level: VarLevel) { + fn update_rrows_unr_with_lvl<'ast>( + table: &mut UnifTable<'ast>, + rrows: UnifRecordRowsUnr<'ast>, + level: VarLevel, + ) -> UnifRecordRowsUnr<'ast> { + rrows.map_state( + |uty, table| Box::new(update_utype_with_lvl(table, *uty, level)), + |rrows, table| Box::new(update_rrows_with_lvl(table, *rrows, level)), + table, + ) + } + + fn update_erows_unr_with_lvl<'ast>( + table: &mut UnifTable<'ast>, + erows: UnifEnumRowsUnr<'ast>, + level: VarLevel, + ) -> UnifEnumRowsUnr<'ast> { + erows.map_state( + |uty, table| Box::new(update_utype_with_lvl(table, *uty, level)), + |erows, table| Box::new(update_erows_with_lvl(table, *erows, level)), + table, + ) + } + + fn update_utype_with_lvl<'ast>( + table: &mut UnifTable<'ast>, + utype: UnifType<'ast>, + level: VarLevel, + ) -> UnifType<'ast> { + let utype = utype.into_root(table); + + match utype { + UnifType::Concrete { + typ, + var_levels_data, + } => { + let typ = typ.map_state( + |uty, table| Box::new(update_utype_with_lvl(table, *uty, level)), + |rrows, table| update_rrows_with_lvl(table, rrows, level), + |erows, table| update_erows_with_lvl(table, erows, level), + |ctr, _| ctr, + table, + ); + + // See [^var-level-kinds] + UnifType::Concrete { + typ, + var_levels_data, + } + } + UnifType::UnifVar { .. } | UnifType::Constant(_) => utype, + } + } + + fn update_rrows_with_lvl<'ast>( + table: &mut UnifTable<'ast>, + rrows: UnifRecordRows<'ast>, + level: VarLevel, + ) -> UnifRecordRows<'ast> { + let rrows = rrows.into_root(table); + + match rrows { + UnifRecordRows::Concrete { + rrows, + var_levels_data, + } => { + let rrows = update_rrows_unr_with_lvl(table, rrows, level); + + // See [^var-level-kinds] + UnifRecordRows::Concrete { + rrows, + var_levels_data, + } + } + UnifRecordRows::UnifVar { .. 
} | UnifRecordRows::Constant(_) => rrows, + } + } + + fn update_erows_with_lvl<'ast>( + table: &mut UnifTable<'ast>, + erows: UnifEnumRows<'ast>, + level: VarLevel, + ) -> UnifEnumRows<'ast> { + let erows = erows.into_root(table); + + match erows { + UnifEnumRows::UnifVar { id, init_level } => { + if table.erows[id].level > level { + table.erows[id].level = level; + } + + UnifEnumRows::UnifVar { id, init_level } + } + UnifEnumRows::Concrete { + erows, + var_levels_data, + } if var_levels_data.upper_bound > level => { + let level = var_levels_data + .pending + .map(|pending_level| max(pending_level, level)) + .unwrap_or(level); + let erows = update_erows_unr_with_lvl(table, erows, level); + + UnifEnumRows::Concrete { + erows, + var_levels_data: VarLevelsData { + upper_bound: level, + pending: None, + }, + } + } + UnifEnumRows::Constant(_) | UnifEnumRows::Concrete { .. } => erows, + } + } + + fn update_erows<'ast>( + table: &mut UnifTable<'ast>, + erows: UnifEnumRows<'ast>, + constant_level: VarLevel, + ) -> (UnifEnumRows<'ast>, bool) { + match erows { + UnifEnumRows::UnifVar { .. } => { + // We should never end up updating the level of a unification variable, as this + // update is done on the spot. 
+ debug_assert!(false); + + (erows, false) + } + UnifEnumRows::Concrete { + erows, + var_levels_data: + VarLevelsData { + pending: Some(pending_level), + upper_bound, + }, + } => { + // See [^irrelevant-level-update] + // + // NOTE(review): the early return below overwrites `upper_bound` with + // `pending_level`. Once both are equal, the condition below holds for every + // `constant_level`, so (unless the entry is reassigned elsewhere) the update + // looks like it stays delayed on all later calls - TODO confirm this is + // intended. + if upper_bound < constant_level || pending_level >= constant_level { + return ( + UnifEnumRows::Concrete { + erows, + var_levels_data: VarLevelsData { + upper_bound: pending_level, + pending: Some(pending_level), + }, + }, + true, + ); + } + + let erows = if upper_bound > pending_level { + update_erows_unr_with_lvl(table, erows, pending_level) + } else { + erows + }; + + ( + UnifEnumRows::Concrete { + erows, + var_levels_data: VarLevelsData { + upper_bound: pending_level, + pending: None, + }, + }, + false, + ) + } + // See [^ignore-no-pending-level] + UnifEnumRows::Constant(_) | UnifEnumRows::Concrete { .. } => (erows, false), + } + } + + let rest = std::mem::take(&mut self.pending_erows_updates) + .into_iter() + .filter(|id| { + // unwrap(): if a unification variable has been pushed on the update stack, it must + // have been done by `assign_erows`, and thus MUST have been assigned to something. + let erows = self.erows[*id].value.take().unwrap(); + let (new_erows, delay) = update_erows(self, erows, constant_level); + self.erows[*id].value = Some(new_erows); + + delay + }) + .collect(); + + self.pending_erows_updates = rest; + } +} + +/// Row constraints. +/// +/// A row constraint applies to a unification variable appearing inside a row type (such as `r` in +/// `{ someId: SomeType ; r }` or `[| 'Foo Number, 'Baz; r |]`). 
A row constraint is a set of +/// identifiers that said row must NOT contain, to forbid ill-formed types with multiple +/// declarations of the same id, for example `{ a: Number, a: String}` or `[| 'Foo String, 'Foo +/// Number |]`. +/// +/// Note that because the syntax (and pattern matching likewise) distinguishes between `'Foo` and +/// `'Foo some_arg`, the type `[| 'Foo, 'Foo SomeType |]` is unproblematic for typechecking.
In +/// some sense, enum tags and enum variants live in a different dimension. It looks like we should +/// use separate sets of constraints for enum tag constraints and enum variants constraints. But a +/// set just for enum tag constraints is useless, because enum tags can never conflict, as they +/// don't have any argument: `'Foo` always "agrees with" another `'Foo` definition. In consequence, +/// we simply record enum variants constraints and ignore enum tags. +/// +/// Note that a `VarId` always refers to either a type unification variable, a record row +/// unification variable or an enum row unification variable. Thus, we can use a single constraint +/// set per variable id (which isn't used at all for type unification variables). Because we expect +/// the map to be rather sparse, we use a `HashMap` instead of a `Vec`. +pub type RowConstrs = HashMap<VarId, HashSet<crate::identifier::Ident>>; + +pub(super) trait PropagateConstrs<'ast> { + /// Check that unifying a variable with a type doesn't violate rows constraints, and update the + /// row constraints of the unified type accordingly if needed. + /// + /// When a unification variable `UnifVar(p)` is unified with a type `uty` which is either a row type + /// or another unification variable which could be later unified with a row type itself, the + /// following operations are required: + /// + /// 1. If `uty` is a concrete row, check that it doesn't contain an identifier which is forbidden by + /// a row constraint on `p`. + /// 2. If `uty` is either a unification variable `u` or a row type ending with a unification + /// variable `u`, we must add the constraints of `p` to the constraints of `u`. Indeed, take the + /// following situation: `p` appears in a row type `{a: Number ; p}`, hence has a constraint that + /// it must not contain a field `a`. Then `p` is unified with a fresh type variable `u`. 
If we + /// don't constrain `u`, `u` could be unified later with a row type `{a : String}` which violates + /// the original constraint on `p`. 
Thus, when unifying `p` with `u` or a row ending with `u`, + /// `u` must inherit all the constraints of `p`. + fn propagate_constrs( + &self, + constr: &mut RowConstrs, + var_id: VarId, + ) -> Result<(), RowUnifError<'ast>>; +} + +impl<'ast> PropagateConstrs<'ast> for UnifRecordRows<'ast> { + fn propagate_constrs( + &self, + constr: &mut RowConstrs, + var_id: VarId, + ) -> Result<(), RowUnifError<'ast>> { + fn propagate<'ast>( + constr: &mut RowConstrs, + var_id: VarId, + var_constr: HashSet<crate::identifier::Ident>, + rrows: &UnifRecordRows<'ast>, + ) -> Result<(), RowUnifError<'ast>> { + match rrows { + UnifRecordRows::Concrete { + rrows: RecordRowsF::Extend { row, .. }, + .. + } if var_constr.contains(&row.id.ident()) => { + Err(RowUnifError::RecordRowConflict(row.clone())) + } + UnifRecordRows::Concrete { + rrows: RecordRowsF::Extend { tail, .. }, + .. + } => propagate(constr, var_id, var_constr, tail), + UnifRecordRows::UnifVar { id, .. } if *id != var_id => { + if let Some(tail_constr) = constr.get_mut(id) { + tail_constr.extend(var_constr); + } else { + constr.insert(*id, var_constr); + } + + Ok(()) + } + _ => Ok(()), + } + } + + if let Some(var_constr) = constr.remove(&var_id) { + propagate(constr, var_id, var_constr, self) + } else { + Ok(()) + } + } +} + +impl<'ast> PropagateConstrs<'ast> for UnifEnumRows<'ast> { + fn propagate_constrs( + &self, + constr: &mut RowConstrs, + var_id: VarId, + ) -> Result<(), RowUnifError<'ast>> { + fn propagate<'ast>( + constr: &mut RowConstrs, + var_id: VarId, + var_constr: HashSet<crate::identifier::Ident>, + erows: &UnifEnumRows<'ast>, + ) -> Result<(), RowUnifError<'ast>> { + match erows { + UnifEnumRows::Concrete { + // If the row is an enum tag (ie `typ` is `None`), it can't cause any conflict. + // See [RowConstrs] for more details. + erows: + EnumRowsF::Extend { + row: + row @ UnifEnumRow { + id: _, + typ: Some(_), + }, + .. + }, + .. 
+ } if var_constr.contains(&row.id.ident()) => { + Err(RowUnifError::EnumRowConflict(row.clone())) + } + UnifEnumRows::Concrete { + erows: EnumRowsF::Extend { tail, .. }, + .. + } => propagate(constr, var_id, var_constr, tail), + UnifEnumRows::UnifVar { id, .. } if *id != var_id => { + if let Some(tail_constr) = constr.get_mut(id) { + tail_constr.extend(var_constr); + } else { + constr.insert(*id, var_constr); + } + + Ok(()) + } + _ => Ok(()), + } + } + + if let Some(var_constr) = constr.remove(&var_id) { + propagate(constr, var_id, var_constr, self) + } else { + Ok(()) + } + } +} + +/// Types which can be unified. +pub(super) trait Unify<'ast> { + type Error; + + /// Try to unify two types. Unification corresponds to imposing an equality constraints on + /// those types. This can fail if the types can't be matched. + fn unify( + self, + t2: Self, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, + ) -> Result<(), Self::Error>; +} + +impl<'ast> Unify<'ast> for UnifType<'ast> { + type Error = UnifError<'ast>; + + fn unify( + self, + t2: UnifType<'ast>, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, + ) -> Result<(), UnifError<'ast>> { + let t1 = self.into_root(state.table); + let t2 = t2.into_root(state.table); + + // t1 and t2 are roots of the type + match (t1, t2) { + // If either type is a wildcard, unify with the associated type var + ( + UnifType::Concrete { + typ: TypeF::Wildcard(id), + .. + }, + ty2, + ) + | ( + ty2, + UnifType::Concrete { + typ: TypeF::Wildcard(id), + .. 
+ }, + ) => { + let ty1 = get_wildcard_var(state.table, ctxt.var_level, state.wildcard_vars, id); + ty1.unify(ty2, state, ctxt) + } + ( + UnifType::Concrete { + typ: s1, + var_levels_data: _, + }, + UnifType::Concrete { + typ: s2, + var_levels_data: _, + }, + ) => match (s1, s2) { + (TypeF::Dyn, TypeF::Dyn) + | (TypeF::Number, TypeF::Number) + | (TypeF::Bool, TypeF::Bool) + | (TypeF::String, TypeF::String) + | (TypeF::Symbol, TypeF::Symbol) => Ok(()), + (TypeF::Array(uty1), TypeF::Array(uty2)) => uty1.unify(*uty2, state, ctxt), + (TypeF::Arrow(s1s, s1t), TypeF::Arrow(s2s, s2t)) => { + s1s.clone() + .unify((*s2s).clone(), state, ctxt) + .map_err(|err| UnifError::DomainMismatch { + expected: UnifType::concrete(TypeF::Arrow(s1s.clone(), s1t.clone())), + inferred: UnifType::concrete(TypeF::Arrow(s2s.clone(), s2t.clone())), + cause: Box::new(err), + })?; + s1t.clone() + .unify((*s2t).clone(), state, ctxt) + .map_err(|err| UnifError::CodomainMismatch { + expected: UnifType::concrete(TypeF::Arrow(s1s, s1t)), + inferred: UnifType::concrete(TypeF::Arrow(s2s, s2t)), + cause: Box::new(err), + }) + } + (TypeF::Contract((t1, env1)), TypeF::Contract((t2, env2))) + if t1.type_eq(t2, &env1, &env2) => + { + Ok(()) + } + (TypeF::Enum(erows1), TypeF::Enum(erows2)) => erows1 + .clone() + .unify(erows2.clone(), state, ctxt) + .map_err(|err| { + err.into_unif_err(mk_buty_enum!(; erows1), mk_buty_enum!(; erows2)) + }), + (TypeF::Record(rrows1), TypeF::Record(rrows2)) => rrows1 + .clone() + .unify(rrows2.clone(), state, ctxt) + .map_err(|err| { + err.into_unif_err(mk_buty_record!(; rrows1), mk_buty_record!(; rrows2)) + }), + ( + TypeF::Dict { + type_fields: uty1, .. + }, + TypeF::Dict { + type_fields: uty2, .. 
+ }, + ) => uty1.unify(*uty2, state, ctxt), + ( + TypeF::Forall { + var: var1, + var_kind: var_kind1, + body: body1, + }, + TypeF::Forall { + var: var2, + var_kind: var_kind2, + body: body2, + }, + ) if var_kind1 == var_kind2 => { + // Very stupid (slow) implementation + let (substd1, substd2) = match var_kind1 { + VarKind::Type => { + let constant_type = state.table.fresh_type_const(ctxt.var_level); + ( + body1.subst(&var1, &constant_type), + body2.subst(&var2, &constant_type), + ) + } + VarKind::RecordRows { .. } => { + let constant_type = state.table.fresh_rrows_const(ctxt.var_level); + ( + body1.subst(&var1, &constant_type), + body2.subst(&var2, &constant_type), + ) + } + VarKind::EnumRows { .. } => { + let constant_type = state.table.fresh_erows_const(ctxt.var_level); + ( + body1.subst(&var1, &constant_type), + body2.subst(&var2, &constant_type), + ) + } + }; + + substd1.unify(substd2, state, ctxt) + } + (TypeF::Var(ident), _) | (_, TypeF::Var(ident)) => { + Err(UnifError::UnboundTypeVariable(ident.into())) + } + (ty1, ty2) => Err(UnifError::TypeMismatch { + expected: UnifType::concrete(ty1), + inferred: UnifType::concrete(ty2), + }), + }, + (UnifType::UnifVar { id, .. }, uty) | (uty, UnifType::UnifVar { id, .. }) => { + // [^check-unif-var-level]: If we are unifying a variable with a rigid type + // variable, force potential unification variable level updates and check that the + // level of the unification variable is greater than or equal to the constant: that is, + // that the variable doesn't "escape its scope". This is required to handle + // polymorphism soundly, and is the whole point of all the machinery around variable + // levels. 
+ if let UnifType::Constant(cst_id) = uty { + let constant_level = state.table.get_level(cst_id); + state.table.force_type_updates(constant_level); + + if state.table.get_level(id) < constant_level { + return Err(UnifError::VarLevelMismatch { + constant_id: cst_id, + var_kind: VarKindDiscriminant::Type, + }); + } + } + + state.table.assign_type(id, uty); + Ok(()) + } + (UnifType::Constant(i1), UnifType::Constant(i2)) if i1 == i2 => Ok(()), + (UnifType::Constant(i1), UnifType::Constant(i2)) => Err(UnifError::ConstMismatch { + var_kind: VarKindDiscriminant::Type, + expected_const_id: i1, + inferred_const_id: i2, + }), + (ty, UnifType::Constant(i)) | (UnifType::Constant(i), ty) => { + Err(UnifError::WithConst { + var_kind: VarKindDiscriminant::Type, + expected_const_id: i, + inferred: ty, + }) + } + } + } +} + +impl<'ast> Unify<'ast> for UnifEnumRows<'ast> { + type Error = RowUnifError<'ast>; + + fn unify( + self, + uerows2: UnifEnumRows<'ast>, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, + ) -> Result<(), RowUnifError<'ast>> { + let uerows1 = self.into_root(state.table); + let uerows2 = uerows2.into_root(state.table); + + match (uerows1, uerows2) { + ( + UnifEnumRows::Concrete { + erows: erows1, + var_levels_data: _, + }, + UnifEnumRows::Concrete { + erows: erows2, + var_levels_data: var_levels2, + }, + ) => match (erows1, erows2) { + (EnumRowsF::TailVar(id), _) | (_, EnumRowsF::TailVar(id)) => { + Err(RowUnifError::UnboundTypeVariable(id)) + } + (EnumRowsF::Empty, EnumRowsF::Empty) => Ok(()), + ( + EnumRowsF::Empty, + EnumRowsF::Extend { + row: UnifEnumRow { id, .. }, + .. + }, + ) => Err(RowUnifError::ExtraRow(id)), + ( + EnumRowsF::Extend { + row: UnifEnumRow { id, .. }, + .. + }, + EnumRowsF::Empty, + ) => Err(RowUnifError::MissingRow(id)), + (EnumRowsF::Extend { row, tail }, erows2 @ EnumRowsF::Extend { .. 
}) => { + let uerows2 = UnifEnumRows::Concrete { + erows: erows2, + var_levels_data: var_levels2, + }; + + let (ty2_result, t2_without_row) = + //TODO[adts]: it's ugly to create a temporary Option just to please the + //Box/Nobox types, we should find a better signature for remove_row + uerows2.remove_row(&row.id, &row.typ.clone().map(|typ| *typ), state, ctxt.var_level).map_err(|err| match err { + RemoveRowError::Missing => RowUnifError::MissingRow(row.id), + RemoveRowError::Conflict => RowUnifError::EnumRowConflict(row.clone()), + })?; + + // The alternative to this if-condition is `RemoveRowResult::Extended`, which + // means that `t2` could be successfully extended with the row `id typ`, in + // which case we don't have to perform additional unification for this specific + // row + if let RemoveRowResult::Extracted(ty2) = ty2_result { + match (row.typ, ty2) { + (Some(typ), Some(ty2)) => { + typ.unify(ty2, state, ctxt).map_err(|err| { + RowUnifError::EnumRowMismatch { + id: row.id, + cause: Some(Box::new(err)), + } + })?; + } + (Some(_), None) | (None, Some(_)) => { + return Err(RowUnifError::EnumRowMismatch { + id: row.id, + cause: None, + }); + } + (None, None) => (), + } + } + + tail.unify(t2_without_row, state, ctxt) + } + }, + (UnifEnumRows::UnifVar { id, init_level: _ }, uerows) + | (uerows, UnifEnumRows::UnifVar { id, init_level: _ }) => { + // see [^check-unif-var-level] + if let UnifEnumRows::Constant(cst_id) = uerows { + let constant_level = state.table.get_erows_level(cst_id); + state.table.force_erows_updates(constant_level); + + if state.table.get_erows_level(id) < constant_level { + return Err(RowUnifError::VarLevelMismatch { + constant_id: cst_id, + var_kind: VarKindDiscriminant::EnumRows, + }); + } + } + + uerows.propagate_constrs(state.constr, id)?; + state.table.assign_erows(id, uerows); + Ok(()) + } + (UnifEnumRows::Constant(i1), UnifEnumRows::Constant(i2)) if i1 == i2 => Ok(()), + (UnifEnumRows::Constant(i1), UnifEnumRows::Constant(i2)) => { + 
Err(RowUnifError::ConstMismatch { + var_kind: VarKindDiscriminant::EnumRows, + expected_const_id: i1, + inferred_const_id: i2, + }) + } + (uerows, UnifEnumRows::Constant(i)) | (UnifEnumRows::Constant(i), uerows) => { + //TODO ROWS: should we refactor RowUnifError as well? + Err(RowUnifError::WithConst { + var_kind: VarKindDiscriminant::EnumRows, + expected_const_id: i, + inferred: UnifType::concrete(TypeF::Enum(uerows)), + }) + } + } + } +} + +impl<'ast> Unify<'ast> for UnifRecordRows<'ast> { + type Error = RowUnifError<'ast>; + + fn unify( + self, + urrows2: UnifRecordRows<'ast>, + state: &mut State<'ast, '_>, + ctxt: &Context<'ast>, + ) -> Result<(), RowUnifError<'ast>> { + let urrows1 = self.into_root(state.table); + let urrows2 = urrows2.into_root(state.table); + + match (urrows1, urrows2) { + ( + UnifRecordRows::Concrete { + rrows: rrows1, + var_levels_data: _, + }, + UnifRecordRows::Concrete { + rrows: rrows2, + var_levels_data: var_levels2, + }, + ) => match (rrows1, rrows2) { + (RecordRowsF::TailVar(id), _) | (_, RecordRowsF::TailVar(id)) => { + Err(RowUnifError::UnboundTypeVariable(id)) + } + (RecordRowsF::Empty, RecordRowsF::Empty) + | (RecordRowsF::TailDyn, RecordRowsF::TailDyn) => Ok(()), + (RecordRowsF::Empty, RecordRowsF::TailDyn) => Err(RowUnifError::ExtraDynTail), + (RecordRowsF::TailDyn, RecordRowsF::Empty) => Err(RowUnifError::MissingDynTail), + ( + RecordRowsF::Empty, + RecordRowsF::Extend { + row: UnifRecordRow { id, .. }, + .. + }, + ) + | ( + RecordRowsF::TailDyn, + RecordRowsF::Extend { + row: UnifRecordRow { id, .. }, + .. + }, + ) => Err(RowUnifError::ExtraRow(id)), + ( + RecordRowsF::Extend { + row: UnifRecordRow { id, .. }, + .. + }, + RecordRowsF::TailDyn, + ) + | ( + RecordRowsF::Extend { + row: UnifRecordRow { id, .. }, + .. + }, + RecordRowsF::Empty, + ) => Err(RowUnifError::MissingRow(id)), + (RecordRowsF::Extend { row, tail }, rrows2 @ RecordRowsF::Extend { .. 
}) => { + let urrows2 = UnifRecordRows::Concrete { + rrows: rrows2, + var_levels_data: var_levels2, + }; + + let (ty2_result, urrows2_without_ty2) = urrows2 + .remove_row(&row.id, &row.typ, state, ctxt.var_level) + .map_err(|err| match err { + RemoveRowError::Missing => RowUnifError::MissingRow(row.id), + RemoveRowError::Conflict => { + RowUnifError::RecordRowConflict(row.clone()) + } + })?; + + // The alternative to this if-condition is `RemoveRowResult::Extended`, which + // means that `t2` could be successfully extended with the row `id typ`, in + // which case we don't have to perform additional unification for this specific + // row + if let RemoveRowResult::Extracted(ty2) = ty2_result { + row.typ.unify(ty2, state, ctxt).map_err(|err| { + RowUnifError::RecordRowMismatch { + id: row.id, + cause: Box::new(err), + } + })?; + } + + tail.unify(urrows2_without_ty2, state, ctxt) + } + }, + (UnifRecordRows::UnifVar { id, init_level: _ }, urrows) + | (urrows, UnifRecordRows::UnifVar { id, init_level: _ }) => { + // see [^check-unif-var-level] + if let UnifRecordRows::Constant(cst_id) = urrows { + let constant_level = state.table.get_rrows_level(cst_id); + state.table.force_rrows_updates(constant_level); + + if state.table.get_rrows_level(id) < constant_level { + return Err(RowUnifError::VarLevelMismatch { + constant_id: cst_id, + var_kind: VarKindDiscriminant::RecordRows, + }); + } + } + + urrows.propagate_constrs(state.constr, id)?; + state.table.assign_rrows(id, urrows); + Ok(()) + } + (UnifRecordRows::Constant(i1), UnifRecordRows::Constant(i2)) if i1 == i2 => Ok(()), + (UnifRecordRows::Constant(i1), UnifRecordRows::Constant(i2)) => { + Err(RowUnifError::ConstMismatch { + var_kind: VarKindDiscriminant::RecordRows, + expected_const_id: i1, + inferred_const_id: i2, + }) + } + (urrows, UnifRecordRows::Constant(i)) | (UnifRecordRows::Constant(i), urrows) => { + Err(RowUnifError::WithConst { + var_kind: VarKindDiscriminant::RecordRows, + expected_const_id: i, + inferred: 
UnifType::concrete(TypeF::Record(urrows)), + }) + } + } + } +} + +#[derive(Clone, Copy, Debug)] +pub(super) enum RemoveRowError { + // The row to add was missing and the row type was closed (no free unification variable in tail + // position). + Missing, + // The row to add was missing and the row type couldn't be extended because of row constraints. + Conflict, +} + +#[derive(Clone, Debug)] +pub enum RemoveRowResult<RowContent: Clone> { + Extracted(RowContent), + Extended, +} + +pub(super) trait RemoveRow<'ast>: Sized { + /// The row data minus the identifier. + type RowContent: Clone; + + /// Fetch a specific `row_id` from a row type, and return the content of the row together with + /// the original row type without the found row. + /// + /// If the searched row isn't found: + /// + /// - If the row type is extensible, i.e. it ends with a free unification variable in tail + /// position, this function adds the missing row (with `row.types` as a type for record rows, + /// if allowed by row constraints) and then acts as if `remove_row` was called again on + /// this extended row type. That is, `remove_row` returns the new row and the extended type + /// without the added row. + /// - Otherwise, raise a missing row error. + /// + /// # Motivation + /// + /// This method is used as part of row unification: let's say we want to unify `{ r1, ..tail1 + /// }` with `{ ..tail2 }` where `r1` is a row (the head of the left hand side rows), and + /// `tail1` and `tail2` are sequences of rows. + /// + /// For those to unify, we must have either: + /// + /// - `r1` is somewhere in `tail2`, and `tail1` unifies with `{..tail2'}` where `tail2'` is + /// `tail2` without `r1`. + /// - `tail2` is extensible, in which case we can extend `tail2` with `r1`, assuming that + /// `tail1` unifies with `{..tail2'}`, where `tail2'` is `tail2` after extending with `r1` + /// and then removing it. 
Modulo fresh unification variable shuffling, `tail2'` is in fact + /// isomorphic to `tail2` before it was extended. + /// + /// When we unify two row types, we destructure the left hand side to extract the head `r1` and + /// the tail `tail1`. Then, we try to find and extract `r1` from `tail2`. If `r1` was found, we + /// additionally unify the extracted type found in `tail2` (returned as part of + /// [RemoveRowResult::Extracted]) with `r1.typ` to make sure they agree. In case of extension, + /// we were free to choose the type of the new added row, which we set to be `r1.typ` (the + /// `row_content` parameter of `remove_row`), and there's no additional check to perform (and + /// indeed [RemoveRowResult::Extended] doesn't carry any information). + /// + /// Finally, since `remove_row` returns the initial row type minus the extracted row, we can go + /// on recursively and unify `tail1` with this rest. + /// + /// # Parameters + /// + /// - `row_id`: the identifier of the row to extract + /// - `row_content`: as explained above, `remove_row` is used in the context of unifying two row + /// types. If `self` doesn't contain `row_id` but is extensible, we must add a corresponding + /// new row: we fill it with `row_content`. In the context of unification, this is the content of + /// the row coming from the other row type. 
+ /// - `state`: the unification state + /// - `var_level`: the ambient variable level + fn remove_row( + self, + row_id: &LocIdent, + row_content: &Self::RowContent, + state: &mut State<'ast, '_>, + var_level: VarLevel, + ) -> Result<(RemoveRowResult<Self::RowContent>, Self), RemoveRowError>; +} + +impl<'ast> RemoveRow<'ast> for UnifRecordRows<'ast> { + type RowContent = UnifType<'ast>; + + fn remove_row( + self, + target: &LocIdent, + target_content: &Self::RowContent, + state: &mut State<'ast, '_>, + var_level: VarLevel, + ) -> Result<(RemoveRowResult<Self::RowContent>, UnifRecordRows<'ast>), RemoveRowError> { + let rrows = self.into_root(state.table); + + match rrows { + UnifRecordRows::Concrete { rrows, .. } => match rrows { + RecordRowsF::Empty | RecordRowsF::TailDyn | RecordRowsF::TailVar(_) => { + Err(RemoveRowError::Missing) + } + RecordRowsF::Extend { + row: next_row, + tail, + } => { + if target.ident() == next_row.id.ident() { + Ok((RemoveRowResult::Extracted(*next_row.typ), *tail)) + } else { + let (extracted_row, rest) = + tail.remove_row(target, target_content, state, var_level)?; + Ok(( + extracted_row, + UnifRecordRows::concrete(RecordRowsF::Extend { + row: next_row, + tail: Box::new(rest), + }), + )) + } + } + }, + UnifRecordRows::UnifVar { id: var_id, .. 
} => { + let tail_var_id = state.table.fresh_rrows_var_id(var_level); + // We have to manually insert the constraint that `tail_var_id` can't contain a row + // `target`, to avoid producing ill-formed record rows later + state + .constr + .insert(tail_var_id, HashSet::from([target.ident()])); + + let row_to_insert = UnifRecordRow { + id: *target, + typ: Box::new(target_content.clone()), + }; + + let tail_var = UnifRecordRows::UnifVar { + id: tail_var_id, + init_level: var_level, + }; + + let tail_extended = UnifRecordRows::concrete(RecordRowsF::Extend { + row: row_to_insert, + tail: Box::new(tail_var.clone()), + }); + + tail_extended + .propagate_constrs(state.constr, var_id) + .map_err(|_| RemoveRowError::Conflict)?; + state.table.assign_rrows(var_id, tail_extended); + + Ok((RemoveRowResult::Extended, tail_var)) + } + UnifRecordRows::Constant(_) => Err(RemoveRowError::Missing), + } + } +} + +impl<'ast> RemoveRow<'ast> for UnifEnumRows<'ast> { + type RowContent = Option<UnifType<'ast>>; + + fn remove_row( + self, + target: &LocIdent, + target_content: &Self::RowContent, + state: &mut State<'ast, '_>, + var_level: VarLevel, + ) -> Result<(RemoveRowResult<Self::RowContent>, UnifEnumRows<'ast>), RemoveRowError> { + let uerows = self.into_root(state.table); + + match uerows { + UnifEnumRows::Concrete { erows, .. } => match erows { + EnumRowsF::Empty | EnumRowsF::TailVar(_) => Err(RemoveRowError::Missing), + EnumRowsF::Extend { + row: next_row, + tail, + } => { + // Enum variants and enum tags don't conflict, and can thus coexist in the same + // row type (for example, [| 'Foo Number, 'Foo |]). In some sense, they live + // inside different dimensions.
Thus, when matching rows, we don't only compare + // the tag but also the nature of the enum row (tag vs variant) + if target.ident() == next_row.id.ident() + && target_content.is_some() == next_row.typ.is_some() + { + Ok(( + RemoveRowResult::Extracted(next_row.typ.map(|typ| *typ)), + *tail, + )) + } else { + let (extracted_row, rest) = + tail.remove_row(target, target_content, state, var_level)?; + Ok(( + extracted_row, + UnifEnumRows::concrete(EnumRowsF::Extend { + row: next_row, + tail: Box::new(rest), + }), + )) + } + } + }, + UnifEnumRows::UnifVar { id: var_id, .. } => { + let tail_var_id = state.table.fresh_erows_var_id(var_level); + + // Enum tag are ignored for row conflict. See [RowConstrs] + if target_content.is_some() { + state + .constr + .insert(tail_var_id, HashSet::from([target.ident()])); + } + + let row_to_insert = UnifEnumRow { + id: *target, + typ: target_content.clone().map(Box::new), + }; + + let tail_var = UnifEnumRows::UnifVar { + id: tail_var_id, + init_level: var_level, + }; + + let tail_extended = UnifEnumRows::concrete(EnumRowsF::Extend { + row: row_to_insert, + tail: Box::new(tail_var.clone()), + }); + + tail_extended + .propagate_constrs(state.constr, var_id) + .map_err(|_| RemoveRowError::Conflict)?; + state.table.assign_erows(var_id, tail_extended); + + Ok((RemoveRowResult::Extended, tail_var)) + } + UnifEnumRows::Constant(_) => Err(RemoveRowError::Missing), + } + } +} diff --git a/core/src/combine.rs b/core/src/combine.rs index f508089b79..67339fe15a 100644 --- a/core/src/combine.rs +++ b/core/src/combine.rs @@ -6,7 +6,7 @@ use crate::bytecode::ast::AstAlloc; /// Trait for structures representing a series of annotation that can be combined (flattened). /// Pedantically, `Combine` is just a monoid. -pub trait Combine: Default { +pub trait Combine { /// Combine two elements. 
fn combine(left: Self, right: Self) -> Self; } diff --git a/core/src/identifier.rs b/core/src/identifier.rs index 705bdecfb4..6225cbc57b 100644 --- a/core/src/identifier.rs +++ b/core/src/identifier.rs @@ -206,6 +206,24 @@ where } } +/// Wrapper around [Ident] with a fast ordering function that only compares the underlying symbols. +/// Useful when a bunch of idents need to be sorted for algorithmic reasons, but one doesn't need +/// the actual natural order on strings nor care about the specific order. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct FastOrdIdent(pub Ident); + +impl PartialOrd for FastOrdIdent { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for FastOrdIdent { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0 .0.cmp(&other.0 .0) + } +} + // False-positive Clippy error: if we apply this suggestion, // we end up with an implementation of `From for String`. // Then setting `F = Ident` in the implementation above gives @@ -244,7 +262,7 @@ mod interner { /// A symbol is a correspondence between an [Ident](super::Ident) and its string representation /// stored in the [Interner]. - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct Symbol(u32); /// The interner, which serves a double purpose: it pre-allocates space diff --git a/core/src/label.rs b/core/src/label.rs index 5314bf2f46..2956e89d8b 100644 --- a/core/src/label.rs +++ b/core/src/label.rs @@ -340,6 +340,7 @@ impl ReifyAsUnifType for Polarity { mk_uty_enum!("Positive", "Negative") } } + /// A polarity. 
See [`Label`] #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Polarity { diff --git a/core/src/lib.rs b/core/src/lib.rs index 59031df64a..327dd932b1 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -21,6 +21,7 @@ pub mod serialize; pub mod stdlib; pub mod term; pub mod transform; +pub mod traverse; pub mod typ; pub mod typecheck; diff --git a/core/src/parser/uniterm.rs b/core/src/parser/uniterm.rs index c011a57448..ed3fffffb8 100644 --- a/core/src/parser/uniterm.rs +++ b/core/src/parser/uniterm.rs @@ -8,7 +8,7 @@ use crate::{ self, record::{FieldDef, FieldMetadata, FieldPathElem}, typ::{EnumRow, EnumRows, RecordRow, RecordRows, Type}, - Annotation, Ast, AstAlloc, MergePriority, Node, + *, }, environment::Environment, identifier::Ident, @@ -78,21 +78,6 @@ impl UniTerm<'_> { } } -/// Similar to `TryFrom`, but takes an additional allocator for conversion from and to -/// [crate::bytecode::ast::Ast] that requires to thread an explicit allocator. -/// -/// We chose a different name than `try_from` for the method - although it has a different -/// signature from the standard `TryFrom` (two arguments vs one) - to avoid confusing the compiler -/// which would otherwise have difficulties disambiguating calls like `Ast::try_from`. -pub(crate) trait TryConvert<'ast, T> -where - Self: Sized, -{ - type Error; - - fn try_convert(alloc: &'ast AstAlloc, from: T) -> Result; -} - // For nodes such as `Type` or `Record`, the following implementation has to choose between two // positions to use: the one of the wrapping `UniTerm`, and the one stored inside the `RichTerm` or // the `Type`. This implementation assumes that the latest set is the one of `UniTerm`, which is diff --git a/core/src/position.rs b/core/src/position.rs index 46645c7154..884ef2b785 100644 --- a/core/src/position.rs +++ b/core/src/position.rs @@ -127,6 +127,15 @@ impl TermPos { } } + /// Return either `self` or `other` if and only if exactly one of them is defined. 
If both are + /// `None` or both are defined, `None` is returned. + pub fn xor(self, other: Self) -> Self { + match (self, other) { + (defn, TermPos::None) | (TermPos::None, defn) => defn, + _ => TermPos::None, + } + } + /// Determine is the position is defined. Return `false` if it is `None`, and `true` otherwise. pub fn is_def(&self) -> bool { matches!(self, TermPos::Original(_) | TermPos::Inherited(_)) diff --git a/core/src/repl/mod.rs b/core/src/repl/mod.rs index 03998d024b..c6feb1b471 100644 --- a/core/src/repl/mod.rs +++ b/core/src/repl/mod.rs @@ -6,30 +6,34 @@ //! Dually, the frontend is the user-facing part, which may be a CLI, a web application, a //! jupyter-kernel (which is not exactly user-facing, but still manages input/output and //! formatting), etc. -use crate::cache::{Cache, Envs, ErrorTolerance, InputFormat, SourcePath}; -use crate::error::NullReporter; -use crate::error::{ - report::{self, ColorOpt, ErrorFormat}, - Error, EvalError, IOError, IntoDiagnostics, ParseError, ParseErrors, ReplError, +use crate::{ + cache::{Cache, Envs, ErrorTolerance, InputFormat, SourcePath}, + error::{ + report::{self, ColorOpt, ErrorFormat}, + Error, EvalError, IOError, IntoDiagnostics, NullReporter, ParseError, ParseErrors, + ReplError, + }, + eval::{self, cache::Cache as EvalCache, Closure, VirtualMachine}, + files::FileId, + identifier::LocIdent, + parser::{grammar, lexer, ErrorTolerantParserCompat, ExtendedTerm}, + program::FieldPath, + term::{record::Field, RichTerm, Term}, + transform::{self, import_resolution}, + traverse::{Traverse, TraverseOrder}, + typ::Type, + typecheck::{self, TypecheckMode}, }; -use crate::eval::cache::Cache as EvalCache; -use crate::eval::{Closure, VirtualMachine}; -use crate::files::FileId; -use crate::identifier::LocIdent; -use crate::parser::{grammar, lexer, ErrorTolerantParserCompat, ExtendedTerm}; -use crate::program::FieldPath; -use crate::term::TraverseOrder; -use crate::term::{record::Field, RichTerm, Term, Traverse}; -use 
crate::transform::import_resolution; -use crate::typ::Type; -use crate::typecheck::TypecheckMode; -use crate::{eval, transform, typecheck}; + use simple_counter::*; -use std::convert::Infallible; -use std::ffi::{OsStr, OsString}; -use std::io::Write; -use std::result::Result; -use std::str::FromStr; + +use std::{ + convert::Infallible, + ffi::{OsStr, OsString}, + io::Write, + result::Result, + str::FromStr, +}; #[cfg(feature = "repl")] use ansi_term::{Colour, Style}; diff --git a/core/src/term/mod.rs b/core/src/term/mod.rs index 364b756fdb..a92b673ed6 100644 --- a/core/src/term/mod.rs +++ b/core/src/term/mod.rs @@ -32,6 +32,7 @@ use crate::{ match_sharedterm, position::{RawSpan, TermPos}, pretty::PrettyPrintCap, + traverse::*, typ::{Type, UnboundTypeVariableError}, typecheck::eq::{contract_eq, type_eq_noenv}, }; @@ -2110,12 +2111,6 @@ impl fmt::Display for NAryOp { } } -#[derive(Copy, Clone)] -pub enum TraverseOrder { - TopDown, - BottomUp, -} - /// Wrap [Term] with positional information. #[derive(Debug, PartialEq, Clone)] pub struct RichTerm { @@ -2173,78 +2168,6 @@ impl RichTerm { impl PrettyPrintCap for RichTerm {} -/// Flow control for tree traverals. -pub enum TraverseControl { - /// Normal control flow: continue recursing into the children. - /// - /// Pass the state &S to all children. - ContinueWithScope(S), - /// Normal control flow: continue recursing into the children. - /// - /// The state that was passed to the parent will be re-used for the children. - Continue, - - /// Skip this branch of the tree. - SkipBranch, - - /// Finish traversing immediately (and return a value). - Return(U), -} - -impl From> for TraverseControl { - fn from(value: Option) -> Self { - match value { - Some(u) => TraverseControl::Return(u), - None => TraverseControl::Continue, - } - } -} - -pub trait Traverse: Sized { - /// Apply a transformation on a object containing syntactic elements of type `T` (terms, types, - /// etc.) 
by mapping a faillible function `f` on each such node as prescribed by the order. - /// - /// `f` may return a generic error `E` and use the state `S` which is passed around. - fn traverse(self, f: &mut F, order: TraverseOrder) -> Result - where - F: FnMut(T) -> Result; - - /// Recurse through the tree of objects top-down (a.k.a. pre-order), applying `f` to - /// each object. - /// - /// Through its return value, `f` can short-circuit one branch of the traversal or - /// the entire traversal. - /// - /// This traversal can make use of "scoped" state. The `scope` argument is passed to - /// each callback, and the callback can optionally override that scope just for its - /// own subtree in the traversal. For example, when traversing a tree of terms you can - /// maintain an environment. Most of the time the environment should get passed around - /// unchanged, but a `Term::Let` should override the environment of its subtree. It - /// does this by returning a `TraverseControl::ContinueWithScope` that contains the - /// new environment. - fn traverse_ref( - &self, - f: &mut dyn FnMut(&T, &S) -> TraverseControl, - scope: &S, - ) -> Option; - - fn find_map(&self, mut pred: impl FnMut(&T) -> Option) -> Option - where - T: Clone, - { - self.traverse_ref( - &mut |t, _state: &()| { - if let Some(s) = pred(t) { - TraverseControl::Return(s) - } else { - TraverseControl::Continue - } - }, - &(), - ) - } -} - impl Traverse for RichTerm { /// Traverse through all `RichTerm`s in the tree. /// diff --git a/core/src/transform/import_resolution.rs b/core/src/transform/import_resolution.rs index 6741fbb919..5a5bb51e44 100644 --- a/core/src/transform/import_resolution.rs +++ b/core/src/transform/import_resolution.rs @@ -56,7 +56,7 @@ pub mod strict { /// Resolve the import if the term is an unresolved import, or return the term unchanged. This /// function is not recursive, and is to be used in conjunction with e.g. - /// [crate::term::Traverse]. + /// [crate::traverse::Traverse]. 
pub fn transform_one( rt: RichTerm, resolver: &mut R, @@ -77,9 +77,12 @@ pub mod strict { /// together with a (partially) resolved term. pub mod tolerant { use super::ImportResolver; - use crate::error::ImportError; - use crate::files::FileId; - use crate::term::{RichTerm, Term, Traverse, TraverseOrder}; + use crate::{ + error::ImportError, + files::FileId, + term::{RichTerm, Term}, + traverse::{Traverse, TraverseOrder}, + }; /// The result of an error tolerant import resolution. #[derive(Debug)] diff --git a/core/src/transform/mod.rs b/core/src/transform/mod.rs index 426d10f019..872a526c79 100644 --- a/core/src/transform/mod.rs +++ b/core/src/transform/mod.rs @@ -1,7 +1,8 @@ //! Program transformations. use crate::{ cache::ImportResolver, - term::{RichTerm, Traverse, TraverseOrder}, + term::RichTerm, + traverse::{Traverse, TraverseOrder}, typ::UnboundTypeVariableError, typecheck::Wildcards, }; diff --git a/core/src/transform/substitute_wildcards.rs b/core/src/transform/substitute_wildcards.rs index 36dc273e0f..8e4352467a 100644 --- a/core/src/transform/substitute_wildcards.rs +++ b/core/src/transform/substitute_wildcards.rs @@ -11,8 +11,9 @@ use crate::{ match_sharedterm, term::{ record::{Field, FieldMetadata, RecordData}, - LabeledType, RichTerm, Term, Traverse, TraverseOrder, TypeAnnotation, + LabeledType, RichTerm, Term, TypeAnnotation, }, + traverse::{Traverse, TraverseOrder}, typ::{Type, TypeF}, typecheck::Wildcards, }; diff --git a/core/src/traverse.rs b/core/src/traverse.rs new file mode 100644 index 0000000000..4ef6efb515 --- /dev/null +++ b/core/src/traverse.rs @@ -0,0 +1,139 @@ +//! Traversal of trees of objects. + +use crate::bytecode::ast::{Allocable, AstAlloc}; + +#[derive(Copy, Clone)] +pub enum TraverseOrder { + TopDown, + BottomUp, +} + +/// Flow control for tree traverals. +pub enum TraverseControl { + /// Normal control flow: continue recursing into the children. + /// + /// Pass the state &S to all children. 
+ ContinueWithScope(S), + /// Normal control flow: continue recursing into the children. + /// + /// The state that was passed to the parent will be re-used for the children. + Continue, + + /// Skip this branch of the tree. + SkipBranch, + + /// Finish traversing immediately (and return a value). + Return(U), +} + +impl From> for TraverseControl { + fn from(value: Option) -> Self { + match value { + Some(u) => TraverseControl::Return(u), + None => TraverseControl::Continue, + } + } +} + +pub trait Traverse: Sized { + /// Apply a transformation on an object containing syntactic elements of type `T` (terms, types, + /// etc.) by mapping a fallible function `f` on each such node as prescribed by the order. + /// + /// `f` may return a generic error `E` and use the state `S` which is passed around. + fn traverse(self, f: &mut F, order: TraverseOrder) -> Result + where + F: FnMut(T) -> Result; + + /// Recurse through the tree of objects top-down (a.k.a. pre-order), applying `f` to + /// each object. + /// + /// Through its return value, `f` can short-circuit one branch of the traversal or + /// the entire traversal. + /// + /// This traversal can make use of "scoped" state. The `scope` argument is passed to + /// each callback, and the callback can optionally override that scope just for its + /// own subtree in the traversal. For example, when traversing a tree of terms you can + /// maintain an environment. Most of the time the environment should get passed around + /// unchanged, but a let binder should override the environment of its subtree. It + /// does this by returning a `TraverseControl::ContinueWithScope` that contains the + /// new environment.
+ fn traverse_ref( + &self, + f: &mut dyn FnMut(&T, &S) -> TraverseControl, + scope: &S, + ) -> Option; + + fn find_map(&self, mut pred: impl FnMut(&T) -> Option) -> Option + where + T: Clone, + { + self.traverse_ref( + &mut |t, _state: &()| { + if let Some(s) = pred(t) { + TraverseControl::Return(s) + } else { + TraverseControl::Continue + } + }, + &(), + ) + } +} + +/// Similar to [Traverse], but takes an additional AST allocator for AST components that require +/// such an allocator in order to build the result. +pub trait TraverseAlloc<'ast, T>: Sized { + /// Same as [Traverse::traverse], but takes an additional AST allocator. + fn traverse( + self, + alloc: &'ast AstAlloc, + f: &mut F, + order: TraverseOrder, + ) -> Result + where + F: FnMut(T) -> Result; + + /// Same as [Traverse::traverse_ref]. Unlike [TraverseAlloc::traverse], it takes no allocator. + fn traverse_ref( + &self, + f: &mut dyn FnMut(&T, &S) -> TraverseControl, + scope: &S, + ) -> Option; + + fn find_map(&self, mut pred: impl FnMut(&T) -> Option) -> Option + where + T: Clone, + { + self.traverse_ref( + &mut |t, _state: &()| { + if let Some(s) = pred(t) { + TraverseControl::Return(s) + } else { + TraverseControl::Continue + } + }, + &(), + ) + } +} + +/// Takes an iterator whose item type implements [TraverseAlloc], traverses each element, and +/// collects the result as a slice allocated via `alloc`.
+pub fn traverse_alloc_many<'ast, T, U, I, F, E>( + alloc: &'ast AstAlloc, + it: I, + f: &mut F, + order: TraverseOrder, +) -> Result<&'ast [U], E> +where + U: TraverseAlloc<'ast, T> + Sized + Allocable, + I: IntoIterator, + F: FnMut(T) -> Result, +{ + let collected: Result, E> = it + .into_iter() + .map(|elt| elt.traverse(alloc, f, order)) + .collect(); + + Ok(alloc.alloc_many(collected?)) +} diff --git a/core/src/typ.rs b/core/src/typ.rs index 5b51175c62..5b58f9ef10 100644 --- a/core/src/typ.rs +++ b/core/src/typ.rs @@ -54,8 +54,9 @@ use crate::{ term::pattern::compile::Compile, term::{ array::Array, make as mk_term, record::RecordData, string::NickelString, IndexMap, - MatchBranch, MatchData, RichTerm, Term, Traverse, TraverseControl, TraverseOrder, + MatchBranch, MatchData, RichTerm, Term, }, + traverse::*, }; use std::{collections::HashSet, convert::Infallible}; @@ -258,7 +259,7 @@ pub enum DictTypeFlavour { /// - `RRows`: the recursive unfolding of record rows /// - `ERows`: the recursive unfolding of enum rows /// - `Te`: the type of a term (used to store contracts) -#[derive(Clone, PartialEq, Debug)] +#[derive(Clone, PartialEq, Eq, Debug)] pub enum TypeF { /// The dynamic type, or unitype. Assigned to values whose actual type is not statically known /// or checked. 
diff --git a/core/src/typecheck/mod.rs b/core/src/typecheck/mod.rs index c3a9e98736..df9bd80955 100644 --- a/core/src/typecheck/mod.rs +++ b/core/src/typecheck/mod.rs @@ -62,8 +62,9 @@ use crate::{ mk_uty_arrow, mk_uty_enum, mk_uty_record, mk_uty_record_row, stdlib as nickel_stdlib, term::{ pattern::bindings::Bindings as _, record::Field, LabeledType, MatchBranch, RichTerm, - StrChunk, Term, Traverse, TraverseOrder, TypeAnnotation, + StrChunk, Term, TypeAnnotation, }, + traverse::{Traverse, TraverseOrder}, typ::*, }; diff --git a/lsp/nls/src/analysis.rs b/lsp/nls/src/analysis.rs index 248f0bb420..14ab6b741e 100644 --- a/lsp/nls/src/analysis.rs +++ b/lsp/nls/src/analysis.rs @@ -4,7 +4,8 @@ use nickel_lang_core::{ files::FileId, identifier::Ident, position::RawSpan, - term::{BinaryOp, RichTerm, Term, Traverse, TraverseControl, UnaryOp}, + term::{BinaryOp, RichTerm, Term, UnaryOp}, + traverse::{Traverse, TraverseControl}, typ::{Type, TypeF}, typecheck::{ reporting::{NameReg, ToType}, diff --git a/lsp/nls/src/cache.rs b/lsp/nls/src/cache.rs index 995a2f98ac..e006ba98c8 100644 --- a/lsp/nls/src/cache.rs +++ b/lsp/nls/src/cache.rs @@ -1,12 +1,13 @@ use codespan::ByteIndex; use lsp_types::{TextDocumentPositionParams, Url}; -use nickel_lang_core::cache::InputFormat; -use nickel_lang_core::term::{RichTerm, Term, Traverse}; use nickel_lang_core::{ + cache::InputFormat, cache::{Cache, CacheError, CacheOp, EntryState, SourcePath, TermEntry}, error::{Error, ImportError}, files::FileId, position::RawPos, + term::{RichTerm, Term}, + traverse::Traverse, typecheck::{self}, }; diff --git a/lsp/nls/src/position.rs b/lsp/nls/src/position.rs index 477144476a..66524e6fa3 100644 --- a/lsp/nls/src/position.rs +++ b/lsp/nls/src/position.rs @@ -3,7 +3,8 @@ use std::ops::Range; use codespan::ByteIndex; use nickel_lang_core::{ position::TermPos, - term::{pattern::bindings::Bindings, RichTerm, Term, Traverse, TraverseControl}, + term::{pattern::bindings::Bindings, RichTerm, Term}, + 
traverse::{Traverse, TraverseControl}, }; use crate::{identifier::LocIdent, term::RichTermPtr}; diff --git a/lsp/nls/src/usage.rs b/lsp/nls/src/usage.rs index 588baccf50..6187f7e671 100644 --- a/lsp/nls/src/usage.rs +++ b/lsp/nls/src/usage.rs @@ -4,7 +4,8 @@ use nickel_lang_core::{ environment::Environment as GenericEnvironment, identifier::Ident, position::RawSpan, - term::{pattern::bindings::Bindings, MatchData, RichTerm, Term, Traverse, TraverseControl}, + term::{pattern::bindings::Bindings, MatchData, RichTerm, Term}, + traverse::{Traverse, TraverseControl}, }; use crate::{field_walker::Def, identifier::LocIdent};