diff --git a/rflx/expr.py b/rflx/expr.py index 14d805a1b..3eb8e7d97 100644 --- a/rflx/expr.py +++ b/rflx/expr.py @@ -736,7 +736,7 @@ class MathAssExpr(AssExpr): def __init__(self, *terms: Expr, location: Optional[Location] = None) -> None: super().__init__(*terms, location=location) common_type = rty.common_type([t.type_ for t in terms]) - self.type_ = common_type if common_type != rty.UNDEFINED else rty.AnyInteger() + self.type_ = common_type if common_type != rty.UNDEFINED else rty.BASE_INTEGER def _check_type_subexpr(self) -> RecordFluxError: error = RecordFluxError() diff --git a/rflx/typing_.py b/rflx/typing_.py index ef73e6a67..459fdbc2e 100644 --- a/rflx/typing_.py +++ b/rflx/typing_.py @@ -89,19 +89,11 @@ def __str__(self) -> str: @attr.s(frozen=True) class AnyInteger(Any): DESCRIPTIVE_NAME: ClassVar[str] = "integer type" - bounds: Optional[Bounds] = attr.ib(default=None) - - def __str__(self) -> str: - bounds = f" ({self.bounds})" if self.bounds else "" - return f"{self.DESCRIPTIVE_NAME}{bounds}" + bounds: Bounds = attr.ib() + @abstractmethod def is_compatible(self, other: Type) -> bool: - return other == Any() or isinstance(other, AnyInteger) - - def common_type(self, other: Type) -> Type: - if other == Any() or isinstance(other, AnyInteger): - return self - return Undefined() + raise NotImplementedError @attr.s(frozen=True) @@ -109,6 +101,13 @@ class UniversalInteger(AnyInteger): DESCRIPTIVE_NAME: ClassVar[str] = "type universal integer" bounds: Bounds = attr.ib() + def __str__(self) -> str: + bounds = f" ({self.bounds})" if self.bounds else "" + return f"{self.DESCRIPTIVE_NAME}{bounds}" + + def is_compatible(self, other: Type) -> bool: + return other == Any() or isinstance(other, AnyInteger) + def common_type(self, other: Type) -> Type: if isinstance(other, UniversalInteger) and self.bounds != other.bounds: return UniversalInteger( @@ -132,6 +131,9 @@ def __str__(self) -> str: bounds = f" ({self.bounds})" if self.bounds else "" return f'{self.DESCRIPTIVE_NAME} "{self.identifier}"{bounds}' + def is_compatible(self, other: Type) -> bool: + return other == Any() or isinstance(other, AnyInteger) + def is_compatible_strong(self, other: Type) -> bool: return ( self == other @@ -148,7 +150,7 @@ def common_type(self, other: Type) -> Type: if isinstance(other, Integer) and ( self.identifier != other.identifier or self.bounds != other.bounds ): - return AnyInteger() + return BASE_INTEGER if isinstance(other, AnyInteger): return other if other == Any() or self == other: diff --git a/tests/unit/expr_test.py b/tests/unit/expr_test.py index 5e76e3e7f..454564e13 100644 --- a/tests/unit/expr_test.py +++ b/tests/unit/expr_test.py @@ -131,9 +131,9 @@ def test_not_type() -> None: def test_not_type_error() -> None: assert_type_error( - Not(Variable("X", type_=rty.AnyInteger(), location=Location((10, 20)))), + Not(Variable("X", type_=INT_TY, location=Location((10, 20)))), r'^:10:20: error: expected enumeration type "__BUILTINS__::Boolean"\n' - r":10:20: error: found integer type$", + r':10:20: error: found integer type "I" \(10 \.\. 100\)$', ) @@ -615,10 +615,6 @@ def test_number_hashable() -> None: @pytest.mark.parametrize("operation", [Add, Mul, Sub, Div, Pow]) def test_math_expr_type(operation: Callable[[Expr, Expr], Expr]) -> None: - assert_type( - operation(Variable("X", type_=rty.AnyInteger()), Variable("Y", type_=rty.AnyInteger())), - rty.AnyInteger(), - ) assert_type( operation(Variable("X", type_=INT_TY), Variable("Y", type_=INT_TY)), INT_TY, @@ -941,11 +937,11 @@ def test_attribute() -> None: @pytest.mark.parametrize( ("attribute", "expr", "expected"), [ - (Size, Variable("X", type_=rty.AnyInteger()), rty.UNIVERSAL_INTEGER), - (Length, Variable("X", type_=rty.AnyInteger()), rty.UNIVERSAL_INTEGER), - (First, Variable("X", type_=rty.AnyInteger()), rty.UNIVERSAL_INTEGER), - (Last, Variable("X", type_=rty.AnyInteger()), rty.UNIVERSAL_INTEGER), - (ValidChecksum, Variable("X", type_=rty.AnyInteger()), rty.BOOLEAN), + (Size, Variable("X", type_=INT_TY), rty.UNIVERSAL_INTEGER), + (Length, Variable("X", type_=INT_TY), rty.UNIVERSAL_INTEGER), + (First, Variable("X", type_=INT_TY), rty.UNIVERSAL_INTEGER), + (Last, Variable("X", type_=INT_TY), rty.UNIVERSAL_INTEGER), + (ValidChecksum, Variable("X", type_=INT_TY), rty.BOOLEAN), (Valid, Variable("X", type_=rty.Message("A")), rty.BOOLEAN), ( Present, @@ -1160,19 +1156,22 @@ def test_aggregate_precedence() -> None: @pytest.mark.parametrize("relation", [Less, LessEqual, Equal, GreaterEqual, Greater, NotEqual]) def test_relation_integer_type(relation: Callable[[Expr, Expr], Expr]) -> None: assert_type( - relation(Variable("X", type_=rty.AnyInteger()), Variable("Y", type_=rty.AnyInteger())), + relation(Variable("X", type_=INT_TY), Variable("Y", type_=INT_TY)), rty.BOOLEAN, ) @pytest.mark.parametrize("relation", [Less, LessEqual, Equal, GreaterEqual, Greater, NotEqual]) def test_relation_integer_type_error(relation: Callable[[Expr, Expr], Expr]) -> None: + integer_type = ( + r'integer type "I" \(10 \.\. 100\)' if relation in [Equal, NotEqual] else r"integer type" + ) assert_type_error( relation( - Variable("X", type_=rty.AnyInteger()), + Variable("X", type_=INT_TY), Variable("True", type_=rty.BOOLEAN, location=Location((10, 30))), ), - r"^:10:30: error: expected integer type\n" + rf"^:10:30: error: expected {integer_type}\n" r':10:30: error: found enumeration type "__BUILTINS__::Boolean"$', ) @@ -1181,8 +1180,8 @@ def test_relation_integer_type_error(relation: Callable[[Expr, Expr], Expr]) -> def test_relation_composite_type(relation: Callable[[Expr, Expr], Expr]) -> None: assert_type( relation( - Variable("X", type_=rty.AnyInteger()), - Variable("Y", type_=rty.Sequence("A", rty.AnyInteger())), + Variable("X", type_=INT_TY), + Variable("Y", type_=rty.Sequence("A", INT_TY)), ), rty.BOOLEAN, ) @@ -1192,20 +1191,20 @@ def test_relation_composite_type(relation: Callable[[Expr, Expr], Expr]) -> None def test_relation_composite_type_error(relation: Callable[[Expr, Expr], Expr]) -> None: assert_type_error( relation( - Variable("X", type_=rty.AnyInteger(), location=Location((10, 20))), + Variable("X", type_=INT_TY, location=Location((10, 20))), Variable("True", type_=rty.BOOLEAN, location=Location((10, 30))), ), r"^:10:30: error: expected aggregate" - r" with element integer type\n" + r' with element integer type "I" \(10 \.\. 100\)\n' r':10:30: error: found enumeration type "__BUILTINS__::Boolean"$', ) assert_type_error( relation( - Variable("X", type_=rty.AnyInteger(), location=Location((10, 20))), + Variable("X", type_=INT_TY, location=Location((10, 20))), Variable("Y", type_=rty.Sequence("A", rty.BOOLEAN), location=Location((10, 30))), ), r"^:10:30: error: expected aggregate" - r" with element integer type\n" + r' with element integer type "I" \(10 \.\. 100\)\n' r':10:30: error: found sequence type "A"' r' with element enumeration type "__BUILTINS__::Boolean"$', ) @@ -1470,13 +1469,16 @@ def test_value_range_type_error() -> None: assert_type_error( ValueRange( Variable("X", type_=rty.BOOLEAN, location=Location((10, 30))), - Variable("Y", type_=rty.Sequence("A", rty.AnyInteger()), location=Location((10, 40))), + Variable("Y", type_=rty.Sequence("A", INT_TY), location=Location((10, 40))), location=Location((10, 20)), ), - r"^:10:30: error: expected integer type\n" + r"^" + r":10:30: error: expected integer type\n" r':10:30: error: found enumeration type "__BUILTINS__::Boolean"\n' r":10:40: error: expected integer type\n" - r':10:40: error: found sequence type "A" with element integer type$', + r':10:40: error: found sequence type "A"' + r' with element integer type "I" \(10 \.\. 100\)' + r"$", ) @@ -1525,11 +1527,12 @@ def test_quantified_expression_type(expr: Callable[[str, Expr, Expr], Expr]) -> [ ( Variable("Y", type_=rty.BOOLEAN, location=Location((10, 30))), - Variable("Z", type_=rty.Sequence("A", rty.AnyInteger()), location=Location((10, 40))), + Variable("Z", type_=rty.Sequence("A", INT_TY), location=Location((10, 40))), r"^:10:30: error: expected composite type\n" r':10:30: error: found enumeration type "__BUILTINS__::Boolean"\n' r':10:40: error: expected enumeration type "__BUILTINS__::Boolean"\n' - r':10:40: error: found sequence type "A" with element integer type$', + r':10:40: error: found sequence type "A"' + r' with element integer type "I" \(10 \.\. 100\)$', ), ( Variable("Y", type_=rty.BOOLEAN, location=Location((10, 30))), @@ -1940,17 +1943,17 @@ def test_call_type_error() -> None: "X", rty.BOOLEAN, [ - Variable("Y", type_=rty.AnyInteger(), location=Location((10, 30))), + Variable("Y", type_=INT_TY, location=Location((10, 30))), Variable("Z", type_=rty.BOOLEAN, location=Location((10, 40))), ], argument_types=[ rty.BOOLEAN, - rty.AnyInteger(), + INT_TY, ], ), r'^:10:30: error: expected enumeration type "__BUILTINS__::Boolean"\n' - r":10:30: error: found integer type\n" - r":10:40: error: expected integer type\n" + r':10:30: error: found integer type "I" \(10 \.\. 100\)\n' + r':10:40: error: expected integer type "I" \(10 \.\. 100\)\n' r':10:40: error: found enumeration type "__BUILTINS__::Boolean"$', ) diff --git a/tests/unit/generator/session_test.py b/tests/unit/generator/session_test.py index d2e440244..049241785 100644 --- a/tests/unit/generator/session_test.py +++ b/tests/unit/generator/session_test.py @@ -1105,11 +1105,11 @@ def _update_str(self) -> None: ir.ObjFieldAccess( "Z", "Z", - rty.Message("B", {("Z",)}, {}, {ID("Z"): rty.Aggregate(rty.AnyInteger())}), + rty.Message("B", {("Z",)}, {}, {ID("Z"): rty.Aggregate(INT_TY)}), origin=ir.ConstructedOrigin("", Location((10, 20))), ), FatalError, - r'unexpected type \(aggregate with element integer type\) for "Z.Z"' + r'unexpected type \(aggregate with element integer type "I" \(1 \.\. 100\)\) for "Z.Z"' r' in assignment of "X"', ), ( @@ -1322,14 +1322,15 @@ def _update_str(self) -> None: "E", ir.ObjVar( "L", - rty.Sequence("A", rty.AnyInteger()), + rty.Sequence("A", INT_TY), origin=ir.ConstructedOrigin("", Location((10, 20))), ), ir.ComplexExpr([], ir.ObjVar("E", INT_TY)), ir.ComplexBoolExpr([], ir.Greater(ir.IntVar("E", INT_TY), ir.IntVal(0))), ), RecordFluxError, - r"iterating over sequence of integer type in list comprehension not yet supported", + r'iterating over sequence of integer type "I" \(1 \.\. 100\) in list comprehension' + r" not yet supported", ), ( INT_TY, @@ -1337,7 +1338,7 @@ def _update_str(self) -> None: "E", ir.ObjVar( "L", - rty.Sequence("A", rty.AnyInteger()), + rty.Sequence("A", INT_TY), origin=ir.ConstructedOrigin("", Location((10, 20))), ), ir.ComplexExpr([], ir.ObjVar("E", INT_TY)), @@ -1347,7 +1348,8 @@ def _update_str(self) -> None: ), ), RecordFluxError, - r"iterating over sequence of integer type in list comprehension not yet supported", + r'iterating over sequence of integer type "I" \(1 \.\. 100\) in list comprehension' + r" not yet supported", ), ( rty.Sequence("A", INT_TY), diff --git a/tests/unit/typing__test.py b/tests/unit/typing__test.py index 57732d1a4..871eded8f 100644 --- a/tests/unit/typing__test.py +++ b/tests/unit/typing__test.py @@ -9,9 +9,9 @@ from rflx.rapidflux import Location, RecordFluxError from rflx.rapidflux.ty import Bounds from rflx.typing_ import ( + BASE_INTEGER, Aggregate, Any, - AnyInteger, Channel, Enumeration, Integer, @@ -63,20 +63,20 @@ def test_enumeration_is_compatible(enumeration: Type, other: Type, expected: boo @pytest.mark.parametrize( ("base_integer", "other", "expected"), [ - (AnyInteger(), Any(), AnyInteger()), - (AnyInteger(), AnyInteger(), AnyInteger()), + (BASE_INTEGER, Any(), BASE_INTEGER), + (BASE_INTEGER, BASE_INTEGER, BASE_INTEGER), ( - AnyInteger(), + BASE_INTEGER, Integer("A", Bounds(10, 100)), - AnyInteger(), + BASE_INTEGER, ), ( - AnyInteger(), + BASE_INTEGER, UniversalInteger(Bounds(10, 100)), - AnyInteger(), + BASE_INTEGER, ), - (AnyInteger(), Undefined(), Undefined()), - (AnyInteger(), ENUMERATION_B, Undefined()), + (BASE_INTEGER, Undefined(), Undefined()), + (BASE_INTEGER, ENUMERATION_B, Undefined()), ], ) def test_base_integer_common_type(base_integer: Type, other: Type, expected: Type) -> None: @@ -87,20 +87,20 @@ def test_base_integer_common_type(base_integer: Type, other: Type, expected: Typ @pytest.mark.parametrize( ("base_integer", "other", "expected"), [ - (AnyInteger(), Any(), True), - (AnyInteger(), AnyInteger(), True), + (BASE_INTEGER, Any(), True), + (BASE_INTEGER, BASE_INTEGER, True), ( - AnyInteger(), + BASE_INTEGER, Integer("A", Bounds(10, 100)), True, ), ( - AnyInteger(), + BASE_INTEGER, UniversalInteger(Bounds(10, 100)), True, ), - (AnyInteger(), Undefined(), False), - (AnyInteger(), ENUMERATION_B, False), + (BASE_INTEGER, Undefined(), False), + (BASE_INTEGER, ENUMERATION_B, False), ], ) def test_base_integer_is_compatible(base_integer: Type, other: Type, expected: bool) -> None: @@ -112,7 +112,7 @@ def test_base_integer_is_compatible(base_integer: Type, other: Type, expected: b ("universal_integer", "other", "expected"), [ (UniversalInteger(Bounds(10, 100)), Any(), UniversalInteger(Bounds(10, 100))), - (UniversalInteger(Bounds(10, 100)), AnyInteger(), AnyInteger()), + (UniversalInteger(Bounds(10, 100)), BASE_INTEGER, BASE_INTEGER), ( UniversalInteger(Bounds(10, 100)), UniversalInteger(Bounds(10, 100)), @@ -145,7 +145,7 @@ def test_universal_integer_common_type( ("universal_integer", "other", "expected"), [ (UniversalInteger(Bounds(10, 100)), Any(), True), - (UniversalInteger(Bounds(10, 100)), AnyInteger(), True), + (UniversalInteger(Bounds(10, 100)), BASE_INTEGER, True), (UniversalInteger(Bounds(10, 100)), UniversalInteger(Bounds(10, 100)), True), ( UniversalInteger(Bounds(10, 100)), @@ -175,8 +175,8 @@ def test_universal_integer_is_compatible( ), ( Integer("A", Bounds(10, 100)), - AnyInteger(), - AnyInteger(), + BASE_INTEGER, + BASE_INTEGER, ), ( Integer("A", Bounds(10, 100)), @@ -191,7 +191,7 @@ def test_universal_integer_is_compatible( ( Integer("A", Bounds(10, 100)), Integer("B", Bounds(10, 100)), - AnyInteger(), + BASE_INTEGER, ), ( Integer("A", Bounds(10, 100)), @@ -219,7 +219,7 @@ def test_integer_common_type(integer: Type, other: Type, expected: Type) -> None ("integer", "other", "expected"), [ (Integer("A", Bounds(10, 100)), Any(), True), - (Integer("A", Bounds(10, 100)), AnyInteger(), True), + (Integer("A", Bounds(10, 100)), BASE_INTEGER, True), (Integer("A", Bounds(10, 100)), Integer("A", Bounds(10, 100)), True), (Integer("A", Bounds(10, 100)), UniversalInteger(Bounds(10, 100)), True), ( @@ -250,7 +250,7 @@ def test_integer_is_compatible(integer: Type, other: Type, expected: bool) -> No ("integer", "other", "expected"), [ (Integer("A", Bounds(10, 100)), Any(), True), - (Integer("A", Bounds(10, 100)), AnyInteger(), False), + (Integer("A", Bounds(10, 100)), BASE_INTEGER, False), (Integer("A", Bounds(10, 100)), Integer("A", Bounds(10, 100)), True), (Integer("A", Bounds(10, 100)), UniversalInteger(Bounds(10, 100)), True), ( @@ -293,12 +293,12 @@ def test_integer_is_compatible_strong(integer: Type, other: Type, expected: bool ( Aggregate(Integer("A", Bounds(10, 100))), Aggregate(Integer("B", Bounds(10, 100))), - Aggregate(AnyInteger()), + Aggregate(BASE_INTEGER), ), ( Aggregate(Integer("A", Bounds(10, 100))), Aggregate(Integer("A", Bounds(20, 200))), - Aggregate(AnyInteger()), + Aggregate(BASE_INTEGER), ), ( Aggregate(UniversalInteger(Bounds(10, 100))), @@ -576,7 +576,7 @@ def test_channel_is_compatible(channel: Type, other: Type, expected: bool) -> No Aggregate(UniversalInteger(Bounds(20, 100))), Aggregate(Integer("B", Bounds(20, 200))), ], - Aggregate(AnyInteger()), + Aggregate(BASE_INTEGER), ), ( [ @@ -616,10 +616,13 @@ def test_check_type(actual: Type, expected: Type) -> None: r':10:20: error: found message type "A"$', ), ( - AnyInteger(), + BASE_INTEGER, Message("A"), - r'^:10:20: error: expected message type "A"\n' - r":10:20: error: found integer type$", + r"^" + r':10:20: error: expected message type "A"\n' + r':10:20: error: found integer type "__BUILTINS__::Base_Integer"' + r" \(0 \.\. 9223372036854775807\)" + r"$", ), ( Undefined(), @@ -659,10 +662,13 @@ def test_check_type_instance( r':10:20: error: found message type "M"$', ), ( - AnyInteger(), + BASE_INTEGER, (Sequence, Message), - r"^:10:20: error: expected sequence type or message type\n" - r":10:20: error: found integer type$", + r"^" + r":10:20: error: expected sequence type or message type\n" + r':10:20: error: found integer type "__BUILTINS__::Base_Integer"' + r" \(0 \.\. 9223372036854775807\)" + r"$", ), ( Undefined(),