Skip to content

Commit

Permalink
Make AnyInteger abstract
Browse files Browse the repository at this point in the history
Ref. eng/recordflux/RecordFlux#1672
  • Loading branch information
treiher committed Jul 10, 2024
1 parent 89fb782 commit dc9093c
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 79 deletions.
2 changes: 1 addition & 1 deletion rflx/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ class MathAssExpr(AssExpr):
def __init__(self, *terms: Expr, location: Optional[Location] = None) -> None:
super().__init__(*terms, location=location)
common_type = rty.common_type([t.type_ for t in terms])
self.type_ = common_type if common_type != rty.UNDEFINED else rty.AnyInteger()
self.type_ = common_type if common_type != rty.UNDEFINED else rty.BASE_INTEGER

def _check_type_subexpr(self) -> RecordFluxError:
error = RecordFluxError()
Expand Down
26 changes: 14 additions & 12 deletions rflx/typing_.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,26 +89,25 @@ def __str__(self) -> str:
@attr.s(frozen=True)
class AnyInteger(Any):
DESCRIPTIVE_NAME: ClassVar[str] = "integer type"
bounds: Optional[Bounds] = attr.ib(default=None)

def __str__(self) -> str:
bounds = f" ({self.bounds})" if self.bounds else ""
return f"{self.DESCRIPTIVE_NAME}{bounds}"
bounds: Bounds = attr.ib()

@abstractmethod
def is_compatible(self, other: Type) -> bool:
return other == Any() or isinstance(other, AnyInteger)

def common_type(self, other: Type) -> Type:
if other == Any() or isinstance(other, AnyInteger):
return self
return Undefined()
raise NotImplementedError


@attr.s(frozen=True)
class UniversalInteger(AnyInteger):
DESCRIPTIVE_NAME: ClassVar[str] = "type universal integer"
bounds: Bounds = attr.ib()

def __str__(self) -> str:
bounds = f" ({self.bounds})" if self.bounds else ""
return f"{self.DESCRIPTIVE_NAME}{bounds}"

def is_compatible(self, other: Type) -> bool:
return other == Any() or isinstance(other, AnyInteger)

def common_type(self, other: Type) -> Type:
if isinstance(other, UniversalInteger) and self.bounds != other.bounds:
return UniversalInteger(
Expand All @@ -132,6 +131,9 @@ def __str__(self) -> str:
bounds = f" ({self.bounds})" if self.bounds else ""
return f'{self.DESCRIPTIVE_NAME} "{self.identifier}"{bounds}'

def is_compatible(self, other: Type) -> bool:
return other == Any() or isinstance(other, AnyInteger)

def is_compatible_strong(self, other: Type) -> bool:
return (
self == other
Expand All @@ -148,7 +150,7 @@ def common_type(self, other: Type) -> Type:
if isinstance(other, Integer) and (
self.identifier != other.identifier or self.bounds != other.bounds
):
return AnyInteger()
return BASE_INTEGER
if isinstance(other, AnyInteger):
return other
if other == Any() or self == other:
Expand Down
61 changes: 32 additions & 29 deletions tests/unit/expr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ def test_not_type() -> None:

def test_not_type_error() -> None:
assert_type_error(
Not(Variable("X", type_=rty.AnyInteger(), location=Location((10, 20)))),
Not(Variable("X", type_=INT_TY, location=Location((10, 20)))),
r'^<stdin>:10:20: error: expected enumeration type "__BUILTINS__::Boolean"\n'
r"<stdin>:10:20: error: found integer type$",
r'<stdin>:10:20: error: found integer type "I" \(10 \.\. 100\)$',
)


Expand Down Expand Up @@ -615,10 +615,6 @@ def test_number_hashable() -> None:

@pytest.mark.parametrize("operation", [Add, Mul, Sub, Div, Pow])
def test_math_expr_type(operation: Callable[[Expr, Expr], Expr]) -> None:
assert_type(
operation(Variable("X", type_=rty.AnyInteger()), Variable("Y", type_=rty.AnyInteger())),
rty.AnyInteger(),
)
assert_type(
operation(Variable("X", type_=INT_TY), Variable("Y", type_=INT_TY)),
INT_TY,
Expand Down Expand Up @@ -941,11 +937,11 @@ def test_attribute() -> None:
@pytest.mark.parametrize(
("attribute", "expr", "expected"),
[
(Size, Variable("X", type_=rty.AnyInteger()), rty.UNIVERSAL_INTEGER),
(Length, Variable("X", type_=rty.AnyInteger()), rty.UNIVERSAL_INTEGER),
(First, Variable("X", type_=rty.AnyInteger()), rty.UNIVERSAL_INTEGER),
(Last, Variable("X", type_=rty.AnyInteger()), rty.UNIVERSAL_INTEGER),
(ValidChecksum, Variable("X", type_=rty.AnyInteger()), rty.BOOLEAN),
(Size, Variable("X", type_=INT_TY), rty.UNIVERSAL_INTEGER),
(Length, Variable("X", type_=INT_TY), rty.UNIVERSAL_INTEGER),
(First, Variable("X", type_=INT_TY), rty.UNIVERSAL_INTEGER),
(Last, Variable("X", type_=INT_TY), rty.UNIVERSAL_INTEGER),
(ValidChecksum, Variable("X", type_=INT_TY), rty.BOOLEAN),
(Valid, Variable("X", type_=rty.Message("A")), rty.BOOLEAN),
(
Present,
Expand Down Expand Up @@ -1160,19 +1156,22 @@ def test_aggregate_precedence() -> None:
@pytest.mark.parametrize("relation", [Less, LessEqual, Equal, GreaterEqual, Greater, NotEqual])
def test_relation_integer_type(relation: Callable[[Expr, Expr], Expr]) -> None:
assert_type(
relation(Variable("X", type_=rty.AnyInteger()), Variable("Y", type_=rty.AnyInteger())),
relation(Variable("X", type_=INT_TY), Variable("Y", type_=INT_TY)),
rty.BOOLEAN,
)


@pytest.mark.parametrize("relation", [Less, LessEqual, Equal, GreaterEqual, Greater, NotEqual])
def test_relation_integer_type_error(relation: Callable[[Expr, Expr], Expr]) -> None:
integer_type = (
r'integer type "I" \(10 \.\. 100\)' if relation in [Equal, NotEqual] else r"integer type"
)
assert_type_error(
relation(
Variable("X", type_=rty.AnyInteger()),
Variable("X", type_=INT_TY),
Variable("True", type_=rty.BOOLEAN, location=Location((10, 30))),
),
r"^<stdin>:10:30: error: expected integer type\n"
rf"^<stdin>:10:30: error: expected {integer_type}\n"
r'<stdin>:10:30: error: found enumeration type "__BUILTINS__::Boolean"$',
)

Expand All @@ -1181,8 +1180,8 @@ def test_relation_integer_type_error(relation: Callable[[Expr, Expr], Expr]) ->
def test_relation_composite_type(relation: Callable[[Expr, Expr], Expr]) -> None:
assert_type(
relation(
Variable("X", type_=rty.AnyInteger()),
Variable("Y", type_=rty.Sequence("A", rty.AnyInteger())),
Variable("X", type_=INT_TY),
Variable("Y", type_=rty.Sequence("A", INT_TY)),
),
rty.BOOLEAN,
)
Expand All @@ -1192,20 +1191,20 @@ def test_relation_composite_type(relation: Callable[[Expr, Expr], Expr]) -> None
def test_relation_composite_type_error(relation: Callable[[Expr, Expr], Expr]) -> None:
assert_type_error(
relation(
Variable("X", type_=rty.AnyInteger(), location=Location((10, 20))),
Variable("X", type_=INT_TY, location=Location((10, 20))),
Variable("True", type_=rty.BOOLEAN, location=Location((10, 30))),
),
r"^<stdin>:10:30: error: expected aggregate"
r" with element integer type\n"
r' with element integer type "I" \(10 \.\. 100\)\n'
r'<stdin>:10:30: error: found enumeration type "__BUILTINS__::Boolean"$',
)
assert_type_error(
relation(
Variable("X", type_=rty.AnyInteger(), location=Location((10, 20))),
Variable("X", type_=INT_TY, location=Location((10, 20))),
Variable("Y", type_=rty.Sequence("A", rty.BOOLEAN), location=Location((10, 30))),
),
r"^<stdin>:10:30: error: expected aggregate"
r" with element integer type\n"
r' with element integer type "I" \(10 \.\. 100\)\n'
r'<stdin>:10:30: error: found sequence type "A"'
r' with element enumeration type "__BUILTINS__::Boolean"$',
)
Expand Down Expand Up @@ -1470,13 +1469,16 @@ def test_value_range_type_error() -> None:
assert_type_error(
ValueRange(
Variable("X", type_=rty.BOOLEAN, location=Location((10, 30))),
Variable("Y", type_=rty.Sequence("A", rty.AnyInteger()), location=Location((10, 40))),
Variable("Y", type_=rty.Sequence("A", INT_TY), location=Location((10, 40))),
location=Location((10, 20)),
),
r"^<stdin>:10:30: error: expected integer type\n"
r"^"
r"<stdin>:10:30: error: expected integer type\n"
r'<stdin>:10:30: error: found enumeration type "__BUILTINS__::Boolean"\n'
r"<stdin>:10:40: error: expected integer type\n"
r'<stdin>:10:40: error: found sequence type "A" with element integer type$',
r'<stdin>:10:40: error: found sequence type "A"'
r' with element integer type "I" \(10 \.\. 100\)'
r"$",
)


Expand Down Expand Up @@ -1525,11 +1527,12 @@ def test_quantified_expression_type(expr: Callable[[str, Expr, Expr], Expr]) ->
[
(
Variable("Y", type_=rty.BOOLEAN, location=Location((10, 30))),
Variable("Z", type_=rty.Sequence("A", rty.AnyInteger()), location=Location((10, 40))),
Variable("Z", type_=rty.Sequence("A", INT_TY), location=Location((10, 40))),
r"^<stdin>:10:30: error: expected composite type\n"
r'<stdin>:10:30: error: found enumeration type "__BUILTINS__::Boolean"\n'
r'<stdin>:10:40: error: expected enumeration type "__BUILTINS__::Boolean"\n'
r'<stdin>:10:40: error: found sequence type "A" with element integer type$',
r'<stdin>:10:40: error: found sequence type "A"'
r' with element integer type "I" \(10 \.\. 100\)$',
),
(
Variable("Y", type_=rty.BOOLEAN, location=Location((10, 30))),
Expand Down Expand Up @@ -1940,17 +1943,17 @@ def test_call_type_error() -> None:
"X",
rty.BOOLEAN,
[
Variable("Y", type_=rty.AnyInteger(), location=Location((10, 30))),
Variable("Y", type_=INT_TY, location=Location((10, 30))),
Variable("Z", type_=rty.BOOLEAN, location=Location((10, 40))),
],
argument_types=[
rty.BOOLEAN,
rty.AnyInteger(),
INT_TY,
],
),
r'^<stdin>:10:30: error: expected enumeration type "__BUILTINS__::Boolean"\n'
r"<stdin>:10:30: error: found integer type\n"
r"<stdin>:10:40: error: expected integer type\n"
r'<stdin>:10:30: error: found integer type "I" \(10 \.\. 100\)\n'
r'<stdin>:10:40: error: expected integer type "I" \(10 \.\. 100\)\n'
r'<stdin>:10:40: error: found enumeration type "__BUILTINS__::Boolean"$',
)

Expand Down
14 changes: 8 additions & 6 deletions tests/unit/generator/session_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,11 +1105,11 @@ def _update_str(self) -> None:
ir.ObjFieldAccess(
"Z",
"Z",
rty.Message("B", {("Z",)}, {}, {ID("Z"): rty.Aggregate(rty.AnyInteger())}),
rty.Message("B", {("Z",)}, {}, {ID("Z"): rty.Aggregate(INT_TY)}),
origin=ir.ConstructedOrigin("", Location((10, 20))),
),
FatalError,
r'unexpected type \(aggregate with element integer type\) for "Z.Z"'
r'unexpected type \(aggregate with element integer type "I" \(1 \.\. 100\)\) for "Z.Z"'
r' in assignment of "X"',
),
(
Expand Down Expand Up @@ -1322,22 +1322,23 @@ def _update_str(self) -> None:
"E",
ir.ObjVar(
"L",
rty.Sequence("A", rty.AnyInteger()),
rty.Sequence("A", INT_TY),
origin=ir.ConstructedOrigin("", Location((10, 20))),
),
ir.ComplexExpr([], ir.ObjVar("E", INT_TY)),
ir.ComplexBoolExpr([], ir.Greater(ir.IntVar("E", INT_TY), ir.IntVal(0))),
),
RecordFluxError,
r"iterating over sequence of integer type in list comprehension not yet supported",
r'iterating over sequence of integer type "I" \(1 \.\. 100\) in list comprehension'
r" not yet supported",
),
(
INT_TY,
ir.Find(
"E",
ir.ObjVar(
"L",
rty.Sequence("A", rty.AnyInteger()),
rty.Sequence("A", INT_TY),
origin=ir.ConstructedOrigin("", Location((10, 20))),
),
ir.ComplexExpr([], ir.ObjVar("E", INT_TY)),
Expand All @@ -1347,7 +1348,8 @@ def _update_str(self) -> None:
),
),
RecordFluxError,
r"iterating over sequence of integer type in list comprehension not yet supported",
r'iterating over sequence of integer type "I" \(1 \.\. 100\) in list comprehension'
r" not yet supported",
),
(
rty.Sequence("A", INT_TY),
Expand Down
Loading

0 comments on commit dc9093c

Please sign in to comment.