From 9455fe72962e6ea3ee71fabc0aab4621a2d4e254 Mon Sep 17 00:00:00 2001 From: sezna Date: Tue, 30 Jul 2024 18:49:38 -0400 Subject: [PATCH 01/10] wip: TDD --- petr-playground/index.js | 2 ++ petr-typecheck/src/lib.rs | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/petr-playground/index.js b/petr-playground/index.js index 5c5b206..af5ac79 100644 --- a/petr-playground/index.js +++ b/petr-playground/index.js @@ -18,6 +18,7 @@ monaco.languages.setMonarchTokensProvider("petr", { [/\@[a-zA-Z_]+/, "intrinsic"], [/[0-9]+/, "integer-literal"], [/\".*\"/, "string-literal"], + [/\{-.*-\}/, "comment"], ], }, }); @@ -52,6 +53,7 @@ monaco.editor.defineTheme("petr-theme", { { token: "string-literal", foreground: literalColor }, { token: "integer-literal", foreground: literalColor }, { token: "keyword", foreground: literalColor }, + { token: "comment", foreground: "C4A484", fontStyle: "italic"}, ], colors: { "editor.foreground": "#ffffff", diff --git a/petr-typecheck/src/lib.rs b/petr-typecheck/src/lib.rs index a674708..a1c474d 100644 --- a/petr-typecheck/src/lib.rs +++ b/petr-typecheck/src/lib.rs @@ -2115,4 +2115,25 @@ fn main() returns 'int ~hi(1, 2)"#, "#]], ) } + + #[test] + fn disallow_wrong_sum_type_in_add() { + check( + r#" + type IntBelowFive = 1 | 2 | 3 | 4 | 5 + {- reject an `add` which may return an int above five -} + fn add(a in 'IntBelowFive, b in 'IntBelowFive) returns 'IntBelowFive @add(a, b) +"#, +expect![[r#""#]]) + } + #[test] + fn allow_wrong_sum_type_in_add() { + check( + r#" + type IntBelowFive = 1 | 2 | 3 | 4 | 5 + {- reject an `add` which may return an int above five -} + fn add(a in 'IntBelowFive, b in 'IntBelowFive) returns 'int @add(a, b) +"#, +expect![[r#""#]]) + } } From 8df47d006c50991b00389d331bf98a9880083b08 Mon Sep 17 00:00:00 2001 From: sezna Date: Tue, 30 Jul 2024 20:19:18 -0700 Subject: [PATCH 02/10] fix unification and satisfy rules --- petr-typecheck/src/lib.rs | 169 +++++++++++++++++++++++++++++++------- 1 file 
changed, 140 insertions(+), 29 deletions(-) diff --git a/petr-typecheck/src/lib.rs b/petr-typecheck/src/lib.rs index a1c474d..4fc4b1e 100644 --- a/petr-typecheck/src/lib.rs +++ b/petr-typecheck/src/lib.rs @@ -206,6 +206,21 @@ pub enum GeneralType { Sum(BTreeSet), } +impl GeneralType { + /// Because [`GeneralType`]'s type info is less detailed (specific) than [`SpecificType`], + /// we can losslessly cast any [`GeneralType`] into an instance of [`SpecificType`]. + pub fn safely_upcast(&self) -> SpecificType { + match self { + GeneralType::Unit => SpecificType::Unit, + GeneralType::Integer => SpecificType::Integer, + GeneralType::Boolean => SpecificType::Boolean, + GeneralType::String => SpecificType::String, + GeneralType::ErrorRecovery => SpecificType::ErrorRecovery, + _ => todo!(), + } + } +} + /// This is an information-rich type -- it tracks effects and data types. It is used for /// the type-checking stage to provide rich information to the user. /// Types are generalized into instances of [`GeneralType`] for monomorphization and @@ -297,23 +312,32 @@ impl SpecificType { } } - /// If `self` is a generalized form of `sum_tys`, return true + /// If `self` is a generalized form of `b`, return true /// A generalized form is a type that is a superset of the sum types. 
/// For example, `String` is a generalized form of `Sum(Literal("a") | Literal("B"))` - fn is_subset_of( + fn is_superset_of( &self, - sum_tys: &BTreeSet, + b: &SpecificType, ctx: &TypeContext, ) -> bool { - use petr_resolve::Literal; use SpecificType::*; - match self { - String => sum_tys.iter().all(|ty| matches!(ty, Literal(Literal::String(_)))), - Integer => sum_tys.iter().all(|ty| matches!(ty, Literal(Literal::Integer(_)))), - Boolean => sum_tys.iter().all(|ty| matches!(ty, Literal(Literal::Boolean(_)))), - Ref(ty) => ctx.types.get(*ty).is_subset_of(sum_tys, ctx), - Sum(tys) => tys.iter().all(|ty| sum_tys.contains(ty)), - _ => false, + let generalized_b = b.generalize(ctx).safely_upcast(); + match (self, b) { + (a, b) if a == b || *a == generalized_b => true, + (Sum(a_tys), b) if a_tys.contains(b) || a_tys.contains(&generalized_b) => true, + (Sum(a_tys), Sum(b_tys)) => { + // if a_tys is a superset of b_tys, + // every element OR its generalized version is contained in a_tys + for b_ty in b_tys { + let b_ty_generalized = b_ty.generalize(ctx).safely_upcast(); + if !(a_tys.contains(b_ty) || a_tys.contains(&b_ty_generalized)) { + return false; + } + } + + true + }, + _ => todo!(), } } } @@ -551,18 +575,31 @@ impl TypeChecker { // to the first self.ctx.update_type(t2, Ref(t1)); }, - (Sum(a_tys), Sum(b_tys)) => { - // the unification of two sum types is the union of the two types - let union = a_tys.iter().chain(b_tys.iter()).cloned().collect(); - self.ctx.update_type(t1, Sum(union)); - self.ctx.update_type(t2, Ref(t1)); + (a @ Sum(_), b @ Sum(_)) => { + // the unification of two sum types is the union of the two types if and only if + // `t2` is a total subset of `t1` + // `t1` remains unchanged, as we are trying to coerce `t2` into something that + // represents `t1` + // TODO remove clone + if a.is_superset_of(&b, &self.ctx) { + } else { + self.push_error(span.with_item(self.unify_err(a, b))); + } }, + // If `t2` is a non-sum type, and `t1` is a sum type, then 
`t1` must contain either + // exactly the same specific type OR the generalization of that type + // If the latter, then the specific type must be updated to its generalization (Sum(sum_tys), other) => { - // the unfication of a sum type and another type is the sum type plus the other - // type - let union: BTreeSet<_> = sum_tys.iter().cloned().chain(std::iter::once(other)).collect(); - self.ctx.update_type(t1, Sum(union)); - self.ctx.update_type(t2, Ref(t1)); + if sum_tys.contains(&other) { + self.ctx.update_type(t2, Ref(t1)); + } else { + let generalization = other.generalize(&self.ctx).safely_upcast(); + if sum_tys.contains(&generalization) { + self.ctx.update_type(t2, generalization); + } else { + self.push_error(span.with_item(self.unify_err(Sum(sum_tys.clone()), generalization))); + } + } }, // literals can unify to each other if they're equal (Literal(l1), Literal(l2)) if l1 == l2 => (), @@ -647,15 +684,13 @@ impl TypeChecker { let intersection = a_tys.iter().filter(|a_ty| b_tys.contains(a_ty)).cloned().collect(); self.ctx.update_type(t2, Sum(intersection)); }, - (Sum(sum_tys), other) | (other, Sum(sum_tys)) => { + (ty1 @ Sum(sum_tys), other) => { if // if `other` is a generalized version of the sum type, // then it satisfies the sum type - other.is_subset_of(sum_tys, &self.ctx) || - // `other` must be a member of the Sum type - sum_tys.contains(other) { + ty1.is_superset_of(other, &self.ctx) { } else { - self.push_error(span.with_item(self.satisfy_err(other.clone(), SpecificType::Sum(sum_tys.iter().cloned().collect())))); + self.push_error(span.with_item(self.satisfy_err(SpecificType::Sum(sum_tys.iter().cloned().collect()), other.clone()))); } }, (Literal(l1), Literal(l2)) if l1 == l2 => (), @@ -1519,7 +1554,7 @@ mod pretty_printing { } if !type_checker.errors.is_empty() { - s.push_str("\n__ERRORS__\n"); + s.push_str("\n\n__ERRORS__\n"); for error in type_checker.errors { s.push_str(&format!("{:?}\n", error)); } @@ -2089,6 +2124,8 @@ fn main() returns 'int 
~hi(1, 2)"#, ) } + // TODO remove ignore before merging + #[ignore] #[test] fn disallow_incorrect_constant_bool() { check( @@ -2116,6 +2153,8 @@ fn main() returns 'int ~hi(1, 2)"#, ) } + // TODO remove ignore before merging + #[ignore] #[test] fn disallow_wrong_sum_type_in_add() { check( @@ -2124,8 +2163,11 @@ fn main() returns 'int ~hi(1, 2)"#, {- reject an `add` which may return an int above five -} fn add(a in 'IntBelowFive, b in 'IntBelowFive) returns 'IntBelowFive @add(a, b) "#, -expect![[r#""#]]) + expect![[r#""#]], + ) } + + #[ignore] #[test] fn allow_wrong_sum_type_in_add() { check( @@ -2134,6 +2176,75 @@ expect![[r#""#]]) {- reject an `add` which may return an int above five -} fn add(a in 'IntBelowFive, b in 'IntBelowFive) returns 'int @add(a, b) "#, -expect![[r#""#]]) + expect![[r#""#]], + ) + } + + #[test] + fn sum_type_unifies_to_superset() { + check( + r"fn test(a in 'sum 1 | 2 | 3) returns 'sum 1 | 2 | 3 a + fn test_(a in 'sum 1 | 2) returns 'sum 1 | 2 a + fn main() returns 'int + {- should be of specific type lit 2 -} + let x = 2; + {- should be of specific type 'sum 1 | 2 -} + y = ~test_(x); + {- should be of specific type 'sum 1 | 2 | 3 -} + z = ~test(y); + {- should also be of specific type 'sum 1 | 2 | 3 -} + zz = ~test(x) + + {- and should generalize to 'int with no problems -} + zz + ", + expect![[r#""#]], + ) + } + + #[test] + fn specific_type_generalizes() { + check( + r#"fn test(a in 'sum 'int | 'string) returns 'sum 'int | 'string a + fn test_(a in 'int) returns 'sum 'int | 'string a + fn main() returns 'int + let x = ~test_(5); + y = ~test("a string"); + 42 + "#, + expect![[r#""#]], + ) + } + + #[test] + fn disallow_bad_generalization() { + check( + r#"fn test(a in 'sum 'int | 'string) returns 'sum 'int | 'string a + fn test_(a in 'bool) returns 'sum 'int | 'string a + fn main() returns 'int + {- we are passing 'bool into 'int | 'string so this should fail to satisfy constraints -} + let y = ~test(~test_(true)); + 42 + "#, + 
expect![[r#""#]], + ) + } + + #[test] + fn order_of_sum_type_doesnt_matter() { + check( + r#"fn test(a in 'sum 'int | 'string) returns 'sum 'string | 'int a + "#, + expect![[r#""#]], + ) + } + + #[test] + fn can_return_superset() { + check( + r#"fn test(a in 'sum 'int | 'string) returns 'sum 'string | 'int | 'bool a + "#, + expect![[r#""#]], + ) } } From 2c96d0d0f64b5664be66202716a1cdac2650557d Mon Sep 17 00:00:00 2001 From: sezna Date: Wed, 31 Jul 2024 06:29:33 -0700 Subject: [PATCH 03/10] fix up satisfies constraint --- petr-typecheck/src/lib.rs | 135 +++++++++++++++++++++++++++++++------- 1 file changed, 110 insertions(+), 25 deletions(-) diff --git a/petr-typecheck/src/lib.rs b/petr-typecheck/src/lib.rs index 4fc4b1e..d1fa01c 100644 --- a/petr-typecheck/src/lib.rs +++ b/petr-typecheck/src/lib.rs @@ -323,8 +323,15 @@ impl SpecificType { use SpecificType::*; let generalized_b = b.generalize(ctx).safely_upcast(); match (self, b) { + // If `a` is the generalized form of `b`, then `b` satisfies the constraint. (a, b) if a == b || *a == generalized_b => true, + // If `a` is a sum type which contains `b` OR the generalized form of `b`, then `b` + // satisfies the constraint. 
(Sum(a_tys), b) if a_tys.contains(b) || a_tys.contains(&generalized_b) => true, + // if both `a` and `b` are sum types, then `a` must be a superset of `b`: + // - every element in `b` must either: + // - be a member of `a` + // - generalize to a member of `a` (Sum(a_tys), Sum(b_tys)) => { // if a_tys is a superset of b_tys, // every element OR its generalized version is contained in a_tys @@ -337,7 +344,17 @@ impl SpecificType { true }, - _ => todo!(), + _otherwise => false, + } + } + + /// Use this to construct `[SpecificType::Sum]` types -- + /// it will attempt to collapse the sum into a single type if possible + fn sum(tys: BTreeSet) -> SpecificType { + if tys.len() == 1 { + tys.into_iter().next().expect("invariant") + } else { + SpecificType::Sum(tys) } } } @@ -638,7 +655,7 @@ impl TypeChecker { (other, Sum(sum_tys)) => { // `other` must be a member of the Sum type if !sum_tys.contains(&other) { - self.push_error(span.with_item(self.unify_err(other.clone(), SpecificType::Sum(sum_tys.iter().cloned().collect())))); + self.push_error(span.with_item(self.unify_err(other.clone(), SpecificType::sum(sum_tys.clone())))); } // unify both types to the other type self.ctx.update_type(t2, other); @@ -682,17 +699,11 @@ impl TypeChecker { (Sum(a_tys), Sum(b_tys)) => { // calculate the intersection of these types, update t2 to the intersection let intersection = a_tys.iter().filter(|a_ty| b_tys.contains(a_ty)).cloned().collect(); - self.ctx.update_type(t2, Sum(intersection)); - }, - (ty1 @ Sum(sum_tys), other) => { - if - // if `other` is a generalized version of the sum type, - // then it satisfies the sum type - ty1.is_superset_of(other, &self.ctx) { - } else { - self.push_error(span.with_item(self.satisfy_err(SpecificType::Sum(sum_tys.iter().cloned().collect()), other.clone()))); - } + self.ctx.update_type(t2, SpecificType::sum(intersection)); }, + // if `ty1` is a generalized version of the sum type, + // then it satisfies the sum type + (ty1, other) if 
ty1.is_superset_of(other, &self.ctx) => (), (Literal(l1), Literal(l2)) if l1 == l2 => (), // Literals can satisfy broader parent types (ty, Literal(lit)) => match (lit, ty) { @@ -1553,8 +1564,12 @@ mod pretty_printing { )); } + if !type_checker.monomorphized_functions.is_empty() { + s.push('\n'); + } + if !type_checker.errors.is_empty() { - s.push_str("\n\n__ERRORS__\n"); + s.push_str("\n__ERRORS__\n"); for error in type_checker.errors { s.push_str(&format!("{:?}\n", error)); } @@ -1778,7 +1793,8 @@ mod tests { function call to functionid2 with args: someField: MyType, returns MyComposedType __MONOMORPHIZED FUNCTIONS__ - fn firstVariant(["MyType"]) -> MyComposedType"#]], + fn firstVariant(["MyType"]) -> MyComposedType + "#]], ); } @@ -1835,7 +1851,8 @@ mod tests { intrinsic: @puts(function call to functionid0 with args: ) __MONOMORPHIZED FUNCTIONS__ - fn string_literal([]) -> string"#]], + fn string_literal([]) -> string + "#]], ); } @@ -1902,6 +1919,7 @@ mod tests { __MONOMORPHIZED FUNCTIONS__ fn bool_literal([]) -> bool + __ERRORS__ SpannedItem UnificationFailure("string", "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(110), length: 14 } }] "#]], @@ -1934,7 +1952,8 @@ mod tests { __MONOMORPHIZED FUNCTIONS__ fn bool_literal(["int", "int"]) -> bool - fn bool_literal(["bool", "bool"]) -> bool"#]], + fn bool_literal(["bool", "bool"]) -> bool + "#]], ); } #[test] @@ -1999,7 +2018,8 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn hi(["int", "int"]) -> int - fn main([]) -> int"#]], + fn main([]) -> int + "#]], ) } @@ -2020,6 +2040,7 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn hi(["int"]) -> int fn main([]) -> int + __ERRORS__ SpannedItem UnificationFailure("int", "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(61), length: 2 } }] "#]], @@ -2043,6 +2064,7 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn hi([]) -> int fn main([]) -> int + 
__ERRORS__ SpannedItem UnificationFailure("unit", "1") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(33), length: 46 } }] "#]], @@ -2066,7 +2088,8 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn hi([]) -> unit - fn main([]) -> unit"#]], + fn main([]) -> unit + "#]], ) } @@ -2091,8 +2114,9 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn OneOrTwo(["int"]) -> OneOrTwo fn main([]) -> OneOrTwo + __ERRORS__ - SpannedItem FailedToSatisfy("10", "(1 | 2)") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(104), length: 0 } }] + SpannedItem NotSubtype(["1", "2"], "10") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(104), length: 0 } }] "#]], ) } @@ -2118,8 +2142,9 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn AOrB(["string"]) -> AOrB fn main([]) -> AOrB + __ERRORS__ - SpannedItem FailedToSatisfy("\"c\"", "(\"A\" | \"B\")") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(97), length: 0 } }] + SpannedItem NotSubtype(["\"A\"", "\"B\""], "\"c\"") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(97), length: 0 } }] "#]], ) } @@ -2198,7 +2223,25 @@ fn main() returns 'int ~hi(1, 2)"#, {- and should generalize to 'int with no problems -} zz ", - expect![[r#""#]], + expect![[r#" + fn test: ((1 | 2 | 3) → (1 | 2 | 3)) + variable a: (1 | 2 | 3) + + fn test_: ((1 | 2) → (1 | 2)) + variable a: (1 | 2) + + fn main: int + x: literal: 2 (2), + y: function call to functionid1 with args: symbolid1: variable: symbolid5, ((1 | 2)), + z: function call to functionid0 with args: symbolid1: variable: symbolid6, ((1 | 2 | 3)), + zz: function call to functionid0 with args: symbolid1: variable: symbolid5, ((1 | 2 | 3)), + "variable zz: (1 | 2 | 3)" ((1 | 2 | 3)) + + __MONOMORPHIZED FUNCTIONS__ + fn test(["int"]) -> (1 | 2 | 3) + fn test_(["int"]) -> (1 | 2) + fn main([]) -> int + "#]], ) } @@ -2212,7 +2255,23 @@ fn main() returns 
'int ~hi(1, 2)"#, y = ~test("a string"); 42 "#, - expect![[r#""#]], + expect![[r#" + fn test: ((int | string) → (int | string)) + variable a: (int | string) + + fn test_: (int → (int | string)) + variable a: int + + fn main: int + x: function call to functionid1 with args: symbolid1: literal: 5, ((int | string)), + y: function call to functionid0 with args: symbolid1: literal: "a string", ((int | string)), + "literal: 42" (42) + + __MONOMORPHIZED FUNCTIONS__ + fn test(["string"]) -> (int | string) + fn test_(["int"]) -> (int | string) + fn main([]) -> int + "#]], ) } @@ -2226,7 +2285,25 @@ fn main() returns 'int ~hi(1, 2)"#, let y = ~test(~test_(true)); 42 "#, - expect![[r#""#]], + expect![[r#" + fn test: ((int | string) → (int | string)) + variable a: (int | string) + + fn test_: (bool → (int | string)) + variable a: bool + + fn main: int + y: function call to functionid0 with args: symbolid1: function call to functionid1 with args: symbolid1: literal: true, , ((int | string)), + "literal: 42" (42) + + __MONOMORPHIZED FUNCTIONS__ + fn test(["(int | string)"]) -> (int | string) + fn test_(["bool"]) -> (int | string) + fn main([]) -> int + + __ERRORS__ + SpannedItem NotSubtype(["int", "string"], "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(129), length: 0 } }] + "#]], ) } @@ -2235,7 +2312,11 @@ fn main() returns 'int ~hi(1, 2)"#, check( r#"fn test(a in 'sum 'int | 'string) returns 'sum 'string | 'int a "#, - expect![[r#""#]], + expect![[r#" + fn test: ((int | string) → (int | string)) + variable a: (int | string) + + "#]], ) } @@ -2244,7 +2325,11 @@ fn main() returns 'int ~hi(1, 2)"#, check( r#"fn test(a in 'sum 'int | 'string) returns 'sum 'string | 'int | 'bool a "#, - expect![[r#""#]], + expect![[r#" + fn test: ((int | string) → (int | bool | string)) + variable a: (int | string) + + "#]], ) } } From 008a217c81e7d22a3cc3d99b1f29aaa26dd52942 Mon Sep 17 00:00:00 2001 From: sezna Date: Wed, 31 Jul 2024 06:30:43 -0700 Subject: [PATCH 
04/10] fmt --- petr-fmt/src/tests.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/petr-fmt/src/tests.rs b/petr-fmt/src/tests.rs index 8ed6f8a..aa591a0 100644 --- a/petr-fmt/src/tests.rs +++ b/petr-fmt/src/tests.rs @@ -636,10 +636,14 @@ fn let_bindings_no_trailing_comma() { #[test] fn sum_ty_formatting() { - check(Default::default(), "fn myFunc(x in 'sum 1 | 2 | 3) returns 'int 5", expect![[r#" + check( + Default::default(), + "fn myFunc(x in 'sum 1 | 2 | 3) returns 'int 5", + expect![[r#" fn myFunc( x ∈ 'Σ 1 | 2 | 3, ) → 'int 5 - "#]]) + "#]], + ) } From 5487f5908dfd78069cf5a7ff1da17a2b79766104 Mon Sep 17 00:00:00 2001 From: sezna Date: Wed, 31 Jul 2024 07:09:08 -0700 Subject: [PATCH 05/10] something is terribly wrong --- petr-typecheck/src/error.rs | 3 ++ petr-typecheck/src/lib.rs | 98 ++++++++++++++++++++++++++++++++++++- 2 files changed, 100 insertions(+), 1 deletion(-) diff --git a/petr-typecheck/src/error.rs b/petr-typecheck/src/error.rs index eb3742f..0103423 100644 --- a/petr-typecheck/src/error.rs +++ b/petr-typecheck/src/error.rs @@ -15,4 +15,7 @@ pub enum TypeConstraintError { UnknownInference, #[error("internal compiler error: {0}")] Internal(String), + // TODO better errors here + #[error("This type references itself in a circular way")] + CircularType, } diff --git a/petr-typecheck/src/lib.rs b/petr-typecheck/src/lib.rs index d1fa01c..676370d 100644 --- a/petr-typecheck/src/lib.rs +++ b/petr-typecheck/src/lib.rs @@ -99,6 +99,12 @@ pub enum TypeConstraintKind { Satisfies(TypeVariable, TypeVariable), } +#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] +enum TypeConstraintKindValue { + Unify, + Satisfies, +} + pub struct TypeContext { types: IndexMap, constraints: Vec, @@ -542,6 +548,11 @@ impl TypeChecker { call.type_check(self); } + // before applying existing constraints, it is likely that many duplicate constraints + // exist. 
We can safely remove any duplicate constraints to avoid excessive error + // reporting. + self.deduplicate_constraints(); + // we have now collected our constraints and can solve for them self.apply_constraints(); } @@ -984,6 +995,90 @@ impl TypeChecker { pub fn ctx(&self) -> &TypeContext { &self.ctx } + + /// terms: + /// ### resolved type variable + /// + /// a type variable that is not a `Ref`. To get the resolved type of + /// a type variable, you must follow the chain of `Ref`s until you reach a non-Ref type. + /// + /// ### constraint kind strength: + /// The following is the hierarchy of constraints in terms of strength, from strongest (1) to + /// weakest: + /// 1. Unify(t1, t2) (t2 _must_ be coerceable to exactly equal t1) + /// 2. Satisfies (t2 must be a subset of t1. For all cases where t2 can unify to t1, t2 + /// satisfies t1 as a constraint) + /// + /// ### constraint strength + /// A constraint `a` is _stronger than_ a constraint `b` iff: + /// - `a` is higher than `b` in terms of constraint kind strength `a` is a more specific constraint than `b` + /// - e.g. Unify(Literal(5), x) is stronger than Unify(Int, x) because the former is more specific + /// - e.g. 
Unify(a, b) is stronger than Satisfies(a, b) + /// + /// + /// ### duplicated constraint: + /// A constraint `a` is _duplicated by_ constraint `b` iff: + /// - `a` and `b` are the same constraint kind, and the resolved type variables are the same + /// - `a` is a stronger constraint than `b` + /// + fn deduplicate_constraints(&mut self) { + use TypeConstraintKindValue as Kind; + let mut constraints = ConstraintDeduplicator::default(); + let mut errs = vec![]; + for constraint in &self.ctx.constraints { + println!("on constraint: {:?}", constraint); + let (mut tys, kind) = match &constraint.kind { + TypeConstraintKind::Unify(t1, t2) => (vec![*t1, *t2], Kind::Unify), + TypeConstraintKind::Satisfies(t1, t2) => (vec![*t1, *t2], Kind::Satisfies), + }; + + // resolve all `Ref` types to get a resolved type variable + 'outer: for ty_var in tys.iter_mut() { + // track what we have seen, in case a circular reference is present + let mut seen_vars = BTreeSet::new(); + seen_vars.insert(*ty_var); + let ty = self.ctx.types.get(*ty_var); + while let SpecificType::Ref(t) = ty { + if seen_vars.insert(*t) { + *ty_var = *t; + } else { + // circular reference + errs.push(constraint.span.with_item(TypeConstraintError::CircularType)); + continue 'outer; + } + *ty_var = *t; + } + } + + constraints.insert((kind, tys), *constraint); + } + + for err in errs { + self.push_error(err); + } + + self.ctx.constraints = constraints.into_values(); + } +} + +/// the `key` type is what we use to deduplicate constraints +#[derive(Default)] +struct ConstraintDeduplicator { + constraints: BTreeMap<(TypeConstraintKindValue, Vec), TypeConstraint>, +} + +impl ConstraintDeduplicator { + fn insert( + &mut self, + key: (TypeConstraintKindValue, Vec), + constraint: TypeConstraint, + ) { + self.constraints.insert(key, constraint); + } + + fn into_values(self) -> Vec { + self.constraints.into_values().collect() + } } #[derive(Clone)] @@ -1309,7 +1404,7 @@ impl TypeCheck for SpannedItem { let arg = 
self.item().args[0].type_check(ctx); let arg_ty = ctx.expr_ty(&arg); let int_ty = ctx.int(); - ctx.unify(arg_ty, int_ty, arg.span()); + ctx.unify(int_ty, arg_ty, arg.span()); TypedExprKind::Intrinsic { intrinsic: Intrinsic::Malloc(Box::new(arg)), ty: int_ty, @@ -1583,6 +1678,7 @@ mod pretty_printing { ) -> String { let mut ty = type_checker.look_up_variable(*ty); while let SpecificType::Ref(t) = ty { + println!("looping"); ty = type_checker.look_up_variable(*t); } pretty_print_petr_type(ty, type_checker) From ee2761313f14368c9d8a992eb95f02c8a7f5c373 Mon Sep 17 00:00:00 2001 From: sezna Date: Fri, 2 Aug 2024 09:07:50 -0700 Subject: [PATCH 06/10] wip: begin refactor into separate type solution stage --- petr-typecheck/src/error.rs | 2 + petr-typecheck/src/lib.rs | 595 ++++++++++++++++++++++-------------- 2 files changed, 364 insertions(+), 233 deletions(-) diff --git a/petr-typecheck/src/error.rs b/petr-typecheck/src/error.rs index 0103423..f4f31d0 100644 --- a/petr-typecheck/src/error.rs +++ b/petr-typecheck/src/error.rs @@ -18,4 +18,6 @@ pub enum TypeConstraintError { // TODO better errors here #[error("This type references itself in a circular way")] CircularType, + #[error("Type {1} is not castable to type {0}")] + InvalidTypeUpdate(String, String), } diff --git a/petr-typecheck/src/lib.rs b/petr-typecheck/src/lib.rs index 676370d..40ca5dc 100644 --- a/petr-typecheck/src/lib.rs +++ b/petr-typecheck/src/lib.rs @@ -17,7 +17,7 @@ use error::TypeConstraintError; pub use petr_bind::FunctionId; use petr_resolve::{Expr, ExprKind, QueryableResolvedItems}; pub use petr_resolve::{Intrinsic as ResolvedIntrinsic, IntrinsicName, Literal}; -use petr_utils::{idx_map_key, Identifier, IndexMap, Span, SpannedItem, SymbolId, TypeId}; +use petr_utils::{idx_map_key, Identifier, IndexMap, Span, SpannedItem, SymbolId, SymbolInterner, TypeId}; pub type TypeError = SpannedItem; pub type TResult = Result; @@ -25,9 +25,12 @@ pub type TResult = Result; // TODO return QueryableTypeChecked 
instead of type checker // Clean up API so this is the only function exposed pub fn type_check(resolved: QueryableResolvedItems) -> (Vec, TypeChecker) { - let mut type_checker = TypeChecker::new(resolved); + todo!("design new api") + /* + let solution = TypeChecker::new(resolved); type_checker.fully_type_check(); (type_checker.errors.clone(), type_checker) + */ } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] @@ -97,12 +100,16 @@ pub enum TypeConstraintKind { Unify(TypeVariable, TypeVariable), // constraint that lhs is a "subtype" or satisfies the typeclass constraints of "rhs" Satisfies(TypeVariable, TypeVariable), + // If a type variable is constrained to be an axiom, it means that the type variable + // cannot be updated by the inference engine. It effectively fixes the type, or pins the type. + Axiom(TypeVariable), } #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] enum TypeConstraintKindValue { Unify, Satisfies, + Axiom, } pub struct TypeContext { @@ -167,7 +174,7 @@ impl TypeContext { self.types.insert(SpecificType::Infer(infer_id, span)) } - /// Update a type variable with a new PetrType + /// Update a type variable with a new SpecificType fn update_type( &mut self, t1: TypeVariable, @@ -227,6 +234,301 @@ impl GeneralType { } } +/// Represents the result of the type-checking stage for an individual type variable. 
+pub struct TypeSolutionEntry { + axiomatic: bool, + ty: SpecificType, +} + +impl TypeSolutionEntry { + pub fn new_axiomatic(ty: SpecificType) -> Self { + Self { axiomatic: true, ty } + } +} + +pub struct TypeSolution { + solution: BTreeMap, + unsolved_types: IndexMap, + errors: Vec, + interner: SymbolInterner, +} + +impl TypeSolution { + pub fn new( + unsolved_types: IndexMap, + interner: SymbolInterner, + ) -> Self { + Self { + solution: Default::default(), + unsolved_types, + errors: Default::default(), + interner, + } + } + + fn push_error( + &mut self, + e: TypeError, + ) { + self.errors.push(e); + } + + pub fn insert_solution( + &mut self, + ty: TypeVariable, + entry: TypeSolutionEntry, + span: Span, + ) { + if self.solution.contains_key(&ty) { + self.update_type(ty, entry, span); + return; + } + self.solution.insert(ty, entry); + } + + fn pretty_print_type( + &self, + ty: &SpecificType, + ) -> String { + pretty_printing::pretty_print_petr_type(&ty, &self.unsolved_types, &self.interner) + } + + fn unify_err( + &self, + clone_1: SpecificType, + clone_2: SpecificType, + ) -> TypeConstraintError { + let pretty_printed_b = self.pretty_print_type(&clone_2); + match clone_1 { + SpecificType::Sum(tys) => { + let tys = tys.iter().map(|ty| self.pretty_print_type(ty)).collect::>(); + TypeConstraintError::NotSubtype(tys, pretty_printed_b) + }, + _ => { + let pretty_printed_a = self.pretty_print_type(&clone_1); + TypeConstraintError::UnificationFailure(pretty_printed_a, pretty_printed_b) + }, + } + } + + fn satisfy_err( + &self, + clone_1: SpecificType, + clone_2: SpecificType, + ) -> TypeConstraintError { + let pretty_printed_b = self.pretty_print_type(&clone_2); + match clone_1 { + SpecificType::Sum(tys) => { + let tys = tys.iter().map(|ty| self.pretty_print_type(&ty)).collect::>(); + TypeConstraintError::NotSubtype(tys, pretty_printed_b) + }, + _ => { + let pretty_printed_a = self.pretty_print_type(&clone_1); + TypeConstraintError::FailedToSatisfy(pretty_printed_a, 
pretty_printed_b) + }, + } + } + + pub fn update_type( + &mut self, + ty: TypeVariable, + entry: TypeSolutionEntry, + span: Span, + ) { + match self.solution.get_mut(&ty) { + Some(e) => { + if e.axiomatic { + let pretty_printed_preexisting = pretty_printing::pretty_print_petr_type(&entry.ty, &self.unsolved_types, &self.interner); + let pretty_printed_ty = pretty_printing::pretty_print_petr_type(&entry.ty, &self.unsolved_types, &self.interner); + self.errors + .push(span.with_item(TypeConstraintError::InvalidTypeUpdate(pretty_printed_preexisting, pretty_printed_ty))); + return; + } + *e = entry; + }, + None => { + self.errors.push(span.with_item(TypeConstraintError::Internal( + "attempted to update type that did not exist in solution".into(), + ))); + }, + } + } + + fn into_result(self) -> Result>> { + if self.errors.is_empty() { + Ok(self) + } else { + Err(self.errors) + } + } + + /// Attempt to unify two types, returning an error if they cannot be unified + /// The more specific of the two types will instantiate the more general of the two types. + /// + /// TODO: The unify constraint should attempt to upcast `t2` as `t1` if possible, but will never + /// downcast `t1` as `t2`. This is not currently how it works and needs investigation. 
+ fn apply_unify_constraint( + &mut self, + t1: TypeVariable, + t2: TypeVariable, + span: Span, + ) { + let ty1 = self.get_latest_type(t1).clone(); + let ty2 = self.get_latest_type(t2).clone(); + use SpecificType::*; + match (ty1, ty2) { + (a, b) if a == b => (), + (ErrorRecovery, _) | (_, ErrorRecovery) => (), + (Ref(a), _) => self.apply_unify_constraint(a, t2, span), + (_, Ref(b)) => self.apply_unify_constraint(t1, b, span), + (Infer(id, _), Infer(id2, _)) if id != id2 => { + // if two different inferred types are unified, replace the second with a reference + // to the first + self.update_type(t2, Ref(t1)); + }, + (a @ Sum(_), b @ Sum(_)) => { + // the unification of two sum types is the union of the two types if and only if + // `t2` is a total subset of `t1` + // `t1` remains unchanged, as we are trying to coerce `t2` into something that + // represents `t1` + // TODO remove clone + if a.is_superset_of(&b, &self.ctx) { + } else { + self.push_error(span.with_item(self.unify_err(a, b))); + } + }, + // If `t2` is a non-sum type, and `t1` is a sum type, then `t1` must contain either + // exactly the same specific type OR the generalization of that type + // If the latter, then the specific type must be updated to its generalization + (Sum(sum_tys), other) => { + if sum_tys.contains(&other) { + self.ctx.update_type(t2, Ref(t1)); + } else { + let generalization = other.generalize(&self.ctx).safely_upcast(); + if sum_tys.contains(&generalization) { + self.ctx.update_type(t2, generalization); + } else { + self.push_error(span.with_item(self.unify_err(Sum(sum_tys.clone()), generalization))); + } + } + }, + // literals can unify to each other if they're equal + (Literal(l1), Literal(l2)) if l1 == l2 => (), + (Literal(l1), Literal(l2)) if l1 != l2 => { + // update t1 to a sum type of both, + // and update t2 to reference t1 + let sum = Sum([Literal(l1), Literal(l2)].into()); + self.ctx.update_type(t1, sum); + self.ctx.update_type(t2, Ref(t1)); + }, + (Literal(l1), 
Sum(tys)) => { + // update t1 to a sum type of both, + // and update t2 to reference t1 + let sum = Sum([Literal(l1)].iter().chain(tys.iter()).cloned().collect()); + self.ctx.update_type(t1, sum); + self.ctx.update_type(t2, Ref(t1)); + }, + // literals can unify broader parent types + // but the broader parent type gets instantiated with the literal type + (ty, Literal(lit)) => match (&lit, ty) { + (petr_resolve::Literal::Integer(_), Integer) + | (petr_resolve::Literal::Boolean(_), Boolean) + | (petr_resolve::Literal::String(_), String) => self.ctx.update_type(t1, SpecificType::Literal(lit)), + (lit, ty) => self.push_error(span.with_item(self.unify_err(ty.clone(), SpecificType::Literal(lit.clone())))), + }, + // literals can unify broader parent types + // but the broader parent type gets instantiated with the literal type + (Literal(lit), ty) => match (&lit, ty) { + (petr_resolve::Literal::Integer(_), Integer) + | (petr_resolve::Literal::Boolean(_), Boolean) + | (petr_resolve::Literal::String(_), String) => self.ctx.update_type(t2, SpecificType::Literal(lit)), + (lit, ty) => { + self.push_error(span.with_item(self.unify_err(ty.clone(), SpecificType::Literal(lit.clone())))); + }, + }, + (other, Sum(sum_tys)) => { + // `other` must be a member of the Sum type + if !sum_tys.contains(&other) { + self.push_error(span.with_item(self.unify_err(other.clone(), SpecificType::sum(sum_tys.clone())))); + } + // unify both types to the other type + self.ctx.update_type(t2, other); + }, + // instantiate the infer type with the known type + (Infer(_, _), known) => { + self.ctx.update_type(t1, known); + }, + (known, Infer(_, _)) => { + self.ctx.update_type(t2, known); + }, + // lastly, if no unification rule exists for these two types, it is a mismatch + (a, b) => { + self.push_error(span.with_item(self.unify_err(a, b))); + }, + } + } + + // This function will need to be rewritten when type constraints and bounded polymorphism are + // implemented. 
+ fn apply_satisfies_constraint( + &mut self, + t1: TypeVariable, + t2: TypeVariable, + span: Span, + ) { + let ty1 = self.ctx.types.get(t1); + let ty2 = self.ctx.types.get(t2); + use SpecificType::*; + match (ty1, ty2) { + (a, b) if a == b => (), + (ErrorRecovery, _) | (_, ErrorRecovery) => (), + (Ref(a), _) => self.apply_satisfies_constraint(*a, t2, span), + (_, Ref(b)) => self.apply_satisfies_constraint(t1, *b, span), + // if t1 is a fully instantiated type, then t2 can be updated to be a reference to t1 + (Unit | Integer | Boolean | UserDefined { .. } | String | Arrow(..) | List(..) | Literal(_) | Sum(_), Infer(_, _)) => { + self.ctx.update_type(t2, Ref(t1)); + }, + // the "parent" infer type will not instantiate to the "child" type + (Infer(_, _), Unit | Integer | Boolean | UserDefined { .. } | String | Arrow(..) | List(..) | Literal(_) | Sum(_)) => (), + (Sum(a_tys), Sum(b_tys)) => { + // calculate the intersection of these types, update t2 to the intersection + let intersection = a_tys.iter().filter(|a_ty| b_tys.contains(a_ty)).cloned().collect(); + self.ctx.update_type(t2, SpecificType::sum(intersection)); + }, + // if `ty1` is a generalized version of the sum type, + // then it satisfies the sum type + (ty1, other) if ty1.is_superset_of(other, &self.ctx) => (), + (Literal(l1), Literal(l2)) if l1 == l2 => (), + // Literals can satisfy broader parent types + (ty, Literal(lit)) => match (lit, ty) { + (petr_resolve::Literal::Integer(_), Integer) => (), + (petr_resolve::Literal::Boolean(_), Boolean) => (), + (petr_resolve::Literal::String(_), String) => (), + (lit, ty) => { + self.push_error(span.with_item(self.satisfy_err(ty.clone(), SpecificType::Literal(lit.clone())))); + }, + }, + // if we are trying to satisfy an inferred type with no bounds, this is ok + (Infer(..), _) => (), + (a, b) => { + self.push_error(span.with_item(self.satisfy_err(a.clone(), b.clone()))); + }, + } + } + + /// Gets the latest version of a type available. 
First checks solved types, + /// and if it doesn't exist, gets it from the unsolved types. + fn get_latest_type( + &self, + t1: TypeVariable, + ) -> SpecificType { + self.solution + .get(&t1) + .map(|entry| entry.ty.clone()) + .unwrap_or_else(|| self.unsolved_types.get(t1).clone()) + } +} + /// This is an information-rich type -- it tracks effects and data types. It is used for /// the type-checking stage to provide rich information to the user. /// Types are generalized into instances of [`GeneralType`] for monomorphization and @@ -327,7 +629,10 @@ impl SpecificType { ctx: &TypeContext, ) -> bool { use SpecificType::*; + dbg!(&self); + dbg!(&b); let generalized_b = b.generalize(ctx).safely_upcast(); + dbg!(&generalized_b); match (self, b) { // If `a` is the generalized form of `b`, then `b` satisfies the constraint. (a, b) if a == b || *a == generalized_b => true, @@ -503,7 +808,7 @@ impl TypeChecker { None } - fn fully_type_check(&mut self) { + fn fully_type_check(mut self) -> Result> { for (id, decl) in self.resolved.types() { let ty = self.fresh_ty_var(decl.name.span); let variants = decl @@ -530,7 +835,7 @@ impl TypeChecker { } for (id, func) in self.resolved.functions() { - let typed_function = func.type_check(self); + let typed_function = func.type_check(&mut self); let ty = self.arrow_type([typed_function.params.iter().map(|(_, b)| *b).collect(), vec![typed_function.return_ty]].concat()); self.type_map.insert(id.into(), ty); @@ -545,7 +850,7 @@ impl TypeChecker { args: vec![], span: func.name.span, }; - call.type_check(self); + call.type_check(&mut self); } // before applying existing constraints, it is likely that many duplicate constraints @@ -554,7 +859,7 @@ impl TypeChecker { self.deduplicate_constraints(); // we have now collected our constraints and can solve for them - self.apply_constraints(); + self.into_solution() } pub fn get_main_function(&self) -> Option<(FunctionId, Function)> { @@ -565,9 +870,24 @@ impl TypeChecker { /// - unification tries 
to collapse two types into one /// - satisfaction tries to make one type satisfy the constraints of another, although type /// constraints don't exist in the language yet - fn apply_constraints(&mut self) { + fn into_solution(self) -> Result> { let constraints = self.ctx.constraints.clone(); - for constraint in constraints { + let mut solution = TypeSolution::new(self.ctx.types.clone(), self.resolved.interner); + for TypeConstraint { kind, span } in constraints + .iter() + .filter(|c| if let TypeConstraintKind::Axiom(_) = c.kind { true } else { false }) + { + let TypeConstraintKind::Axiom(axiomatic_variable) = kind else { + unreachable!("above filter ensures that all constraints are axioms here") + }; + // first, pin all axiomatic type variables in the solution + let ty = self.ctx.types.get(*axiomatic_variable).clone(); + solution.insert_solution(*axiomatic_variable, TypeSolutionEntry::new_axiomatic(ty), *span); + } + + /* + // now apply the constraints + for constraint in constraints.iter().filter(|c| !matches!(c.kind, TypeConstraintKind::Axiom(_))) { match &constraint.kind { TypeConstraintKind::Unify(t1, t2) => { self.apply_unify_constraint(*t1, *t2, constraint.span); @@ -575,167 +895,17 @@ impl TypeChecker { TypeConstraintKind::Satisfies(t1, t2) => { self.apply_satisfies_constraint(*t1, *t2, constraint.span); }, + TypeConstraintKind::Axiom(_) => unreachable!(), } } - } - - /// Attempt to unify two types, returning an error if they cannot be unified - /// The more specific of the two types will instantiate the more general of the two types. - /// - /// TODO: The unify constraint should attempt to upcast `t2` as `t1` if possible, but will never - /// downcast `t1` as `t2`. This is not currently how it works and needs investigation. 
- fn apply_unify_constraint( - &mut self, - t1: TypeVariable, - t2: TypeVariable, - span: Span, - ) { - let ty1 = self.ctx.types.get(t1).clone(); - let ty2 = self.ctx.types.get(t2).clone(); - use SpecificType::*; - match (ty1, ty2) { - (a, b) if a == b => (), - (ErrorRecovery, _) | (_, ErrorRecovery) => (), - (Ref(a), _) => self.apply_unify_constraint(a, t2, span), - (_, Ref(b)) => self.apply_unify_constraint(t1, b, span), - (Infer(id, _), Infer(id2, _)) if id != id2 => { - // if two different inferred types are unified, replace the second with a reference - // to the first - self.ctx.update_type(t2, Ref(t1)); - }, - (a @ Sum(_), b @ Sum(_)) => { - // the unification of two sum types is the union of the two types if and only if - // `t2` is a total subset of `t1` - // `t1` remains unchanged, as we are trying to coerce `t2` into something that - // represents `t1` - // TODO remove clone - if a.is_superset_of(&b, &self.ctx) { - } else { - self.push_error(span.with_item(self.unify_err(a, b))); - } - }, - // If `t2` is a non-sum type, and `t1` is a sum type, then `t1` must contain either - // exactly the same specific type OR the generalization of that type - // If the latter, then the specific type must be updated to its generalization - (Sum(sum_tys), other) => { - if sum_tys.contains(&other) { - self.ctx.update_type(t2, Ref(t1)); - } else { - let generalization = other.generalize(&self.ctx).safely_upcast(); - if sum_tys.contains(&generalization) { - self.ctx.update_type(t2, generalization); - } else { - self.push_error(span.with_item(self.unify_err(Sum(sum_tys.clone()), generalization))); - } - } - }, - // literals can unify to each other if they're equal - (Literal(l1), Literal(l2)) if l1 == l2 => (), - (Literal(l1), Literal(l2)) if l1 != l2 => { - // update t1 to a sum type of both, - // and update t2 to reference t1 - let sum = Sum([Literal(l1), Literal(l2)].into()); - self.ctx.update_type(t1, sum); - self.ctx.update_type(t2, Ref(t1)); - }, - (Literal(l1), 
Sum(tys)) => { - // update t1 to a sum type of both, - // and update t2 to reference t1 - let sum = Sum([Literal(l1)].iter().chain(tys.iter()).cloned().collect()); - self.ctx.update_type(t1, sum); - self.ctx.update_type(t2, Ref(t1)); - }, - // literals can unify broader parent types - // but the broader parent type gets instantiated with the literal type - (ty, Literal(lit)) => match (&lit, ty) { - (petr_resolve::Literal::Integer(_), Integer) - | (petr_resolve::Literal::Boolean(_), Boolean) - | (petr_resolve::Literal::String(_), String) => self.ctx.update_type(t1, SpecificType::Literal(lit)), - (lit, ty) => self.push_error(span.with_item(self.unify_err(ty.clone(), SpecificType::Literal(lit.clone())))), - }, - // literals can unify broader parent types - // but the broader parent type gets instantiated with the literal type - (Literal(lit), ty) => match (&lit, ty) { - (petr_resolve::Literal::Integer(_), Integer) - | (petr_resolve::Literal::Boolean(_), Boolean) - | (petr_resolve::Literal::String(_), String) => self.ctx.update_type(t2, SpecificType::Literal(lit)), - (lit, ty) => { - self.push_error(span.with_item(self.unify_err(ty.clone(), SpecificType::Literal(lit.clone())))); - }, - }, - (other, Sum(sum_tys)) => { - // `other` must be a member of the Sum type - if !sum_tys.contains(&other) { - self.push_error(span.with_item(self.unify_err(other.clone(), SpecificType::sum(sum_tys.clone())))); - } - // unify both types to the other type - self.ctx.update_type(t2, other); - }, - // instantiate the infer type with the known type - (Infer(_, _), known) => { - self.ctx.update_type(t1, known); - }, - (known, Infer(_, _)) => { - self.ctx.update_type(t2, known); - }, - // lastly, if no unification rule exists for these two types, it is a mismatch - (a, b) => { - self.push_error(span.with_item(self.unify_err(a, b))); - }, - } - } + */ - // This function will need to be rewritten when type constraints and bounded polymorphism are - // implemented. 
- fn apply_satisfies_constraint( - &mut self, - t1: TypeVariable, - t2: TypeVariable, - span: Span, - ) { - let ty1 = self.ctx.types.get(t1); - let ty2 = self.ctx.types.get(t2); - use SpecificType::*; - match (ty1, ty2) { - (a, b) if a == b => (), - (ErrorRecovery, _) | (_, ErrorRecovery) => (), - (Ref(a), _) => self.apply_satisfies_constraint(*a, t2, span), - (_, Ref(b)) => self.apply_satisfies_constraint(t1, *b, span), - // if t1 is a fully instantiated type, then t2 can be updated to be a reference to t1 - (Unit | Integer | Boolean | UserDefined { .. } | String | Arrow(..) | List(..) | Literal(_) | Sum(_), Infer(_, _)) => { - self.ctx.update_type(t2, Ref(t1)); - }, - // the "parent" infer type will not instantiate to the "child" type - (Infer(_, _), Unit | Integer | Boolean | UserDefined { .. } | String | Arrow(..) | List(..) | Literal(_) | Sum(_)) => (), - (Sum(a_tys), Sum(b_tys)) => { - // calculate the intersection of these types, update t2 to the intersection - let intersection = a_tys.iter().filter(|a_ty| b_tys.contains(a_ty)).cloned().collect(); - self.ctx.update_type(t2, SpecificType::sum(intersection)); - }, - // if `ty1` is a generalized version of the sum type, - // then it satisfies the sum type - (ty1, other) if ty1.is_superset_of(other, &self.ctx) => (), - (Literal(l1), Literal(l2)) if l1 == l2 => (), - // Literals can satisfy broader parent types - (ty, Literal(lit)) => match (lit, ty) { - (petr_resolve::Literal::Integer(_), Integer) => (), - (petr_resolve::Literal::Boolean(_), Boolean) => (), - (petr_resolve::Literal::String(_), String) => (), - (lit, ty) => { - self.push_error(span.with_item(self.satisfy_err(ty.clone(), SpecificType::Literal(lit.clone())))); - }, - }, - // if we are trying to satisfy an inferred type with no bounds, this is ok - (Infer(..), _) => (), - (a, b) => { - self.push_error(span.with_item(self.satisfy_err(a.clone(), b.clone()))); - }, - } + solution.into_result() } - pub fn new(resolved: QueryableResolvedItems) -> Self { 
+ pub fn new(resolved: QueryableResolvedItems) -> Result> { let ctx = TypeContext::default(); - let mut type_checker = TypeChecker { + let type_checker = TypeChecker { ctx, type_map: Default::default(), errors: Default::default(), @@ -745,8 +915,7 @@ impl TypeChecker { monomorphized_functions: Default::default(), }; - type_checker.fully_type_check(); - type_checker + type_checker.fully_type_check() } pub fn insert_variable( @@ -828,13 +997,6 @@ impl TypeChecker { self.ctx.types.insert(ty) } - fn push_error( - &mut self, - e: TypeError, - ) { - self.errors.push(e); - } - pub fn unify( &mut self, ty1: TypeVariable, @@ -906,7 +1068,7 @@ impl TypeChecker { } } - /// Given a concrete [`PetrType`], unify it with the return type of the given expression. + /// Given a concrete [`SpecificType`], unify it with the return type of the given expression. pub fn unify_expr_return( &mut self, ty: TypeVariable, @@ -947,42 +1109,6 @@ impl TypeChecker { &self.errors } - fn unify_err( - &self, - clone_1: SpecificType, - clone_2: SpecificType, - ) -> TypeConstraintError { - let pretty_printed_b = pretty_printing::pretty_print_petr_type(&clone_2, self); - match clone_1 { - SpecificType::Sum(tys) => { - let tys = tys.iter().map(|ty| pretty_printing::pretty_print_petr_type(ty, self)).collect::>(); - TypeConstraintError::NotSubtype(tys, pretty_printed_b) - }, - _ => { - let pretty_printed_a = pretty_printing::pretty_print_petr_type(&clone_1, self); - TypeConstraintError::UnificationFailure(pretty_printed_a, pretty_printed_b) - }, - } - } - - fn satisfy_err( - &self, - clone_1: SpecificType, - clone_2: SpecificType, - ) -> TypeConstraintError { - let pretty_printed_b = pretty_printing::pretty_print_petr_type(&clone_2, self); - match clone_1 { - SpecificType::Sum(tys) => { - let tys = tys.iter().map(|ty| pretty_printing::pretty_print_petr_type(ty, self)).collect::>(); - TypeConstraintError::NotSubtype(tys, pretty_printed_b) - }, - _ => { - let pretty_printed_a = 
pretty_printing::pretty_print_petr_type(&clone_1, self); - TypeConstraintError::FailedToSatisfy(pretty_printed_a, pretty_printed_b) - }, - } - } - fn satisfy_expr_return( &mut self, ty: TypeVariable, @@ -1026,7 +1152,7 @@ impl TypeChecker { let mut constraints = ConstraintDeduplicator::default(); let mut errs = vec![]; for constraint in &self.ctx.constraints { - println!("on constraint: {:?}", constraint); + //println!("on constraint: {:?}", constraint); let (mut tys, kind) = match &constraint.kind { TypeConstraintKind::Unify(t1, t2) => (vec![*t1, *t2], Kind::Unify), TypeConstraintKind::Satisfies(t1, t2) => (vec![*t1, *t2], Kind::Satisfies), @@ -1037,16 +1163,15 @@ impl TypeChecker { // track what we have seen, in case a circular reference is present let mut seen_vars = BTreeSet::new(); seen_vars.insert(*ty_var); - let ty = self.ctx.types.get(*ty_var); + let mut ty = self.ctx.types.get(*ty_var); while let SpecificType::Ref(t) = ty { - if seen_vars.insert(*t) { - *ty_var = *t; - } else { + if seen_vars.contains(t) { // circular reference errs.push(constraint.span.with_item(TypeConstraintError::CircularType)); continue 'outer; } *ty_var = *t; + ty = self.ctx.types.get(*t); } } @@ -1606,6 +1731,8 @@ fn replace_var_reference_types( } mod pretty_printing { + use petr_utils::SymbolInterner; + use crate::*; #[cfg(test)] @@ -1674,28 +1801,29 @@ mod pretty_printing { pub fn pretty_print_ty( ty: &TypeVariable, - type_checker: &TypeChecker, + types: &IndexMap, + interner: &SymbolInterner, ) -> String { - let mut ty = type_checker.look_up_variable(*ty); + let mut ty = types.get(*ty); while let SpecificType::Ref(t) = ty { - println!("looping"); - ty = type_checker.look_up_variable(*t); + ty = types.get(*t); } - pretty_print_petr_type(ty, type_checker) + pretty_print_petr_type(ty, types, interner) } pub fn pretty_print_petr_type( ty: &SpecificType, - type_checker: &TypeChecker, + types: &IndexMap, + interner: &SymbolInterner, ) -> String { match ty { SpecificType::Unit => 
"unit".to_string(), SpecificType::Integer => "int".to_string(), SpecificType::Boolean => "bool".to_string(), SpecificType::String => "string".to_string(), - SpecificType::Ref(ty) => pretty_print_ty(ty, type_checker), + SpecificType::Ref(ty) => pretty_print_ty(ty, types, interner), SpecificType::UserDefined { name, .. } => { - let name = type_checker.resolved.interner.get(name.id); + let name = interner.get(name.id); name.to_string() }, SpecificType::Arrow(tys) => { @@ -1704,7 +1832,7 @@ mod pretty_printing { for (ix, ty) in tys.iter().enumerate() { let is_last = ix == tys.len() - 1; - s.push_str(&pretty_print_ty(ty, type_checker)); + s.push_str(&pretty_print_ty(ty, types, interner)); if !is_last { s.push_str(" → "); } @@ -1713,7 +1841,7 @@ mod pretty_printing { s }, SpecificType::ErrorRecovery => "error recovery".to_string(), - SpecificType::List(ty) => format!("[{}]", pretty_print_petr_type(ty, type_checker)), + SpecificType::List(ty) => format!("[{}]", pretty_print_petr_type(ty, types, interner)), SpecificType::Infer(id, _) => format!("infer t{id}"), SpecificType::Sum(tys) => { let mut s = String::new(); @@ -1721,7 +1849,7 @@ mod pretty_printing { for (ix, ty) in tys.iter().enumerate() { let is_last = ix == tys.len() - 1; // print the petr ty - s.push_str(&pretty_print_petr_type(ty, type_checker)); + s.push_str(&pretty_print_petr_type(ty, types, interner)); if !is_last { s.push_str(" | "); } @@ -1739,23 +1867,24 @@ mod pretty_printing { type_checker: &TypeChecker, ) -> String { let interner = &type_checker.resolved.interner; + let types = &type_checker.ctx.types; match &typed_expr.kind { TypedExprKind::ExprWithBindings { bindings, expression } => { let mut s = String::new(); for (name, expr) in bindings { let ident = interner.get(name.id); let ty = type_checker.expr_ty(expr); - let ty = pretty_print_ty(&ty, type_checker); + let ty = pretty_print_ty(&ty, types, interner); s.push_str(&format!("{ident}: {:?} ({}),\n", expr, ty)); } let expr_ty = 
type_checker.expr_ty(expression); - let expr_ty = pretty_print_ty(&expr_ty, type_checker); - s.push_str(&format!("{:?} ({})", pretty_print_typed_expr(expression, type_checker), expr_ty)); + let expr_ty = pretty_print_ty(&expr_ty, types, interner); + s.push_str(&format!("{:?} ({})", pretty_print_typed_expr(expression, &type_checker), expr_ty)); s }, TypedExprKind::Variable { name, ty } => { let name = interner.get(name.id); - let ty = pretty_print_ty(ty, type_checker); + let ty = pretty_print_ty(ty, types, interner); format!("variable {name}: {ty}") }, @@ -1765,14 +1894,14 @@ mod pretty_printing { for (name, arg) in args { let name = interner.get(name.id); let arg_ty = type_checker.expr_ty(arg); - let arg_ty = pretty_print_ty(&arg_ty, type_checker); + let arg_ty = pretty_print_ty(&arg_ty, types, interner); s.push_str(&format!("{name}: {}, ", arg_ty)); } - let ty = pretty_print_ty(ty, type_checker); + let ty = pretty_print_ty(ty, types, interner); s.push_str(&format!("returns {ty}")); s }, - TypedExprKind::TypeConstructor { ty, .. } => format!("type constructor: {}", pretty_print_ty(ty, type_checker)), + TypedExprKind::TypeConstructor { ty, .. 
} => format!("type constructor: {}", pretty_print_ty(ty, types, interner)), _otherwise => format!("{:?}", typed_expr), } } From b04b5e2917f113a00988d165c7a2cd97433442b0 Mon Sep 17 00:00:00 2001 From: sezna Date: Fri, 2 Aug 2024 18:43:19 -0700 Subject: [PATCH 07/10] progress on refactor --- petr-cli/src/main.rs | 15 +- petr-ir/src/lib.rs | 56 +++-- petr-playground/src/lib.rs | 9 +- petr-typecheck/src/lib.rs | 495 +++++++++++++++++++++++++------------ petr-vm/src/tests.rs | Bin 6132 -> 6154 bytes 5 files changed, 385 insertions(+), 190 deletions(-) diff --git a/petr-cli/src/main.rs b/petr-cli/src/main.rs index 65f6c3d..51cfaa7 100644 --- a/petr-cli/src/main.rs +++ b/petr-cli/src/main.rs @@ -22,6 +22,8 @@ pub mod error { Pkg(#[from] petr_pkg::error::PkgError), #[error("Failed to lower code")] FailedToLower, + #[error("Program contained type errors")] + FailedToTypeCheck, } } @@ -220,12 +222,20 @@ pub fn compile( timings.start("type check"); // type check - let (type_errs, type_checker) = petr_typecheck::type_check(resolved); + let res = petr_typecheck::type_check(resolved); timings.end("type check"); + let type_solution = match res { + Ok(o) => o, + Err(e) => { + render_errors(parse_errs, &source_map); + render_errors(e, &source_map); + return Err(PeteError::FailedToTypeCheck); + }, + }; timings.start("lowering"); - let lowerer: Lowerer = match Lowerer::new(type_checker) { + let lowerer: Lowerer = match Lowerer::new(type_solution) { Ok(l) => l, Err(e) => { eprintln!("Failed to lower: {:?}", e); @@ -235,7 +245,6 @@ pub fn compile( timings.end("lowering"); render_errors(parse_errs, &source_map); - render_errors(type_errs, &source_map); render_errors(resolution_errs, &source_map); Ok(lowerer) } diff --git a/petr-ir/src/lib.rs b/petr-ir/src/lib.rs index a5ec046..c643b38 100644 --- a/petr-ir/src/lib.rs +++ b/petr-ir/src/lib.rs @@ -9,7 +9,7 @@ use std::{collections::BTreeMap, rc::Rc}; -use petr_typecheck::{FunctionSignature, SpecificType, Type, TypeChecker, TypeVariable, 
TypedExpr, TypedExprKind}; +use petr_typecheck::{FunctionSignature, SpecificType, TypeSolution, TypeVariable, TypedExpr, TypedExprKind}; use petr_utils::{idx_map_key, Identifier, IndexMap, SpannedItem, SymbolId}; mod error; @@ -19,8 +19,8 @@ pub use error::LoweringError; use opcodes::*; pub use opcodes::{DataLabel, Intrinsic, IrOpcode, LabelId, Reg, ReservedRegister}; -pub fn lower(checker: TypeChecker) -> Result<(DataSection, Vec)> { - let lowerer = Lowerer::new(checker)?; +pub fn lower(solution: TypeSolution) -> Result<(DataSection, Vec)> { + let lowerer = Lowerer::new(solution)?; Ok(lowerer.finalize()) } @@ -38,7 +38,7 @@ pub struct Lowerer { data_section: DataSection, entry_point: Option, reg_assigner: usize, - type_checker: TypeChecker, + type_solution: TypeSolution, variables_in_scope: Vec>, monomorphized_functions: IndexMap, errors: Vec>, @@ -53,16 +53,19 @@ pub enum DataSectionEntry { } impl Lowerer { - pub fn new(type_checker: TypeChecker) -> Result { + pub fn new(type_solution: petr_typecheck::TypeSolution) -> Result { // if there is an entry point, set that // set entry point to func named main - let entry_point = type_checker.get_main_function(); + let entry_point = match type_solution.get_main_function() { + Some((a, b)) => Some((a.clone(), b.clone())), + None => None, + }; let mut lowerer = Self { data_section: IndexMap::default(), entry_point: None, reg_assigner: 0, - type_checker, + type_solution, variables_in_scope: Default::default(), monomorphized_functions: Default::default(), label_assigner: 0, @@ -110,7 +113,7 @@ impl Lowerer { return Ok(previously_monomorphized_definition.0); } - let func_def = self.type_checker.get_monomorphized_function(&func).clone(); + let func_def = self.type_solution.get_monomorphized_function(&func).clone(); let mut buf = vec![]; self.with_variable_context(|ctx| -> Result<_> { @@ -181,8 +184,8 @@ impl Lowerer { for (arg_name, arg_expr) in args { let reg = self.fresh_reg(); let mut expr = self.lower_expr(arg_expr, 
ReturnDestination::Reg(reg))?; - let arg_ty = self.type_checker.expr_ty(arg_expr); - let petr_ty = self.type_checker.look_up_variable(arg_ty); + let arg_ty = self.type_solution.expr_ty(arg_expr); + let petr_ty = self.type_solution.get_latest_type(arg_ty); arg_types.push((*arg_name, petr_ty.clone())); let ir_ty = self.lower_type(petr_ty.clone()); expr.push(IrOpcode::StackPush(TypedReg { ty: ir_ty, reg })); @@ -192,7 +195,7 @@ impl Lowerer { // push current PC onto the stack buf.push(IrOpcode::PushPc()); - let arg_petr_types = arg_types.iter().map(|(_name, ty)| ty.generalize(self.type_checker.ctx())).collect(); + let arg_petr_types = arg_types.iter().map(|(_name, ty)| self.type_solution.generalize(ty)).collect(); let monomorphized_func_id = self.monomorphize_function((*func, arg_petr_types))?; @@ -215,7 +218,7 @@ impl Lowerer { List { elements, .. } => { let size_of_each_elements = elements .iter() - .map(|el| self.to_ir_type(self.type_checker.expr_ty(el)).size().num_bytes() as u64) + .map(|el| self.to_ir_type(self.type_solution.expr_ty(el)).size().num_bytes() as u64) .sum::(); let size_of_list = size_of_each_elements * elements.len() as u64; let size_of_list_reg = self.fresh_reg(); @@ -235,7 +238,7 @@ impl Lowerer { buf.push(IrOpcode::LoadImmediate(current_offset_reg, current_offset)); buf.push(IrOpcode::Add(current_offset_reg, current_offset_reg, return_reg)); buf.push(IrOpcode::WriteRegisterToMemory(reg, current_offset_reg)); - current_offset += self.to_ir_type(self.type_checker.expr_ty(el)).size().num_bytes() as u64; + current_offset += self.to_ir_type(self.type_solution.expr_ty(el)).size().num_bytes() as u64; } Ok(buf) }, @@ -291,7 +294,7 @@ impl Lowerer { buf.push(IrOpcode::Add(current_size_offset_reg, current_size_offset_reg, return_destination)); buf.push(IrOpcode::WriteRegisterToMemory(reg, current_size_offset_reg)); - let arg_ty = self.type_checker.expr_ty(arg); + let arg_ty = self.type_solution.expr_ty(arg); current_size_offset += 
self.to_ir_type(arg_ty).size().num_bytes() as u64; } @@ -345,7 +348,7 @@ impl Lowerer { // Generalize the type. These general types are much more useful for codegen. // Specific types include extra information about constant literal value types, // data flow analysis, effects tracking, etc., that codegen does not care about. - let ty = ty.generalize(self.type_checker.ctx()); + let ty = self.type_solution.generalize(&ty.as_specific_ty()); use petr_typecheck::GeneralType::*; match ty { Unit => IrTy::Unit, @@ -406,7 +409,7 @@ impl Lowerer { &mut self, param_ty: TypeVariable, ) -> IrTy { - let ty = self.type_checker.look_up_variable(param_ty).clone(); + let ty = self.type_solution.get_latest_type(param_ty).clone(); self.lower_type(ty) } @@ -450,7 +453,7 @@ impl Lowerer { Ok(buf) }, SizeOf(expr) => { - let ty = self.type_checker.expr_ty(expr); + let ty = self.type_solution.expr_ty(expr); let size = self.to_ir_type(ty).size(); match return_destination { ReturnDestination::Reg(reg) => { @@ -582,6 +585,7 @@ mod tests { use expect_test::{expect, Expect}; use petr_resolve::resolve_symbols; + use petr_typecheck::TypeChecker; use petr_utils::render_error; use super::*; @@ -606,15 +610,19 @@ mod tests { } panic!("resolving names failed"); } - let type_checker = TypeChecker::new(resolved); + let mut type_checker = TypeChecker::new(resolved); - let typecheck_errors = type_checker.errors(); - if !typecheck_errors.is_empty() { - typecheck_errors.iter().for_each(|err| eprintln!("{:?}", err)); - panic!("ir gen failed: code didn't typecheck"); - } + type_checker.fully_type_check(); + + let solution = match type_checker.into_solution() { + Ok(o) => o, + Err(errs) => { + errs.iter().for_each(|err| eprintln!("{:?}", err)); + panic!("ir gen failed: code didn't typecheck"); + }, + }; - let lowerer = match Lowerer::new(type_checker) { + let lowerer = match Lowerer::new(solution) { Ok(lowerer) => lowerer, Err(err) => { eprintln!("{:?}", err); diff --git a/petr-playground/src/lib.rs 
b/petr-playground/src/lib.rs index c6dc2cd..3c88684 100644 --- a/petr-playground/src/lib.rs +++ b/petr-playground/src/lib.rs @@ -74,13 +74,18 @@ fn compile_snippet(input: String) -> Result> { return Err(errs.into_iter().map(|err| format!("{:?}", render_error(&source_map, err))).collect()); } - let (errs, type_checker) = type_check(resolved); + let solution = match type_check(resolved) { + Ok(o) => o, + Err(e) => { + return Err(e.into_iter().map(|err| format!("{:?}", render_error(&source_map, err))).collect()); + }, + }; if !errs.is_empty() { return Err(errs.into_iter().map(|err| format!("{:?}", render_error(&source_map, err))).collect()); } - let lowerer = match Lowerer::new(type_checker) { + let lowerer = match Lowerer::new(solution) { Ok(l) => l, Err(err) => panic!("lowering failed: {:?}", err), }; diff --git a/petr-typecheck/src/lib.rs b/petr-typecheck/src/lib.rs index 40ca5dc..7fa8283 100644 --- a/petr-typecheck/src/lib.rs +++ b/petr-typecheck/src/lib.rs @@ -5,6 +5,7 @@ //! - Satisfies //! - UnifyEffects //! 
- SatisfiesEffects +#![allow(warnings)] mod error; @@ -22,15 +23,11 @@ use petr_utils::{idx_map_key, Identifier, IndexMap, Span, SpannedItem, SymbolId, pub type TypeError = SpannedItem; pub type TResult = Result; -// TODO return QueryableTypeChecked instead of type checker -// Clean up API so this is the only function exposed -pub fn type_check(resolved: QueryableResolvedItems) -> (Vec, TypeChecker) { - todo!("design new api") - /* - let solution = TypeChecker::new(resolved); +pub fn type_check(resolved: QueryableResolvedItems) -> Result>> { + let mut type_checker = TypeChecker::new(resolved); type_checker.fully_type_check(); - (type_checker.errors.clone(), type_checker) - */ + + type_checker.into_solution() } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] @@ -182,6 +179,10 @@ impl TypeContext { ) { *self.types.get_mut(t1) = known; } + + pub fn types(&self) -> &IndexMap { + &self.types + } } pub type FunctionSignature = (FunctionId, Box<[GeneralType]>); @@ -236,33 +237,80 @@ impl GeneralType { /// Represents the result of the type-checking stage for an individual type variable. 
pub struct TypeSolutionEntry { - axiomatic: bool, - ty: SpecificType, + source: TypeSolutionSource, + ty: SpecificType, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum TypeSolutionSource { + Inherent, + Axiomatic, + Inferred, } impl TypeSolutionEntry { pub fn new_axiomatic(ty: SpecificType) -> Self { - Self { axiomatic: true, ty } + Self { + source: TypeSolutionSource::Axiomatic, + ty, + } + } + + pub fn new_inherent(ty: SpecificType) -> Self { + Self { + source: TypeSolutionSource::Inherent, + ty, + } + } + + pub fn new_inferred(ty: SpecificType) -> Self { + Self { + source: TypeSolutionSource::Inferred, + ty, + } + } + + pub fn is_axiomatic(&self) -> bool { + self.source == TypeSolutionSource::Axiomatic } } pub struct TypeSolution { - solution: BTreeMap, + solution: BTreeMap, unsolved_types: IndexMap, - errors: Vec, - interner: SymbolInterner, + errors: Vec, + interner: SymbolInterner, + error_recovery: TypeVariable, + unit: TypeVariable, + functions: BTreeMap, + monomorphized_functions: BTreeMap, } impl TypeSolution { pub fn new( unsolved_types: IndexMap, + error_recovery: TypeVariable, + unit: TypeVariable, + functions: BTreeMap, + monomorphized_functions: BTreeMap, interner: SymbolInterner, + preexisting_errors: Vec, ) -> Self { + let solution = vec![ + (unit, TypeSolutionEntry::new_inherent(SpecificType::Unit)), + (error_recovery, TypeSolutionEntry::new_inherent(SpecificType::ErrorRecovery)), + ] + .into_iter() + .collect(); Self { - solution: Default::default(), + solution, unsolved_types, - errors: Default::default(), + errors: preexisting_errors, interner, + functions, + monomorphized_functions, + unit, + error_recovery, } } @@ -337,7 +385,7 @@ impl TypeSolution { ) { match self.solution.get_mut(&ty) { Some(e) => { - if e.axiomatic { + if e.is_axiomatic() { let pretty_printed_preexisting = pretty_printing::pretty_print_petr_type(&entry.ty, &self.unsolved_types, &self.interner); let pretty_printed_ty = 
pretty_printing::pretty_print_petr_type(&entry.ty, &self.unsolved_types, &self.interner); self.errors @@ -347,9 +395,7 @@ impl TypeSolution { *e = entry; }, None => { - self.errors.push(span.with_item(TypeConstraintError::Internal( - "attempted to update type that did not exist in solution".into(), - ))); + self.solution.insert(ty, entry); }, } } @@ -384,7 +430,8 @@ impl TypeSolution { (Infer(id, _), Infer(id2, _)) if id != id2 => { // if two different inferred types are unified, replace the second with a reference // to the first - self.update_type(t2, Ref(t1)); + let entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t2, entry, span); }, (a @ Sum(_), b @ Sum(_)) => { // the unification of two sum types is the union of the two types if and only if @@ -392,7 +439,7 @@ impl TypeSolution { // `t1` remains unchanged, as we are trying to coerce `t2` into something that // represents `t1` // TODO remove clone - if a.is_superset_of(&b, &self.ctx) { + if self.a_superset_of_b(&a, &b) { } else { self.push_error(span.with_item(self.unify_err(a, b))); } @@ -400,40 +447,48 @@ impl TypeSolution { // If `t2` is a non-sum type, and `t1` is a sum type, then `t1` must contain either // exactly the same specific type OR the generalization of that type // If the latter, then the specific type must be updated to its generalization - (Sum(sum_tys), other) => { - if sum_tys.contains(&other) { - self.ctx.update_type(t2, Ref(t1)); + (ref t1_ty @ Sum(_), other) => { + if self.a_superset_of_b(&t1_ty, &other) { + // t2 unifies to the more general form provided by t1 + let entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t2, entry, span); } else { - let generalization = other.generalize(&self.ctx).safely_upcast(); - if sum_tys.contains(&generalization) { - self.ctx.update_type(t2, generalization); - } else { - self.push_error(span.with_item(self.unify_err(Sum(sum_tys.clone()), generalization))); - } + 
self.push_error(span.with_item(self.unify_err(t1_ty.clone(), other))); } }, // literals can unify to each other if they're equal (Literal(l1), Literal(l2)) if l1 == l2 => (), + // if they're not equal, their unification is the sum of both (Literal(l1), Literal(l2)) if l1 != l2 => { // update t1 to a sum type of both, // and update t2 to reference t1 let sum = Sum([Literal(l1), Literal(l2)].into()); - self.ctx.update_type(t1, sum); - self.ctx.update_type(t2, Ref(t1)); + let t1_entry = TypeSolutionEntry::new_inferred(sum); + let t2_entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t1, t1_entry, span); + self.update_type(t2, t2_entry, span); }, (Literal(l1), Sum(tys)) => { // update t1 to a sum type of both, // and update t2 to reference t1 let sum = Sum([Literal(l1)].iter().chain(tys.iter()).cloned().collect()); - self.ctx.update_type(t1, sum); - self.ctx.update_type(t2, Ref(t1)); + let t1_entry = TypeSolutionEntry::new_inferred(sum); + let t2_entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t1, t1_entry, span); + self.update_type(t2, t2_entry, span); }, + /* TODO rewrite below rules // literals can unify broader parent types // but the broader parent type gets instantiated with the literal type + // TODO(alex) this rule feels incorrect. A literal being unified to the parent type + // should upcast the lit, not downcast t1. Check after refactoring. 
(ty, Literal(lit)) => match (&lit, ty) { (petr_resolve::Literal::Integer(_), Integer) | (petr_resolve::Literal::Boolean(_), Boolean) - | (petr_resolve::Literal::String(_), String) => self.ctx.update_type(t1, SpecificType::Literal(lit)), + | (petr_resolve::Literal::String(_), String) => { + let entry = TypeSolutionEntry::new_inferred(SpecificType::Literal(lit)); + self.update_type(t1, entry, span); + }, (lit, ty) => self.push_error(span.with_item(self.unify_err(ty.clone(), SpecificType::Literal(lit.clone())))), }, // literals can unify broader parent types @@ -441,25 +496,28 @@ impl TypeSolution { (Literal(lit), ty) => match (&lit, ty) { (petr_resolve::Literal::Integer(_), Integer) | (petr_resolve::Literal::Boolean(_), Boolean) - | (petr_resolve::Literal::String(_), String) => self.ctx.update_type(t2, SpecificType::Literal(lit)), + | (petr_resolve::Literal::String(_), String) => self.update_type(t2, SpecificType::Literal(lit)), (lit, ty) => { self.push_error(span.with_item(self.unify_err(ty.clone(), SpecificType::Literal(lit.clone())))); }, }, - (other, Sum(sum_tys)) => { - // `other` must be a member of the Sum type - if !sum_tys.contains(&other) { - self.push_error(span.with_item(self.unify_err(other.clone(), SpecificType::sum(sum_tys.clone())))); + */ + (other, ref t2_ty @ Sum(_)) => { + // if `other` is a superset of `t2`, then `t2` unifies to `other` as it is more + // general + if self.a_superset_of_b(&other, &t2_ty) { + let entry = TypeSolutionEntry::new_inferred(other); + self.update_type(t2, entry, span); } - // unify both types to the other type - self.ctx.update_type(t2, other); }, // instantiate the infer type with the known type - (Infer(_, _), known) => { - self.ctx.update_type(t1, known); + (Infer(_, _), _known) => { + let entry = TypeSolutionEntry::new_inferred(Ref(t2)); + self.update_type(t1, entry, span); }, - (known, Infer(_, _)) => { - self.ctx.update_type(t2, known); + (_known, Infer(_, _)) => { + let entry = 
TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t2, entry, span); }, // lastly, if no unification rule exists for these two types, it is a mismatch (a, b) => { @@ -476,28 +534,30 @@ impl TypeSolution { t2: TypeVariable, span: Span, ) { - let ty1 = self.ctx.types.get(t1); - let ty2 = self.ctx.types.get(t2); + let ty1 = self.get_latest_type(t1); + let ty2 = self.get_latest_type(t2); use SpecificType::*; match (ty1, ty2) { (a, b) if a == b => (), (ErrorRecovery, _) | (_, ErrorRecovery) => (), - (Ref(a), _) => self.apply_satisfies_constraint(*a, t2, span), - (_, Ref(b)) => self.apply_satisfies_constraint(t1, *b, span), + (Ref(a), _) => self.apply_satisfies_constraint(a, t2, span), + (_, Ref(b)) => self.apply_satisfies_constraint(t1, b, span), // if t1 is a fully instantiated type, then t2 can be updated to be a reference to t1 (Unit | Integer | Boolean | UserDefined { .. } | String | Arrow(..) | List(..) | Literal(_) | Sum(_), Infer(_, _)) => { - self.ctx.update_type(t2, Ref(t1)); + let entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t2, entry, span); }, // the "parent" infer type will not instantiate to the "child" type (Infer(_, _), Unit | Integer | Boolean | UserDefined { .. } | String | Arrow(..) | List(..) 
| Literal(_) | Sum(_)) => (), (Sum(a_tys), Sum(b_tys)) => { // calculate the intersection of these types, update t2 to the intersection let intersection = a_tys.iter().filter(|a_ty| b_tys.contains(a_ty)).cloned().collect(); - self.ctx.update_type(t2, SpecificType::sum(intersection)); + let entry = TypeSolutionEntry::new_inferred(SpecificType::sum(intersection)); + self.update_type(t2, entry, span); }, // if `ty1` is a generalized version of the sum type, // then it satisfies the sum type - (ty1, other) if ty1.is_superset_of(other, &self.ctx) => (), + (ty1, other) if self.a_superset_of_b(&ty1, &other) => (), (Literal(l1), Literal(l2)) if l1 == l2 => (), // Literals can satisfy broader parent types (ty, Literal(lit)) => match (lit, ty) { @@ -518,7 +578,7 @@ impl TypeSolution { /// Gets the latest version of a type available. First checks solved types, /// and if it doesn't exist, gets it from the unsolved types. - fn get_latest_type( + pub fn get_latest_type( &self, t1: TypeVariable, ) -> SpecificType { @@ -527,6 +587,145 @@ impl TypeSolution { .map(|entry| entry.ty.clone()) .unwrap_or_else(|| self.unsolved_types.get(t1).clone()) } + + /// To reference an error recovery type, you must provide an error. + /// This holds the invariant that error recovery types are only generated when + /// an error occurs. + pub fn error_recovery( + &mut self, + err: TypeError, + ) -> TypeVariable { + self.push_error(err); + self.error_recovery + } + + /// If `a` is a generalized form of `b`, return true + /// A generalized form is a type that is a superset of the sum types. + /// For example, `String` is a generalized form of `Sum(Literal("a") | Literal("B"))` + fn a_superset_of_b( + &self, + a: &SpecificType, + b: &SpecificType, + ) -> bool { + use SpecificType::*; + let generalized_b = self.generalize(&b).safely_upcast(); + match (a, b) { + // If `a` is the generalized form of `b`, then `b` satisfies the constraint. 
+ (a, b) if a == b || *a == generalized_b => true, + // If `a` is a sum type which contains `b` OR the generalized form of `b`, then `b` + // satisfies the constraint. + (Sum(a_tys), b) if a_tys.contains(b) || a_tys.contains(&generalized_b) => true, + // if both `a` and `b` are sum types, then `a` must be a superset of `b`: + // - every element in `b` must either: + // - be a member of `a` + // - generalize to a member of `a` + (Sum(a_tys), Sum(b_tys)) => { + // if a_tys is a superset of b_tys, + // every element OR its generalized version is contained in a_tys + for b_ty in b_tys { + let b_ty_generalized = self.generalize(b_ty).safely_upcast(); + if !(a_tys.contains(b_ty) || a_tys.contains(&b_ty_generalized)) { + return false; + } + } + + true + }, + _otherwise => false, + } + } + + pub fn generalize( + &self, + b: &SpecificType, + ) -> GeneralType { + match b { + SpecificType::Unit => GeneralType::Unit, + SpecificType::Integer => GeneralType::Integer, + SpecificType::Boolean => GeneralType::Boolean, + SpecificType::String => GeneralType::String, + SpecificType::Ref(ty) => self.generalize(&self.get_latest_type(*ty)), + SpecificType::UserDefined { + name, + variants, + constant_literal_types, + } => GeneralType::UserDefined { + name: *name, + variants: variants + .iter() + .map(|variant| { + let generalized_fields = variant.fields.iter().map(|field| self.generalize(field)).collect::>(); + + GeneralizedTypeVariant { + fields: generalized_fields.into_boxed_slice(), + } + }) + .collect(), + constant_literal_types: constant_literal_types.clone(), + }, + SpecificType::Arrow(tys) => GeneralType::Arrow(tys.clone()), + SpecificType::ErrorRecovery => GeneralType::ErrorRecovery, + SpecificType::List(ty) => { + let ty = self.generalize(ty); + GeneralType::List(Box::new(ty)) + }, + SpecificType::Infer(u, s) => GeneralType::Infer(*u, *s), + SpecificType::Literal(l) => match l { + Literal::Integer(_) => GeneralType::Integer, + Literal::Boolean(_) => GeneralType::Boolean, + 
Literal::String(_) => GeneralType::String, + }, + SpecificType::Sum(tys) => { + // generalize all types, fold if possible + let all_generalized: BTreeSet<_> = tys.iter().map(|ty| self.generalize(ty)).collect(); + if all_generalized.len() == 1 { + // in this case, all specific types generalized to the same type + all_generalized.into_iter().next().expect("invariant") + } else { + GeneralType::Sum(all_generalized.into_iter().collect()) + } + }, + } + } + + #[cfg(test)] + fn pretty_print(&self) -> String { + let mut pretty = String::new(); + for (ty, entry) in self.solution.iter().filter(|(id, _)| ![self.unit, self.error_recovery].contains(id)) { + pretty.push_str(&format!("{}: {}\n", ty, self.pretty_print_type(&entry.ty))); + } + pretty + } + + pub fn get_main_function(&self) -> Option<(&FunctionId, &Function)> { + self.functions.iter().find(|(_, func)| &*self.interner.get(func.name.id) == "main") + } + + pub fn get_monomorphized_function( + &self, + id: &FunctionSignature, + ) -> &Function { + self.monomorphized_functions.get(id).expect("invariant: should exist") + } + + pub fn expr_ty( + &self, + expr: &TypedExpr, + ) -> TypeVariable { + use TypedExprKind::*; + match &expr.kind { + FunctionCall { ty, .. } => *ty, + Literal { ty, .. } => *ty, + List { ty, .. } => *ty, + Unit => self.unit, + Variable { ty, .. } => *ty, + Intrinsic { ty, .. } => *ty, + ErrorRecovery(..) => self.error_recovery, + ExprWithBindings { expression, .. } => self.expr_ty(expression), + TypeConstructor { ty, .. } => *ty, + If { then_branch, .. } => self.expr_ty(then_branch), + } + } } /// This is an information-rich type -- it tracks effects and data types. 
It is used for @@ -567,16 +766,16 @@ pub struct GeneralizedTypeVariant { } impl SpecificType { - fn generalize( + fn generalize_inner( &self, - ctx: &TypeContext, + types: &IndexMap, ) -> GeneralType { match self { SpecificType::Unit => GeneralType::Unit, SpecificType::Integer => GeneralType::Integer, SpecificType::Boolean => GeneralType::Boolean, SpecificType::String => GeneralType::String, - SpecificType::Ref(ty) => ctx.types.get(*ty).generalize(ctx), + SpecificType::Ref(ty) => types.get(*ty).generalize(types), SpecificType::UserDefined { name, variants, @@ -586,7 +785,7 @@ impl SpecificType { variants: variants .iter() .map(|variant| { - let generalized_fields = variant.fields.iter().map(|field| field.generalize(ctx)).collect::>(); + let generalized_fields = variant.fields.iter().map(|field| field.generalize(types)).collect::>(); GeneralizedTypeVariant { fields: generalized_fields.into_boxed_slice(), @@ -598,7 +797,7 @@ impl SpecificType { SpecificType::Arrow(tys) => GeneralType::Arrow(tys.clone()), SpecificType::ErrorRecovery => GeneralType::ErrorRecovery, SpecificType::List(ty) => { - let ty = ty.generalize(ctx); + let ty = ty.generalize(types); GeneralType::List(Box::new(ty)) }, SpecificType::Infer(u, s) => GeneralType::Infer(*u, *s), @@ -609,7 +808,7 @@ impl SpecificType { }, SpecificType::Sum(tys) => { // generalize all types, fold if possible - let all_generalized: BTreeSet<_> = tys.iter().map(|ty| ty.generalize(ctx)).collect(); + let all_generalized: BTreeSet<_> = tys.iter().map(|ty| ty.generalize(types)).collect(); if all_generalized.len() == 1 { // in this case, all specific types generalized to the same type all_generalized.into_iter().next().expect("invariant") @@ -620,45 +819,6 @@ impl SpecificType { } } - /// If `self` is a generalized form of `b`, return true - /// A generalized form is a type that is a superset of the sum types. 
- /// For example, `String` is a generalized form of `Sum(Literal("a") | Literal("B"))` - fn is_superset_of( - &self, - b: &SpecificType, - ctx: &TypeContext, - ) -> bool { - use SpecificType::*; - dbg!(&self); - dbg!(&b); - let generalized_b = b.generalize(ctx).safely_upcast(); - dbg!(&generalized_b); - match (self, b) { - // If `a` is the generalized form of `b`, then `b` satisfies the constraint. - (a, b) if a == b || *a == generalized_b => true, - // If `a` is a sum type which contains `b` OR the generalized form of `b`, then `b` - // satisfies the constraint. - (Sum(a_tys), b) if a_tys.contains(b) || a_tys.contains(&generalized_b) => true, - // if both `a` and `b` are sum types, then `a` must be a superset of `b`: - // - every element in `b` must either: - // - be a member of `a` - // - generalize to a member of `a` - (Sum(a_tys), Sum(b_tys)) => { - // if a_tys is a superset of b_tys, - // every element OR its generalized version is contained in a_tys - for b_ty in b_tys { - let b_ty_generalized = b_ty.generalize(ctx).safely_upcast(); - if !(a_tys.contains(b_ty) || a_tys.contains(&b_ty_generalized)) { - return false; - } - } - - true - }, - _otherwise => false, - } - } - /// Use this to construct `[SpecificType::Sum]` types -- /// it will attempt to collapse the sum into a single type if possible fn sum(tys: BTreeSet) -> SpecificType { @@ -676,33 +836,36 @@ pub struct TypeVariant { } pub trait Type { - fn as_specific_ty( - &self, - ctx: &TypeContext, - ) -> SpecificType; + fn as_specific_ty(&self) -> SpecificType; fn generalize( &self, - ctx: &TypeContext, - ) -> GeneralType { - self.as_specific_ty(ctx).generalize(ctx) - } + types: &IndexMap, + ) -> GeneralType; } impl Type for SpecificType { - fn as_specific_ty( - &self, - _ctx: &TypeContext, - ) -> SpecificType { + fn as_specific_ty(&self) -> SpecificType { self.clone() } + + fn generalize( + &self, + types: &IndexMap, + ) -> GeneralType { + self.generalize_inner(&types) + } } impl Type for GeneralType { - 
fn as_specific_ty( + fn generalize( &self, - _ctx: &TypeContext, - ) -> SpecificType { + _: &IndexMap, + ) -> Self { + self.clone() + } + + fn as_specific_ty(&self) -> SpecificType { match self { GeneralType::Unit => SpecificType::Unit, GeneralType::Integer => SpecificType::Integer, @@ -717,7 +880,7 @@ impl Type for GeneralType { variants: variants .iter() .map(|variant| { - let fields = variant.fields.iter().map(|field| field.as_specific_ty(_ctx)).collect::>(); + let fields = variant.fields.iter().map(|field| field.as_specific_ty()).collect::>(); TypeVariant { fields: fields.into_boxed_slice(), @@ -728,10 +891,10 @@ impl Type for GeneralType { }, GeneralType::Arrow(tys) => SpecificType::Arrow(tys.clone()), GeneralType::ErrorRecovery => SpecificType::ErrorRecovery, - GeneralType::List(ty) => SpecificType::List(Box::new(ty.as_specific_ty(_ctx))), + GeneralType::List(ty) => SpecificType::List(Box::new(ty.as_specific_ty())), GeneralType::Infer(u, s) => SpecificType::Infer(*u, *s), GeneralType::Sum(tys) => { - let tys = tys.iter().map(|ty| ty.as_specific_ty(_ctx)).collect(); + let tys = tys.iter().map(|ty| ty.as_specific_ty()).collect(); SpecificType::Sum(tys) }, } @@ -743,7 +906,7 @@ impl TypeChecker { &mut self, ty: &T, ) -> TypeVariable { - let ty = ty.as_specific_ty(&self.ctx); + let ty = ty.as_specific_ty(); // TODO: check if type already exists and return that ID instead self.ctx.types.insert(ty) } @@ -808,7 +971,7 @@ impl TypeChecker { None } - fn fully_type_check(mut self) -> Result> { + pub fn fully_type_check(&mut self) { for (id, decl) in self.resolved.types() { let ty = self.fresh_ty_var(decl.name.span); let variants = decl @@ -835,7 +998,7 @@ impl TypeChecker { } for (id, func) in self.resolved.functions() { - let typed_function = func.type_check(&mut self); + let typed_function = func.type_check(self); let ty = self.arrow_type([typed_function.params.iter().map(|(_, b)| *b).collect(), vec![typed_function.return_ty]].concat()); 
self.type_map.insert(id.into(), ty); @@ -850,16 +1013,13 @@ impl TypeChecker { args: vec![], span: func.name.span, }; - call.type_check(&mut self); + call.type_check(self); } // before applying existing constraints, it is likely that many duplicate constraints // exist. We can safely remove any duplicate constraints to avoid excessive error // reporting. self.deduplicate_constraints(); - - // we have now collected our constraints and can solve for them - self.into_solution() } pub fn get_main_function(&self) -> Option<(FunctionId, Function)> { @@ -870,9 +1030,17 @@ impl TypeChecker { /// - unification tries to collapse two types into one /// - satisfaction tries to make one type satisfy the constraints of another, although type /// constraints don't exist in the language yet - fn into_solution(self) -> Result> { + pub fn into_solution(self) -> Result> { let constraints = self.ctx.constraints.clone(); - let mut solution = TypeSolution::new(self.ctx.types.clone(), self.resolved.interner); + let mut solution = TypeSolution::new( + self.ctx.types.clone(), + self.ctx.error_recovery, + self.ctx.unit_ty, + self.typed_functions, + self.monomorphized_functions, + self.resolved.interner, + self.errors, + ); for TypeConstraint { kind, span } in constraints .iter() .filter(|c| if let TypeConstraintKind::Axiom(_) = c.kind { true } else { false }) @@ -885,27 +1053,25 @@ impl TypeChecker { solution.insert_solution(*axiomatic_variable, TypeSolutionEntry::new_axiomatic(ty), *span); } - /* // now apply the constraints for constraint in constraints.iter().filter(|c| !matches!(c.kind, TypeConstraintKind::Axiom(_))) { match &constraint.kind { TypeConstraintKind::Unify(t1, t2) => { - self.apply_unify_constraint(*t1, *t2, constraint.span); + solution.apply_unify_constraint(*t1, *t2, constraint.span); }, TypeConstraintKind::Satisfies(t1, t2) => { - self.apply_satisfies_constraint(*t1, *t2, constraint.span); + solution.apply_satisfies_constraint(*t1, *t2, constraint.span); }, 
TypeConstraintKind::Axiom(_) => unreachable!(), } } - */ solution.into_result() } - pub fn new(resolved: QueryableResolvedItems) -> Result> { + pub fn new(resolved: QueryableResolvedItems) -> Self { let ctx = TypeContext::default(); - let type_checker = TypeChecker { + TypeChecker { ctx, type_map: Default::default(), errors: Default::default(), @@ -913,9 +1079,7 @@ impl TypeChecker { resolved, variable_scope: Default::default(), monomorphized_functions: Default::default(), - }; - - type_checker.fully_type_check() + } } pub fn insert_variable( @@ -1094,17 +1258,6 @@ impl TypeChecker { self.ctx.bool_ty } - /// To reference an error recovery type, you must provide an error. - /// This holds the invariant that error recovery types are only generated when - /// an error occurs. - pub fn error_recovery( - &mut self, - err: TypeError, - ) -> TypeVariable { - self.push_error(err); - self.ctx.error_recovery - } - pub fn errors(&self) -> &[TypeError] { &self.errors } @@ -1152,10 +1305,10 @@ impl TypeChecker { let mut constraints = ConstraintDeduplicator::default(); let mut errs = vec![]; for constraint in &self.ctx.constraints { - //println!("on constraint: {:?}", constraint); let (mut tys, kind) = match &constraint.kind { TypeConstraintKind::Unify(t1, t2) => (vec![*t1, *t2], Kind::Unify), TypeConstraintKind::Satisfies(t1, t2) => (vec![*t1, *t2], Kind::Satisfies), + TypeConstraintKind::Axiom(t1) => (vec![*t1], Kind::Axiom), }; // resolve all `Ref` types to get a resolved type variable @@ -1178,12 +1331,15 @@ impl TypeChecker { constraints.insert((kind, tys), *constraint); } - for err in errs { - self.push_error(err); - } - self.ctx.constraints = constraints.into_values(); } + + fn push_error( + &mut self, + e: TypeError, + ) { + self.errors.push(e); + } } /// the `key` type is what we use to deduplicate constraints @@ -1642,7 +1798,7 @@ impl TypeCheck for petr_resolve::FunctionCall { let concrete_arg_types: Vec<_> = args .iter() - .map(|(_, _, ty)| 
ctx.look_up_variable(*ty).generalize(&ctx.ctx).clone()) + .map(|(_, _, ty)| ctx.look_up_variable(*ty).generalize(&ctx.ctx.types).clone()) .collect(); let signature: FunctionSignature = (self.function, concrete_arg_types.clone().into_boxed_slice()); @@ -1685,6 +1841,7 @@ impl TypeCheck for petr_resolve::FunctionCall { ); ctx.monomorphized_functions.insert(signature, monomorphized_func_decl); + // if there are any variable exprs in the body, update those ref types TypedExprKind::FunctionCall { func: self.function, @@ -1736,7 +1893,7 @@ mod pretty_printing { use crate::*; #[cfg(test)] - pub fn pretty_print_type_checker(type_checker: TypeChecker) -> String { + pub fn pretty_print_type_checker(type_checker: &TypeChecker) -> String { let mut s = String::new(); for (id, ty) in &type_checker.type_map { let text = match id { @@ -1756,7 +1913,7 @@ mod pretty_printing { }; s.push_str(&text); s.push_str(": "); - s.push_str(&pretty_print_ty(ty, &type_checker)); + s.push_str(&pretty_print_ty(ty, &type_checker.ctx.types, &type_checker.resolved.interner)); s.push('\n'); match id { @@ -1777,12 +1934,16 @@ mod pretty_printing { for func in type_checker.monomorphized_functions.values() { let func_name = type_checker.resolved.interner.get(func.name.id); - let arg_types = func.params.iter().map(|(_, ty)| pretty_print_ty(ty, &type_checker)).collect::>(); + let arg_types = func + .params + .iter() + .map(|(_, ty)| pretty_print_ty(ty, &type_checker.ctx.types, &type_checker.resolved.interner)) + .collect::>(); s.push_str(&format!( "\nfn {}({:?}) -> {}", func_name, arg_types, - pretty_print_ty(&func.return_ty, &type_checker) + pretty_print_ty(&func.return_ty, &type_checker.ctx.types, &type_checker.resolved.interner) )); } @@ -1792,7 +1953,7 @@ mod pretty_printing { if !type_checker.errors.is_empty() { s.push_str("\n__ERRORS__\n"); - for error in type_checker.errors { + for error in &type_checker.errors { s.push_str(&format!("{:?}\n", error)); } } @@ -1932,10 +2093,22 @@ mod tests { 
errs.into_iter().for_each(|err| eprintln!("{:?}", render_error(&source_map, err))); panic!("unresolved symbols in test"); } - let type_checker = TypeChecker::new(resolved); - let res = pretty_print_type_checker(type_checker); + let mut type_checker = TypeChecker::new(resolved); + type_checker.fully_type_check(); + let mut res = pretty_print_type_checker(&type_checker); + + let solved_constraints = match type_checker.into_solution() { + Ok(solution) => solution.pretty_print(), + Err(errs) => { + res.push_str(&"__ERRORS__\n"); + errs.into_iter().map(|err| format!("{:?}", err)).collect::>().join("\n") + }, + }; + + res.push('\n'); + res.push_str(&solved_constraints); - expect.assert_eq(&res); + expect.assert_eq(&res.trim()); } #[test] diff --git a/petr-vm/src/tests.rs b/petr-vm/src/tests.rs index f1b05a068e70eadab5ac1f34a1c77312d7b93e98..b91eebbe42e00d016cb5e9b235505fa6c4d859ab 100644 GIT binary patch delta 128 zcmeyO-(|4Dgjt|CKc}=LGe1wkRv|aBBspWUKeMR0LN%8H6!>Rr6l*Hj+9?$4z{OmP ziZn_p3sU1#i;5tU5cQMWm@_B;W0IQ8$|A=L(q9YFQ)@k0g+-kMX4U3!mbtn pW}1Q`Oy}eWY~qt=Fegs-W>T6g&mzQFJK2s!od?;p&ABXdg#jh=CrbbT From 6645091c52b39618c33c5dde5e627e655ab6b2f6 Mon Sep 17 00:00:00 2001 From: sezna Date: Fri, 2 Aug 2024 20:49:57 -0700 Subject: [PATCH 08/10] progress on axiomatic types --- petr-typecheck/src/lib.rs | 58 +++++++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/petr-typecheck/src/lib.rs b/petr-typecheck/src/lib.rs index 7fa8283..4f720a3 100644 --- a/petr-typecheck/src/lib.rs +++ b/petr-typecheck/src/lib.rs @@ -5,7 +5,6 @@ //! - Satisfies //! - UnifyEffects //! 
- SatisfiesEffects -#![allow(warnings)] mod error; @@ -16,7 +15,7 @@ use std::{ use error::TypeConstraintError; pub use petr_bind::FunctionId; -use petr_resolve::{Expr, ExprKind, QueryableResolvedItems}; +use petr_resolve::{Expr, ExprKind, FunctionCall, QueryableResolvedItems}; pub use petr_resolve::{Intrinsic as ResolvedIntrinsic, IntrinsicName, Literal}; use petr_utils::{idx_map_key, Identifier, IndexMap, Span, SpannedItem, SymbolId, SymbolInterner, TypeId}; @@ -90,6 +89,16 @@ impl TypeConstraint { span, } } + + fn axiom( + t1: TypeVariable, + span: Span, + ) -> Self { + Self { + kind: TypeConstraintKind::Axiom(t1), + span, + } + } } #[derive(Clone, Copy, Debug)] @@ -161,6 +170,14 @@ impl TypeContext { self.constraints.push(TypeConstraint::satisfies(ty1, ty2, span)); } + fn axiom( + &mut self, + ty1: TypeVariable, + span: Span, + ) { + self.constraints.push(TypeConstraint::axiom(ty1, span)); + } + fn new_variable( &mut self, span: Span, @@ -477,6 +494,11 @@ impl TypeSolution { self.update_type(t1, t1_entry, span); self.update_type(t2, t2_entry, span); }, + (a, b) if self.a_superset_of_b(&a, &b) => { + // if `a` is a superset of `b`, then `b` unifies to `a` as it is more general + let entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t2, entry, span); + }, /* TODO rewrite below rules // literals can unify broader parent types // but the broader parent type gets instantiated with the literal type @@ -690,11 +712,18 @@ impl TypeSolution { #[cfg(test)] fn pretty_print(&self) -> String { - let mut pretty = String::new(); + let mut pretty = "__SOLVED TYPES__\n".to_string(); + + let mut num_entries = 0; for (ty, entry) in self.solution.iter().filter(|(id, _)| ![self.unit, self.error_recovery].contains(id)) { - pretty.push_str(&format!("{}: {}\n", ty, self.pretty_print_type(&entry.ty))); + pretty.push_str(&format!("{}: {}\n", Into::::into(*ty), self.pretty_print_type(&entry.ty))); + num_entries += 1; + } + if num_entries == 0 { + Default::default() + } 
else { + pretty } - pretty } pub fn get_main_function(&self) -> Option<(&FunctionId, &Function)> { @@ -1179,6 +1208,14 @@ impl TypeChecker { self.ctx.satisfies(ty1, ty2, span); } + fn axiom( + &mut self, + ty: TypeVariable, + span: Span, + ) { + self.ctx.axiom(ty, span); + } + fn get_untyped_function( &self, function: FunctionId, @@ -1747,9 +1784,12 @@ impl TypeCheck for petr_resolve::Function { ) -> Self::Output { ctx.with_type_scope(|ctx| { let params = self.params.iter().map(|(name, ty)| (*name, ctx.to_type_var(ty))).collect::>(); + // declared parameters are axiomatic, they won't be updated by any inference for (name, ty) in ¶ms { ctx.insert_variable(*name, *ty); + // TODO get span for type annotation instead of just the name of the parameter + ctx.axiom(*ty, name.span); } // unify types within the body with the parameter @@ -1767,7 +1807,7 @@ impl TypeCheck for petr_resolve::Function { } } -impl TypeCheck for petr_resolve::FunctionCall { +impl TypeCheck for FunctionCall { type Output = TypedExprKind; fn type_check( @@ -1951,12 +1991,6 @@ mod pretty_printing { s.push('\n'); } - if !type_checker.errors.is_empty() { - s.push_str("\n__ERRORS__\n"); - for error in &type_checker.errors { - s.push_str(&format!("{:?}\n", error)); - } - } s } From 8077e83e0408a4ea5c731367efbc72e152e887ee Mon Sep 17 00:00:00 2001 From: sezna Date: Fri, 2 Aug 2024 21:13:09 -0700 Subject: [PATCH 09/10] All tests passing with new type solution system --- petr-typecheck/src/error.rs | 2 +- petr-typecheck/src/lib.rs | 159 +++++++++++++++++++++++++----------- 2 files changed, 111 insertions(+), 50 deletions(-) diff --git a/petr-typecheck/src/error.rs b/petr-typecheck/src/error.rs index f4f31d0..b39db21 100644 --- a/petr-typecheck/src/error.rs +++ b/petr-typecheck/src/error.rs @@ -7,7 +7,7 @@ pub enum TypeConstraintError { UnificationFailure(String, String), #[error("type `{0}` does not satisfy the constraints of type {1}")] FailedToSatisfy(String, String), - #[error("type `{1}` is not a 
subtype of sum type `{0:?}`")] + #[error("type `{1}` is not a subtype of sum type `{}`", .0.join(" | "))] NotSubtype(Vec, String), #[error("Function {function} takes {expected:?} arguments, but got {got:?} arguments.")] ArgumentCountMismatch { function: String, expected: usize, got: usize }, diff --git a/petr-typecheck/src/lib.rs b/petr-typecheck/src/lib.rs index 4f720a3..964d84f 100644 --- a/petr-typecheck/src/lib.rs +++ b/petr-typecheck/src/lib.rs @@ -457,8 +457,23 @@ impl TypeSolution { // represents `t1` // TODO remove clone if self.a_superset_of_b(&a, &b) { + let entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t2, entry, span); } else { - self.push_error(span.with_item(self.unify_err(a, b))); + // the union of the two sets is the new type + let a_tys = match a { + Sum(tys) => tys, + _ => unreachable!(), + }; + let b_tys = match b { + Sum(tys) => tys, + _ => unreachable!(), + }; + let union = a_tys.iter().chain(b_tys.iter()).cloned().collect(); + let entry = TypeSolutionEntry::new_inferred(Sum(union)); + self.update_type(t1, entry, span); + let entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t2, entry, span); } }, // If `t2` is a non-sum type, and `t1` is a sum type, then `t1` must contain either @@ -470,7 +485,14 @@ impl TypeSolution { let entry = TypeSolutionEntry::new_inferred(Ref(t1)); self.update_type(t2, entry, span); } else { - self.push_error(span.with_item(self.unify_err(t1_ty.clone(), other))); + // add `other` to `t1` + let mut tys = match t1_ty { + Sum(tys) => tys.clone(), + _ => unreachable!(), + }; + tys.insert(other); + let entry = TypeSolutionEntry::new_inferred(Sum(tys)); + self.update_type(t1, entry, span); } }, // literals can unify to each other if they're equal @@ -1608,7 +1630,7 @@ impl TypeCheck for Expr { } => { let condition = condition.type_check(ctx); let condition_ty = ctx.expr_ty(&condition); - ctx.unify(condition_ty, ctx.bool(), condition.span()); + ctx.unify(ctx.bool(), condition_ty, 
condition.span()); let then_branch = then_branch.type_check(ctx); let then_ty = ctx.expr_ty(&then_branch); @@ -2155,7 +2177,9 @@ mod tests { fn foo: (int → int) variable x: int - "#]], + + __SOLVED TYPES__ + 5: int"#]], ); } @@ -2169,7 +2193,9 @@ mod tests { fn foo: (infer t5 → infer t5) variable x: infer t5 - "#]], + + __SOLVED TYPES__ + 6: infer t5"#]], ); } @@ -2192,7 +2218,9 @@ mod tests { fn foo: (MyType → MyType) variable x: MyType - "#]], + + __SOLVED TYPES__ + 8: MyType"#]], ); } @@ -2226,7 +2254,11 @@ mod tests { __MONOMORPHIZED FUNCTIONS__ fn firstVariant(["MyType"]) -> MyComposedType - "#]], + + __SOLVED TYPES__ + 14: int + 17: infer t16 + 23: MyType"#]], ); } @@ -2242,9 +2274,7 @@ mod tests { literal: 5 fn bar: bool - literal: 5 - - "#]], + literal: 5"#]], ); } @@ -2260,9 +2290,7 @@ mod tests { literal: 5 fn bar: bool - literal: true - - "#]], + literal: true"#]], ); } @@ -2283,8 +2311,7 @@ mod tests { intrinsic: @puts(function call to functionid0 with args: ) __MONOMORPHIZED FUNCTIONS__ - fn string_literal([]) -> string - "#]], + fn string_literal([]) -> string"#]], ); } @@ -2298,7 +2325,9 @@ mod tests { fn my_func: unit intrinsic: @puts(literal: "test") - "#]], + + __SOLVED TYPES__ + 5: string"#]], ); } @@ -2312,10 +2341,9 @@ mod tests { fn my_func: unit intrinsic: @puts(literal: true) - __ERRORS__ - SpannedItem UnificationFailure("string", "true") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(52), length: 4 } }] - "#]], + + SpannedItem UnificationFailure("string", "true") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(52), length: 4 } }]"#]], ); } @@ -2329,7 +2357,9 @@ mod tests { fn my_func: bool intrinsic: @puts(literal: "test") - "#]], + + __SOLVED TYPES__ + 5: string"#]], ); } @@ -2351,10 +2381,9 @@ mod tests { __MONOMORPHIZED FUNCTIONS__ fn bool_literal([]) -> bool - __ERRORS__ - SpannedItem UnificationFailure("string", "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(110), 
length: 14 } }] - "#]], + + SpannedItem UnificationFailure("string", "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(110), length: 14 } }]"#]], ); } @@ -2385,7 +2414,10 @@ mod tests { __MONOMORPHIZED FUNCTIONS__ fn bool_literal(["int", "int"]) -> bool fn bool_literal(["bool", "bool"]) -> bool - "#]], + + __SOLVED TYPES__ + 6: infer t5 + 8: infer t7"#]], ); } #[test] @@ -2398,7 +2430,10 @@ mod tests { fn my_list: infer t8 list: [literal: 1, literal: true, ] - "#]], + + __SOLVED TYPES__ + 5: (1 | true) + 6: 1"#]], ); } @@ -2417,10 +2452,9 @@ mod tests { fn add_five: (int → int) error recovery Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(113), length: 8 } } - __ERRORS__ - SpannedItem ArgumentCountMismatch { function: "add", expected: 2, got: 1 } [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(113), length: 8 } }] - "#]], + + SpannedItem ArgumentCountMismatch { function: "add", expected: 2, got: 1 } [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(113), length: 8 } }]"#]], ); } @@ -2451,7 +2485,10 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn hi(["int", "int"]) -> int fn main([]) -> int - "#]], + + __SOLVED TYPES__ + 5: int + 6: int"#]], ) } @@ -2472,10 +2509,9 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn hi(["int"]) -> int fn main([]) -> int - __ERRORS__ - SpannedItem UnificationFailure("int", "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(61), length: 2 } }] - "#]], + + SpannedItem UnificationFailure("bool", "int") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(61), length: 2 } }]"#]], ) } @@ -2496,10 +2532,9 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn hi([]) -> int fn main([]) -> int - __ERRORS__ - SpannedItem UnificationFailure("unit", "1") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(33), length: 46 } }] - "#]], + + 
SpannedItem UnificationFailure("1", "unit") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(33), length: 46 } }]"#]], ) } @@ -2521,7 +2556,10 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn hi([]) -> unit fn main([]) -> unit - "#]], + + __SOLVED TYPES__ + 5: bool + 6: string"#]], ) } @@ -2546,10 +2584,9 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn OneOrTwo(["int"]) -> OneOrTwo fn main([]) -> OneOrTwo - __ERRORS__ - SpannedItem NotSubtype(["1", "2"], "10") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(104), length: 0 } }] - "#]], + + SpannedItem NotSubtype(["1", "2"], "10") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(104), length: 0 } }]"#]], ) } @@ -2574,10 +2611,9 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn AOrB(["string"]) -> AOrB fn main([]) -> AOrB - __ERRORS__ - SpannedItem NotSubtype(["\"A\"", "\"B\""], "\"c\"") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(97), length: 0 } }] - "#]], + + SpannedItem NotSubtype(["\"A\"", "\"B\""], "\"c\"") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(97), length: 0 } }]"#]], ) } @@ -2673,7 +2709,11 @@ fn main() returns 'int ~hi(1, 2)"#, fn test(["int"]) -> (1 | 2 | 3) fn test_(["int"]) -> (1 | 2) fn main([]) -> int - "#]], + + __SOLVED TYPES__ + 5: (1 | 2 | 3) + 9: (1 | 2) + 11: (1 | 2)"#]], ) } @@ -2703,7 +2743,10 @@ fn main() returns 'int ~hi(1, 2)"#, fn test(["string"]) -> (int | string) fn test_(["int"]) -> (int | string) fn main([]) -> int - "#]], + + __SOLVED TYPES__ + 5: (int | string) + 9: int"#]], ) } @@ -2732,10 +2775,9 @@ fn main() returns 'int ~hi(1, 2)"#, fn test(["(int | string)"]) -> (int | string) fn test_(["bool"]) -> (int | string) fn main([]) -> int - __ERRORS__ - SpannedItem NotSubtype(["int", "string"], "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(129), length: 0 } }] - "#]], + + 
SpannedItem NotSubtype(["int", "string"], "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(129), length: 0 } }]"#]], ) } @@ -2748,7 +2790,9 @@ fn main() returns 'int ~hi(1, 2)"#, fn test: ((int | string) → (int | string)) variable a: (int | string) - "#]], + + __SOLVED TYPES__ + 5: (int | string)"#]], ) } @@ -2761,7 +2805,24 @@ fn main() returns 'int ~hi(1, 2)"#, fn test: ((int | string) → (int | bool | string)) variable a: (int | string) - "#]], + + __SOLVED TYPES__ + 5: (int | string)"#]], ) } + + #[test] + fn if_exp_basic() { + check("fn main() returns 'int if true then 1 else 0", expect![[r#" + fn main: int + if literal: true then literal: 1 else literal: 0 + + __MONOMORPHIZED FUNCTIONS__ + fn main([]) -> int + + __SOLVED TYPES__ + 5: bool + 6: (0 | 1) + 7: 1"#]]); + } } From b942dbf68df10457dadab7f250266613380c89ac Mon Sep 17 00:00:00 2001 From: sezna Date: Sat, 3 Aug 2024 06:19:34 -0700 Subject: [PATCH 10/10] refactor typechecking into modules --- petr-ir/src/lib.rs | 5 +- petr-typecheck/src/constraint_generation.rs | 948 +++++++ petr-typecheck/src/lib.rs | 2791 +------------------ petr-typecheck/src/pretty_printing.rs | 173 ++ petr-typecheck/src/solution.rs | 537 ++++ petr-typecheck/src/tests.rs | 710 +++++ petr-typecheck/src/typed_ast.rs | 248 ++ petr-typecheck/src/types.rs | 238 ++ 8 files changed, 2870 insertions(+), 2780 deletions(-) create mode 100644 petr-typecheck/src/constraint_generation.rs create mode 100644 petr-typecheck/src/pretty_printing.rs create mode 100644 petr-typecheck/src/solution.rs create mode 100644 petr-typecheck/src/tests.rs create mode 100644 petr-typecheck/src/typed_ast.rs create mode 100644 petr-typecheck/src/types.rs diff --git a/petr-ir/src/lib.rs b/petr-ir/src/lib.rs index c643b38..d296587 100644 --- a/petr-ir/src/lib.rs +++ b/petr-ir/src/lib.rs @@ -56,10 +56,7 @@ impl Lowerer { pub fn new(type_solution: petr_typecheck::TypeSolution) -> Result { // if there is an entry point, set that // set entry 
point to func named main - let entry_point = match type_solution.get_main_function() { - Some((a, b)) => Some((a.clone(), b.clone())), - None => None, - }; + let entry_point = type_solution.get_main_function().map(|(a, b)| (*a, b.clone())); let mut lowerer = Self { data_section: IndexMap::default(), diff --git a/petr-typecheck/src/constraint_generation.rs b/petr-typecheck/src/constraint_generation.rs new file mode 100644 index 0000000..0c6920b --- /dev/null +++ b/petr-typecheck/src/constraint_generation.rs @@ -0,0 +1,948 @@ +use std::{ + collections::{BTreeMap, BTreeSet}, + rc::Rc, +}; + +use petr_bind::FunctionId; +use petr_resolve::{Expr, FunctionCall, QueryableResolvedItems}; +use petr_utils::{Identifier, IndexMap, Span, SpannedItem, SymbolId}; + +use crate::{ + error::TypeConstraintError, + solution::{TypeSolution, TypeSolutionEntry}, + typed_ast::{TypedExpr, TypedExprKind}, + types::{GeneralType, SpecificType, Type, TypeVariant}, + Function, TypeError, TypeOrFunctionId, TypeVariable, +}; + +#[derive(Clone, Copy, Debug)] +pub struct TypeConstraint { + kind: TypeConstraintKind, + /// The span from which this type constraint originated + span: Span, +} +impl TypeConstraint { + fn unify( + t1: TypeVariable, + t2: TypeVariable, + span: Span, + ) -> Self { + Self { + kind: TypeConstraintKind::Unify(t1, t2), + span, + } + } + + fn satisfies( + t1: TypeVariable, + t2: TypeVariable, + span: Span, + ) -> Self { + Self { + kind: TypeConstraintKind::Satisfies(t1, t2), + span, + } + } + + fn axiom( + t1: TypeVariable, + span: Span, + ) -> Self { + Self { + kind: TypeConstraintKind::Axiom(t1), + span, + } + } +} + +#[derive(Clone, Copy, Debug)] +pub enum TypeConstraintKind { + Unify(TypeVariable, TypeVariable), + // constraint that lhs is a "subtype" or satisfies the typeclass constraints of "rhs" + Satisfies(TypeVariable, TypeVariable), + // If a type variable is constrained to be an axiom, it means that the type variable + // cannot be updated by the inference engine. 
It effectively fixes the type, or pins the type. + Axiom(TypeVariable), +} + +#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] +enum TypeConstraintKindValue { + Unify, + Satisfies, + Axiom, +} + +pub struct TypeContext { + types: IndexMap, + constraints: Vec, + // known primitive type IDs + unit_ty: TypeVariable, + string_ty: TypeVariable, + int_ty: TypeVariable, + bool_ty: TypeVariable, + error_recovery: TypeVariable, +} + +impl Default for TypeContext { + fn default() -> Self { + let mut types = IndexMap::default(); + // instantiate basic primitive types + let unit_ty = types.insert(SpecificType::Unit); + let string_ty = types.insert(SpecificType::String); + let bool_ty = types.insert(SpecificType::Boolean); + let int_ty = types.insert(SpecificType::Integer); + let error_recovery = types.insert(SpecificType::ErrorRecovery); + // insert primitive types + TypeContext { + types, + constraints: Default::default(), + bool_ty, + unit_ty, + string_ty, + int_ty, + error_recovery, + } + } +} + +impl TypeContext { + fn unify( + &mut self, + ty1: TypeVariable, + ty2: TypeVariable, + span: Span, + ) { + self.constraints.push(TypeConstraint::unify(ty1, ty2, span)); + } + + fn satisfies( + &mut self, + ty1: TypeVariable, + ty2: TypeVariable, + span: Span, + ) { + self.constraints.push(TypeConstraint::satisfies(ty1, ty2, span)); + } + + fn axiom( + &mut self, + ty1: TypeVariable, + span: Span, + ) { + self.constraints.push(TypeConstraint::axiom(ty1, span)); + } + + fn new_variable( + &mut self, + span: Span, + ) -> TypeVariable { + // infer is special -- it knows its own id, mostly for printing + // and disambiguating + let infer_id = self.types.len(); + self.types.insert(SpecificType::Infer(infer_id, span)) + } + + /// Update a type variable with a new SpecificType + fn update_type( + &mut self, + t1: TypeVariable, + known: SpecificType, + ) { + *self.types.get_mut(t1) = known; + } + + pub fn types(&self) -> &IndexMap { + &self.types + } +} + +pub type 
FunctionSignature = (FunctionId, Box<[GeneralType]>); + +pub struct TypeChecker { + ctx: TypeContext, + type_map: BTreeMap, + monomorphized_functions: BTreeMap, + typed_functions: BTreeMap, + errors: Vec, + resolved: QueryableResolvedItems, + variable_scope: Vec>, +} + +pub trait TypeCheck { + type Output; + fn type_check( + &self, + ctx: &mut TypeChecker, + ) -> Self::Output; +} + +impl TypeChecker { + pub fn insert_type( + &mut self, + ty: &T, + ) -> TypeVariable { + let ty = ty.as_specific_ty(); + // TODO: check if type already exists and return that ID instead + self.ctx.types.insert(ty) + } + + pub fn look_up_variable( + &self, + ty: TypeVariable, + ) -> &SpecificType { + self.ctx.types.get(ty) + } + + pub fn get_symbol( + &self, + id: SymbolId, + ) -> Rc { + self.resolved.interner.get(id).clone() + } + + pub(crate) fn with_type_scope( + &mut self, + f: impl FnOnce(&mut Self) -> T, + ) -> T { + self.variable_scope.push(Default::default()); + let res = f(self); + self.variable_scope.pop(); + res + } + + fn generic_type( + &mut self, + id: &Identifier, + ) -> TypeVariable { + for scope in self.variable_scope.iter().rev() { + if let Some(ty) = scope.get(id) { + return *ty; + } + } + let fresh_ty = self.fresh_ty_var(id.span); + match self.variable_scope.last_mut() { + Some(entry) => { + entry.insert(*id, fresh_ty); + }, + None => { + self.errors.push(id.span.with_item(TypeConstraintError::Internal( + "attempted to insert generic type into variable scope when no variable scope existed".into(), + ))); + self.ctx.update_type(fresh_ty, SpecificType::ErrorRecovery); + }, + }; + fresh_ty + } + + pub(crate) fn find_variable( + &self, + id: Identifier, + ) -> Option { + for scope in self.variable_scope.iter().rev() { + if let Some(ty) = scope.get(&id) { + return Some(*ty); + } + } + None + } + + pub fn fully_type_check(&mut self) { + for (id, decl) in self.resolved.types() { + let ty = self.fresh_ty_var(decl.name.span); + let variants = decl + .variants + .iter() + 
.map(|variant| { + self.with_type_scope(|ctx| { + let fields = variant.fields.iter().map(|field| ctx.to_petr_type(&field.ty)).collect::>(); + TypeVariant { + fields: fields.into_boxed_slice(), + } + }) + }) + .collect::>(); + self.ctx.update_type( + ty, + SpecificType::UserDefined { + name: decl.name, + variants, + constant_literal_types: decl.constant_literal_types, + }, + ); + self.type_map.insert(id.into(), ty); + } + + for (id, func) in self.resolved.functions() { + let typed_function = func.type_check(self); + + let ty = self.arrow_type([typed_function.params.iter().map(|(_, b)| *b).collect(), vec![typed_function.return_ty]].concat()); + self.type_map.insert(id.into(), ty); + self.typed_functions.insert(id, typed_function); + } + // type check the main func with no params + let main_func = self.get_main_function(); + // construct a function call for the main function, if one exists + if let Some((id, func)) = main_func { + let call = petr_resolve::FunctionCall { + function: id, + args: vec![], + span: func.name.span, + }; + call.type_check(self); + } + + // before applying existing constraints, it is likely that many duplicate constraints + // exist. We can safely remove any duplicate constraints to avoid excessive error + // reporting. 
+ self.deduplicate_constraints(); + } + + pub fn get_main_function(&self) -> Option<(FunctionId, Function)> { + self.functions().find(|(_, func)| &*self.get_symbol(func.name.id) == "main") + } + + /// iterate through each constraint and transform the underlying types to satisfy them + /// - unification tries to collapse two types into one + /// - satisfaction tries to make one type satisfy the constraints of another, although type + /// constraints don't exist in the language yet + pub fn into_solution(self) -> Result> { + let constraints = self.ctx.constraints.clone(); + let mut solution = TypeSolution::new( + self.ctx.types.clone(), + self.ctx.error_recovery, + self.ctx.unit_ty, + self.typed_functions, + self.monomorphized_functions, + self.resolved.interner, + self.errors, + ); + for TypeConstraint { kind, span } in constraints.iter().filter(|c| matches!(c.kind, TypeConstraintKind::Axiom(_))) { + let TypeConstraintKind::Axiom(axiomatic_variable) = kind else { + unreachable!("above filter ensures that all constraints are axioms here") + }; + // first, pin all axiomatic type variables in the solution + let ty = self.ctx.types.get(*axiomatic_variable).clone(); + solution.insert_solution(*axiomatic_variable, TypeSolutionEntry::new_axiomatic(ty), *span); + } + + // now apply the constraints + for constraint in constraints.iter().filter(|c| !matches!(c.kind, TypeConstraintKind::Axiom(_))) { + match &constraint.kind { + TypeConstraintKind::Unify(t1, t2) => { + solution.apply_unify_constraint(*t1, *t2, constraint.span); + }, + TypeConstraintKind::Satisfies(t1, t2) => { + solution.apply_satisfies_constraint(*t1, *t2, constraint.span); + }, + TypeConstraintKind::Axiom(_) => unreachable!(), + } + } + + solution.into_result() + } + + pub fn new(resolved: QueryableResolvedItems) -> Self { + let ctx = TypeContext::default(); + TypeChecker { + ctx, + type_map: Default::default(), + errors: Default::default(), + typed_functions: Default::default(), + resolved, + variable_scope: 
Default::default(), + monomorphized_functions: Default::default(), + } + } + + pub fn insert_variable( + &mut self, + id: Identifier, + ty: TypeVariable, + ) { + self.variable_scope + .last_mut() + .expect("inserted variable when no scope existed") + .insert(id, ty); + } + + pub fn fresh_ty_var( + &mut self, + span: Span, + ) -> TypeVariable { + self.ctx.new_variable(span) + } + + fn arrow_type( + &mut self, + tys: Vec, + ) -> TypeVariable { + assert!(!tys.is_empty(), "arrow_type: tys is empty"); + + if tys.len() == 1 { + return tys[0]; + } + + let ty = SpecificType::Arrow(tys); + self.ctx.types.insert(ty) + } + + pub fn to_petr_type( + &mut self, + ty: &petr_resolve::Type, + ) -> SpecificType { + match ty { + petr_resolve::Type::Integer => SpecificType::Integer, + petr_resolve::Type::Bool => SpecificType::Boolean, + petr_resolve::Type::Unit => SpecificType::Unit, + petr_resolve::Type::String => SpecificType::String, + petr_resolve::Type::ErrorRecovery(_) => { + // unifies to anything, fresh var + SpecificType::ErrorRecovery + }, + petr_resolve::Type::Named(ty_id) => SpecificType::Ref(*self.type_map.get(&ty_id.into()).expect("type did not exist in type map")), + petr_resolve::Type::Generic(generic_name) => { + // TODO don't create an ID and then reference it -- this is messy + let id = self.generic_type(generic_name); + SpecificType::Ref(id) + }, + petr_resolve::Type::Sum(tys) => SpecificType::Sum(tys.iter().map(|ty| self.to_petr_type(ty)).collect()), + petr_resolve::Type::Literal(l) => SpecificType::Literal(l.clone()), + } + } + + pub fn to_type_var( + &mut self, + ty: &petr_resolve::Type, + ) -> TypeVariable { + let petr_ty = self.to_petr_type(ty); + self.ctx.types.insert(petr_ty) + } + + pub fn get_type( + &self, + key: impl Into, + ) -> &TypeVariable { + self.type_map.get(&key.into()).expect("type did not exist in type map") + } + + pub(crate) fn convert_literal_to_type( + &mut self, + literal: &petr_resolve::Literal, + ) -> TypeVariable { + let ty = 
SpecificType::Literal(literal.clone()); + self.ctx.types.insert(ty) + } + + pub fn unify( + &mut self, + ty1: TypeVariable, + ty2: TypeVariable, + span: Span, + ) { + self.ctx.unify(ty1, ty2, span); + } + + pub fn satisfies( + &mut self, + ty1: TypeVariable, + ty2: TypeVariable, + span: Span, + ) { + self.ctx.satisfies(ty1, ty2, span); + } + + pub fn axiom( + &mut self, + ty: TypeVariable, + span: Span, + ) { + self.ctx.axiom(ty, span); + } + + fn get_untyped_function( + &self, + function: FunctionId, + ) -> &petr_resolve::Function { + self.resolved.get_function(function) + } + + pub fn get_function( + &mut self, + id: &FunctionId, + ) -> Function { + if let Some(func) = self.typed_functions.get(id) { + return func.clone(); + } + + // if the function hasn't been type checked yet, type check it + let func = self.get_untyped_function(*id).clone(); + let type_checked = func.type_check(self); + self.typed_functions.insert(*id, type_checked.clone()); + type_checked + } + + pub fn get_monomorphized_function( + &self, + id: &FunctionSignature, + ) -> &Function { + self.monomorphized_functions.get(id).expect("invariant: should exist") + } + + // TODO unideal clone + pub fn functions(&self) -> impl Iterator { + self.typed_functions.iter().map(|(a, b)| (*a, b.clone())).collect::>().into_iter() + } + + pub fn expr_ty( + &self, + expr: &TypedExpr, + ) -> TypeVariable { + use TypedExprKind::*; + match &expr.kind { + FunctionCall { ty, .. } => *ty, + Literal { ty, .. } => *ty, + List { ty, .. } => *ty, + Unit => self.unit(), + Variable { ty, .. } => *ty, + Intrinsic { ty, .. } => *ty, + ErrorRecovery(..) => self.ctx.error_recovery, + ExprWithBindings { expression, .. } => self.expr_ty(expression), + TypeConstructor { ty, .. } => *ty, + If { then_branch, .. } => self.expr_ty(then_branch), + } + } + + /// Given a concrete [`SpecificType`], unify it with the return type of the given expression. 
+    pub fn unify_expr_return(
+        &mut self,
+        ty: TypeVariable,
+        expr: &TypedExpr,
+    ) {
+        // look up the type variable of the expression, then record a `Unify` constraint
+        // between `ty` and it, attributed to the expression's span
+        let expr_ty = self.expr_ty(expr);
+        self.unify(ty, expr_ty, expr.span());
+    }
+
+    /// The type variable that the context registered for the primitive string type.
+    pub fn string(&self) -> TypeVariable {
+        self.ctx.string_ty
+    }
+
+    /// The type variable that the context registered for the primitive unit type.
+    pub fn unit(&self) -> TypeVariable {
+        self.ctx.unit_ty
+    }
+
+    /// The type variable that the context registered for the primitive integer type.
+    pub fn int(&self) -> TypeVariable {
+        self.ctx.int_ty
+    }
+
+    /// The type variable that the context registered for the primitive boolean type.
+    pub fn bool(&self) -> TypeVariable {
+        self.ctx.bool_ty
+    }
+
+    /// The type errors pushed onto this checker so far.
+    pub fn errors(&self) -> &[TypeError] {
+        &self.errors
+    }
+
+    /// Record a `Satisfies(ty, <type of expr>)` constraint, attributed to the span of `expr`.
+    /// This is the `Satisfies` counterpart of [`Self::unify_expr_return`].
+    pub fn satisfy_expr_return(
+        &mut self,
+        ty: TypeVariable,
+        expr: &TypedExpr,
+    ) {
+        let expr_ty = self.expr_ty(expr);
+        self.satisfies(ty, expr_ty, expr.span());
+    }
+
+    /// Shared access to the underlying [`TypeContext`].
+    pub fn ctx(&self) -> &TypeContext {
+        &self.ctx
+    }
+
+    /// terms:
+    /// ### resolved type variable
+    ///
+    /// a type variable that is not a `Ref`. To get the resolved type of
+    /// a type variable, you must follow the chain of `Ref`s until you reach a non-Ref type.
+    ///
+    /// ### constraint kind strength:
+    /// The following is the hierarchy of constraints in terms of strength, from strongest (1) to
+    /// weakest:
+    /// 1. Unify(t1, t2) (t2 _must_ be coercible to exactly equal t1)
+    /// 2. Satisfies (t2 must be a subset of t1. For all cases where t2 can unify to t1, t2
+    ///    satisfies t1 as a constraint)
+    ///
+    /// ### constraint strength
+    /// A constraint `a` is _stronger than_ a constraint `b` iff:
+    /// - `a` is higher than `b` in terms of constraint kind strength, or
+    ///   `a` is a more specific constraint than `b`
+    ///     - e.g. Unify(Literal(5), x) is stronger than Unify(Int, x) because the former is more specific
+    ///     - e.g.
Unify(a, b) is stronger than Satisfies(a, b)
+    ///
+    ///
+    /// ### duplicated constraint:
+    /// A constraint `a` is _duplicated by_ constraint `b` iff:
+    /// - `a` and `b` are the same constraint kind, and the resolved type variables are the same
+    /// - `a` is a stronger constraint than `b`
+    ///
+    /// ### implementation notes
+    /// Constraints are keyed by `(constraint kind, resolved type variables)`. When two
+    /// constraints share a key, the one inserted later replaces the earlier one.
+    ///
+    /// While resolving `Ref` chains, any circular chain is reported as a
+    /// `TypeConstraintError::CircularType` error on the checker. The constraint itself is
+    /// still kept, using whatever variable the chain had reached when the cycle was found.
+    fn deduplicate_constraints(&mut self) {
+        use TypeConstraintKindValue as Kind;
+        let mut constraints = ConstraintDeduplicator::default();
+        let mut errs = vec![];
+        for constraint in &self.ctx.constraints {
+            let (mut tys, kind) = match &constraint.kind {
+                TypeConstraintKind::Unify(t1, t2) => (vec![*t1, *t2], Kind::Unify),
+                TypeConstraintKind::Satisfies(t1, t2) => (vec![*t1, *t2], Kind::Satisfies),
+                TypeConstraintKind::Axiom(t1) => (vec![*t1], Kind::Axiom),
+            };
+
+            // resolve all `Ref` types to get a resolved type variable
+            'outer: for ty_var in tys.iter_mut() {
+                // track every variable visited along the chain, so that a cycle is detected
+                // even when it does not lead back to the variable we started from
+                let mut seen_vars = BTreeSet::new();
+                seen_vars.insert(*ty_var);
+                let mut ty = self.ctx.types.get(*ty_var);
+                while let SpecificType::Ref(t) = ty {
+                    // `insert` returns `false` when `t` was already visited
+                    if !seen_vars.insert(*t) {
+                        // circular reference
+                        errs.push(constraint.span.with_item(TypeConstraintError::CircularType));
+                        continue 'outer;
+                    }
+                    *ty_var = *t;
+                    ty = self.ctx.types.get(*t);
+                }
+            }
+
+            constraints.insert((kind, tys), *constraint);
+        }
+
+        // surface the circular-reference errors instead of silently dropping them
+        self.errors.extend(errs);
+        self.ctx.constraints = constraints.into_values();
+    }
+
+    /// Record a type error on this checker.
+    pub fn push_error(
+        &mut self,
+        e: TypeError,
+    ) {
+        self.errors.push(e);
+    }
+
+    /// All monomorphized functions recorded so far, keyed by their signature.
+    pub fn monomorphized_functions(&self) -> &BTreeMap {
+        &self.monomorphized_functions
+    }
+
+    /// Store `monomorphized_func_decl` under `signature`, replacing any previous entry.
+    pub(crate) fn insert_monomorphized_function(
+        &mut self,
+        signature: (FunctionId, Box<[GeneralType]>),
+        monomorphized_func_decl: Function,
+    ) {
+        self.monomorphized_functions.insert(signature, monomorphized_func_decl);
+    }
+
+    /// The map from type/function ids to their type variables.
+    pub fn type_map(&self) -> &BTreeMap {
+        &self.type_map
+    }
+
+    /// The resolved items this checker was constructed from.
+    pub fn resolved(&self) -> &QueryableResolvedItems {
+        &self.resolved
+    }
+
+    pub fn typed_functions(&self)
-> &BTreeMap { + &self.typed_functions + } +} + +/// the `key` type is what we use to deduplicate constraints +#[derive(Default)] +struct ConstraintDeduplicator { + constraints: BTreeMap<(TypeConstraintKindValue, Vec), TypeConstraint>, +} + +impl ConstraintDeduplicator { + fn insert( + &mut self, + key: (TypeConstraintKindValue, Vec), + constraint: TypeConstraint, + ) { + self.constraints.insert(key, constraint); + } + + fn into_values(self) -> Vec { + self.constraints.into_values().collect() + } +} + +pub fn unify_basic_math_op( + lhs: &Expr, + rhs: &Expr, + ctx: &mut TypeChecker, +) -> (TypedExpr, TypedExpr) { + let lhs = lhs.type_check(ctx); + let rhs = rhs.type_check(ctx); + let lhs_ty = ctx.expr_ty(&lhs); + let rhs_ty = ctx.expr_ty(&rhs); + let int_ty = ctx.int(); + ctx.unify(int_ty, lhs_ty, lhs.span()); + ctx.unify(int_ty, rhs_ty, rhs.span()); + (lhs, rhs) +} + +impl TypeCheck for petr_resolve::Function { + type Output = Function; + + fn type_check( + &self, + ctx: &mut TypeChecker, + ) -> Self::Output { + ctx.with_type_scope(|ctx| { + let params = self.params.iter().map(|(name, ty)| (*name, ctx.to_type_var(ty))).collect::>(); + // declared parameters are axiomatic, they won't be updated by any inference + + for (name, ty) in ¶ms { + ctx.insert_variable(*name, *ty); + // TODO get span for type annotation instead of just the name of the parameter + ctx.axiom(*ty, name.span); + } + + // unify types within the body with the parameter + let body = self.body.type_check(ctx); + + let declared_return_type = ctx.to_type_var(&self.return_type); + + Function { + name: self.name, + params, + return_ty: declared_return_type, + body, + } + }) + } +} + +impl TypeCheck for FunctionCall { + type Output = TypedExprKind; + + fn type_check( + &self, + ctx: &mut TypeChecker, + ) -> Self::Output { + let func_decl = ctx.get_function(&self.function).clone(); + + if self.args.len() != func_decl.params.len() { + // TODO: support partial application + 
ctx.push_error(self.span().with_item(TypeConstraintError::ArgumentCountMismatch { + expected: func_decl.params.len(), + got: self.args.len(), + function: ctx.get_symbol(func_decl.name.id).to_string(), + })); + return TypedExprKind::ErrorRecovery(self.span()); + } + + let mut args: Vec<(Identifier, TypedExpr, TypeVariable)> = Vec::with_capacity(self.args.len()); + + // unify all of the arg types with the param types + for (arg, (name, param_ty)) in self.args.iter().zip(func_decl.params.iter()) { + let arg = arg.type_check(ctx); + let arg_ty = ctx.expr_ty(&arg); + ctx.satisfies(*param_ty, arg_ty, arg.span()); + args.push((*name, arg, arg_ty)); + } + + let concrete_arg_types: Vec<_> = args + .iter() + .map(|(_, _, ty)| ctx.look_up_variable(*ty).generalize(ctx.ctx().types()).clone()) + .collect(); + + let signature: FunctionSignature = (self.function, concrete_arg_types.clone().into_boxed_slice()); + // now that we know the argument types, check if this signature has been monomorphized + // already + if ctx.monomorphized_functions().contains_key(&signature) { + return TypedExprKind::FunctionCall { + func: self.function, + args: args.into_iter().map(|(name, expr, _)| (name, expr)).collect(), + ty: func_decl.return_ty, + }; + } + + // unify declared return type with body return type + let declared_return_type = func_decl.return_ty; + + ctx.satisfy_expr_return(declared_return_type, &func_decl.body); + + // to create a monomorphized func decl, we don't actually have to update all of the types + // throughout the entire definition. We only need to update the parameter types. 
+ let mut monomorphized_func_decl = Function { + name: func_decl.name, + params: func_decl.params.clone(), + return_ty: declared_return_type, + body: func_decl.body.clone(), + }; + + // update the parameter types to be the concrete types + for (param, concrete_ty) in monomorphized_func_decl.params.iter_mut().zip(concrete_arg_types.iter()) { + let param_ty = ctx.insert_type(concrete_ty); + param.1 = param_ty; + } + + // if there are any variable exprs in the body, update those ref types + let mut num_replacements = 0; + replace_var_reference_types( + &mut monomorphized_func_decl.body.kind, + &monomorphized_func_decl.params, + &mut num_replacements, + ); + + ctx.insert_monomorphized_function(signature, monomorphized_func_decl); + // if there are any variable exprs in the body, update those ref types + + TypedExprKind::FunctionCall { + func: self.function, + args: args.into_iter().map(|(name, expr, _)| (name, expr)).collect(), + ty: declared_return_type, + } + } +} + +impl TypeCheck for SpannedItem { + type Output = TypedExpr; + + fn type_check( + &self, + ctx: &mut TypeChecker, + ) -> Self::Output { + use petr_resolve::IntrinsicName::*; + let kind = match self.item().intrinsic { + Puts => { + if self.item().args.len() != 1 { + todo!("puts arg len check"); + } + // puts takes a single string and returns unit + let arg = self.item().args[0].type_check(ctx); + ctx.unify_expr_return(ctx.string(), &arg); + TypedExprKind::Intrinsic { + intrinsic: crate::Intrinsic::Puts(Box::new(arg)), + ty: ctx.unit(), + } + }, + Add => { + if self.item().args.len() != 2 { + todo!("add arg len check"); + } + let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx); + TypedExprKind::Intrinsic { + intrinsic: crate::Intrinsic::Add(Box::new(lhs), Box::new(rhs)), + ty: ctx.int(), + } + }, + Subtract => { + if self.item().args.len() != 2 { + todo!("sub arg len check"); + } + let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx); + + 
TypedExprKind::Intrinsic { + intrinsic: crate::Intrinsic::Subtract(Box::new(lhs), Box::new(rhs)), + ty: ctx.int(), + } + }, + Multiply => { + if self.item().args.len() != 2 { + todo!("mult arg len check"); + } + + let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx); + TypedExprKind::Intrinsic { + intrinsic: crate::Intrinsic::Multiply(Box::new(lhs), Box::new(rhs)), + ty: ctx.int(), + } + }, + + Divide => { + if self.item().args.len() != 2 { + todo!("Divide arg len check"); + } + + let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx); + TypedExprKind::Intrinsic { + intrinsic: crate::Intrinsic::Divide(Box::new(lhs), Box::new(rhs)), + ty: ctx.int(), + } + }, + Malloc => { + // malloc takes one integer (the number of bytes to allocate) + // and returns a pointer to the allocated memory + // will return `0` if the allocation fails + // in the future, this might change to _words_ of allocation, + // depending on the compilation target + if self.item().args.len() != 1 { + todo!("malloc arg len check"); + } + let arg = self.item().args[0].type_check(ctx); + let arg_ty = ctx.expr_ty(&arg); + let int_ty = ctx.int(); + ctx.unify(int_ty, arg_ty, arg.span()); + TypedExprKind::Intrinsic { + intrinsic: crate::Intrinsic::Malloc(Box::new(arg)), + ty: int_ty, + } + }, + SizeOf => { + if self.item().args.len() != 1 { + todo!("size_of arg len check"); + } + + let arg = self.item().args[0].type_check(ctx); + + TypedExprKind::Intrinsic { + intrinsic: crate::Intrinsic::SizeOf(Box::new(arg)), + ty: ctx.int(), + } + }, + Equals => { + if self.item().args.len() != 2 { + todo!("equal arg len check"); + } + + let lhs = self.item().args[0].type_check(ctx); + let rhs = self.item().args[1].type_check(ctx); + ctx.unify(ctx.expr_ty(&lhs), ctx.expr_ty(&rhs), self.span()); + TypedExprKind::Intrinsic { + intrinsic: crate::Intrinsic::Equals(Box::new(lhs), Box::new(rhs)), + ty: ctx.bool(), + } + }, + }; + + TypedExpr { kind, span: 
self.span() } + } +} + +fn replace_var_reference_types( + expr: &mut TypedExprKind, + params: &Vec<(Identifier, TypeVariable)>, + num_replacements: &mut usize, +) { + match expr { + TypedExprKind::Variable { ref mut ty, name } => { + if let Some((_param_name, ty_var)) = params.iter().find(|(param_name, _)| param_name.id == name.id) { + *num_replacements += 1; + *ty = *ty_var; + } + }, + TypedExprKind::FunctionCall { args, .. } => { + for (_, arg) in args { + replace_var_reference_types(&mut arg.kind, params, num_replacements); + } + }, + TypedExprKind::Intrinsic { intrinsic, .. } => { + use crate::Intrinsic::*; + match intrinsic { + // intrinsics which take one arg, grouped for convenience + Puts(a) | Malloc(a) | SizeOf(a) => { + replace_var_reference_types(&mut a.kind, params, num_replacements); + }, + // intrinsics which take two args, grouped for convenience + Add(a, b) | Subtract(a, b) | Multiply(a, b) | Divide(a, b) | Equals(a, b) => { + replace_var_reference_types(&mut a.kind, params, num_replacements); + replace_var_reference_types(&mut b.kind, params, num_replacements); + }, + } + }, + // TODO other expr kinds like bindings + _ => (), + } +} diff --git a/petr-typecheck/src/lib.rs b/petr-typecheck/src/lib.rs index 964d84f..a2fab05 100644 --- a/petr-typecheck/src/lib.rs +++ b/petr-typecheck/src/lib.rs @@ -1,23 +1,28 @@ //! TODO: //! - Effectual Types //! - Formalize constraints: -//! - Unify -//! - Satisfies //! - UnifyEffects //! 
- SatisfiesEffects mod error; -use std::{ - collections::{BTreeMap, BTreeSet}, - rc::Rc, -}; - -use error::TypeConstraintError; +pub use constraint_generation::{unify_basic_math_op, FunctionSignature, TypeCheck, TypeChecker}; +pub use error::TypeConstraintError; pub use petr_bind::FunctionId; -use petr_resolve::{Expr, ExprKind, FunctionCall, QueryableResolvedItems}; +use petr_resolve::QueryableResolvedItems; pub use petr_resolve::{Intrinsic as ResolvedIntrinsic, IntrinsicName, Literal}; -use petr_utils::{idx_map_key, Identifier, IndexMap, Span, SpannedItem, SymbolId, SymbolInterner, TypeId}; +use petr_utils::{idx_map_key, IndexMap, SpannedItem, TypeId}; +pub use solution::TypeSolution; +pub use typed_ast::*; +pub use types::*; + +mod constraint_generation; +mod pretty_printing; +mod solution; +#[cfg(test)] +mod tests; +mod typed_ast; +mod types; pub type TypeError = SpannedItem; pub type TResult = Result; @@ -60,2769 +65,3 @@ impl From<&FunctionId> for TypeOrFunctionId { } idx_map_key!(TypeVariable); - -#[derive(Clone, Copy, Debug)] -pub struct TypeConstraint { - kind: TypeConstraintKind, - /// The span from which this type constraint originated - span: Span, -} -impl TypeConstraint { - fn unify( - t1: TypeVariable, - t2: TypeVariable, - span: Span, - ) -> Self { - Self { - kind: TypeConstraintKind::Unify(t1, t2), - span, - } - } - - fn satisfies( - t1: TypeVariable, - t2: TypeVariable, - span: Span, - ) -> Self { - Self { - kind: TypeConstraintKind::Satisfies(t1, t2), - span, - } - } - - fn axiom( - t1: TypeVariable, - span: Span, - ) -> Self { - Self { - kind: TypeConstraintKind::Axiom(t1), - span, - } - } -} - -#[derive(Clone, Copy, Debug)] -pub enum TypeConstraintKind { - Unify(TypeVariable, TypeVariable), - // constraint that lhs is a "subtype" or satisfies the typeclass constraints of "rhs" - Satisfies(TypeVariable, TypeVariable), - // If a type variable is constrained to be an axiom, it means that the type variable - // cannot be updated by the inference 
engine. It effectively fixes the type, or pins the type. - Axiom(TypeVariable), -} - -#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] -enum TypeConstraintKindValue { - Unify, - Satisfies, - Axiom, -} - -pub struct TypeContext { - types: IndexMap, - constraints: Vec, - // known primitive type IDs - unit_ty: TypeVariable, - string_ty: TypeVariable, - int_ty: TypeVariable, - bool_ty: TypeVariable, - error_recovery: TypeVariable, -} - -impl Default for TypeContext { - fn default() -> Self { - let mut types = IndexMap::default(); - // instantiate basic primitive types - let unit_ty = types.insert(SpecificType::Unit); - let string_ty = types.insert(SpecificType::String); - let bool_ty = types.insert(SpecificType::Boolean); - let int_ty = types.insert(SpecificType::Integer); - let error_recovery = types.insert(SpecificType::ErrorRecovery); - // insert primitive types - TypeContext { - types, - constraints: Default::default(), - bool_ty, - unit_ty, - string_ty, - int_ty, - error_recovery, - } - } -} - -impl TypeContext { - fn unify( - &mut self, - ty1: TypeVariable, - ty2: TypeVariable, - span: Span, - ) { - self.constraints.push(TypeConstraint::unify(ty1, ty2, span)); - } - - fn satisfies( - &mut self, - ty1: TypeVariable, - ty2: TypeVariable, - span: Span, - ) { - self.constraints.push(TypeConstraint::satisfies(ty1, ty2, span)); - } - - fn axiom( - &mut self, - ty1: TypeVariable, - span: Span, - ) { - self.constraints.push(TypeConstraint::axiom(ty1, span)); - } - - fn new_variable( - &mut self, - span: Span, - ) -> TypeVariable { - // infer is special -- it knows its own id, mostly for printing - // and disambiguating - let infer_id = self.types.len(); - self.types.insert(SpecificType::Infer(infer_id, span)) - } - - /// Update a type variable with a new SpecificType - fn update_type( - &mut self, - t1: TypeVariable, - known: SpecificType, - ) { - *self.types.get_mut(t1) = known; - } - - pub fn types(&self) -> &IndexMap { - &self.types - } -} - -pub type 
FunctionSignature = (FunctionId, Box<[GeneralType]>); - -pub struct TypeChecker { - ctx: TypeContext, - type_map: BTreeMap, - monomorphized_functions: BTreeMap, - typed_functions: BTreeMap, - errors: Vec, - resolved: QueryableResolvedItems, - variable_scope: Vec>, -} - -#[derive(Clone, PartialEq, Debug, Eq, PartialOrd, Ord)] -/// A type which is general, and has no constraints applied to it. -/// This is a generalization of [`SpecificType`]. -/// This is more useful for IR generation, since functions are monomorphized -/// based on general types. -pub enum GeneralType { - Unit, - Integer, - Boolean, - String, - UserDefined { - name: Identifier, - // TODO these should be boxed slices, as their size is not changed - variants: Box<[GeneralizedTypeVariant]>, - constant_literal_types: Vec, - }, - Arrow(Vec), - ErrorRecovery, - List(Box), - Infer(usize, Span), - Sum(BTreeSet), -} - -impl GeneralType { - /// Because [`GeneralType`]'s type info is less detailed (specific) than [`SpecificType`], - /// we can losslessly cast any [`GeneralType`] into an instance of [`SpecificType`]. - pub fn safely_upcast(&self) -> SpecificType { - match self { - GeneralType::Unit => SpecificType::Unit, - GeneralType::Integer => SpecificType::Integer, - GeneralType::Boolean => SpecificType::Boolean, - GeneralType::String => SpecificType::String, - GeneralType::ErrorRecovery => SpecificType::ErrorRecovery, - _ => todo!(), - } - } -} - -/// Represents the result of the type-checking stage for an individual type variable. 
-pub struct TypeSolutionEntry { - source: TypeSolutionSource, - ty: SpecificType, -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum TypeSolutionSource { - Inherent, - Axiomatic, - Inferred, -} - -impl TypeSolutionEntry { - pub fn new_axiomatic(ty: SpecificType) -> Self { - Self { - source: TypeSolutionSource::Axiomatic, - ty, - } - } - - pub fn new_inherent(ty: SpecificType) -> Self { - Self { - source: TypeSolutionSource::Inherent, - ty, - } - } - - pub fn new_inferred(ty: SpecificType) -> Self { - Self { - source: TypeSolutionSource::Inferred, - ty, - } - } - - pub fn is_axiomatic(&self) -> bool { - self.source == TypeSolutionSource::Axiomatic - } -} - -pub struct TypeSolution { - solution: BTreeMap, - unsolved_types: IndexMap, - errors: Vec, - interner: SymbolInterner, - error_recovery: TypeVariable, - unit: TypeVariable, - functions: BTreeMap, - monomorphized_functions: BTreeMap, -} - -impl TypeSolution { - pub fn new( - unsolved_types: IndexMap, - error_recovery: TypeVariable, - unit: TypeVariable, - functions: BTreeMap, - monomorphized_functions: BTreeMap, - interner: SymbolInterner, - preexisting_errors: Vec, - ) -> Self { - let solution = vec![ - (unit, TypeSolutionEntry::new_inherent(SpecificType::Unit)), - (error_recovery, TypeSolutionEntry::new_inherent(SpecificType::ErrorRecovery)), - ] - .into_iter() - .collect(); - Self { - solution, - unsolved_types, - errors: preexisting_errors, - interner, - functions, - monomorphized_functions, - unit, - error_recovery, - } - } - - fn push_error( - &mut self, - e: TypeError, - ) { - self.errors.push(e); - } - - pub fn insert_solution( - &mut self, - ty: TypeVariable, - entry: TypeSolutionEntry, - span: Span, - ) { - if self.solution.contains_key(&ty) { - self.update_type(ty, entry, span); - return; - } - self.solution.insert(ty, entry); - } - - fn pretty_print_type( - &self, - ty: &SpecificType, - ) -> String { - pretty_printing::pretty_print_petr_type(&ty, &self.unsolved_types, &self.interner) - } - - 
fn unify_err( - &self, - clone_1: SpecificType, - clone_2: SpecificType, - ) -> TypeConstraintError { - let pretty_printed_b = self.pretty_print_type(&clone_2); - match clone_1 { - SpecificType::Sum(tys) => { - let tys = tys.iter().map(|ty| self.pretty_print_type(ty)).collect::>(); - TypeConstraintError::NotSubtype(tys, pretty_printed_b) - }, - _ => { - let pretty_printed_a = self.pretty_print_type(&clone_1); - TypeConstraintError::UnificationFailure(pretty_printed_a, pretty_printed_b) - }, - } - } - - fn satisfy_err( - &self, - clone_1: SpecificType, - clone_2: SpecificType, - ) -> TypeConstraintError { - let pretty_printed_b = self.pretty_print_type(&clone_2); - match clone_1 { - SpecificType::Sum(tys) => { - let tys = tys.iter().map(|ty| self.pretty_print_type(&ty)).collect::>(); - TypeConstraintError::NotSubtype(tys, pretty_printed_b) - }, - _ => { - let pretty_printed_a = self.pretty_print_type(&clone_1); - TypeConstraintError::FailedToSatisfy(pretty_printed_a, pretty_printed_b) - }, - } - } - - pub fn update_type( - &mut self, - ty: TypeVariable, - entry: TypeSolutionEntry, - span: Span, - ) { - match self.solution.get_mut(&ty) { - Some(e) => { - if e.is_axiomatic() { - let pretty_printed_preexisting = pretty_printing::pretty_print_petr_type(&entry.ty, &self.unsolved_types, &self.interner); - let pretty_printed_ty = pretty_printing::pretty_print_petr_type(&entry.ty, &self.unsolved_types, &self.interner); - self.errors - .push(span.with_item(TypeConstraintError::InvalidTypeUpdate(pretty_printed_preexisting, pretty_printed_ty))); - return; - } - *e = entry; - }, - None => { - self.solution.insert(ty, entry); - }, - } - } - - fn into_result(self) -> Result>> { - if self.errors.is_empty() { - Ok(self) - } else { - Err(self.errors) - } - } - - /// Attempt to unify two types, returning an error if they cannot be unified - /// The more specific of the two types will instantiate the more general of the two types. 
- /// - /// TODO: The unify constraint should attempt to upcast `t2` as `t1` if possible, but will never - /// downcast `t1` as `t2`. This is not currently how it works and needs investigation. - fn apply_unify_constraint( - &mut self, - t1: TypeVariable, - t2: TypeVariable, - span: Span, - ) { - let ty1 = self.get_latest_type(t1).clone(); - let ty2 = self.get_latest_type(t2).clone(); - use SpecificType::*; - match (ty1, ty2) { - (a, b) if a == b => (), - (ErrorRecovery, _) | (_, ErrorRecovery) => (), - (Ref(a), _) => self.apply_unify_constraint(a, t2, span), - (_, Ref(b)) => self.apply_unify_constraint(t1, b, span), - (Infer(id, _), Infer(id2, _)) if id != id2 => { - // if two different inferred types are unified, replace the second with a reference - // to the first - let entry = TypeSolutionEntry::new_inferred(Ref(t1)); - self.update_type(t2, entry, span); - }, - (a @ Sum(_), b @ Sum(_)) => { - // the unification of two sum types is the union of the two types if and only if - // `t2` is a total subset of `t1` - // `t1` remains unchanged, as we are trying to coerce `t2` into something that - // represents `t1` - // TODO remove clone - if self.a_superset_of_b(&a, &b) { - let entry = TypeSolutionEntry::new_inferred(Ref(t1)); - self.update_type(t2, entry, span); - } else { - // the union of the two sets is the new type - let a_tys = match a { - Sum(tys) => tys, - _ => unreachable!(), - }; - let b_tys = match b { - Sum(tys) => tys, - _ => unreachable!(), - }; - let union = a_tys.iter().chain(b_tys.iter()).cloned().collect(); - let entry = TypeSolutionEntry::new_inferred(Sum(union)); - self.update_type(t1, entry, span); - let entry = TypeSolutionEntry::new_inferred(Ref(t1)); - self.update_type(t2, entry, span); - } - }, - // If `t2` is a non-sum type, and `t1` is a sum type, then `t1` must contain either - // exactly the same specific type OR the generalization of that type - // If the latter, then the specific type must be updated to its generalization - (ref t1_ty @ 
Sum(_), other) => { - if self.a_superset_of_b(&t1_ty, &other) { - // t2 unifies to the more general form provided by t1 - let entry = TypeSolutionEntry::new_inferred(Ref(t1)); - self.update_type(t2, entry, span); - } else { - // add `other` to `t1` - let mut tys = match t1_ty { - Sum(tys) => tys.clone(), - _ => unreachable!(), - }; - tys.insert(other); - let entry = TypeSolutionEntry::new_inferred(Sum(tys)); - self.update_type(t1, entry, span); - } - }, - // literals can unify to each other if they're equal - (Literal(l1), Literal(l2)) if l1 == l2 => (), - // if they're not equal, their unification is the sum of both - (Literal(l1), Literal(l2)) if l1 != l2 => { - // update t1 to a sum type of both, - // and update t2 to reference t1 - let sum = Sum([Literal(l1), Literal(l2)].into()); - let t1_entry = TypeSolutionEntry::new_inferred(sum); - let t2_entry = TypeSolutionEntry::new_inferred(Ref(t1)); - self.update_type(t1, t1_entry, span); - self.update_type(t2, t2_entry, span); - }, - (Literal(l1), Sum(tys)) => { - // update t1 to a sum type of both, - // and update t2 to reference t1 - let sum = Sum([Literal(l1)].iter().chain(tys.iter()).cloned().collect()); - let t1_entry = TypeSolutionEntry::new_inferred(sum); - let t2_entry = TypeSolutionEntry::new_inferred(Ref(t1)); - self.update_type(t1, t1_entry, span); - self.update_type(t2, t2_entry, span); - }, - (a, b) if self.a_superset_of_b(&a, &b) => { - // if `a` is a superset of `b`, then `b` unifies to `a` as it is more general - let entry = TypeSolutionEntry::new_inferred(Ref(t1)); - self.update_type(t2, entry, span); - }, - /* TODO rewrite below rules - // literals can unify broader parent types - // but the broader parent type gets instantiated with the literal type - // TODO(alex) this rule feels incorrect. A literal being unified to the parent type - // should upcast the lit, not downcast t1. Check after refactoring. 
- (ty, Literal(lit)) => match (&lit, ty) { - (petr_resolve::Literal::Integer(_), Integer) - | (petr_resolve::Literal::Boolean(_), Boolean) - | (petr_resolve::Literal::String(_), String) => { - let entry = TypeSolutionEntry::new_inferred(SpecificType::Literal(lit)); - self.update_type(t1, entry, span); - }, - (lit, ty) => self.push_error(span.with_item(self.unify_err(ty.clone(), SpecificType::Literal(lit.clone())))), - }, - // literals can unify broader parent types - // but the broader parent type gets instantiated with the literal type - (Literal(lit), ty) => match (&lit, ty) { - (petr_resolve::Literal::Integer(_), Integer) - | (petr_resolve::Literal::Boolean(_), Boolean) - | (petr_resolve::Literal::String(_), String) => self.update_type(t2, SpecificType::Literal(lit)), - (lit, ty) => { - self.push_error(span.with_item(self.unify_err(ty.clone(), SpecificType::Literal(lit.clone())))); - }, - }, - */ - (other, ref t2_ty @ Sum(_)) => { - // if `other` is a superset of `t2`, then `t2` unifies to `other` as it is more - // general - if self.a_superset_of_b(&other, &t2_ty) { - let entry = TypeSolutionEntry::new_inferred(other); - self.update_type(t2, entry, span); - } - }, - // instantiate the infer type with the known type - (Infer(_, _), _known) => { - let entry = TypeSolutionEntry::new_inferred(Ref(t2)); - self.update_type(t1, entry, span); - }, - (_known, Infer(_, _)) => { - let entry = TypeSolutionEntry::new_inferred(Ref(t1)); - self.update_type(t2, entry, span); - }, - // lastly, if no unification rule exists for these two types, it is a mismatch - (a, b) => { - self.push_error(span.with_item(self.unify_err(a, b))); - }, - } - } - - // This function will need to be rewritten when type constraints and bounded polymorphism are - // implemented. 
- fn apply_satisfies_constraint( - &mut self, - t1: TypeVariable, - t2: TypeVariable, - span: Span, - ) { - let ty1 = self.get_latest_type(t1); - let ty2 = self.get_latest_type(t2); - use SpecificType::*; - match (ty1, ty2) { - (a, b) if a == b => (), - (ErrorRecovery, _) | (_, ErrorRecovery) => (), - (Ref(a), _) => self.apply_satisfies_constraint(a, t2, span), - (_, Ref(b)) => self.apply_satisfies_constraint(t1, b, span), - // if t1 is a fully instantiated type, then t2 can be updated to be a reference to t1 - (Unit | Integer | Boolean | UserDefined { .. } | String | Arrow(..) | List(..) | Literal(_) | Sum(_), Infer(_, _)) => { - let entry = TypeSolutionEntry::new_inferred(Ref(t1)); - self.update_type(t2, entry, span); - }, - // the "parent" infer type will not instantiate to the "child" type - (Infer(_, _), Unit | Integer | Boolean | UserDefined { .. } | String | Arrow(..) | List(..) | Literal(_) | Sum(_)) => (), - (Sum(a_tys), Sum(b_tys)) => { - // calculate the intersection of these types, update t2 to the intersection - let intersection = a_tys.iter().filter(|a_ty| b_tys.contains(a_ty)).cloned().collect(); - let entry = TypeSolutionEntry::new_inferred(SpecificType::sum(intersection)); - self.update_type(t2, entry, span); - }, - // if `ty1` is a generalized version of the sum type, - // then it satisfies the sum type - (ty1, other) if self.a_superset_of_b(&ty1, &other) => (), - (Literal(l1), Literal(l2)) if l1 == l2 => (), - // Literals can satisfy broader parent types - (ty, Literal(lit)) => match (lit, ty) { - (petr_resolve::Literal::Integer(_), Integer) => (), - (petr_resolve::Literal::Boolean(_), Boolean) => (), - (petr_resolve::Literal::String(_), String) => (), - (lit, ty) => { - self.push_error(span.with_item(self.satisfy_err(ty.clone(), SpecificType::Literal(lit.clone())))); - }, - }, - // if we are trying to satisfy an inferred type with no bounds, this is ok - (Infer(..), _) => (), - (a, b) => { - 
self.push_error(span.with_item(self.satisfy_err(a.clone(), b.clone()))); - }, - } - } - - /// Gets the latest version of a type available. First checks solved types, - /// and if it doesn't exist, gets it from the unsolved types. - pub fn get_latest_type( - &self, - t1: TypeVariable, - ) -> SpecificType { - self.solution - .get(&t1) - .map(|entry| entry.ty.clone()) - .unwrap_or_else(|| self.unsolved_types.get(t1).clone()) - } - - /// To reference an error recovery type, you must provide an error. - /// This holds the invariant that error recovery types are only generated when - /// an error occurs. - pub fn error_recovery( - &mut self, - err: TypeError, - ) -> TypeVariable { - self.push_error(err); - self.error_recovery - } - - /// If `a` is a generalized form of `b`, return true - /// A generalized form is a type that is a superset of the sum types. - /// For example, `String` is a generalized form of `Sum(Literal("a") | Literal("B"))` - fn a_superset_of_b( - &self, - a: &SpecificType, - b: &SpecificType, - ) -> bool { - use SpecificType::*; - let generalized_b = self.generalize(&b).safely_upcast(); - match (a, b) { - // If `a` is the generalized form of `b`, then `b` satisfies the constraint. - (a, b) if a == b || *a == generalized_b => true, - // If `a` is a sum type which contains `b` OR the generalized form of `b`, then `b` - // satisfies the constraint. 
- (Sum(a_tys), b) if a_tys.contains(b) || a_tys.contains(&generalized_b) => true, - // if both `a` and `b` are sum types, then `a` must be a superset of `b`: - // - every element in `b` must either: - // - be a member of `a` - // - generalize to a member of `a` - (Sum(a_tys), Sum(b_tys)) => { - // if a_tys is a superset of b_tys, - // every element OR its generalized version is contained in a_tys - for b_ty in b_tys { - let b_ty_generalized = self.generalize(b_ty).safely_upcast(); - if !(a_tys.contains(b_ty) || a_tys.contains(&b_ty_generalized)) { - return false; - } - } - - true - }, - _otherwise => false, - } - } - - pub fn generalize( - &self, - b: &SpecificType, - ) -> GeneralType { - match b { - SpecificType::Unit => GeneralType::Unit, - SpecificType::Integer => GeneralType::Integer, - SpecificType::Boolean => GeneralType::Boolean, - SpecificType::String => GeneralType::String, - SpecificType::Ref(ty) => self.generalize(&self.get_latest_type(*ty)), - SpecificType::UserDefined { - name, - variants, - constant_literal_types, - } => GeneralType::UserDefined { - name: *name, - variants: variants - .iter() - .map(|variant| { - let generalized_fields = variant.fields.iter().map(|field| self.generalize(field)).collect::>(); - - GeneralizedTypeVariant { - fields: generalized_fields.into_boxed_slice(), - } - }) - .collect(), - constant_literal_types: constant_literal_types.clone(), - }, - SpecificType::Arrow(tys) => GeneralType::Arrow(tys.clone()), - SpecificType::ErrorRecovery => GeneralType::ErrorRecovery, - SpecificType::List(ty) => { - let ty = self.generalize(ty); - GeneralType::List(Box::new(ty)) - }, - SpecificType::Infer(u, s) => GeneralType::Infer(*u, *s), - SpecificType::Literal(l) => match l { - Literal::Integer(_) => GeneralType::Integer, - Literal::Boolean(_) => GeneralType::Boolean, - Literal::String(_) => GeneralType::String, - }, - SpecificType::Sum(tys) => { - // generalize all types, fold if possible - let all_generalized: BTreeSet<_> = 
tys.iter().map(|ty| self.generalize(ty)).collect(); - if all_generalized.len() == 1 { - // in this case, all specific types generalized to the same type - all_generalized.into_iter().next().expect("invariant") - } else { - GeneralType::Sum(all_generalized.into_iter().collect()) - } - }, - } - } - - #[cfg(test)] - fn pretty_print(&self) -> String { - let mut pretty = "__SOLVED TYPES__\n".to_string(); - - let mut num_entries = 0; - for (ty, entry) in self.solution.iter().filter(|(id, _)| ![self.unit, self.error_recovery].contains(id)) { - pretty.push_str(&format!("{}: {}\n", Into::::into(*ty), self.pretty_print_type(&entry.ty))); - num_entries += 1; - } - if num_entries == 0 { - Default::default() - } else { - pretty - } - } - - pub fn get_main_function(&self) -> Option<(&FunctionId, &Function)> { - self.functions.iter().find(|(_, func)| &*self.interner.get(func.name.id) == "main") - } - - pub fn get_monomorphized_function( - &self, - id: &FunctionSignature, - ) -> &Function { - self.monomorphized_functions.get(id).expect("invariant: should exist") - } - - pub fn expr_ty( - &self, - expr: &TypedExpr, - ) -> TypeVariable { - use TypedExprKind::*; - match &expr.kind { - FunctionCall { ty, .. } => *ty, - Literal { ty, .. } => *ty, - List { ty, .. } => *ty, - Unit => self.unit, - Variable { ty, .. } => *ty, - Intrinsic { ty, .. } => *ty, - ErrorRecovery(..) => self.error_recovery, - ExprWithBindings { expression, .. } => self.expr_ty(expression), - TypeConstructor { ty, .. } => *ty, - If { then_branch, .. } => self.expr_ty(then_branch), - } - } -} - -/// This is an information-rich type -- it tracks effects and data types. It is used for -/// the type-checking stage to provide rich information to the user. -/// Types are generalized into instances of [`GeneralType`] for monomorphization and -/// code generation. 
-#[derive(Clone, PartialEq, Debug, Eq, PartialOrd, Ord)] -pub enum SpecificType { - Unit, - Integer, - Boolean, - /// a static length string known at compile time - String, - /// A reference to another type - Ref(TypeVariable), - /// A user-defined type - UserDefined { - name: Identifier, - // TODO these should be boxed slices, as their size is not changed - variants: Vec, - constant_literal_types: Vec, - }, - Arrow(Vec), - ErrorRecovery, - // TODO make this petr type instead of typevariable - List(Box), - /// the usize is just an identifier for use in rendering the type - /// the span is the location of the inference, for error reporting if the inference is never - /// resolved - Infer(usize, Span), - Sum(BTreeSet), - Literal(Literal), -} - -#[derive(Clone, PartialEq, Debug, Eq, PartialOrd, Ord)] -pub struct GeneralizedTypeVariant { - pub fields: Box<[GeneralType]>, -} - -impl SpecificType { - fn generalize_inner( - &self, - types: &IndexMap, - ) -> GeneralType { - match self { - SpecificType::Unit => GeneralType::Unit, - SpecificType::Integer => GeneralType::Integer, - SpecificType::Boolean => GeneralType::Boolean, - SpecificType::String => GeneralType::String, - SpecificType::Ref(ty) => types.get(*ty).generalize(types), - SpecificType::UserDefined { - name, - variants, - constant_literal_types, - } => GeneralType::UserDefined { - name: *name, - variants: variants - .iter() - .map(|variant| { - let generalized_fields = variant.fields.iter().map(|field| field.generalize(types)).collect::>(); - - GeneralizedTypeVariant { - fields: generalized_fields.into_boxed_slice(), - } - }) - .collect(), - constant_literal_types: constant_literal_types.clone(), - }, - SpecificType::Arrow(tys) => GeneralType::Arrow(tys.clone()), - SpecificType::ErrorRecovery => GeneralType::ErrorRecovery, - SpecificType::List(ty) => { - let ty = ty.generalize(types); - GeneralType::List(Box::new(ty)) - }, - SpecificType::Infer(u, s) => GeneralType::Infer(*u, *s), - SpecificType::Literal(l) => 
match l { - Literal::Integer(_) => GeneralType::Integer, - Literal::Boolean(_) => GeneralType::Boolean, - Literal::String(_) => GeneralType::String, - }, - SpecificType::Sum(tys) => { - // generalize all types, fold if possible - let all_generalized: BTreeSet<_> = tys.iter().map(|ty| ty.generalize(types)).collect(); - if all_generalized.len() == 1 { - // in this case, all specific types generalized to the same type - all_generalized.into_iter().next().expect("invariant") - } else { - GeneralType::Sum(all_generalized.into_iter().collect()) - } - }, - } - } - - /// Use this to construct `[SpecificType::Sum]` types -- - /// it will attempt to collapse the sum into a single type if possible - fn sum(tys: BTreeSet) -> SpecificType { - if tys.len() == 1 { - tys.into_iter().next().expect("invariant") - } else { - SpecificType::Sum(tys) - } - } -} - -#[derive(Clone, PartialEq, Debug, Eq, PartialOrd, Ord)] -pub struct TypeVariant { - pub fields: Box<[SpecificType]>, -} - -pub trait Type { - fn as_specific_ty(&self) -> SpecificType; - - fn generalize( - &self, - types: &IndexMap, - ) -> GeneralType; -} - -impl Type for SpecificType { - fn as_specific_ty(&self) -> SpecificType { - self.clone() - } - - fn generalize( - &self, - types: &IndexMap, - ) -> GeneralType { - self.generalize_inner(&types) - } -} - -impl Type for GeneralType { - fn generalize( - &self, - _: &IndexMap, - ) -> Self { - self.clone() - } - - fn as_specific_ty(&self) -> SpecificType { - match self { - GeneralType::Unit => SpecificType::Unit, - GeneralType::Integer => SpecificType::Integer, - GeneralType::Boolean => SpecificType::Boolean, - GeneralType::String => SpecificType::String, - GeneralType::UserDefined { - name, - variants, - constant_literal_types, - } => SpecificType::UserDefined { - name: *name, - variants: variants - .iter() - .map(|variant| { - let fields = variant.fields.iter().map(|field| field.as_specific_ty()).collect::>(); - - TypeVariant { - fields: fields.into_boxed_slice(), - } - }) - 
.collect(), - constant_literal_types: constant_literal_types.clone(), - }, - GeneralType::Arrow(tys) => SpecificType::Arrow(tys.clone()), - GeneralType::ErrorRecovery => SpecificType::ErrorRecovery, - GeneralType::List(ty) => SpecificType::List(Box::new(ty.as_specific_ty())), - GeneralType::Infer(u, s) => SpecificType::Infer(*u, *s), - GeneralType::Sum(tys) => { - let tys = tys.iter().map(|ty| ty.as_specific_ty()).collect(); - SpecificType::Sum(tys) - }, - } - } -} - -impl TypeChecker { - pub fn insert_type( - &mut self, - ty: &T, - ) -> TypeVariable { - let ty = ty.as_specific_ty(); - // TODO: check if type already exists and return that ID instead - self.ctx.types.insert(ty) - } - - pub fn look_up_variable( - &self, - ty: TypeVariable, - ) -> &SpecificType { - self.ctx.types.get(ty) - } - - pub fn get_symbol( - &self, - id: SymbolId, - ) -> Rc { - self.resolved.interner.get(id).clone() - } - - fn with_type_scope( - &mut self, - f: impl FnOnce(&mut Self) -> T, - ) -> T { - self.variable_scope.push(Default::default()); - let res = f(self); - self.variable_scope.pop(); - res - } - - fn generic_type( - &mut self, - id: &Identifier, - ) -> TypeVariable { - for scope in self.variable_scope.iter().rev() { - if let Some(ty) = scope.get(id) { - return *ty; - } - } - let fresh_ty = self.fresh_ty_var(id.span); - match self.variable_scope.last_mut() { - Some(entry) => { - entry.insert(*id, fresh_ty); - }, - None => { - self.errors.push(id.span.with_item(TypeConstraintError::Internal( - "attempted to insert generic type into variable scope when no variable scope existed".into(), - ))); - self.ctx.update_type(fresh_ty, SpecificType::ErrorRecovery); - }, - }; - fresh_ty - } - - fn find_variable( - &self, - id: Identifier, - ) -> Option { - for scope in self.variable_scope.iter().rev() { - if let Some(ty) = scope.get(&id) { - return Some(*ty); - } - } - None - } - - pub fn fully_type_check(&mut self) { - for (id, decl) in self.resolved.types() { - let ty = 
self.fresh_ty_var(decl.name.span); - let variants = decl - .variants - .iter() - .map(|variant| { - self.with_type_scope(|ctx| { - let fields = variant.fields.iter().map(|field| ctx.to_petr_type(&field.ty)).collect::>(); - TypeVariant { - fields: fields.into_boxed_slice(), - } - }) - }) - .collect::>(); - self.ctx.update_type( - ty, - SpecificType::UserDefined { - name: decl.name, - variants, - constant_literal_types: decl.constant_literal_types, - }, - ); - self.type_map.insert(id.into(), ty); - } - - for (id, func) in self.resolved.functions() { - let typed_function = func.type_check(self); - - let ty = self.arrow_type([typed_function.params.iter().map(|(_, b)| *b).collect(), vec![typed_function.return_ty]].concat()); - self.type_map.insert(id.into(), ty); - self.typed_functions.insert(id, typed_function); - } - // type check the main func with no params - let main_func = self.get_main_function(); - // construct a function call for the main function, if one exists - if let Some((id, func)) = main_func { - let call = petr_resolve::FunctionCall { - function: id, - args: vec![], - span: func.name.span, - }; - call.type_check(self); - } - - // before applying existing constraints, it is likely that many duplicate constraints - // exist. We can safely remove any duplicate constraints to avoid excessive error - // reporting. 
- self.deduplicate_constraints(); - } - - pub fn get_main_function(&self) -> Option<(FunctionId, Function)> { - self.functions().find(|(_, func)| &*self.get_symbol(func.name.id) == "main") - } - - /// iterate through each constraint and transform the underlying types to satisfy them - /// - unification tries to collapse two types into one - /// - satisfaction tries to make one type satisfy the constraints of another, although type - /// constraints don't exist in the language yet - pub fn into_solution(self) -> Result> { - let constraints = self.ctx.constraints.clone(); - let mut solution = TypeSolution::new( - self.ctx.types.clone(), - self.ctx.error_recovery, - self.ctx.unit_ty, - self.typed_functions, - self.monomorphized_functions, - self.resolved.interner, - self.errors, - ); - for TypeConstraint { kind, span } in constraints - .iter() - .filter(|c| if let TypeConstraintKind::Axiom(_) = c.kind { true } else { false }) - { - let TypeConstraintKind::Axiom(axiomatic_variable) = kind else { - unreachable!("above filter ensures that all constraints are axioms here") - }; - // first, pin all axiomatic type variables in the solution - let ty = self.ctx.types.get(*axiomatic_variable).clone(); - solution.insert_solution(*axiomatic_variable, TypeSolutionEntry::new_axiomatic(ty), *span); - } - - // now apply the constraints - for constraint in constraints.iter().filter(|c| !matches!(c.kind, TypeConstraintKind::Axiom(_))) { - match &constraint.kind { - TypeConstraintKind::Unify(t1, t2) => { - solution.apply_unify_constraint(*t1, *t2, constraint.span); - }, - TypeConstraintKind::Satisfies(t1, t2) => { - solution.apply_satisfies_constraint(*t1, *t2, constraint.span); - }, - TypeConstraintKind::Axiom(_) => unreachable!(), - } - } - - solution.into_result() - } - - pub fn new(resolved: QueryableResolvedItems) -> Self { - let ctx = TypeContext::default(); - TypeChecker { - ctx, - type_map: Default::default(), - errors: Default::default(), - typed_functions: Default::default(), 
- resolved, - variable_scope: Default::default(), - monomorphized_functions: Default::default(), - } - } - - pub fn insert_variable( - &mut self, - id: Identifier, - ty: TypeVariable, - ) { - self.variable_scope - .last_mut() - .expect("inserted variable when no scope existed") - .insert(id, ty); - } - - pub fn fresh_ty_var( - &mut self, - span: Span, - ) -> TypeVariable { - self.ctx.new_variable(span) - } - - fn arrow_type( - &mut self, - tys: Vec, - ) -> TypeVariable { - assert!(!tys.is_empty(), "arrow_type: tys is empty"); - - if tys.len() == 1 { - return tys[0]; - } - - let ty = SpecificType::Arrow(tys); - self.ctx.types.insert(ty) - } - - pub fn to_petr_type( - &mut self, - ty: &petr_resolve::Type, - ) -> SpecificType { - match ty { - petr_resolve::Type::Integer => SpecificType::Integer, - petr_resolve::Type::Bool => SpecificType::Boolean, - petr_resolve::Type::Unit => SpecificType::Unit, - petr_resolve::Type::String => SpecificType::String, - petr_resolve::Type::ErrorRecovery(_) => { - // unifies to anything, fresh var - SpecificType::ErrorRecovery - }, - petr_resolve::Type::Named(ty_id) => SpecificType::Ref(*self.type_map.get(&ty_id.into()).expect("type did not exist in type map")), - petr_resolve::Type::Generic(generic_name) => { - // TODO don't create an ID and then reference it -- this is messy - let id = self.generic_type(generic_name); - SpecificType::Ref(id) - }, - petr_resolve::Type::Sum(tys) => SpecificType::Sum(tys.iter().map(|ty| self.to_petr_type(ty)).collect()), - petr_resolve::Type::Literal(l) => SpecificType::Literal(l.clone()), - } - } - - pub fn to_type_var( - &mut self, - ty: &petr_resolve::Type, - ) -> TypeVariable { - let petr_ty = self.to_petr_type(ty); - self.ctx.types.insert(petr_ty) - } - - pub fn get_type( - &self, - key: impl Into, - ) -> &TypeVariable { - self.type_map.get(&key.into()).expect("type did not exist in type map") - } - - fn convert_literal_to_type( - &mut self, - literal: &petr_resolve::Literal, - ) -> TypeVariable { - 
let ty = SpecificType::Literal(literal.clone()); - self.ctx.types.insert(ty) - } - - pub fn unify( - &mut self, - ty1: TypeVariable, - ty2: TypeVariable, - span: Span, - ) { - self.ctx.unify(ty1, ty2, span); - } - - pub fn satisfies( - &mut self, - ty1: TypeVariable, - ty2: TypeVariable, - span: Span, - ) { - self.ctx.satisfies(ty1, ty2, span); - } - - fn axiom( - &mut self, - ty: TypeVariable, - span: Span, - ) { - self.ctx.axiom(ty, span); - } - - fn get_untyped_function( - &self, - function: FunctionId, - ) -> &petr_resolve::Function { - self.resolved.get_function(function) - } - - pub fn get_function( - &mut self, - id: &FunctionId, - ) -> Function { - if let Some(func) = self.typed_functions.get(id) { - return func.clone(); - } - - // if the function hasn't been type checked yet, type check it - let func = self.get_untyped_function(*id).clone(); - let type_checked = func.type_check(self); - self.typed_functions.insert(*id, type_checked.clone()); - type_checked - } - - pub fn get_monomorphized_function( - &self, - id: &FunctionSignature, - ) -> &Function { - self.monomorphized_functions.get(id).expect("invariant: should exist") - } - - // TODO unideal clone - pub fn functions(&self) -> impl Iterator { - self.typed_functions.iter().map(|(a, b)| (*a, b.clone())).collect::>().into_iter() - } - - pub fn expr_ty( - &self, - expr: &TypedExpr, - ) -> TypeVariable { - use TypedExprKind::*; - match &expr.kind { - FunctionCall { ty, .. } => *ty, - Literal { ty, .. } => *ty, - List { ty, .. } => *ty, - Unit => self.unit(), - Variable { ty, .. } => *ty, - Intrinsic { ty, .. } => *ty, - ErrorRecovery(..) => self.ctx.error_recovery, - ExprWithBindings { expression, .. } => self.expr_ty(expression), - TypeConstructor { ty, .. } => *ty, - If { then_branch, .. } => self.expr_ty(then_branch), - } - } - - /// Given a concrete [`SpecificType`], unify it with the return type of the given expression. 
- pub fn unify_expr_return( - &mut self, - ty: TypeVariable, - expr: &TypedExpr, - ) { - let expr_ty = self.expr_ty(expr); - self.unify(ty, expr_ty, expr.span()); - } - - pub fn string(&self) -> TypeVariable { - self.ctx.string_ty - } - - pub fn unit(&self) -> TypeVariable { - self.ctx.unit_ty - } - - pub fn int(&self) -> TypeVariable { - self.ctx.int_ty - } - - pub fn bool(&self) -> TypeVariable { - self.ctx.bool_ty - } - - pub fn errors(&self) -> &[TypeError] { - &self.errors - } - - fn satisfy_expr_return( - &mut self, - ty: TypeVariable, - expr: &TypedExpr, - ) { - let expr_ty = self.expr_ty(expr); - self.satisfies(ty, expr_ty, expr.span()); - } - - pub fn ctx(&self) -> &TypeContext { - &self.ctx - } - - /// terms: - /// ### resolved type variable - /// - /// a type variable that is not a `Ref`. To get the resolved type of - /// a type variable, you must follow the chain of `Ref`s until you reach a non-Ref type. - /// - /// ### constraint kind strength: - /// The following is the hierarchy of constraints in terms of strength, from strongest (1) to - /// weakest: - /// 1. Unify(t1, t2) (t2 _must_ be coerceable to exactly equal t1) - /// 2. Satisfies (t2 must be a subset of t1. For all cases where t2 can unify to t1, t2 - /// satisfies t1 as a constraint) - /// - /// ### constraint strength - /// A constraint `a` is _stronger than_ a constraint `b` iff: - /// - `a` is higher than `b` in terms of constraint kind strength `a` is a more specific constraint than `b` - /// - e.g. Unify(Literal(5), x) is stronger than Unify(Int, x) because the former is more specific - /// - e.g. 
Unify(a, b) is stronger than Satisfies(a, b) - /// - /// - /// ### duplicated constraint: - /// A constraint `a` is _duplicated by_ constraint `b` iff: - /// - `a` and `b` are the same constraint kind, and the resolved type variables are the same - /// - `a` is a stronger constraint than `b` - /// - fn deduplicate_constraints(&mut self) { - use TypeConstraintKindValue as Kind; - let mut constraints = ConstraintDeduplicator::default(); - let mut errs = vec![]; - for constraint in &self.ctx.constraints { - let (mut tys, kind) = match &constraint.kind { - TypeConstraintKind::Unify(t1, t2) => (vec![*t1, *t2], Kind::Unify), - TypeConstraintKind::Satisfies(t1, t2) => (vec![*t1, *t2], Kind::Satisfies), - TypeConstraintKind::Axiom(t1) => (vec![*t1], Kind::Axiom), - }; - - // resolve all `Ref` types to get a resolved type variable - 'outer: for ty_var in tys.iter_mut() { - // track what we have seen, in case a circular reference is present - let mut seen_vars = BTreeSet::new(); - seen_vars.insert(*ty_var); - let mut ty = self.ctx.types.get(*ty_var); - while let SpecificType::Ref(t) = ty { - if seen_vars.contains(t) { - // circular reference - errs.push(constraint.span.with_item(TypeConstraintError::CircularType)); - continue 'outer; - } - *ty_var = *t; - ty = self.ctx.types.get(*t); - } - } - - constraints.insert((kind, tys), *constraint); - } - - self.ctx.constraints = constraints.into_values(); - } - - fn push_error( - &mut self, - e: TypeError, - ) { - self.errors.push(e); - } -} - -/// the `key` type is what we use to deduplicate constraints -#[derive(Default)] -struct ConstraintDeduplicator { - constraints: BTreeMap<(TypeConstraintKindValue, Vec), TypeConstraint>, -} - -impl ConstraintDeduplicator { - fn insert( - &mut self, - key: (TypeConstraintKindValue, Vec), - constraint: TypeConstraint, - ) { - self.constraints.insert(key, constraint); - } - - fn into_values(self) -> Vec { - self.constraints.into_values().collect() - } -} - -#[derive(Clone)] -pub enum Intrinsic { 
- Puts(Box), - Add(Box, Box), - Multiply(Box, Box), - Divide(Box, Box), - Subtract(Box, Box), - Malloc(Box), - SizeOf(Box), - Equals(Box, Box), -} - -impl std::fmt::Debug for Intrinsic { - fn fmt( - &self, - f: &mut std::fmt::Formatter<'_>, - ) -> std::fmt::Result { - match self { - Intrinsic::Puts(expr) => write!(f, "@puts({:?})", expr), - Intrinsic::Add(lhs, rhs) => write!(f, "@add({:?}, {:?})", lhs, rhs), - Intrinsic::Multiply(lhs, rhs) => write!(f, "@multiply({:?}, {:?})", lhs, rhs), - Intrinsic::Divide(lhs, rhs) => write!(f, "@divide({:?}, {:?})", lhs, rhs), - Intrinsic::Subtract(lhs, rhs) => write!(f, "@subtract({:?}, {:?})", lhs, rhs), - Intrinsic::Malloc(size) => write!(f, "@malloc({:?})", size), - Intrinsic::SizeOf(expr) => write!(f, "@sizeof({:?})", expr), - Intrinsic::Equals(lhs, rhs) => write!(f, "@equal({:?}, {:?})", lhs, rhs), - } - } -} - -#[derive(Clone)] -pub struct TypedExpr { - pub kind: TypedExprKind, - span: Span, -} - -impl TypedExpr { - pub fn span(&self) -> Span { - self.span - } -} - -#[derive(Clone, Debug)] -pub enum TypedExprKind { - FunctionCall { - func: FunctionId, - args: Vec<(Identifier, TypedExpr)>, - ty: TypeVariable, - }, - Literal { - value: Literal, - ty: TypeVariable, - }, - List { - elements: Vec, - ty: TypeVariable, - }, - Unit, - Variable { - ty: TypeVariable, - name: Identifier, - }, - Intrinsic { - ty: TypeVariable, - intrinsic: Intrinsic, - }, - ErrorRecovery(Span), - ExprWithBindings { - bindings: Vec<(Identifier, TypedExpr)>, - expression: Box, - }, - TypeConstructor { - ty: TypeVariable, - args: Box<[TypedExpr]>, - }, - If { - condition: Box, - then_branch: Box, - else_branch: Box, - }, -} - -impl std::fmt::Debug for TypedExpr { - fn fmt( - &self, - f: &mut std::fmt::Formatter<'_>, - ) -> std::fmt::Result { - use TypedExprKind::*; - match &self.kind { - FunctionCall { func, args, .. 
} => { - write!(f, "function call to {} with args: ", func)?; - for (name, arg) in args { - write!(f, "{}: {:?}, ", name.id, arg)?; - } - Ok(()) - }, - Literal { value, .. } => write!(f, "literal: {}", value), - List { elements, .. } => { - write!(f, "list: [")?; - for elem in elements { - write!(f, "{:?}, ", elem)?; - } - write!(f, "]") - }, - Unit => write!(f, "unit"), - Variable { name, .. } => write!(f, "variable: {}", name.id), - Intrinsic { intrinsic, .. } => write!(f, "intrinsic: {:?}", intrinsic), - ErrorRecovery(span) => { - write!(f, "error recovery {span:?}") - }, - ExprWithBindings { bindings, expression } => { - write!(f, "bindings: ")?; - for (name, expr) in bindings { - write!(f, "{}: {:?}, ", name.id, expr)?; - } - write!(f, "expression: {:?}", expression) - }, - TypeConstructor { ty, .. } => write!(f, "type constructor: {:?}", ty), - If { - condition, - then_branch, - else_branch, - } => { - write!(f, "if {:?} then {:?} else {:?}", condition, then_branch, else_branch) - }, - } - } -} - -impl TypeCheck for Expr { - type Output = TypedExpr; - - fn type_check( - &self, - ctx: &mut TypeChecker, - ) -> Self::Output { - let kind = match &self.kind { - ExprKind::Literal(lit) => { - let ty = ctx.convert_literal_to_type(lit); - TypedExprKind::Literal { value: lit.clone(), ty } - }, - ExprKind::List(exprs) => { - if exprs.is_empty() { - let ty = ctx.unit(); - TypedExprKind::List { elements: vec![], ty } - } else { - let type_checked_exprs = exprs.iter().map(|expr| expr.type_check(ctx)).collect::>(); - // unify the type of the first expr against everything else in the list - let first_ty = ctx.expr_ty(&type_checked_exprs[0]); - for expr in type_checked_exprs.iter().skip(1) { - let second_ty = ctx.expr_ty(expr); - ctx.unify(first_ty, second_ty, expr.span()); - } - let first_ty = ctx.ctx.types.get(first_ty).clone(); - TypedExprKind::List { - elements: type_checked_exprs, - ty: ctx.insert_type::(&SpecificType::List(Box::new(first_ty))), - } - } - }, - 
ExprKind::FunctionCall(call) => (*call).type_check(ctx), - ExprKind::Unit => TypedExprKind::Unit, - ExprKind::ErrorRecovery => TypedExprKind::ErrorRecovery(self.span), - ExprKind::Variable { name, ty } => { - // look up variable in scope - // find its expr return type - let var_ty = ctx.find_variable(*name).expect("variable not found in scope"); - let ty = ctx.to_type_var(ty); - - ctx.unify(var_ty, ty, name.span()); - - TypedExprKind::Variable { ty, name: *name } - }, - ExprKind::Intrinsic(intrinsic) => return self.span.with_item(intrinsic.clone()).type_check(ctx), - ExprKind::TypeConstructor(parent_type_id, args) => { - // This ExprKind only shows up in the body of type constructor functions, and - // is basically a noop. The surrounding function decl will handle type checking for - // the type constructor. - let args = args.iter().map(|arg| arg.type_check(ctx)).collect::>(); - let ty = ctx.get_type(*parent_type_id); - TypedExprKind::TypeConstructor { - ty: *ty, - args: args.into_boxed_slice(), - } - }, - ExprKind::ExpressionWithBindings { bindings, expression } => { - // for each binding, type check the rhs - ctx.with_type_scope(|ctx| { - let mut type_checked_bindings = Vec::with_capacity(bindings.len()); - for binding in bindings { - let binding_ty = binding.expression.type_check(ctx); - let binding_expr_return_ty = ctx.expr_ty(&binding_ty); - ctx.insert_variable(binding.name, binding_expr_return_ty); - type_checked_bindings.push((binding.name, binding_ty)); - } - - TypedExprKind::ExprWithBindings { - bindings: type_checked_bindings, - expression: Box::new(expression.type_check(ctx)), - } - }) - }, - ExprKind::If { - condition, - then_branch, - else_branch, - } => { - let condition = condition.type_check(ctx); - let condition_ty = ctx.expr_ty(&condition); - ctx.unify(ctx.bool(), condition_ty, condition.span()); - - let then_branch = then_branch.type_check(ctx); - let then_ty = ctx.expr_ty(&then_branch); - - let else_branch = else_branch.type_check(ctx); - let 
else_ty = ctx.expr_ty(&else_branch); - - ctx.unify(then_ty, else_ty, else_branch.span()); - - TypedExprKind::If { - condition: Box::new(condition), - then_branch: Box::new(then_branch), - else_branch: Box::new(else_branch), - } - }, - }; - - TypedExpr { kind, span: self.span } - } -} - -fn unify_basic_math_op( - lhs: &Expr, - rhs: &Expr, - ctx: &mut TypeChecker, -) -> (TypedExpr, TypedExpr) { - let lhs = lhs.type_check(ctx); - let rhs = rhs.type_check(ctx); - let lhs_ty = ctx.expr_ty(&lhs); - let rhs_ty = ctx.expr_ty(&rhs); - let int_ty = ctx.int(); - ctx.unify(int_ty, lhs_ty, lhs.span()); - ctx.unify(int_ty, rhs_ty, rhs.span()); - (lhs, rhs) -} - -impl TypeCheck for SpannedItem { - type Output = TypedExpr; - - fn type_check( - &self, - ctx: &mut TypeChecker, - ) -> Self::Output { - use petr_resolve::IntrinsicName::*; - let kind = match self.item().intrinsic { - Puts => { - if self.item().args.len() != 1 { - todo!("puts arg len check"); - } - // puts takes a single string and returns unit - let arg = self.item().args[0].type_check(ctx); - ctx.unify_expr_return(ctx.string(), &arg); - TypedExprKind::Intrinsic { - intrinsic: Intrinsic::Puts(Box::new(arg)), - ty: ctx.unit(), - } - }, - Add => { - if self.item().args.len() != 2 { - todo!("add arg len check"); - } - let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx); - TypedExprKind::Intrinsic { - intrinsic: Intrinsic::Add(Box::new(lhs), Box::new(rhs)), - ty: ctx.int(), - } - }, - Subtract => { - if self.item().args.len() != 2 { - todo!("sub arg len check"); - } - let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx); - - TypedExprKind::Intrinsic { - intrinsic: Intrinsic::Subtract(Box::new(lhs), Box::new(rhs)), - ty: ctx.int(), - } - }, - Multiply => { - if self.item().args.len() != 2 { - todo!("mult arg len check"); - } - - let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx); - TypedExprKind::Intrinsic { - intrinsic: 
Intrinsic::Multiply(Box::new(lhs), Box::new(rhs)), - ty: ctx.int(), - } - }, - - Divide => { - if self.item().args.len() != 2 { - todo!("Divide arg len check"); - } - - let (lhs, rhs) = unify_basic_math_op(&self.item().args[0], &self.item().args[1], ctx); - TypedExprKind::Intrinsic { - intrinsic: Intrinsic::Divide(Box::new(lhs), Box::new(rhs)), - ty: ctx.int(), - } - }, - Malloc => { - // malloc takes one integer (the number of bytes to allocate) - // and returns a pointer to the allocated memory - // will return `0` if the allocation fails - // in the future, this might change to _words_ of allocation, - // depending on the compilation target - if self.item().args.len() != 1 { - todo!("malloc arg len check"); - } - let arg = self.item().args[0].type_check(ctx); - let arg_ty = ctx.expr_ty(&arg); - let int_ty = ctx.int(); - ctx.unify(int_ty, arg_ty, arg.span()); - TypedExprKind::Intrinsic { - intrinsic: Intrinsic::Malloc(Box::new(arg)), - ty: int_ty, - } - }, - SizeOf => { - if self.item().args.len() != 1 { - todo!("size_of arg len check"); - } - - let arg = self.item().args[0].type_check(ctx); - - TypedExprKind::Intrinsic { - intrinsic: Intrinsic::SizeOf(Box::new(arg)), - ty: ctx.int(), - } - }, - Equals => { - if self.item().args.len() != 2 { - todo!("equal arg len check"); - } - - let lhs = self.item().args[0].type_check(ctx); - let rhs = self.item().args[1].type_check(ctx); - ctx.unify(ctx.expr_ty(&lhs), ctx.expr_ty(&rhs), self.span()); - TypedExprKind::Intrinsic { - intrinsic: Intrinsic::Equals(Box::new(lhs), Box::new(rhs)), - ty: ctx.bool(), - } - }, - }; - - TypedExpr { kind, span: self.span() } - } -} - -trait TypeCheck { - type Output; - fn type_check( - &self, - ctx: &mut TypeChecker, - ) -> Self::Output; -} - -#[derive(Clone, Debug)] -pub struct Function { - pub name: Identifier, - pub params: Vec<(Identifier, TypeVariable)>, - pub body: TypedExpr, - pub return_ty: TypeVariable, -} - -impl TypeCheck for petr_resolve::Function { - type Output = Function; - 
- fn type_check( - &self, - ctx: &mut TypeChecker, - ) -> Self::Output { - ctx.with_type_scope(|ctx| { - let params = self.params.iter().map(|(name, ty)| (*name, ctx.to_type_var(ty))).collect::>(); - // declared parameters are axiomatic, they won't be updated by any inference - - for (name, ty) in ¶ms { - ctx.insert_variable(*name, *ty); - // TODO get span for type annotation instead of just the name of the parameter - ctx.axiom(*ty, name.span); - } - - // unify types within the body with the parameter - let body = self.body.type_check(ctx); - - let declared_return_type = ctx.to_type_var(&self.return_type); - - Function { - name: self.name, - params, - return_ty: declared_return_type, - body, - } - }) - } -} - -impl TypeCheck for FunctionCall { - type Output = TypedExprKind; - - fn type_check( - &self, - ctx: &mut TypeChecker, - ) -> Self::Output { - let func_decl = ctx.get_function(&self.function).clone(); - - if self.args.len() != func_decl.params.len() { - // TODO: support partial application - ctx.push_error(self.span().with_item(TypeConstraintError::ArgumentCountMismatch { - expected: func_decl.params.len(), - got: self.args.len(), - function: ctx.get_symbol(func_decl.name.id).to_string(), - })); - return TypedExprKind::ErrorRecovery(self.span()); - } - - let mut args: Vec<(Identifier, TypedExpr, TypeVariable)> = Vec::with_capacity(self.args.len()); - - // unify all of the arg types with the param types - for (arg, (name, param_ty)) in self.args.iter().zip(func_decl.params.iter()) { - let arg = arg.type_check(ctx); - let arg_ty = ctx.expr_ty(&arg); - ctx.satisfies(*param_ty, arg_ty, arg.span()); - args.push((*name, arg, arg_ty)); - } - - let concrete_arg_types: Vec<_> = args - .iter() - .map(|(_, _, ty)| ctx.look_up_variable(*ty).generalize(&ctx.ctx.types).clone()) - .collect(); - - let signature: FunctionSignature = (self.function, concrete_arg_types.clone().into_boxed_slice()); - // now that we know the argument types, check if this signature has been 
monomorphized - // already - if ctx.monomorphized_functions.contains_key(&signature) { - return TypedExprKind::FunctionCall { - func: self.function, - args: args.into_iter().map(|(name, expr, _)| (name, expr)).collect(), - ty: func_decl.return_ty, - }; - } - - // unify declared return type with body return type - let declared_return_type = func_decl.return_ty; - - ctx.satisfy_expr_return(declared_return_type, &func_decl.body); - - // to create a monomorphized func decl, we don't actually have to update all of the types - // throughout the entire definition. We only need to update the parameter types. - let mut monomorphized_func_decl = Function { - name: func_decl.name, - params: func_decl.params.clone(), - return_ty: declared_return_type, - body: func_decl.body.clone(), - }; - - // update the parameter types to be the concrete types - for (param, concrete_ty) in monomorphized_func_decl.params.iter_mut().zip(concrete_arg_types.iter()) { - let param_ty = ctx.insert_type(concrete_ty); - param.1 = param_ty; - } - - // if there are any variable exprs in the body, update those ref types - let mut num_replacements = 0; - replace_var_reference_types( - &mut monomorphized_func_decl.body.kind, - &monomorphized_func_decl.params, - &mut num_replacements, - ); - - ctx.monomorphized_functions.insert(signature, monomorphized_func_decl); - // if there are any variable exprs in the body, update those ref types - - TypedExprKind::FunctionCall { - func: self.function, - args: args.into_iter().map(|(name, expr, _)| (name, expr)).collect(), - ty: declared_return_type, - } - } -} - -fn replace_var_reference_types( - expr: &mut TypedExprKind, - params: &Vec<(Identifier, TypeVariable)>, - num_replacements: &mut usize, -) { - match expr { - TypedExprKind::Variable { ref mut ty, name } => { - if let Some((_param_name, ty_var)) = params.iter().find(|(param_name, _)| param_name.id == name.id) { - *num_replacements += 1; - *ty = *ty_var; - } - }, - TypedExprKind::FunctionCall { args, .. 
} => { - for (_, arg) in args { - replace_var_reference_types(&mut arg.kind, params, num_replacements); - } - }, - TypedExprKind::Intrinsic { intrinsic, .. } => { - use Intrinsic::*; - match intrinsic { - // intrinsics which take one arg, grouped for convenience - Puts(a) | Malloc(a) | SizeOf(a) => { - replace_var_reference_types(&mut a.kind, params, num_replacements); - }, - // intrinsics which take two args, grouped for convenience - Add(a, b) | Subtract(a, b) | Multiply(a, b) | Divide(a, b) | Equals(a, b) => { - replace_var_reference_types(&mut a.kind, params, num_replacements); - replace_var_reference_types(&mut b.kind, params, num_replacements); - }, - } - }, - // TODO other expr kinds like bindings - _ => (), - } -} - -mod pretty_printing { - use petr_utils::SymbolInterner; - - use crate::*; - - #[cfg(test)] - pub fn pretty_print_type_checker(type_checker: &TypeChecker) -> String { - let mut s = String::new(); - for (id, ty) in &type_checker.type_map { - let text = match id { - TypeOrFunctionId::TypeId(id) => { - let ty = type_checker.resolved.get_type(*id); - - let name = type_checker.resolved.interner.get(ty.name.id); - format!("type {}", name) - }, - TypeOrFunctionId::FunctionId(id) => { - let func = type_checker.resolved.get_function(*id); - - let name = type_checker.resolved.interner.get(func.name.id); - - format!("fn {}", name) - }, - }; - s.push_str(&text); - s.push_str(": "); - s.push_str(&pretty_print_ty(ty, &type_checker.ctx.types, &type_checker.resolved.interner)); - - s.push('\n'); - match id { - TypeOrFunctionId::TypeId(_) => (), - TypeOrFunctionId::FunctionId(func) => { - let func = type_checker.typed_functions.get(func).unwrap(); - let body = &func.body; - s.push_str(&pretty_print_typed_expr(body, &type_checker)); - s.push('\n'); - }, - } - s.push('\n'); - } - - if !type_checker.monomorphized_functions.is_empty() { - s.push_str("__MONOMORPHIZED FUNCTIONS__"); - } - - for func in type_checker.monomorphized_functions.values() { - let func_name = 
type_checker.resolved.interner.get(func.name.id); - let arg_types = func - .params - .iter() - .map(|(_, ty)| pretty_print_ty(ty, &type_checker.ctx.types, &type_checker.resolved.interner)) - .collect::>(); - s.push_str(&format!( - "\nfn {}({:?}) -> {}", - func_name, - arg_types, - pretty_print_ty(&func.return_ty, &type_checker.ctx.types, &type_checker.resolved.interner) - )); - } - - if !type_checker.monomorphized_functions.is_empty() { - s.push('\n'); - } - - s - } - - pub fn pretty_print_ty( - ty: &TypeVariable, - types: &IndexMap, - interner: &SymbolInterner, - ) -> String { - let mut ty = types.get(*ty); - while let SpecificType::Ref(t) = ty { - ty = types.get(*t); - } - pretty_print_petr_type(ty, types, interner) - } - - pub fn pretty_print_petr_type( - ty: &SpecificType, - types: &IndexMap, - interner: &SymbolInterner, - ) -> String { - match ty { - SpecificType::Unit => "unit".to_string(), - SpecificType::Integer => "int".to_string(), - SpecificType::Boolean => "bool".to_string(), - SpecificType::String => "string".to_string(), - SpecificType::Ref(ty) => pretty_print_ty(ty, types, interner), - SpecificType::UserDefined { name, .. 
} => { - let name = interner.get(name.id); - name.to_string() - }, - SpecificType::Arrow(tys) => { - let mut s = String::new(); - s.push('('); - for (ix, ty) in tys.iter().enumerate() { - let is_last = ix == tys.len() - 1; - - s.push_str(&pretty_print_ty(ty, types, interner)); - if !is_last { - s.push_str(" → "); - } - } - s.push(')'); - s - }, - SpecificType::ErrorRecovery => "error recovery".to_string(), - SpecificType::List(ty) => format!("[{}]", pretty_print_petr_type(ty, types, interner)), - SpecificType::Infer(id, _) => format!("infer t{id}"), - SpecificType::Sum(tys) => { - let mut s = String::new(); - s.push('('); - for (ix, ty) in tys.iter().enumerate() { - let is_last = ix == tys.len() - 1; - // print the petr ty - s.push_str(&pretty_print_petr_type(ty, types, interner)); - if !is_last { - s.push_str(" | "); - } - } - s.push(')'); - s - }, - SpecificType::Literal(l) => format!("{}", l), - } - } - - #[cfg(test)] - pub fn pretty_print_typed_expr( - typed_expr: &TypedExpr, - type_checker: &TypeChecker, - ) -> String { - let interner = &type_checker.resolved.interner; - let types = &type_checker.ctx.types; - match &typed_expr.kind { - TypedExprKind::ExprWithBindings { bindings, expression } => { - let mut s = String::new(); - for (name, expr) in bindings { - let ident = interner.get(name.id); - let ty = type_checker.expr_ty(expr); - let ty = pretty_print_ty(&ty, types, interner); - s.push_str(&format!("{ident}: {:?} ({}),\n", expr, ty)); - } - let expr_ty = type_checker.expr_ty(expression); - let expr_ty = pretty_print_ty(&expr_ty, types, interner); - s.push_str(&format!("{:?} ({})", pretty_print_typed_expr(expression, &type_checker), expr_ty)); - s - }, - TypedExprKind::Variable { name, ty } => { - let name = interner.get(name.id); - let ty = pretty_print_ty(ty, types, interner); - format!("variable {name}: {ty}") - }, - - TypedExprKind::FunctionCall { func, args, ty } => { - let mut s = String::new(); - s.push_str(&format!("function call to {} with args: ", 
func)); - for (name, arg) in args { - let name = interner.get(name.id); - let arg_ty = type_checker.expr_ty(arg); - let arg_ty = pretty_print_ty(&arg_ty, types, interner); - s.push_str(&format!("{name}: {}, ", arg_ty)); - } - let ty = pretty_print_ty(ty, types, interner); - s.push_str(&format!("returns {ty}")); - s - }, - TypedExprKind::TypeConstructor { ty, .. } => format!("type constructor: {}", pretty_print_ty(ty, types, interner)), - _otherwise => format!("{:?}", typed_expr), - } - } -} - -#[cfg(test)] -mod tests { - use expect_test::{expect, Expect}; - use petr_resolve::resolve_symbols; - use petr_utils::render_error; - - use super::*; - use crate::pretty_printing::*; - - fn check( - input: impl Into, - expect: Expect, - ) { - let input = input.into(); - let parser = petr_parse::Parser::new(vec![("test", input)]); - let (ast, errs, interner, source_map) = parser.into_result(); - if !errs.is_empty() { - errs.into_iter().for_each(|err| eprintln!("{:?}", render_error(&source_map, err))); - panic!("test failed: code didn't parse"); - } - let (errs, resolved) = resolve_symbols(ast, interner, Default::default()); - if !errs.is_empty() { - errs.into_iter().for_each(|err| eprintln!("{:?}", render_error(&source_map, err))); - panic!("unresolved symbols in test"); - } - let mut type_checker = TypeChecker::new(resolved); - type_checker.fully_type_check(); - let mut res = pretty_print_type_checker(&type_checker); - - let solved_constraints = match type_checker.into_solution() { - Ok(solution) => solution.pretty_print(), - Err(errs) => { - res.push_str(&"__ERRORS__\n"); - errs.into_iter().map(|err| format!("{:?}", err)).collect::>().join("\n") - }, - }; - - res.push('\n'); - res.push_str(&solved_constraints); - - expect.assert_eq(&res.trim()); - } - - #[test] - fn identity_resolution_concrete_type() { - check( - r#" - fn foo(x in 'int) returns 'int x - "#, - expect![[r#" - fn foo: (int → int) - variable x: int - - - __SOLVED TYPES__ - 5: int"#]], - ); - } - - #[test] - fn 
identity_resolution_generic() { - check( - r#" - fn foo(x in 'A) returns 'A x - "#, - expect![[r#" - fn foo: (infer t5 → infer t5) - variable x: infer t5 - - - __SOLVED TYPES__ - 6: infer t5"#]], - ); - } - - #[test] - fn identity_resolution_custom_type() { - check( - r#" - type MyType = A | B - fn foo(x in 'MyType) returns 'MyType x - "#, - expect![[r#" - type MyType: MyType - - fn A: MyType - type constructor: MyType - - fn B: MyType - type constructor: MyType - - fn foo: (MyType → MyType) - variable x: MyType - - - __SOLVED TYPES__ - 8: MyType"#]], - ); - } - - #[test] - fn identity_resolution_two_custom_types() { - check( - r#" - type MyType = A | B - type MyComposedType = firstVariant someField 'MyType | secondVariant someField 'int someField2 'MyType someField3 'GenericType - fn foo(x in 'MyType) returns 'MyComposedType ~firstVariant(x) - "#, - expect![[r#" - type MyType: MyType - - type MyComposedType: MyComposedType - - fn A: MyType - type constructor: MyType - - fn B: MyType - type constructor: MyType - - fn firstVariant: (MyType → MyComposedType) - type constructor: MyComposedType - - fn secondVariant: (int → MyType → infer t16 → MyComposedType) - type constructor: MyComposedType - - fn foo: (MyType → MyComposedType) - function call to functionid2 with args: someField: MyType, returns MyComposedType - - __MONOMORPHIZED FUNCTIONS__ - fn firstVariant(["MyType"]) -> MyComposedType - - __SOLVED TYPES__ - 14: int - 17: infer t16 - 23: MyType"#]], - ); - } - - #[test] - fn literal_unification_fail() { - check( - r#" - fn foo() returns 'int 5 - fn bar() returns 'bool 5 - "#, - expect![[r#" - fn foo: int - literal: 5 - - fn bar: bool - literal: 5"#]], - ); - } - - #[test] - fn literal_unification_success() { - check( - r#" - fn foo() returns 'int 5 - fn bar() returns 'bool true - "#, - expect![[r#" - fn foo: int - literal: 5 - - fn bar: bool - literal: true"#]], - ); - } - - #[test] - fn pass_zero_arity_func_to_intrinsic() { - check( - r#" - fn string_literal() 
returns 'string - "This is a string literal." - - fn my_func() returns 'unit - @puts(~string_literal)"#, - expect![[r#" - fn string_literal: string - literal: "This is a string literal." - - fn my_func: unit - intrinsic: @puts(function call to functionid0 with args: ) - - __MONOMORPHIZED FUNCTIONS__ - fn string_literal([]) -> string"#]], - ); - } - - #[test] - fn pass_literal_string_to_intrinsic() { - check( - r#" - fn my_func() returns 'unit - @puts("test")"#, - expect![[r#" - fn my_func: unit - intrinsic: @puts(literal: "test") - - - __SOLVED TYPES__ - 5: string"#]], - ); - } - - #[test] - fn pass_wrong_type_literal_to_intrinsic() { - check( - r#" - fn my_func() returns 'unit - @puts(true)"#, - expect![[r#" - fn my_func: unit - intrinsic: @puts(literal: true) - - __ERRORS__ - - SpannedItem UnificationFailure("string", "true") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(52), length: 4 } }]"#]], - ); - } - - #[test] - fn intrinsic_and_return_ty_dont_match() { - check( - r#" - fn my_func() returns 'bool - @puts("test")"#, - expect![[r#" - fn my_func: bool - intrinsic: @puts(literal: "test") - - - __SOLVED TYPES__ - 5: string"#]], - ); - } - - #[test] - fn pass_wrong_type_fn_call_to_intrinsic() { - check( - r#" - fn bool_literal() returns 'bool - true - - fn my_func() returns 'unit - @puts(~bool_literal)"#, - expect![[r#" - fn bool_literal: bool - literal: true - - fn my_func: unit - intrinsic: @puts(function call to functionid0 with args: ) - - __MONOMORPHIZED FUNCTIONS__ - fn bool_literal([]) -> bool - __ERRORS__ - - SpannedItem UnificationFailure("string", "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(110), length: 14 } }]"#]], - ); - } - - #[test] - fn multiple_calls_to_fn_dont_unify_params_themselves() { - check( - r#" - fn bool_literal(a in 'A, b in 'B) returns 'bool - true - - fn my_func() returns 'bool - ~bool_literal(1, 2) - - {- should not unify the parameter types of bool_literal -} - fn 
my_second_func() returns 'bool - ~bool_literal(true, false) - "#, - expect![[r#" - fn bool_literal: (infer t5 → infer t7 → bool) - literal: true - - fn my_func: bool - function call to functionid0 with args: a: 1, b: 2, returns bool - - fn my_second_func: bool - function call to functionid0 with args: a: true, b: false, returns bool - - __MONOMORPHIZED FUNCTIONS__ - fn bool_literal(["int", "int"]) -> bool - fn bool_literal(["bool", "bool"]) -> bool - - __SOLVED TYPES__ - 6: infer t5 - 8: infer t7"#]], - ); - } - #[test] - fn list_different_types_type_err() { - check( - r#" - fn my_list() returns 'list [ 1, true ] - "#, - expect![[r#" - fn my_list: infer t8 - list: [literal: 1, literal: true, ] - - - __SOLVED TYPES__ - 5: (1 | true) - 6: 1"#]], - ); - } - - #[test] - fn incorrect_number_of_args() { - check( - r#" - fn add(a in 'int, b in 'int) returns 'int a - - fn add_five(a in 'int) returns 'int ~add(5) - "#, - expect![[r#" - fn add: (int → int → int) - variable a: int - - fn add_five: (int → int) - error recovery Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(113), length: 8 } } - - __ERRORS__ - - SpannedItem ArgumentCountMismatch { function: "add", expected: 2, got: 1 } [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(113), length: 8 } }]"#]], - ); - } - - #[test] - fn infer_let_bindings() { - check( - r#" - fn hi(x in 'int, y in 'int) returns 'int - let a = x; - b = y; - c = 20; - d = 30; - e = 42; - a -fn main() returns 'int ~hi(1, 2)"#, - expect![[r#" - fn hi: (int → int → int) - a: variable: symbolid2 (int), - b: variable: symbolid4 (int), - c: literal: 20 (20), - d: literal: 30 (30), - e: literal: 42 (42), - "variable a: int" (int) - - fn main: int - function call to functionid0 with args: x: 1, y: 2, returns int - - __MONOMORPHIZED FUNCTIONS__ - fn hi(["int", "int"]) -> int - fn main([]) -> int - - __SOLVED TYPES__ - 5: int - 6: int"#]], - ) - } - - #[test] - fn if_rejects_non_bool_condition() { - check( - r#" - fn 
hi(x in 'int) returns 'int - if x then 1 else 2 - fn main() returns 'int ~hi(1)"#, - expect![[r#" - fn hi: (int → int) - if variable: symbolid2 then literal: 1 else literal: 2 - - fn main: int - function call to functionid0 with args: x: 1, returns int - - __MONOMORPHIZED FUNCTIONS__ - fn hi(["int"]) -> int - fn main([]) -> int - __ERRORS__ - - SpannedItem UnificationFailure("bool", "int") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(61), length: 2 } }]"#]], - ) - } - - #[test] - fn if_rejects_non_unit_missing_else() { - check( - r#" - fn hi() returns 'int - if true then 1 - fn main() returns 'int ~hi()"#, - expect![[r#" - fn hi: int - if literal: true then literal: 1 else unit - - fn main: int - function call to functionid0 with args: returns int - - __MONOMORPHIZED FUNCTIONS__ - fn hi([]) -> int - fn main([]) -> int - __ERRORS__ - - SpannedItem UnificationFailure("1", "unit") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(33), length: 46 } }]"#]], - ) - } - - #[test] - fn if_allows_unit_missing_else() { - check( - r#" - fn hi() returns 'unit - if true then @puts "hi" - - fn main() returns 'unit ~hi()"#, - expect![[r#" - fn hi: unit - if literal: true then intrinsic: @puts(literal: "hi") else unit - - fn main: unit - function call to functionid0 with args: returns unit - - __MONOMORPHIZED FUNCTIONS__ - fn hi([]) -> unit - fn main([]) -> unit - - __SOLVED TYPES__ - 5: bool - 6: string"#]], - ) - } - - #[test] - fn disallow_incorrect_constant_int() { - check( - r#" - type OneOrTwo = 1 | 2 - - fn main() returns 'OneOrTwo - ~OneOrTwo 10 - "#, - expect![[r#" - type OneOrTwo: OneOrTwo - - fn OneOrTwo: ((1 | 2) → OneOrTwo) - type constructor: OneOrTwo - - fn main: OneOrTwo - function call to functionid0 with args: OneOrTwo: 10, returns OneOrTwo - - __MONOMORPHIZED FUNCTIONS__ - fn OneOrTwo(["int"]) -> OneOrTwo - fn main([]) -> OneOrTwo - __ERRORS__ - - SpannedItem NotSubtype(["1", "2"], "10") [Span { source: SourceId(0), 
span: SourceSpan { offset: SourceOffset(104), length: 0 } }]"#]], - ) - } - - #[test] - fn disallow_incorrect_constant_string() { - check( - r#" - type AOrB = "A" | "B" - - fn main() returns 'AOrB - ~AOrB "c" - "#, - expect![[r#" - type AOrB: AOrB - - fn AOrB: (("A" | "B") → AOrB) - type constructor: AOrB - - fn main: AOrB - function call to functionid0 with args: AOrB: "c", returns AOrB - - __MONOMORPHIZED FUNCTIONS__ - fn AOrB(["string"]) -> AOrB - fn main([]) -> AOrB - __ERRORS__ - - SpannedItem NotSubtype(["\"A\"", "\"B\""], "\"c\"") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(97), length: 0 } }]"#]], - ) - } - - // TODO remove ignore before merging - #[ignore] - #[test] - fn disallow_incorrect_constant_bool() { - check( - r#" - type AlwaysTrue = true - - fn main() returns 'AlwaysTrue - ~AlwaysTrue false - "#, - expect![[r#" - type AlwaysTrue: AlwaysTrue - - fn AlwaysTrue: ((true) → AlwaysTrue) - type constructor: AlwaysTrue - - fn main: AlwaysTrue - function call to functionid0 with args: AlwaysTrue: false, returns AlwaysTrue - - __MONOMORPHIZED FUNCTIONS__ - fn AlwaysTrue(["bool"]) -> AlwaysTrue - fn main([]) -> AlwaysTrue - __ERRORS__ - SpannedItem FailedToSatisfy("false", "(true)") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(100), length: 0 } }] - "#]], - ) - } - - // TODO remove ignore before merging - #[ignore] - #[test] - fn disallow_wrong_sum_type_in_add() { - check( - r#" - type IntBelowFive = 1 | 2 | 3 | 4 | 5 - {- reject an `add` which may return an int above five -} - fn add(a in 'IntBelowFive, b in 'IntBelowFive) returns 'IntBelowFive @add(a, b) -"#, - expect![[r#""#]], - ) - } - - #[ignore] - #[test] - fn allow_wrong_sum_type_in_add() { - check( - r#" - type IntBelowFive = 1 | 2 | 3 | 4 | 5 - {- reject an `add` which may return an int above five -} - fn add(a in 'IntBelowFive, b in 'IntBelowFive) returns 'int @add(a, b) -"#, - expect![[r#""#]], - ) - } - - #[test] - fn 
sum_type_unifies_to_superset() { - check( - r"fn test(a in 'sum 1 | 2 | 3) returns 'sum 1 | 2 | 3 a - fn test_(a in 'sum 1 | 2) returns 'sum 1 | 2 a - fn main() returns 'int - {- should be of specific type lit 2 -} - let x = 2; - {- should be of specific type 'sum 1 | 2 -} - y = ~test_(x); - {- should be of specific type 'sum 1 | 2 | 3 -} - z = ~test(y); - {- should also be of specific type 'sum 1 | 2 | 3 -} - zz = ~test(x) - - {- and should generalize to 'int with no problems -} - zz - ", - expect![[r#" - fn test: ((1 | 2 | 3) → (1 | 2 | 3)) - variable a: (1 | 2 | 3) - - fn test_: ((1 | 2) → (1 | 2)) - variable a: (1 | 2) - - fn main: int - x: literal: 2 (2), - y: function call to functionid1 with args: symbolid1: variable: symbolid5, ((1 | 2)), - z: function call to functionid0 with args: symbolid1: variable: symbolid6, ((1 | 2 | 3)), - zz: function call to functionid0 with args: symbolid1: variable: symbolid5, ((1 | 2 | 3)), - "variable zz: (1 | 2 | 3)" ((1 | 2 | 3)) - - __MONOMORPHIZED FUNCTIONS__ - fn test(["int"]) -> (1 | 2 | 3) - fn test_(["int"]) -> (1 | 2) - fn main([]) -> int - - __SOLVED TYPES__ - 5: (1 | 2 | 3) - 9: (1 | 2) - 11: (1 | 2)"#]], - ) - } - - #[test] - fn specific_type_generalizes() { - check( - r#"fn test(a in 'sum 'int | 'string) returns 'sum 'int | 'string a - fn test_(a in 'int) returns 'sum 'int | 'string a - fn main() returns 'int - let x = ~test_(5); - y = ~test("a string"); - 42 - "#, - expect![[r#" - fn test: ((int | string) → (int | string)) - variable a: (int | string) - - fn test_: (int → (int | string)) - variable a: int - - fn main: int - x: function call to functionid1 with args: symbolid1: literal: 5, ((int | string)), - y: function call to functionid0 with args: symbolid1: literal: "a string", ((int | string)), - "literal: 42" (42) - - __MONOMORPHIZED FUNCTIONS__ - fn test(["string"]) -> (int | string) - fn test_(["int"]) -> (int | string) - fn main([]) -> int - - __SOLVED TYPES__ - 5: (int | string) - 9: int"#]], - ) - } - 
- #[test] - fn disallow_bad_generalization() { - check( - r#"fn test(a in 'sum 'int | 'string) returns 'sum 'int | 'string a - fn test_(a in 'bool) returns 'sum 'int | 'string a - fn main() returns 'int - {- we are passing 'bool into 'int | 'string so this should fail to satisfy constraints -} - let y = ~test(~test_(true)); - 42 - "#, - expect![[r#" - fn test: ((int | string) → (int | string)) - variable a: (int | string) - - fn test_: (bool → (int | string)) - variable a: bool - - fn main: int - y: function call to functionid0 with args: symbolid1: function call to functionid1 with args: symbolid1: literal: true, , ((int | string)), - "literal: 42" (42) - - __MONOMORPHIZED FUNCTIONS__ - fn test(["(int | string)"]) -> (int | string) - fn test_(["bool"]) -> (int | string) - fn main([]) -> int - __ERRORS__ - - SpannedItem NotSubtype(["int", "string"], "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(129), length: 0 } }]"#]], - ) - } - - #[test] - fn order_of_sum_type_doesnt_matter() { - check( - r#"fn test(a in 'sum 'int | 'string) returns 'sum 'string | 'int a - "#, - expect![[r#" - fn test: ((int | string) → (int | string)) - variable a: (int | string) - - - __SOLVED TYPES__ - 5: (int | string)"#]], - ) - } - - #[test] - fn can_return_superset() { - check( - r#"fn test(a in 'sum 'int | 'string) returns 'sum 'string | 'int | 'bool a - "#, - expect![[r#" - fn test: ((int | string) → (int | bool | string)) - variable a: (int | string) - - - __SOLVED TYPES__ - 5: (int | string)"#]], - ) - } - - #[test] - fn if_exp_basic() { - check("fn main() returns 'int if true then 1 else 0", expect![[r#" - fn main: int - if literal: true then literal: 1 else literal: 0 - - __MONOMORPHIZED FUNCTIONS__ - fn main([]) -> int - - __SOLVED TYPES__ - 5: bool - 6: (0 | 1) - 7: 1"#]]); - } -} diff --git a/petr-typecheck/src/pretty_printing.rs b/petr-typecheck/src/pretty_printing.rs new file mode 100644 index 0000000..d0d82c6 --- /dev/null +++ 
b/petr-typecheck/src/pretty_printing.rs @@ -0,0 +1,173 @@ +use petr_utils::SymbolInterner; +use types::SpecificType; + +use crate::*; + +#[cfg(test)] +pub fn pretty_print_type_checker(type_checker: &TypeChecker) -> String { + let mut s = String::new(); + for (id, ty) in type_checker.type_map() { + let text = match id { + TypeOrFunctionId::TypeId(id) => { + let ty = type_checker.resolved().get_type(*id); + + let name = type_checker.resolved().interner.get(ty.name.id); + format!("type {}", name) + }, + TypeOrFunctionId::FunctionId(id) => { + let func = type_checker.resolved().get_function(*id); + + let name = type_checker.resolved().interner.get(func.name.id); + + format!("fn {}", name) + }, + }; + s.push_str(&text); + s.push_str(": "); + s.push_str(&pretty_print_ty(ty, type_checker.ctx().types(), &type_checker.resolved().interner)); + + s.push('\n'); + match id { + TypeOrFunctionId::TypeId(_) => (), + TypeOrFunctionId::FunctionId(func) => { + let func = type_checker.typed_functions().get(func).unwrap(); + let body = &func.body; + s.push_str(&pretty_print_typed_expr(body, type_checker)); + s.push('\n'); + }, + } + s.push('\n'); + } + + if !type_checker.monomorphized_functions().is_empty() { + s.push_str("__MONOMORPHIZED FUNCTIONS__"); + } + + for func in type_checker.monomorphized_functions().values() { + let func_name = type_checker.resolved().interner.get(func.name.id); + let arg_types = func + .params + .iter() + .map(|(_, ty)| pretty_print_ty(ty, type_checker.ctx().types(), &type_checker.resolved().interner)) + .collect::>(); + s.push_str(&format!( + "\nfn {}({:?}) -> {}", + func_name, + arg_types, + pretty_print_ty(&func.return_ty, type_checker.ctx().types(), &type_checker.resolved().interner) + )); + } + + if !type_checker.monomorphized_functions().is_empty() { + s.push('\n'); + } + + s +} + +pub fn pretty_print_ty( + ty: &TypeVariable, + types: &IndexMap, + interner: &SymbolInterner, +) -> String { + let mut ty = types.get(*ty); + while let 
SpecificType::Ref(t) = ty { + ty = types.get(*t); + } + pretty_print_petr_type(ty, types, interner) +} + +pub fn pretty_print_petr_type( + ty: &SpecificType, + types: &IndexMap, + interner: &SymbolInterner, +) -> String { + match ty { + SpecificType::Unit => "unit".to_string(), + SpecificType::Integer => "int".to_string(), + SpecificType::Boolean => "bool".to_string(), + SpecificType::String => "string".to_string(), + SpecificType::Ref(ty) => pretty_print_ty(ty, types, interner), + SpecificType::UserDefined { name, .. } => { + let name = interner.get(name.id); + name.to_string() + }, + SpecificType::Arrow(tys) => { + let mut s = String::new(); + s.push('('); + for (ix, ty) in tys.iter().enumerate() { + let is_last = ix == tys.len() - 1; + + s.push_str(&pretty_print_ty(ty, types, interner)); + if !is_last { + s.push_str(" → "); + } + } + s.push(')'); + s + }, + SpecificType::ErrorRecovery => "error recovery".to_string(), + SpecificType::List(ty) => format!("[{}]", pretty_print_petr_type(ty, types, interner)), + SpecificType::Infer(id, _) => format!("infer t{id}"), + SpecificType::Sum(tys) => { + let mut s = String::new(); + s.push('('); + for (ix, ty) in tys.iter().enumerate() { + let is_last = ix == tys.len() - 1; + // print the petr ty + s.push_str(&pretty_print_petr_type(ty, types, interner)); + if !is_last { + s.push_str(" | "); + } + } + s.push(')'); + s + }, + SpecificType::Literal(l) => format!("{}", l), + } +} + +#[cfg(test)] +pub fn pretty_print_typed_expr( + typed_expr: &TypedExpr, + type_checker: &TypeChecker, +) -> String { + let interner = &type_checker.resolved().interner; + let types = &type_checker.ctx().types(); + match &typed_expr.kind { + TypedExprKind::ExprWithBindings { bindings, expression } => { + let mut s = String::new(); + for (name, expr) in bindings { + let ident = interner.get(name.id); + let ty = type_checker.expr_ty(expr); + let ty = pretty_print_ty(&ty, types, interner); + s.push_str(&format!("{ident}: {:?} ({}),\n", expr, ty)); + } + 
let expr_ty = type_checker.expr_ty(expression); + let expr_ty = pretty_print_ty(&expr_ty, types, interner); + s.push_str(&format!("{:?} ({})", pretty_print_typed_expr(expression, type_checker), expr_ty)); + s + }, + TypedExprKind::Variable { name, ty } => { + let name = interner.get(name.id); + let ty = pretty_print_ty(ty, types, interner); + format!("variable {name}: {ty}") + }, + + TypedExprKind::FunctionCall { func, args, ty } => { + let mut s = String::new(); + s.push_str(&format!("function call to {} with args: ", func)); + for (name, arg) in args { + let name = interner.get(name.id); + let arg_ty = type_checker.expr_ty(arg); + let arg_ty = pretty_print_ty(&arg_ty, types, interner); + s.push_str(&format!("{name}: {}, ", arg_ty)); + } + let ty = pretty_print_ty(ty, types, interner); + s.push_str(&format!("returns {ty}")); + s + }, + TypedExprKind::TypeConstructor { ty, .. } => format!("type constructor: {}", pretty_print_ty(ty, types, interner)), + _otherwise => format!("{:?}", typed_expr), + } +} diff --git a/petr-typecheck/src/solution.rs b/petr-typecheck/src/solution.rs new file mode 100644 index 0000000..8477847 --- /dev/null +++ b/petr-typecheck/src/solution.rs @@ -0,0 +1,537 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use petr_bind::FunctionId; +use petr_resolve::Literal; +use petr_utils::{IndexMap, Span, SpannedItem, SymbolInterner}; + +use crate::{ + constraint_generation::FunctionSignature, + error::TypeConstraintError, + pretty_printing, + typed_ast::{TypedExpr, TypedExprKind}, + types::{GeneralType, GeneralizedTypeVariant, SpecificType}, + Function, TypeError, TypeVariable, +}; + +/// Represents the result of the type-checking stage for an individual type variable. 
+pub struct TypeSolutionEntry { + source: TypeSolutionSource, + ty: SpecificType, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum TypeSolutionSource { + Inherent, + Axiomatic, + Inferred, +} + +impl TypeSolutionEntry { + pub fn new_axiomatic(ty: SpecificType) -> Self { + Self { + source: TypeSolutionSource::Axiomatic, + ty, + } + } + + pub fn new_inherent(ty: SpecificType) -> Self { + Self { + source: TypeSolutionSource::Inherent, + ty, + } + } + + pub fn new_inferred(ty: SpecificType) -> Self { + Self { + source: TypeSolutionSource::Inferred, + ty, + } + } + + pub fn is_axiomatic(&self) -> bool { + self.source == TypeSolutionSource::Axiomatic + } +} + +pub struct TypeSolution { + solution: BTreeMap, + unsolved_types: IndexMap, + errors: Vec, + interner: SymbolInterner, + error_recovery: TypeVariable, + unit: TypeVariable, + functions: BTreeMap, + monomorphized_functions: BTreeMap, +} + +impl TypeSolution { + pub fn new( + unsolved_types: IndexMap, + error_recovery: TypeVariable, + unit: TypeVariable, + functions: BTreeMap, + monomorphized_functions: BTreeMap, + interner: SymbolInterner, + preexisting_errors: Vec, + ) -> Self { + let solution = vec![ + (unit, TypeSolutionEntry::new_inherent(SpecificType::Unit)), + (error_recovery, TypeSolutionEntry::new_inherent(SpecificType::ErrorRecovery)), + ] + .into_iter() + .collect(); + Self { + solution, + unsolved_types, + errors: preexisting_errors, + interner, + functions, + monomorphized_functions, + unit, + error_recovery, + } + } + + fn push_error( + &mut self, + e: TypeError, + ) { + self.errors.push(e); + } + + pub fn insert_solution( + &mut self, + ty: TypeVariable, + entry: TypeSolutionEntry, + span: Span, + ) { + if self.solution.contains_key(&ty) { + self.update_type(ty, entry, span); + return; + } + self.solution.insert(ty, entry); + } + + fn pretty_print_type( + &self, + ty: &SpecificType, + ) -> String { + pretty_printing::pretty_print_petr_type(ty, &self.unsolved_types, &self.interner) + } + + 
fn unify_err( + &self, + clone_1: SpecificType, + clone_2: SpecificType, + ) -> TypeConstraintError { + let pretty_printed_b = self.pretty_print_type(&clone_2); + match clone_1 { + SpecificType::Sum(tys) => { + let tys = tys.iter().map(|ty| self.pretty_print_type(ty)).collect::>(); + TypeConstraintError::NotSubtype(tys, pretty_printed_b) + }, + _ => { + let pretty_printed_a = self.pretty_print_type(&clone_1); + TypeConstraintError::UnificationFailure(pretty_printed_a, pretty_printed_b) + }, + } + } + + fn satisfy_err( + &self, + clone_1: SpecificType, + clone_2: SpecificType, + ) -> TypeConstraintError { + let pretty_printed_b = self.pretty_print_type(&clone_2); + match clone_1 { + SpecificType::Sum(tys) => { + let tys = tys.iter().map(|ty| self.pretty_print_type(ty)).collect::>(); + TypeConstraintError::NotSubtype(tys, pretty_printed_b) + }, + _ => { + let pretty_printed_a = self.pretty_print_type(&clone_1); + TypeConstraintError::FailedToSatisfy(pretty_printed_a, pretty_printed_b) + }, + } + } + + pub fn update_type( + &mut self, + ty: TypeVariable, + entry: TypeSolutionEntry, + span: Span, + ) { + match self.solution.get_mut(&ty) { + Some(e) => { + if e.is_axiomatic() { + let pretty_printed_preexisting = pretty_printing::pretty_print_petr_type(&entry.ty, &self.unsolved_types, &self.interner); + let pretty_printed_ty = pretty_printing::pretty_print_petr_type(&entry.ty, &self.unsolved_types, &self.interner); + self.errors + .push(span.with_item(TypeConstraintError::InvalidTypeUpdate(pretty_printed_preexisting, pretty_printed_ty))); + return; + } + *e = entry; + }, + None => { + self.solution.insert(ty, entry); + }, + } + } + + pub(crate) fn into_result(self) -> Result>> { + if self.errors.is_empty() { + Ok(self) + } else { + Err(self.errors) + } + } + + /// Attempt to unify two types, returning an error if they cannot be unified + pub(crate) fn apply_unify_constraint( + &mut self, + t1: TypeVariable, + t2: TypeVariable, + span: Span, + ) { + let ty1 = 
self.get_latest_type(t1).clone(); + let ty2 = self.get_latest_type(t2).clone(); + use SpecificType::*; + match (ty1, ty2) { + (a, b) if a == b => (), + (ErrorRecovery, _) | (_, ErrorRecovery) => (), + (Ref(a), _) => self.apply_unify_constraint(a, t2, span), + (_, Ref(b)) => self.apply_unify_constraint(t1, b, span), + (Infer(id, _), Infer(id2, _)) if id != id2 => { + // if two different inferred types are unified, replace the second with a reference + // to the first + let entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t2, entry, span); + }, + (a @ Sum(_), b @ Sum(_)) => { + // the unification of two sum types is the union of the two types if and only if + // `t2` is a total subset of `t1` + // `t1` remains unchanged, as we are trying to coerce `t2` into something that + // represents `t1` + // TODO remove clone + if self.a_superset_of_b(&a, &b) { + let entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t2, entry, span); + } else { + // the union of the two sets is the new type + let a_tys = match a { + Sum(tys) => tys, + _ => unreachable!(), + }; + let b_tys = match b { + Sum(tys) => tys, + _ => unreachable!(), + }; + let union = a_tys.iter().chain(b_tys.iter()).cloned().collect(); + let entry = TypeSolutionEntry::new_inferred(Sum(union)); + self.update_type(t1, entry, span); + let entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t2, entry, span); + } + }, + // If `t2` is a non-sum type, and `t1` is a sum type, then `t1` must contain either + // exactly the same specific type OR the generalization of that type + // If the latter, then the specific type must be updated to its generalization + (ref t1_ty @ Sum(_), other) => { + if self.a_superset_of_b(t1_ty, &other) { + // t2 unifies to the more general form provided by t1 + let entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t2, entry, span); + } else { + // add `other` to `t1` + let mut tys = match t1_ty { + Sum(tys) => tys.clone(), + 
_ => unreachable!(), + }; + tys.insert(other); + let entry = TypeSolutionEntry::new_inferred(Sum(tys)); + self.update_type(t1, entry, span); + } + }, + // literals can unify to each other if they're equal + (Literal(l1), Literal(l2)) if l1 == l2 => (), + // if they're not equal, their unification is the sum of both + (Literal(l1), Literal(l2)) if l1 != l2 => { + // update t1 to a sum type of both, + // and update t2 to reference t1 + let sum = Sum([Literal(l1), Literal(l2)].into()); + let t1_entry = TypeSolutionEntry::new_inferred(sum); + let t2_entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t1, t1_entry, span); + self.update_type(t2, t2_entry, span); + }, + (Literal(l1), Sum(tys)) => { + // update t1 to a sum type of both, + // and update t2 to reference t1 + let sum = Sum([Literal(l1)].iter().chain(tys.iter()).cloned().collect()); + let t1_entry = TypeSolutionEntry::new_inferred(sum); + let t2_entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t1, t1_entry, span); + self.update_type(t2, t2_entry, span); + }, + (a, b) if self.a_superset_of_b(&a, &b) => { + // if `a` is a superset of `b`, then `b` unifies to `a` as it is more general + let entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t2, entry, span); + }, + /* TODO rewrite below rules + // literals can unify broader parent types + // but the broader parent type gets instantiated with the literal type + // TODO(alex) this rule feels incorrect. A literal being unified to the parent type + // should upcast the lit, not downcast t1. Check after refactoring. 
+ (ty, Literal(lit)) => match (&lit, ty) { + (petr_resolve::Literal::Integer(_), Integer) + | (petr_resolve::Literal::Boolean(_), Boolean) + | (petr_resolve::Literal::String(_), String) => { + let entry = TypeSolutionEntry::new_inferred(SpecificType::Literal(lit)); + self.update_type(t1, entry, span); + }, + (lit, ty) => self.push_error(span.with_item(self.unify_err(ty.clone(), SpecificType::Literal(lit.clone())))), + }, + // literals can unify broader parent types + // but the broader parent type gets instantiated with the literal type + (Literal(lit), ty) => match (&lit, ty) { + (petr_resolve::Literal::Integer(_), Integer) + | (petr_resolve::Literal::Boolean(_), Boolean) + | (petr_resolve::Literal::String(_), String) => self.update_type(t2, SpecificType::Literal(lit)), + (lit, ty) => { + self.push_error(span.with_item(self.unify_err(ty.clone(), SpecificType::Literal(lit.clone())))); + }, + }, + */ + (other, ref t2_ty @ Sum(_)) => { + // if `other` is a superset of `t2`, then `t2` unifies to `other` as it is more + // general + if self.a_superset_of_b(&other, t2_ty) { + let entry = TypeSolutionEntry::new_inferred(other); + self.update_type(t2, entry, span); + } + }, + // instantiate the infer type with the known type + (Infer(_, _), _known) => { + let entry = TypeSolutionEntry::new_inferred(Ref(t2)); + self.update_type(t1, entry, span); + }, + (_known, Infer(_, _)) => { + let entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t2, entry, span); + }, + // lastly, if no unification rule exists for these two types, it is a mismatch + (a, b) => { + self.push_error(span.with_item(self.unify_err(a, b))); + }, + } + } + + // This function will need to be rewritten when type constraints and bounded polymorphism are + // implemented. 
+ pub(crate) fn apply_satisfies_constraint( + &mut self, + t1: TypeVariable, + t2: TypeVariable, + span: Span, + ) { + let ty1 = self.get_latest_type(t1); + let ty2 = self.get_latest_type(t2); + use SpecificType::*; + match (ty1, ty2) { + (a, b) if a == b => (), + (ErrorRecovery, _) | (_, ErrorRecovery) => (), + (Ref(a), _) => self.apply_satisfies_constraint(a, t2, span), + (_, Ref(b)) => self.apply_satisfies_constraint(t1, b, span), + // if t1 is a fully instantiated type, then t2 can be updated to be a reference to t1 + (Unit | Integer | Boolean | UserDefined { .. } | String | Arrow(..) | List(..) | Literal(_) | Sum(_), Infer(_, _)) => { + let entry = TypeSolutionEntry::new_inferred(Ref(t1)); + self.update_type(t2, entry, span); + }, + // the "parent" infer type will not instantiate to the "child" type + (Infer(_, _), Unit | Integer | Boolean | UserDefined { .. } | String | Arrow(..) | List(..) | Literal(_) | Sum(_)) => (), + (Sum(a_tys), Sum(b_tys)) => { + // calculate the intersection of these types, update t2 to the intersection + let intersection = a_tys.iter().filter(|a_ty| b_tys.contains(a_ty)).cloned().collect(); + let entry = TypeSolutionEntry::new_inferred(SpecificType::sum(intersection)); + self.update_type(t2, entry, span); + }, + // if `ty1` is a generalized version of the sum type, + // then it satisfies the sum type + (ty1, other) if self.a_superset_of_b(&ty1, &other) => (), + (Literal(l1), Literal(l2)) if l1 == l2 => (), + // Literals can satisfy broader parent types + (ty, Literal(lit)) => match (lit, ty) { + (petr_resolve::Literal::Integer(_), Integer) => (), + (petr_resolve::Literal::Boolean(_), Boolean) => (), + (petr_resolve::Literal::String(_), String) => (), + (lit, ty) => { + self.push_error(span.with_item(self.satisfy_err(ty.clone(), SpecificType::Literal(lit.clone())))); + }, + }, + // if we are trying to satisfy an inferred type with no bounds, this is ok + (Infer(..), _) => (), + (a, b) => { + 
self.push_error(span.with_item(self.satisfy_err(a.clone(), b.clone()))); + }, + } + } + + /// Gets the latest version of a type available. First checks solved types, + /// and if it doesn't exist, gets it from the unsolved types. + pub fn get_latest_type( + &self, + t1: TypeVariable, + ) -> SpecificType { + self.solution + .get(&t1) + .map(|entry| entry.ty.clone()) + .unwrap_or_else(|| self.unsolved_types.get(t1).clone()) + } + + /// To reference an error recovery type, you must provide an error. + /// This holds the invariant that error recovery types are only generated when + /// an error occurs. + pub fn error_recovery( + &mut self, + err: TypeError, + ) -> TypeVariable { + self.push_error(err); + self.error_recovery + } + + /// If `a` is a generalized form of `b`, return true + /// A generalized form is a type that is a superset of the sum types. + /// For example, `String` is a generalized form of `Sum(Literal("a") | Literal("B"))` + fn a_superset_of_b( + &self, + a: &SpecificType, + b: &SpecificType, + ) -> bool { + use SpecificType::*; + let generalized_b = self.generalize(b).safely_upcast(); + match (a, b) { + // If `a` is the generalized form of `b`, then `b` satisfies the constraint. + (a, b) if a == b || *a == generalized_b => true, + // If `a` is a sum type which contains `b` OR the generalized form of `b`, then `b` + // satisfies the constraint. 
+ (Sum(a_tys), b) if a_tys.contains(b) || a_tys.contains(&generalized_b) => true, + // if both `a` and `b` are sum types, then `a` must be a superset of `b`: + // - every element in `b` must either: + // - be a member of `a` + // - generalize to a member of `a` + (Sum(a_tys), Sum(b_tys)) => { + // if a_tys is a superset of b_tys, + // every element OR its generalized version is contained in a_tys + for b_ty in b_tys { + let b_ty_generalized = self.generalize(b_ty).safely_upcast(); + if !(a_tys.contains(b_ty) || a_tys.contains(&b_ty_generalized)) { + return false; + } + } + + true + }, + _otherwise => false, + } + } + + pub fn generalize( + &self, + b: &SpecificType, + ) -> GeneralType { + match b { + SpecificType::Unit => GeneralType::Unit, + SpecificType::Integer => GeneralType::Integer, + SpecificType::Boolean => GeneralType::Boolean, + SpecificType::String => GeneralType::String, + SpecificType::Ref(ty) => self.generalize(&self.get_latest_type(*ty)), + SpecificType::UserDefined { + name, + variants, + constant_literal_types, + } => GeneralType::UserDefined { + name: *name, + variants: variants + .iter() + .map(|variant| { + let generalized_fields = variant.fields.iter().map(|field| self.generalize(field)).collect::>(); + + GeneralizedTypeVariant { + fields: generalized_fields.into_boxed_slice(), + } + }) + .collect(), + constant_literal_types: constant_literal_types.clone(), + }, + SpecificType::Arrow(tys) => GeneralType::Arrow(tys.clone()), + SpecificType::ErrorRecovery => GeneralType::ErrorRecovery, + SpecificType::List(ty) => { + let ty = self.generalize(ty); + GeneralType::List(Box::new(ty)) + }, + SpecificType::Infer(u, s) => GeneralType::Infer(*u, *s), + SpecificType::Literal(l) => match l { + Literal::Integer(_) => GeneralType::Integer, + Literal::Boolean(_) => GeneralType::Boolean, + Literal::String(_) => GeneralType::String, + }, + SpecificType::Sum(tys) => { + // generalize all types, fold if possible + let all_generalized: BTreeSet<_> = 
tys.iter().map(|ty| self.generalize(ty)).collect(); + if all_generalized.len() == 1 { + // in this case, all specific types generalized to the same type + all_generalized.into_iter().next().expect("invariant") + } else { + GeneralType::Sum(all_generalized.into_iter().collect()) + } + }, + } + } + + #[cfg(test)] + pub fn pretty_print(&self) -> String { + let mut pretty = "__SOLVED TYPES__\n".to_string(); + + let mut num_entries = 0; + for (ty, entry) in self.solution.iter().filter(|(id, _)| ![self.unit, self.error_recovery].contains(id)) { + pretty.push_str(&format!("{}: {}\n", Into::::into(*ty), self.pretty_print_type(&entry.ty))); + num_entries += 1; + } + if num_entries == 0 { + Default::default() + } else { + pretty + } + } + + pub fn get_main_function(&self) -> Option<(&FunctionId, &Function)> { + self.functions.iter().find(|(_, func)| &*self.interner.get(func.name.id) == "main") + } + + pub fn get_monomorphized_function( + &self, + id: &FunctionSignature, + ) -> &Function { + self.monomorphized_functions.get(id).expect("invariant: should exist") + } + + pub fn expr_ty( + &self, + expr: &TypedExpr, + ) -> TypeVariable { + use TypedExprKind::*; + match &expr.kind { + FunctionCall { ty, .. } => *ty, + Literal { ty, .. } => *ty, + List { ty, .. } => *ty, + Unit => self.unit, + Variable { ty, .. } => *ty, + Intrinsic { ty, .. } => *ty, + ErrorRecovery(..) => self.error_recovery, + ExprWithBindings { expression, .. } => self.expr_ty(expression), + TypeConstructor { ty, .. } => *ty, + If { then_branch, .. 
} => self.expr_ty(then_branch), + } + } +} diff --git a/petr-typecheck/src/tests.rs b/petr-typecheck/src/tests.rs new file mode 100644 index 0000000..1a17ca4 --- /dev/null +++ b/petr-typecheck/src/tests.rs @@ -0,0 +1,710 @@ +use expect_test::{expect, Expect}; +use petr_resolve::resolve_symbols; +use petr_utils::render_error; + +use crate::{constraint_generation::TypeChecker, pretty_printing::*}; + +fn check( + input: impl Into, + expect: Expect, +) { + let input = input.into(); + let parser = petr_parse::Parser::new(vec![("test", input)]); + let (ast, errs, interner, source_map) = parser.into_result(); + if !errs.is_empty() { + errs.into_iter().for_each(|err| eprintln!("{:?}", render_error(&source_map, err))); + panic!("test failed: code didn't parse"); + } + let (errs, resolved) = resolve_symbols(ast, interner, Default::default()); + if !errs.is_empty() { + errs.into_iter().for_each(|err| eprintln!("{:?}", render_error(&source_map, err))); + panic!("unresolved symbols in test"); + } + let mut type_checker = TypeChecker::new(resolved); + type_checker.fully_type_check(); + let mut res = pretty_print_type_checker(&type_checker); + + let solved_constraints = match type_checker.into_solution() { + Ok(solution) => solution.pretty_print(), + Err(errs) => { + res.push_str("__ERRORS__\n"); + errs.into_iter().map(|err| format!("{:?}", err)).collect::>().join("\n") + }, + }; + + res.push('\n'); + res.push_str(&solved_constraints); + + expect.assert_eq(res.trim()); +} + +#[test] +fn identity_resolution_concrete_type() { + check( + r#" + fn foo(x in 'int) returns 'int x + "#, + expect![[r#" + fn foo: (int → int) + variable x: int + + + __SOLVED TYPES__ + 5: int"#]], + ); +} + +#[test] +fn identity_resolution_generic() { + check( + r#" + fn foo(x in 'A) returns 'A x + "#, + expect![[r#" + fn foo: (infer t5 → infer t5) + variable x: infer t5 + + + __SOLVED TYPES__ + 6: infer t5"#]], + ); +} + +#[test] +fn identity_resolution_custom_type() { + check( + r#" + type MyType = A | B + 
fn foo(x in 'MyType) returns 'MyType x + "#, + expect![[r#" + type MyType: MyType + + fn A: MyType + type constructor: MyType + + fn B: MyType + type constructor: MyType + + fn foo: (MyType → MyType) + variable x: MyType + + + __SOLVED TYPES__ + 8: MyType"#]], + ); +} + +#[test] +fn identity_resolution_two_custom_types() { + check( + r#" + type MyType = A | B + type MyComposedType = firstVariant someField 'MyType | secondVariant someField 'int someField2 'MyType someField3 'GenericType + fn foo(x in 'MyType) returns 'MyComposedType ~firstVariant(x) + "#, + expect![[r#" + type MyType: MyType + + type MyComposedType: MyComposedType + + fn A: MyType + type constructor: MyType + + fn B: MyType + type constructor: MyType + + fn firstVariant: (MyType → MyComposedType) + type constructor: MyComposedType + + fn secondVariant: (int → MyType → infer t16 → MyComposedType) + type constructor: MyComposedType + + fn foo: (MyType → MyComposedType) + function call to functionid2 with args: someField: MyType, returns MyComposedType + + __MONOMORPHIZED FUNCTIONS__ + fn firstVariant(["MyType"]) -> MyComposedType + + __SOLVED TYPES__ + 14: int + 17: infer t16 + 23: MyType"#]], + ); +} + +#[test] +fn literal_unification_fail() { + check( + r#" + fn foo() returns 'int 5 + fn bar() returns 'bool 5 + "#, + expect![[r#" + fn foo: int + literal: 5 + + fn bar: bool + literal: 5"#]], + ); +} + +#[test] +fn literal_unification_success() { + check( + r#" + fn foo() returns 'int 5 + fn bar() returns 'bool true + "#, + expect![[r#" + fn foo: int + literal: 5 + + fn bar: bool + literal: true"#]], + ); +} + +#[test] +fn pass_zero_arity_func_to_intrinsic() { + check( + r#" + fn string_literal() returns 'string + "This is a string literal." + + fn my_func() returns 'unit + @puts(~string_literal)"#, + expect![[r#" + fn string_literal: string + literal: "This is a string literal." 
+ + fn my_func: unit + intrinsic: @puts(function call to functionid0 with args: ) + + __MONOMORPHIZED FUNCTIONS__ + fn string_literal([]) -> string"#]], + ); +} + +#[test] +fn pass_literal_string_to_intrinsic() { + check( + r#" + fn my_func() returns 'unit + @puts("test")"#, + expect![[r#" + fn my_func: unit + intrinsic: @puts(literal: "test") + + + __SOLVED TYPES__ + 5: string"#]], + ); +} + +#[test] +fn pass_wrong_type_literal_to_intrinsic() { + check( + r#" + fn my_func() returns 'unit + @puts(true)"#, + expect![[r#" + fn my_func: unit + intrinsic: @puts(literal: true) + + __ERRORS__ + + SpannedItem UnificationFailure("string", "true") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(52), length: 4 } }]"#]], + ); +} + +#[test] +fn intrinsic_and_return_ty_dont_match() { + check( + r#" + fn my_func() returns 'bool + @puts("test")"#, + expect![[r#" + fn my_func: bool + intrinsic: @puts(literal: "test") + + + __SOLVED TYPES__ + 5: string"#]], + ); +} + +#[test] +fn pass_wrong_type_fn_call_to_intrinsic() { + check( + r#" + fn bool_literal() returns 'bool + true + + fn my_func() returns 'unit + @puts(~bool_literal)"#, + expect![[r#" + fn bool_literal: bool + literal: true + + fn my_func: unit + intrinsic: @puts(function call to functionid0 with args: ) + + __MONOMORPHIZED FUNCTIONS__ + fn bool_literal([]) -> bool + __ERRORS__ + + SpannedItem UnificationFailure("string", "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(110), length: 14 } }]"#]], + ); +} + +#[test] +fn multiple_calls_to_fn_dont_unify_params_themselves() { + check( + r#" + fn bool_literal(a in 'A, b in 'B) returns 'bool + true + + fn my_func() returns 'bool + ~bool_literal(1, 2) + + {- should not unify the parameter types of bool_literal -} + fn my_second_func() returns 'bool + ~bool_literal(true, false) + "#, + expect![[r#" + fn bool_literal: (infer t5 → infer t7 → bool) + literal: true + + fn my_func: bool + function call to functionid0 with args: a: 1, 
b: 2, returns bool + + fn my_second_func: bool + function call to functionid0 with args: a: true, b: false, returns bool + + __MONOMORPHIZED FUNCTIONS__ + fn bool_literal(["int", "int"]) -> bool + fn bool_literal(["bool", "bool"]) -> bool + + __SOLVED TYPES__ + 6: infer t5 + 8: infer t7"#]], + ); +} +#[test] +fn list_different_types_type_err() { + check( + r#" + fn my_list() returns 'list [ 1, true ] + "#, + expect![[r#" + fn my_list: infer t8 + list: [literal: 1, literal: true, ] + + + __SOLVED TYPES__ + 5: (1 | true) + 6: 1"#]], + ); +} + +#[test] +fn incorrect_number_of_args() { + check( + r#" + fn add(a in 'int, b in 'int) returns 'int a + + fn add_five(a in 'int) returns 'int ~add(5) + "#, + expect![[r#" + fn add: (int → int → int) + variable a: int + + fn add_five: (int → int) + error recovery Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(113), length: 8 } } + + __ERRORS__ + + SpannedItem ArgumentCountMismatch { function: "add", expected: 2, got: 1 } [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(113), length: 8 } }]"#]], + ); +} + +#[test] +fn infer_let_bindings() { + check( + r#" + fn hi(x in 'int, y in 'int) returns 'int + let a = x; + b = y; + c = 20; + d = 30; + e = 42; + a +fn main() returns 'int ~hi(1, 2)"#, + expect![[r#" + fn hi: (int → int → int) + a: variable: symbolid2 (int), + b: variable: symbolid4 (int), + c: literal: 20 (20), + d: literal: 30 (30), + e: literal: 42 (42), + "variable a: int" (int) + + fn main: int + function call to functionid0 with args: x: 1, y: 2, returns int + + __MONOMORPHIZED FUNCTIONS__ + fn hi(["int", "int"]) -> int + fn main([]) -> int + + __SOLVED TYPES__ + 5: int + 6: int"#]], + ) +} + +#[test] +fn if_rejects_non_bool_condition() { + check( + r#" + fn hi(x in 'int) returns 'int + if x then 1 else 2 + fn main() returns 'int ~hi(1)"#, + expect![[r#" + fn hi: (int → int) + if variable: symbolid2 then literal: 1 else literal: 2 + + fn main: int + function call to functionid0 
with args: x: 1, returns int + + __MONOMORPHIZED FUNCTIONS__ + fn hi(["int"]) -> int + fn main([]) -> int + __ERRORS__ + + SpannedItem UnificationFailure("bool", "int") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(61), length: 2 } }]"#]], + ) +} + +#[test] +fn if_rejects_non_unit_missing_else() { + check( + r#" + fn hi() returns 'int + if true then 1 + fn main() returns 'int ~hi()"#, + expect![[r#" + fn hi: int + if literal: true then literal: 1 else unit + + fn main: int + function call to functionid0 with args: returns int + + __MONOMORPHIZED FUNCTIONS__ + fn hi([]) -> int + fn main([]) -> int + __ERRORS__ + + SpannedItem UnificationFailure("1", "unit") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(33), length: 46 } }]"#]], + ) +} + +#[test] +fn if_allows_unit_missing_else() { + check( + r#" + fn hi() returns 'unit + if true then @puts "hi" + + fn main() returns 'unit ~hi()"#, + expect![[r#" + fn hi: unit + if literal: true then intrinsic: @puts(literal: "hi") else unit + + fn main: unit + function call to functionid0 with args: returns unit + + __MONOMORPHIZED FUNCTIONS__ + fn hi([]) -> unit + fn main([]) -> unit + + __SOLVED TYPES__ + 5: bool + 6: string"#]], + ) +} + +#[test] +fn disallow_incorrect_constant_int() { + check( + r#" + type OneOrTwo = 1 | 2 + + fn main() returns 'OneOrTwo + ~OneOrTwo 10 + "#, + expect![[r#" + type OneOrTwo: OneOrTwo + + fn OneOrTwo: ((1 | 2) → OneOrTwo) + type constructor: OneOrTwo + + fn main: OneOrTwo + function call to functionid0 with args: OneOrTwo: 10, returns OneOrTwo + + __MONOMORPHIZED FUNCTIONS__ + fn OneOrTwo(["int"]) -> OneOrTwo + fn main([]) -> OneOrTwo + __ERRORS__ + + SpannedItem NotSubtype(["1", "2"], "10") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(104), length: 0 } }]"#]], + ) +} + +#[test] +fn disallow_incorrect_constant_string() { + check( + r#" + type AOrB = "A" | "B" + + fn main() returns 'AOrB + ~AOrB "c" + "#, + expect![[r#" + type 
AOrB: AOrB + + fn AOrB: (("A" | "B") → AOrB) + type constructor: AOrB + + fn main: AOrB + function call to functionid0 with args: AOrB: "c", returns AOrB + + __MONOMORPHIZED FUNCTIONS__ + fn AOrB(["string"]) -> AOrB + fn main([]) -> AOrB + __ERRORS__ + + SpannedItem NotSubtype(["\"A\"", "\"B\""], "\"c\"") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(97), length: 0 } }]"#]], + ) +} + +#[test] +fn disallow_incorrect_constant_bool() { + check( + r#" + type AlwaysTrue = true + + fn main() returns 'AlwaysTrue + ~AlwaysTrue false + "#, + expect![[r#" + type AlwaysTrue: AlwaysTrue + + fn AlwaysTrue: ((true) → AlwaysTrue) + type constructor: AlwaysTrue + + fn main: AlwaysTrue + function call to functionid0 with args: AlwaysTrue: false, returns AlwaysTrue + + __MONOMORPHIZED FUNCTIONS__ + fn AlwaysTrue(["bool"]) -> AlwaysTrue + fn main([]) -> AlwaysTrue + __ERRORS__ + + SpannedItem NotSubtype(["true"], "false") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(100), length: 0 } }]"#]], + ) +} + +#[test] +fn disallow_wrong_sum_type_in_add() { + check( + r#" + type IntBelowFive = 1 | 2 | 3 | 4 | 5 + {- reject an `add` which may return an int above five -} + fn add(a in 'IntBelowFive, b in 'IntBelowFive) returns 'IntBelowFive @add(a, b) +"#, + expect![[r#" + type IntBelowFive: IntBelowFive + + fn IntBelowFive: ((1 | 2 | 3 | 4 | 5) → IntBelowFive) + type constructor: IntBelowFive + + fn add: (IntBelowFive → IntBelowFive → IntBelowFive) + intrinsic: @add(variable: symbolid3, variable: symbolid4) + + __ERRORS__ + + SpannedItem UnificationFailure("int", "IntBelowFive") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(208), length: 2 } }]"#]], + ) +} + +// TODO: decide if user-defined types should be able to safely upcast, or if only +// anonymous types should be able to do so +#[ignore] +#[test] +fn allow_wrong_sum_type_in_add() { + check( + r#" + type IntBelowFive = 1 | 2 | 3 | 4 | 5 + {- reject an `add` which may 
return an int above five -} + fn add(a in 'IntBelowFive, b in 'IntBelowFive) returns 'int @add(a, b) +"#, + expect![[r#""#]], + ) +} + +#[test] +fn sum_type_unifies_to_superset() { + check( + r"fn test(a in 'sum 1 | 2 | 3) returns 'sum 1 | 2 | 3 a + fn test_(a in 'sum 1 | 2) returns 'sum 1 | 2 a + fn main() returns 'int + {- should be of specific type lit 2 -} + let x = 2; + {- should be of specific type 'sum 1 | 2 -} + y = ~test_(x); + {- should be of specific type 'sum 1 | 2 | 3 -} + z = ~test(y); + {- should also be of specific type 'sum 1 | 2 | 3 -} + zz = ~test(x) + + {- and should generalize to 'int with no problems -} + zz + ", + expect![[r#" + fn test: ((1 | 2 | 3) → (1 | 2 | 3)) + variable a: (1 | 2 | 3) + + fn test_: ((1 | 2) → (1 | 2)) + variable a: (1 | 2) + + fn main: int + x: literal: 2 (2), + y: function call to functionid1 with args: symbolid1: variable: symbolid5, ((1 | 2)), + z: function call to functionid0 with args: symbolid1: variable: symbolid6, ((1 | 2 | 3)), + zz: function call to functionid0 with args: symbolid1: variable: symbolid5, ((1 | 2 | 3)), + "variable zz: (1 | 2 | 3)" ((1 | 2 | 3)) + + __MONOMORPHIZED FUNCTIONS__ + fn test(["int"]) -> (1 | 2 | 3) + fn test_(["int"]) -> (1 | 2) + fn main([]) -> int + + __SOLVED TYPES__ + 5: (1 | 2 | 3) + 9: (1 | 2) + 11: (1 | 2)"#]], + ) +} + +#[test] +fn specific_type_generalizes() { + check( + r#"fn test(a in 'sum 'int | 'string) returns 'sum 'int | 'string a + fn test_(a in 'int) returns 'sum 'int | 'string a + fn main() returns 'int + let x = ~test_(5); + y = ~test("a string"); + 42 + "#, + expect![[r#" + fn test: ((int | string) → (int | string)) + variable a: (int | string) + + fn test_: (int → (int | string)) + variable a: int + + fn main: int + x: function call to functionid1 with args: symbolid1: literal: 5, ((int | string)), + y: function call to functionid0 with args: symbolid1: literal: "a string", ((int | string)), + "literal: 42" (42) + + __MONOMORPHIZED FUNCTIONS__ + fn 
test(["string"]) -> (int | string) + fn test_(["int"]) -> (int | string) + fn main([]) -> int + + __SOLVED TYPES__ + 5: (int | string) + 9: int"#]], + ) +} + +#[test] +fn disallow_bad_generalization() { + check( + r#"fn test(a in 'sum 'int | 'string) returns 'sum 'int | 'string a + fn test_(a in 'bool) returns 'sum 'int | 'string a + fn main() returns 'int + {- we are passing 'bool into 'int | 'string so this should fail to satisfy constraints -} + let y = ~test(~test_(true)); + 42 + "#, + expect![[r#" + fn test: ((int | string) → (int | string)) + variable a: (int | string) + + fn test_: (bool → (int | string)) + variable a: bool + + fn main: int + y: function call to functionid0 with args: symbolid1: function call to functionid1 with args: symbolid1: literal: true, , ((int | string)), + "literal: 42" (42) + + __MONOMORPHIZED FUNCTIONS__ + fn test(["(int | string)"]) -> (int | string) + fn test_(["bool"]) -> (int | string) + fn main([]) -> int + __ERRORS__ + + SpannedItem NotSubtype(["int", "string"], "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(129), length: 0 } }]"#]], + ) +} + +#[test] +fn order_of_sum_type_doesnt_matter() { + check( + r#"fn test(a in 'sum 'int | 'string) returns 'sum 'string | 'int a + "#, + expect![[r#" + fn test: ((int | string) → (int | string)) + variable a: (int | string) + + + __SOLVED TYPES__ + 5: (int | string)"#]], + ) +} + +#[test] +fn can_return_superset() { + check( + r#"fn test(a in 'sum 'int | 'string) returns 'sum 'string | 'int | 'bool a + "#, + expect![[r#" + fn test: ((int | string) → (int | bool | string)) + variable a: (int | string) + + + __SOLVED TYPES__ + 5: (int | string)"#]], + ) +} + +#[test] +fn if_exp_basic() { + check( + "fn main() returns 'int if true then 1 else 0", + expect![[r#" + fn main: int + if literal: true then literal: 1 else literal: 0 + + __MONOMORPHIZED FUNCTIONS__ + fn main([]) -> int + + __SOLVED TYPES__ + 5: bool + 6: (0 | 1) + 7: 1"#]], + ); +} diff --git 
a/petr-typecheck/src/typed_ast.rs b/petr-typecheck/src/typed_ast.rs new file mode 100644 index 0000000..e0f9bc3 --- /dev/null +++ b/petr-typecheck/src/typed_ast.rs @@ -0,0 +1,248 @@ +use petr_bind::FunctionId; +use petr_resolve::{Expr, ExprKind, Literal}; +use petr_utils::{Identifier, Span}; + +use crate::{ + constraint_generation::{TypeCheck, TypeChecker}, + types::SpecificType, + TypeVariable, +}; + +#[derive(Clone)] +pub enum Intrinsic { + Puts(Box), + Add(Box, Box), + Multiply(Box, Box), + Divide(Box, Box), + Subtract(Box, Box), + Malloc(Box), + SizeOf(Box), + Equals(Box, Box), +} + +impl std::fmt::Debug for Intrinsic { + fn fmt( + &self, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + match self { + Intrinsic::Puts(expr) => write!(f, "@puts({:?})", expr), + Intrinsic::Add(lhs, rhs) => write!(f, "@add({:?}, {:?})", lhs, rhs), + Intrinsic::Multiply(lhs, rhs) => write!(f, "@multiply({:?}, {:?})", lhs, rhs), + Intrinsic::Divide(lhs, rhs) => write!(f, "@divide({:?}, {:?})", lhs, rhs), + Intrinsic::Subtract(lhs, rhs) => write!(f, "@subtract({:?}, {:?})", lhs, rhs), + Intrinsic::Malloc(size) => write!(f, "@malloc({:?})", size), + Intrinsic::SizeOf(expr) => write!(f, "@sizeof({:?})", expr), + Intrinsic::Equals(lhs, rhs) => write!(f, "@equal({:?}, {:?})", lhs, rhs), + } + } +} + +#[derive(Clone)] +pub struct TypedExpr { + pub kind: TypedExprKind, + pub span: Span, +} + +impl TypedExpr { + pub fn span(&self) -> Span { + self.span + } +} + +#[derive(Clone, Debug)] +pub enum TypedExprKind { + FunctionCall { + func: FunctionId, + args: Vec<(Identifier, TypedExpr)>, + ty: TypeVariable, + }, + Literal { + value: Literal, + ty: TypeVariable, + }, + List { + elements: Vec, + ty: TypeVariable, + }, + Unit, + Variable { + ty: TypeVariable, + name: Identifier, + }, + Intrinsic { + ty: TypeVariable, + intrinsic: Intrinsic, + }, + ErrorRecovery(Span), + ExprWithBindings { + bindings: Vec<(Identifier, TypedExpr)>, + expression: Box, + }, + TypeConstructor { + ty: 
TypeVariable, + args: Box<[TypedExpr]>, + }, + If { + condition: Box, + then_branch: Box, + else_branch: Box, + }, +} + +impl std::fmt::Debug for TypedExpr { + fn fmt( + &self, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + use TypedExprKind::*; + match &self.kind { + FunctionCall { func, args, .. } => { + write!(f, "function call to {} with args: ", func)?; + for (name, arg) in args { + write!(f, "{}: {:?}, ", name.id, arg)?; + } + Ok(()) + }, + Literal { value, .. } => write!(f, "literal: {}", value), + List { elements, .. } => { + write!(f, "list: [")?; + for elem in elements { + write!(f, "{:?}, ", elem)?; + } + write!(f, "]") + }, + Unit => write!(f, "unit"), + Variable { name, .. } => write!(f, "variable: {}", name.id), + Intrinsic { intrinsic, .. } => write!(f, "intrinsic: {:?}", intrinsic), + ErrorRecovery(span) => { + write!(f, "error recovery {span:?}") + }, + ExprWithBindings { bindings, expression } => { + write!(f, "bindings: ")?; + for (name, expr) in bindings { + write!(f, "{}: {:?}, ", name.id, expr)?; + } + write!(f, "expression: {:?}", expression) + }, + TypeConstructor { ty, .. 
} => write!(f, "type constructor: {:?}", ty), + If { + condition, + then_branch, + else_branch, + } => { + write!(f, "if {:?} then {:?} else {:?}", condition, then_branch, else_branch) + }, + } + } +} + +impl TypeCheck for Expr { + type Output = TypedExpr; + + fn type_check( + &self, + ctx: &mut TypeChecker, + ) -> Self::Output { + let kind = match &self.kind { + ExprKind::Literal(lit) => { + let ty = ctx.convert_literal_to_type(lit); + TypedExprKind::Literal { value: lit.clone(), ty } + }, + ExprKind::List(exprs) => { + if exprs.is_empty() { + let ty = ctx.unit(); + TypedExprKind::List { elements: vec![], ty } + } else { + let type_checked_exprs = exprs.iter().map(|expr| expr.type_check(ctx)).collect::>(); + // unify the type of the first expr against everything else in the list + let first_ty = ctx.expr_ty(&type_checked_exprs[0]); + for expr in type_checked_exprs.iter().skip(1) { + let second_ty = ctx.expr_ty(expr); + ctx.unify(first_ty, second_ty, expr.span()); + } + let first_ty = ctx.ctx().types().get(first_ty).clone(); + TypedExprKind::List { + elements: type_checked_exprs, + ty: ctx.insert_type::(&SpecificType::List(Box::new(first_ty))), + } + } + }, + ExprKind::FunctionCall(call) => (*call).type_check(ctx), + ExprKind::Unit => TypedExprKind::Unit, + ExprKind::ErrorRecovery => TypedExprKind::ErrorRecovery(self.span), + ExprKind::Variable { name, ty } => { + // look up variable in scope + // find its expr return type + let var_ty = ctx.find_variable(*name).expect("variable not found in scope"); + let ty = ctx.to_type_var(ty); + + ctx.unify(var_ty, ty, name.span()); + + TypedExprKind::Variable { ty, name: *name } + }, + ExprKind::Intrinsic(intrinsic) => return self.span.with_item(intrinsic.clone()).type_check(ctx), + ExprKind::TypeConstructor(parent_type_id, args) => { + // This ExprKind only shows up in the body of type constructor functions, and + // is basically a noop. The surrounding function decl will handle type checking for + // the type constructor. 
+ let args = args.iter().map(|arg| arg.type_check(ctx)).collect::>(); + let ty = ctx.get_type(*parent_type_id); + TypedExprKind::TypeConstructor { + ty: *ty, + args: args.into_boxed_slice(), + } + }, + ExprKind::ExpressionWithBindings { bindings, expression } => { + // for each binding, type check the rhs + ctx.with_type_scope(|ctx| { + let mut type_checked_bindings = Vec::with_capacity(bindings.len()); + for binding in bindings { + let binding_ty = binding.expression.type_check(ctx); + let binding_expr_return_ty = ctx.expr_ty(&binding_ty); + ctx.insert_variable(binding.name, binding_expr_return_ty); + type_checked_bindings.push((binding.name, binding_ty)); + } + + TypedExprKind::ExprWithBindings { + bindings: type_checked_bindings, + expression: Box::new(expression.type_check(ctx)), + } + }) + }, + ExprKind::If { + condition, + then_branch, + else_branch, + } => { + let condition = condition.type_check(ctx); + let condition_ty = ctx.expr_ty(&condition); + ctx.unify(ctx.bool(), condition_ty, condition.span()); + + let then_branch = then_branch.type_check(ctx); + let then_ty = ctx.expr_ty(&then_branch); + + let else_branch = else_branch.type_check(ctx); + let else_ty = ctx.expr_ty(&else_branch); + + ctx.unify(then_ty, else_ty, else_branch.span()); + + TypedExprKind::If { + condition: Box::new(condition), + then_branch: Box::new(then_branch), + else_branch: Box::new(else_branch), + } + }, + }; + + TypedExpr { kind, span: self.span } + } +} + +#[derive(Clone, Debug)] +pub struct Function { + pub name: Identifier, + pub params: Vec<(Identifier, TypeVariable)>, + pub body: TypedExpr, + pub return_ty: TypeVariable, +} diff --git a/petr-typecheck/src/types.rs b/petr-typecheck/src/types.rs new file mode 100644 index 0000000..a16f699 --- /dev/null +++ b/petr-typecheck/src/types.rs @@ -0,0 +1,238 @@ +use std::collections::BTreeSet; + +use petr_resolve::Literal; +use petr_utils::{Identifier, IndexMap, Span}; + +use crate::TypeVariable; + +#[derive(Clone, PartialEq, Debug, Eq, 
PartialOrd, Ord)] +/// A type which is general, and has no constraints applied to it. +/// This is a generalization of [`SpecificType`]. +/// This is more useful for IR generation, since functions are monomorphized +/// based on general types. +pub enum GeneralType { + Unit, + Integer, + Boolean, + String, + UserDefined { + name: Identifier, + // TODO these should be boxed slices, as their size is not changed + variants: Box<[GeneralizedTypeVariant]>, + constant_literal_types: Vec, + }, + Arrow(Vec), + ErrorRecovery, + List(Box), + Infer(usize, Span), + Sum(BTreeSet), +} + +impl GeneralType { + /// Because [`GeneralType`]'s type info is less detailed (specific) than [`SpecificType`], + /// we can losslessly cast any [`GeneralType`] into an instance of [`SpecificType`]. + pub fn safely_upcast(&self) -> SpecificType { + match self { + GeneralType::Unit => SpecificType::Unit, + GeneralType::Integer => SpecificType::Integer, + GeneralType::Boolean => SpecificType::Boolean, + GeneralType::String => SpecificType::String, + GeneralType::ErrorRecovery => SpecificType::ErrorRecovery, + GeneralType::Infer(u, s) => SpecificType::Infer(*u, *s), + GeneralType::UserDefined { + name, + variants, + constant_literal_types, + } => SpecificType::UserDefined { + name: *name, + variants: variants + .iter() + .map(|variant| { + let fields = variant.fields.iter().map(|field| field.safely_upcast()).collect::>(); + + TypeVariant { + fields: fields.into_boxed_slice(), + } + }) + .collect(), + constant_literal_types: constant_literal_types.clone(), + }, + GeneralType::Arrow(tys) => SpecificType::Arrow(tys.clone()), + GeneralType::List(ty) => SpecificType::List(Box::new(ty.safely_upcast())), + GeneralType::Sum(tys) => SpecificType::Sum(tys.iter().map(|ty| ty.safely_upcast()).collect()), + } + } +} +/// This is an information-rich type -- it tracks effects and data types. It is used for +/// the type-checking stage to provide rich information to the user. 
+/// Types are generalized into instances of [`GeneralType`] for monomorphization and +/// code generation. +#[derive(Clone, PartialEq, Debug, Eq, PartialOrd, Ord)] +pub enum SpecificType { + Unit, + Integer, + Boolean, + /// a static length string known at compile time + String, + /// A reference to another type + Ref(TypeVariable), + /// A user-defined type + UserDefined { + name: Identifier, + // TODO these should be boxed slices, as their size is not changed + variants: Vec, + constant_literal_types: Vec, + }, + Arrow(Vec), + ErrorRecovery, + // TODO make this petr type instead of typevariable + List(Box), + /// the usize is just an identifier for use in rendering the type + /// the span is the location of the inference, for error reporting if the inference is never + /// resolved + Infer(usize, Span), + Sum(BTreeSet), + Literal(Literal), +} + +#[derive(Clone, PartialEq, Debug, Eq, PartialOrd, Ord)] +pub struct GeneralizedTypeVariant { + pub fields: Box<[GeneralType]>, +} + +impl SpecificType { + fn generalize_inner( + &self, + types: &IndexMap, + ) -> GeneralType { + match self { + SpecificType::Unit => GeneralType::Unit, + SpecificType::Integer => GeneralType::Integer, + SpecificType::Boolean => GeneralType::Boolean, + SpecificType::String => GeneralType::String, + SpecificType::Ref(ty) => types.get(*ty).generalize(types), + SpecificType::UserDefined { + name, + variants, + constant_literal_types, + } => GeneralType::UserDefined { + name: *name, + variants: variants + .iter() + .map(|variant| { + let generalized_fields = variant.fields.iter().map(|field| field.generalize(types)).collect::>(); + + GeneralizedTypeVariant { + fields: generalized_fields.into_boxed_slice(), + } + }) + .collect(), + constant_literal_types: constant_literal_types.clone(), + }, + SpecificType::Arrow(tys) => GeneralType::Arrow(tys.clone()), + SpecificType::ErrorRecovery => GeneralType::ErrorRecovery, + SpecificType::List(ty) => { + let ty = ty.generalize(types); + 
GeneralType::List(Box::new(ty)) + }, + SpecificType::Infer(u, s) => GeneralType::Infer(*u, *s), + SpecificType::Literal(l) => match l { + Literal::Integer(_) => GeneralType::Integer, + Literal::Boolean(_) => GeneralType::Boolean, + Literal::String(_) => GeneralType::String, + }, + SpecificType::Sum(tys) => { + // generalize all types, fold if possible + let all_generalized: BTreeSet<_> = tys.iter().map(|ty| ty.generalize(types)).collect(); + if all_generalized.len() == 1 { + // in this case, all specific types generalized to the same type + all_generalized.into_iter().next().expect("invariant") + } else { + GeneralType::Sum(all_generalized.into_iter().collect()) + } + }, + } + } + + /// Use this to construct `[SpecificType::Sum]` types -- + /// it will attempt to collapse the sum into a single type if possible + pub(crate) fn sum(tys: BTreeSet) -> SpecificType { + if tys.len() == 1 { + tys.into_iter().next().expect("invariant") + } else { + SpecificType::Sum(tys) + } + } +} + +#[derive(Clone, PartialEq, Debug, Eq, PartialOrd, Ord)] +pub struct TypeVariant { + pub fields: Box<[SpecificType]>, +} + +pub trait Type { + fn as_specific_ty(&self) -> SpecificType; + + fn generalize( + &self, + types: &IndexMap, + ) -> GeneralType; +} + +impl Type for SpecificType { + fn as_specific_ty(&self) -> SpecificType { + self.clone() + } + + fn generalize( + &self, + types: &IndexMap, + ) -> GeneralType { + self.generalize_inner(types) + } +} + +impl Type for GeneralType { + fn generalize( + &self, + _: &IndexMap, + ) -> Self { + self.clone() + } + + fn as_specific_ty(&self) -> SpecificType { + match self { + GeneralType::Unit => SpecificType::Unit, + GeneralType::Integer => SpecificType::Integer, + GeneralType::Boolean => SpecificType::Boolean, + GeneralType::String => SpecificType::String, + GeneralType::UserDefined { + name, + variants, + constant_literal_types, + } => SpecificType::UserDefined { + name: *name, + variants: variants + .iter() + .map(|variant| { + let fields = 
variant.fields.iter().map(|field| field.as_specific_ty()).collect::>(); + + TypeVariant { + fields: fields.into_boxed_slice(), + } + }) + .collect(), + constant_literal_types: constant_literal_types.clone(), + }, + GeneralType::Arrow(tys) => SpecificType::Arrow(tys.clone()), + GeneralType::ErrorRecovery => SpecificType::ErrorRecovery, + GeneralType::List(ty) => SpecificType::List(Box::new(ty.as_specific_ty())), + GeneralType::Infer(u, s) => SpecificType::Infer(*u, *s), + GeneralType::Sum(tys) => { + let tys = tys.iter().map(|ty| ty.as_specific_ty()).collect(); + SpecificType::Sum(tys) + }, + } + } +} +