diff --git a/petr-typecheck/src/lib.rs b/petr-typecheck/src/lib.rs index 4fc4b1e..d1fa01c 100644 --- a/petr-typecheck/src/lib.rs +++ b/petr-typecheck/src/lib.rs @@ -323,8 +323,15 @@ impl SpecificType { use SpecificType::*; let generalized_b = b.generalize(ctx).safely_upcast(); match (self, b) { + // If `a` is the generalized form of `b`, then `b` satisfies the constraint. (a, b) if a == b || *a == generalized_b => true, + // If `a` is a sum type which contains `b` OR the generalized form of `b`, then `b` + // satisfies the constraint. (Sum(a_tys), b) if a_tys.contains(b) || a_tys.contains(&generalized_b) => true, + // if both `a` and `b` are sum types, then `a` must be a superset of `b`: + // - every element in `b` must either: + // - be a member of `a` + // - generalize to a member of `a` (Sum(a_tys), Sum(b_tys)) => { // if a_tys is a superset of b_tys, // every element OR its generalized version is contained in a_tys @@ -337,7 +344,17 @@ impl SpecificType { true }, - _ => todo!(), + _otherwise => false, + } + } + + /// Use this to construct `[SpecificType::Sum]` types -- + /// it will attempt to collapse the sum into a single type if possible + fn sum(tys: BTreeSet) -> SpecificType { + if tys.len() == 1 { + tys.into_iter().next().expect("invariant") + } else { + SpecificType::Sum(tys) } } } @@ -638,7 +655,7 @@ impl TypeChecker { (other, Sum(sum_tys)) => { // `other` must be a member of the Sum type if !sum_tys.contains(&other) { - self.push_error(span.with_item(self.unify_err(other.clone(), SpecificType::Sum(sum_tys.iter().cloned().collect())))); + self.push_error(span.with_item(self.unify_err(other.clone(), SpecificType::sum(sum_tys.clone())))); } // unify both types to the other type self.ctx.update_type(t2, other); @@ -682,17 +699,11 @@ impl TypeChecker { (Sum(a_tys), Sum(b_tys)) => { // calculate the intersection of these types, update t2 to the intersection let intersection = a_tys.iter().filter(|a_ty| b_tys.contains(a_ty)).cloned().collect(); - self.ctx.update_type(t2, Sum(intersection)); - }, - (ty1 @ Sum(sum_tys), other) => { - if - // if `other` is a generalized version of the sum type, - // then it satisfies the sum type - ty1.is_superset_of(other, &self.ctx) { - } else { - self.push_error(span.with_item(self.satisfy_err(SpecificType::Sum(sum_tys.iter().cloned().collect()), other.clone()))); - } + self.ctx.update_type(t2, SpecificType::sum(intersection)); }, + // if `ty1` is a generalized version of the sum type, + // then it satisfies the sum type + (ty1, other) if ty1.is_superset_of(other, &self.ctx) => (), (Literal(l1), Literal(l2)) if l1 == l2 => (), // Literals can satisfy broader parent types (ty, Literal(lit)) => match (lit, ty) { @@ -1553,8 +1564,12 @@ mod pretty_printing { )); } + if !type_checker.monomorphized_functions.is_empty() { + s.push('\n'); + } + if !type_checker.errors.is_empty() { - s.push_str("\n\n__ERRORS__\n"); + s.push_str("\n__ERRORS__\n"); for error in type_checker.errors { s.push_str(&format!("{:?}\n", error)); } @@ -1778,7 +1793,8 @@ mod tests { function call to functionid2 with args: someField: MyType, returns MyComposedType __MONOMORPHIZED FUNCTIONS__ - fn firstVariant(["MyType"]) -> MyComposedType"#]], + fn firstVariant(["MyType"]) -> MyComposedType + "#]], ); } @@ -1835,7 +1851,8 @@ mod tests { intrinsic: @puts(function call to functionid0 with args: ) __MONOMORPHIZED FUNCTIONS__ - fn string_literal([]) -> string"#]], + fn string_literal([]) -> string + "#]], ); } @@ -1902,6 +1919,7 @@ mod tests { __MONOMORPHIZED FUNCTIONS__ fn bool_literal([]) -> bool + __ERRORS__ SpannedItem UnificationFailure("string", "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(110), length: 14 } }] "#]], @@ -1934,7 +1952,8 @@ mod tests { __MONOMORPHIZED FUNCTIONS__ fn bool_literal(["int", "int"]) -> bool - fn bool_literal(["bool", "bool"]) -> bool"#]], + fn bool_literal(["bool", "bool"]) -> bool + "#]], ); } #[test] @@ -1999,7 +2018,8 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn hi(["int", "int"]) -> int - fn main([]) -> int"#]], + fn main([]) -> int + "#]], ) } @@ -2020,6 +2040,7 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn hi(["int"]) -> int fn main([]) -> int + __ERRORS__ SpannedItem UnificationFailure("int", "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(61), length: 2 } }] "#]], @@ -2043,6 +2064,7 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn hi([]) -> int fn main([]) -> int + __ERRORS__ SpannedItem UnificationFailure("unit", "1") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(33), length: 46 } }] "#]], @@ -2066,7 +2088,8 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn hi([]) -> unit - fn main([]) -> unit"#]], + fn main([]) -> unit + "#]], ) } @@ -2091,8 +2114,9 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn OneOrTwo(["int"]) -> OneOrTwo fn main([]) -> OneOrTwo + __ERRORS__ - SpannedItem FailedToSatisfy("10", "(1 | 2)") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(104), length: 0 } }] + SpannedItem NotSubtype(["1", "2"], "10") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(104), length: 0 } }] "#]], ) } @@ -2118,8 +2142,9 @@ fn main() returns 'int ~hi(1, 2)"#, __MONOMORPHIZED FUNCTIONS__ fn AOrB(["string"]) -> AOrB fn main([]) -> AOrB + __ERRORS__ - SpannedItem FailedToSatisfy("\"c\"", "(\"A\" | \"B\")") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(97), length: 0 } }] + SpannedItem NotSubtype(["\"A\"", "\"B\""], "\"c\"") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(97), length: 0 } }] "#]], ) } @@ -2198,7 +2223,25 @@ fn main() returns 'int ~hi(1, 2)"#, {- and should generalize to 'int with no problems -} zz ", - expect![[r#""#]], + expect![[r#" + fn test: ((1 | 2 | 3) → (1 | 2 | 3)) + variable a: (1 | 2 | 3) + + fn test_: ((1 | 2) → (1 | 2)) + variable a: (1 | 2) + + fn main: int + x: literal: 2 (2), + y: function call to functionid1 with args: symbolid1: variable: symbolid5, ((1 | 2)), + z: function call to functionid0 with args: symbolid1: variable: symbolid6, ((1 | 2 | 3)), + zz: function call to functionid0 with args: symbolid1: variable: symbolid5, ((1 | 2 | 3)), + "variable zz: (1 | 2 | 3)" ((1 | 2 | 3)) + + __MONOMORPHIZED FUNCTIONS__ + fn test(["int"]) -> (1 | 2 | 3) + fn test_(["int"]) -> (1 | 2) + fn main([]) -> int + "#]], ) } @@ -2212,7 +2255,23 @@ fn main() returns 'int ~hi(1, 2)"#, y = ~test("a string"); 42 "#, - expect![[r#""#]], + expect![[r#" + fn test: ((int | string) → (int | string)) + variable a: (int | string) + + fn test_: (int → (int | string)) + variable a: int + + fn main: int + x: function call to functionid1 with args: symbolid1: literal: 5, ((int | string)), + y: function call to functionid0 with args: symbolid1: literal: "a string", ((int | string)), + "literal: 42" (42) + + __MONOMORPHIZED FUNCTIONS__ + fn test(["string"]) -> (int | string) + fn test_(["int"]) -> (int | string) + fn main([]) -> int + "#]], ) } @@ -2226,7 +2285,25 @@ fn main() returns 'int ~hi(1, 2)"#, let y = ~test(~test_(true)); 42 "#, - expect![[r#""#]], + expect![[r#" + fn test: ((int | string) → (int | string)) + variable a: (int | string) + + fn test_: (bool → (int | string)) + variable a: bool + + fn main: int + y: function call to functionid0 with args: symbolid1: function call to functionid1 with args: symbolid1: literal: true, , ((int | string)), + "literal: 42" (42) + + __MONOMORPHIZED FUNCTIONS__ + fn test(["(int | string)"]) -> (int | string) + fn test_(["bool"]) -> (int | string) + fn main([]) -> int + + __ERRORS__ + SpannedItem NotSubtype(["int", "string"], "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(129), length: 0 } }] + "#]], ) } @@ -2235,7 +2312,11 @@ fn main() returns 'int ~hi(1, 2)"#, check( r#"fn test(a in 'sum 'int | 'string) returns 'sum 'string | 'int a "#, - expect![[r#""#]], + expect![[r#" + fn test: ((int | string) → (int | string)) + variable a: (int | string) + + "#]], ) } @@ -2244,7 +2325,11 @@ fn main() returns 'int ~hi(1, 2)"#, check( r#"fn test(a in 'sum 'int | 'string) returns 'sum 'string | 'int | 'bool a "#, - expect![[r#""#]], + expect![[r#" + fn test: ((int | string) → (int | bool | string)) + variable a: (int | string) + + "#]], ) } }