Skip to content

Commit

Permalink
fix up satisfies constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
sezna committed Jul 31, 2024
1 parent 05df7ae commit 2c96d0d
Showing 1 changed file with 110 additions and 25 deletions.
135 changes: 110 additions & 25 deletions petr-typecheck/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,15 @@ impl SpecificType {
use SpecificType::*;
let generalized_b = b.generalize(ctx).safely_upcast();
match (self, b) {
// If `a` is the generalized form of `b`, then `b` satisfies the constraint.
(a, b) if a == b || *a == generalized_b => true,
// If `a` is a sum type which contains `b` OR the generalized form of `b`, then `b`
// satisfies the constraint.
(Sum(a_tys), b) if a_tys.contains(b) || a_tys.contains(&generalized_b) => true,
// if both `a` and `b` are sum types, then `a` must be a superset of `b`:
// - every element in `b` must either:
// - be a member of `a`
// - generalize to a member of `a`
(Sum(a_tys), Sum(b_tys)) => {
// if a_tys is a superset of b_tys,
// every element OR its generalized version is contained in a_tys
Expand All @@ -337,7 +344,17 @@ impl SpecificType {

true
},
_ => todo!(),
_otherwise => false,
}
}

/// Use this to construct `[SpecificType::Sum]` types --
/// it will attempt to collapse the sum into a single type if possible
fn sum(tys: BTreeSet<SpecificType>) -> SpecificType {
if tys.len() == 1 {
tys.into_iter().next().expect("invariant")
} else {
SpecificType::Sum(tys)
}
}
}
Expand Down Expand Up @@ -638,7 +655,7 @@ impl TypeChecker {
(other, Sum(sum_tys)) => {
// `other` must be a member of the Sum type
if !sum_tys.contains(&other) {
self.push_error(span.with_item(self.unify_err(other.clone(), SpecificType::Sum(sum_tys.iter().cloned().collect()))));
self.push_error(span.with_item(self.unify_err(other.clone(), SpecificType::sum(sum_tys.clone()))));
}
// unify both types to the other type
self.ctx.update_type(t2, other);
Expand Down Expand Up @@ -682,17 +699,11 @@ impl TypeChecker {
(Sum(a_tys), Sum(b_tys)) => {
// calculate the intersection of these types, update t2 to the intersection
let intersection = a_tys.iter().filter(|a_ty| b_tys.contains(a_ty)).cloned().collect();
self.ctx.update_type(t2, Sum(intersection));
},
(ty1 @ Sum(sum_tys), other) => {
if
// if `other` is a generalized version of the sum type,
// then it satisfies the sum type
ty1.is_superset_of(other, &self.ctx) {
} else {
self.push_error(span.with_item(self.satisfy_err(SpecificType::Sum(sum_tys.iter().cloned().collect()), other.clone())));
}
self.ctx.update_type(t2, SpecificType::sum(intersection));
},
// if `ty1` is a generalized version of the sum type,
// then it satisfies the sum type
(ty1, other) if ty1.is_superset_of(other, &self.ctx) => (),
(Literal(l1), Literal(l2)) if l1 == l2 => (),
// Literals can satisfy broader parent types
(ty, Literal(lit)) => match (lit, ty) {
Expand Down Expand Up @@ -1553,8 +1564,12 @@ mod pretty_printing {
));
}

if !type_checker.monomorphized_functions.is_empty() {
s.push('\n');
}

if !type_checker.errors.is_empty() {
s.push_str("\n\n__ERRORS__\n");
s.push_str("\n__ERRORS__\n");
for error in type_checker.errors {
s.push_str(&format!("{:?}\n", error));
}
Expand Down Expand Up @@ -1778,7 +1793,8 @@ mod tests {
function call to functionid2 with args: someField: MyType, returns MyComposedType
__MONOMORPHIZED FUNCTIONS__
fn firstVariant(["MyType"]) -> MyComposedType"#]],
fn firstVariant(["MyType"]) -> MyComposedType
"#]],
);
}

Expand Down Expand Up @@ -1835,7 +1851,8 @@ mod tests {
intrinsic: @puts(function call to functionid0 with args: )
__MONOMORPHIZED FUNCTIONS__
fn string_literal([]) -> string"#]],
fn string_literal([]) -> string
"#]],
);
}

Expand Down Expand Up @@ -1902,6 +1919,7 @@ mod tests {
__MONOMORPHIZED FUNCTIONS__
fn bool_literal([]) -> bool
__ERRORS__
SpannedItem UnificationFailure("string", "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(110), length: 14 } }]
"#]],
Expand Down Expand Up @@ -1934,7 +1952,8 @@ mod tests {
__MONOMORPHIZED FUNCTIONS__
fn bool_literal(["int", "int"]) -> bool
fn bool_literal(["bool", "bool"]) -> bool"#]],
fn bool_literal(["bool", "bool"]) -> bool
"#]],
);
}
#[test]
Expand Down Expand Up @@ -1999,7 +2018,8 @@ fn main() returns 'int ~hi(1, 2)"#,
__MONOMORPHIZED FUNCTIONS__
fn hi(["int", "int"]) -> int
fn main([]) -> int"#]],
fn main([]) -> int
"#]],
)
}

Expand All @@ -2020,6 +2040,7 @@ fn main() returns 'int ~hi(1, 2)"#,
__MONOMORPHIZED FUNCTIONS__
fn hi(["int"]) -> int
fn main([]) -> int
__ERRORS__
SpannedItem UnificationFailure("int", "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(61), length: 2 } }]
"#]],
Expand All @@ -2043,6 +2064,7 @@ fn main() returns 'int ~hi(1, 2)"#,
__MONOMORPHIZED FUNCTIONS__
fn hi([]) -> int
fn main([]) -> int
__ERRORS__
SpannedItem UnificationFailure("unit", "1") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(33), length: 46 } }]
"#]],
Expand All @@ -2066,7 +2088,8 @@ fn main() returns 'int ~hi(1, 2)"#,
__MONOMORPHIZED FUNCTIONS__
fn hi([]) -> unit
fn main([]) -> unit"#]],
fn main([]) -> unit
"#]],
)
}

Expand All @@ -2091,8 +2114,9 @@ fn main() returns 'int ~hi(1, 2)"#,
__MONOMORPHIZED FUNCTIONS__
fn OneOrTwo(["int"]) -> OneOrTwo
fn main([]) -> OneOrTwo
__ERRORS__
SpannedItem FailedToSatisfy("10", "(1 | 2)") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(104), length: 0 } }]
SpannedItem NotSubtype(["1", "2"], "10") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(104), length: 0 } }]
"#]],
)
}
Expand All @@ -2118,8 +2142,9 @@ fn main() returns 'int ~hi(1, 2)"#,
__MONOMORPHIZED FUNCTIONS__
fn AOrB(["string"]) -> AOrB
fn main([]) -> AOrB
__ERRORS__
SpannedItem FailedToSatisfy("\"c\"", "(\"A\" | \"B\")") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(97), length: 0 } }]
SpannedItem NotSubtype(["\"A\"", "\"B\""], "\"c\"") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(97), length: 0 } }]
"#]],
)
}
Expand Down Expand Up @@ -2198,7 +2223,25 @@ fn main() returns 'int ~hi(1, 2)"#,
{- and should generalize to 'int with no problems -}
zz
",
expect![[r#""#]],
expect![[r#"
fn test: ((1 | 2 | 3) → (1 | 2 | 3))
variable a: (1 | 2 | 3)
fn test_: ((1 | 2) → (1 | 2))
variable a: (1 | 2)
fn main: int
x: literal: 2 (2),
y: function call to functionid1 with args: symbolid1: variable: symbolid5, ((1 | 2)),
z: function call to functionid0 with args: symbolid1: variable: symbolid6, ((1 | 2 | 3)),
zz: function call to functionid0 with args: symbolid1: variable: symbolid5, ((1 | 2 | 3)),
"variable zz: (1 | 2 | 3)" ((1 | 2 | 3))
__MONOMORPHIZED FUNCTIONS__
fn test(["int"]) -> (1 | 2 | 3)
fn test_(["int"]) -> (1 | 2)
fn main([]) -> int
"#]],
)
}

Expand All @@ -2212,7 +2255,23 @@ fn main() returns 'int ~hi(1, 2)"#,
y = ~test("a string");
42
"#,
expect![[r#""#]],
expect![[r#"
fn test: ((int | string) → (int | string))
variable a: (int | string)
fn test_: (int → (int | string))
variable a: int
fn main: int
x: function call to functionid1 with args: symbolid1: literal: 5, ((int | string)),
y: function call to functionid0 with args: symbolid1: literal: "a string", ((int | string)),
"literal: 42" (42)
__MONOMORPHIZED FUNCTIONS__
fn test(["string"]) -> (int | string)
fn test_(["int"]) -> (int | string)
fn main([]) -> int
"#]],
)
}

Expand All @@ -2226,7 +2285,25 @@ fn main() returns 'int ~hi(1, 2)"#,
let y = ~test(~test_(true));
42
"#,
expect![[r#""#]],
expect![[r#"
fn test: ((int | string) → (int | string))
variable a: (int | string)
fn test_: (bool → (int | string))
variable a: bool
fn main: int
y: function call to functionid0 with args: symbolid1: function call to functionid1 with args: symbolid1: literal: true, , ((int | string)),
"literal: 42" (42)
__MONOMORPHIZED FUNCTIONS__
fn test(["(int | string)"]) -> (int | string)
fn test_(["bool"]) -> (int | string)
fn main([]) -> int
__ERRORS__
SpannedItem NotSubtype(["int", "string"], "bool") [Span { source: SourceId(0), span: SourceSpan { offset: SourceOffset(129), length: 0 } }]
"#]],
)
}

Expand All @@ -2235,7 +2312,11 @@ fn main() returns 'int ~hi(1, 2)"#,
check(
r#"fn test(a in 'sum 'int | 'string) returns 'sum 'string | 'int a
"#,
expect![[r#""#]],
expect![[r#"
fn test: ((int | string) → (int | string))
variable a: (int | string)
"#]],
)
}

Expand All @@ -2244,7 +2325,11 @@ fn main() returns 'int ~hi(1, 2)"#,
check(
r#"fn test(a in 'sum 'int | 'string) returns 'sum 'string | 'int | 'bool a
"#,
expect![[r#""#]],
expect![[r#"
fn test: ((int | string) → (int | bool | string))
variable a: (int | string)
"#]],
)
}
}

0 comments on commit 2c96d0d

Please sign in to comment.