diff --git a/rflx/generator/session.py b/rflx/generator/session.py index 94fd71a45..b2ba42213 100644 --- a/rflx/generator/session.py +++ b/rflx/generator/session.py @@ -5193,7 +5193,10 @@ def _convert_type( expression: ir.Expr, target_type: rty.Type, ) -> ir.Expr: - if target_type.is_compatible_strong(expression.type_): + if target_type.is_compatible_strong(expression.type_) and not isinstance( + expression, + ir.BinaryIntExpr, + ): return expression assert isinstance(target_type, (rty.Integer, rty.Enumeration)), target_type @@ -5202,18 +5205,25 @@ def _convert_type( if isinstance(expression, ir.BinaryIntExpr): assert isinstance(target_type, rty.Integer) - return expression.__class__( - ir.IntConversion( + left = ( + expression.left + if target_type.is_compatible_strong(expression.left.type_) + else ir.IntConversion( self._ada_type(target_type.identifier), expression.left, target_type, - ), - ir.IntConversion( + ) + ) + right = ( + expression.right + if target_type.is_compatible_strong(expression.right.type_) + else ir.IntConversion( self._ada_type(target_type.identifier), expression.right, target_type, - ), + ) ) + return expression.__class__(left, right) return ir.Conversion(self._ada_type(target_type.identifier), expression, target_type) diff --git a/rflx/ir.py b/rflx/ir.py index a5f1601e7..207715655 100644 --- a/rflx/ir.py +++ b/rflx/ir.py @@ -195,7 +195,11 @@ def substituted(self, mapping: Mapping[ID, ID]) -> Assign: ) def preconditions(self, variable_id: Generator[ID, None, None]) -> list[Cond]: - return self.expression.preconditions(variable_id) + return ( + self.expression.preconditions(variable_id, self.type_) + if isinstance(self.expression, BinaryIntExpr) + else self.expression.preconditions(variable_id) + ) def to_z3_expr(self) -> z3.BoolRef: target: Var @@ -856,6 +860,14 @@ class BinaryIntExpr(BinaryExpr, IntExpr): right: BasicIntExpr origin: Optional[Origin] = None + @abstractmethod + def preconditions( + self, + variable_id: Generator[ID, None, None], + target_type: rty.Type | None = None, + ) -> list[Cond]: + raise NotImplementedError + @property def type_(self) -> rty.AnyInteger: type_ = self.left.type_.common_type(self.right.type_) @@ -884,12 +896,17 @@ class Add(BinaryIntExpr): def to_z3_expr(self) -> z3.ArithRef: return self.left.to_z3_expr() + self.right.to_z3_expr() - def preconditions(self, variable_id: Generator[ID, None, None]) -> list[Cond]: + def preconditions( + self, + variable_id: Generator[ID, None, None], + target_type: rty.Type | None = None, + ) -> list[Cond]: + target_type = target_type or self.type_ v_id = next(variable_id) - v_type = to_integer(self.type_) + v_type = rty.BASE_INTEGER upper_bound = ( - self.type_.bounds.upper - if isinstance(self.type_, rty.AnyInteger) and self.type_.bounds is not None + target_type.bounds.upper + if isinstance(target_type, rty.AnyInteger) and target_type.bounds is not None else INT_MAX ) return [ @@ -935,7 +952,11 @@ class Sub(BinaryIntExpr): def to_z3_expr(self) -> z3.ArithRef: return self.left.to_z3_expr() - self.right.to_z3_expr() - def preconditions(self, _variable_id: Generator[ID, None, None]) -> list[Cond]: + def preconditions( + self, + _variable_id: Generator[ID, None, None], + _target_type: rty.Type | None = None, + ) -> list[Cond]: return [ # Left >= Right Cond(GreaterEqual(self.left, self.right)), @@ -951,12 +972,17 @@ class Mul(BinaryIntExpr): def to_z3_expr(self) -> z3.ArithRef: return self.left.to_z3_expr() * self.right.to_z3_expr() - def preconditions(self, variable_id: Generator[ID, None, None]) -> list[Cond]: + def preconditions( + self, + variable_id: Generator[ID, None, None], + target_type: rty.Type | None = None, + ) -> list[Cond]: + target_type = target_type or self.type_ v_id = next(variable_id) - v_type = to_integer(self.type_) + v_type = rty.BASE_INTEGER upper_bound = ( - self.type_.bounds.upper - if isinstance(self.type_, rty.AnyInteger) and self.type_.bounds is not None + target_type.bounds.upper + if isinstance(target_type, rty.AnyInteger) and target_type.bounds is not None else INT_MAX ) return [ @@ -994,7 +1020,11 @@ class Div(BinaryIntExpr): def to_z3_expr(self) -> z3.ArithRef: return self.left.to_z3_expr() / self.right.to_z3_expr() - def preconditions(self, _variable_id: Generator[ID, None, None]) -> list[Cond]: + def preconditions( + self, + _variable_id: Generator[ID, None, None], + _target_type: rty.Type | None = None, + ) -> list[Cond]: return [ # Right /= 0 Cond(NotEqual(self.right, IntVal(0))), @@ -1010,12 +1040,17 @@ class Pow(BinaryIntExpr): def to_z3_expr(self) -> z3.ArithRef: return self.left.to_z3_expr() ** self.right.to_z3_expr() - def preconditions(self, variable_id: Generator[ID, None, None]) -> list[Cond]: + def preconditions( + self, + variable_id: Generator[ID, None, None], + target_type: rty.Type | None = None, + ) -> list[Cond]: + target_type = target_type or self.type_ v_id = next(variable_id) - v_type = to_integer(self.type_) + v_type = rty.BASE_INTEGER upper_bound = ( - self.type_.bounds.upper - if isinstance(self.type_, rty.AnyInteger) and self.type_.bounds is not None + target_type.bounds.upper + if isinstance(target_type, rty.AnyInteger) and target_type.bounds is not None else INT_MAX ) return [ @@ -1027,7 +1062,7 @@ def preconditions(self, variable_id: Generator[ID, None, None]) -> list[Cond]: ), [ VarDecl(v_id, v_type, None, origin=self.origin), - Assign(v_id, self, to_integer(self.type_), origin=self.origin), + Assign(v_id, self, v_type, origin=self.origin), ], ), ] @@ -1042,7 +1077,11 @@ class Mod(BinaryIntExpr): def to_z3_expr(self) -> z3.ArithRef: return self.left.to_z3_expr() % self.right.to_z3_expr() - def preconditions(self, _variable_id: Generator[ID, None, None]) -> list[Cond]: + def preconditions( + self, + _variable_id: Generator[ID, None, None], + _target_type: rty.Type | None = None, + ) -> list[Cond]: return [ # Right /= 0 Cond(NotEqual(self.right, IntVal(0))), diff --git a/rflx/typing_.py b/rflx/typing_.py index 459fdbc2e..761e4b62f 100644 --- a/rflx/typing_.py +++ b/rflx/typing_.py @@ -113,10 +113,10 @@ def common_type(self, other: Type) -> Type: return UniversalInteger( Bounds.merge(self.bounds, other.bounds), ) - if isinstance(other, AnyInteger): - return other if other == Any() or self == other: return self + if isinstance(other, AnyInteger): + return BASE_INTEGER return Undefined() @@ -145,16 +145,10 @@ def is_compatible_strong(self, other: Type) -> bool: ) def common_type(self, other: Type) -> Type: - if isinstance(other, UniversalInteger): + if other == Any(): return self - if isinstance(other, Integer) and ( - self.identifier != other.identifier or self.bounds != other.bounds - ): - return BASE_INTEGER if isinstance(other, AnyInteger): - return other - if other == Any() or self == other: - return self + return BASE_INTEGER return Undefined() diff --git a/tests/feature/session_comprehension_on_message_field/generated/rflx-test-session.adb b/tests/feature/session_comprehension_on_message_field/generated/rflx-test-session.adb index 49cf12828..f2a3cf6b4 100644 --- a/tests/feature/session_comprehension_on_message_field/generated/rflx-test-session.adb +++ b/tests/feature/session_comprehension_on_message_field/generated/rflx-test-session.adb @@ -297,7 +297,7 @@ is pragma Assert (Universal.Message.Sufficient_Space (Ctx.P.Message_Ctx, Universal.Message.F_Message_Type)); Universal.Message.Set_Message_Type (Ctx.P.Message_Ctx, Universal.MT_Option_Types); pragma Assert (Universal.Message.Sufficient_Space (Ctx.P.Message_Ctx, Universal.Message.F_Length)); - Universal.Message.Set_Length (Ctx.P.Message_Ctx, Universal.Length (T_5) / Universal.Length (8)); + Universal.Message.Set_Length (Ctx.P.Message_Ctx, Universal.Length (T_5) / 8); if not Universal.Message.Valid_Length (Ctx.P.Message_Ctx, Universal.Message.F_Option_Types, Universal.Option_Types.Byte_Size (Option_Types_Ctx)) then Ctx.P.Next_State := S_Final; pragma Assert (Process_Invariant); diff --git a/tests/feature/session_comprehension_on_sequence/generated/rflx-test-session.adb b/tests/feature/session_comprehension_on_sequence/generated/rflx-test-session.adb index 0642e338f..0c48218fb 100644 --- a/tests/feature/session_comprehension_on_sequence/generated/rflx-test-session.adb +++ b/tests/feature/session_comprehension_on_sequence/generated/rflx-test-session.adb @@ -331,7 +331,7 @@ is pragma Assert (Universal.Message.Sufficient_Space (Ctx.P.Message_1_Ctx, Universal.Message.F_Message_Type)); Universal.Message.Set_Message_Type (Ctx.P.Message_1_Ctx, Universal.MT_Option_Types); pragma Assert (Universal.Message.Sufficient_Space (Ctx.P.Message_1_Ctx, Universal.Message.F_Length)); - Universal.Message.Set_Length (Ctx.P.Message_1_Ctx, Universal.Length (T_2) / Universal.Length (8)); + Universal.Message.Set_Length (Ctx.P.Message_1_Ctx, Universal.Length (T_2) / 8); if not Universal.Message.Valid_Length (Ctx.P.Message_1_Ctx, Universal.Message.F_Option_Types, Universal.Option_Types.Byte_Size (Option_Types_Ctx)) then Ctx.P.Next_State := S_Final; pragma Assert (Process_Invariant); @@ -459,7 +459,7 @@ is pragma Assert (Universal.Message.Sufficient_Space (Ctx.P.Message_2_Ctx, Universal.Message.F_Message_Type)); Universal.Message.Set_Message_Type (Ctx.P.Message_2_Ctx, Universal.MT_Options); pragma Assert (Universal.Message.Sufficient_Space (Ctx.P.Message_2_Ctx, Universal.Message.F_Length)); - Universal.Message.Set_Length (Ctx.P.Message_2_Ctx, Universal.Length (T_4) / Universal.Length (8)); + Universal.Message.Set_Length (Ctx.P.Message_2_Ctx, Universal.Length (T_4) / 8); if not Universal.Message.Valid_Length (Ctx.P.Message_2_Ctx, Universal.Message.F_Options, Universal.Options.Byte_Size (Message_Options_Ctx)) then Ctx.P.Next_State := S_Final; pragma Assert (Process_Invariant); diff --git a/tests/feature/session_functions/generated/rflx-test-session.adb b/tests/feature/session_functions/generated/rflx-test-session.adb index e24e9f376..f12eea898 100644 --- a/tests/feature/session_functions/generated/rflx-test-session.adb +++ b/tests/feature/session_functions/generated/rflx-test-session.adb @@ -118,7 +118,7 @@ is goto Finalize_Process; end if; -- tests/feature/session_functions/test.rflx:58:10 - Length := Test.Length (T_6) / Test.Length (T_7); + Length := Test.Length (T_6) / T_7; -- tests/feature/session_functions/test.rflx:60:10 declare Definite_Message : Test.Definite_Message.Structure; @@ -188,7 +188,7 @@ is -- tests/feature/session_functions/test.rflx:79:20 T_8 := RFLX.RFLX_Types.Base_Integer (Universal.Message.Size (Ctx.P.Message_Ctx)); -- tests/feature/session_functions/test.rflx:79:10 - Length := Test.Length (T_8) / Test.Length (8); + Length := Test.Length (T_8) / 8; -- tests/feature/session_functions/test.rflx:81:10 declare Definite_Message : Test.Definite_Message.Structure; diff --git a/tests/feature/session_message_creation/generated/rflx-test-session.adb b/tests/feature/session_message_creation/generated/rflx-test-session.adb index 775ac5f18..a675c421b 100644 --- a/tests/feature/session_message_creation/generated/rflx-test-session.adb +++ b/tests/feature/session_message_creation/generated/rflx-test-session.adb @@ -123,7 +123,7 @@ is pragma Assert (Universal.Message.Sufficient_Space (Ctx.P.M_S_Ctx, Universal.Message.F_Message_Type)); Universal.Message.Set_Message_Type (Ctx.P.M_S_Ctx, Universal.MT_Data); pragma Assert (Universal.Message.Sufficient_Space (Ctx.P.M_S_Ctx, Universal.Message.F_Length)); - Universal.Message.Set_Length (Ctx.P.M_S_Ctx, Universal.Length (T_11) / Universal.Length (8)); + Universal.Message.Set_Length (Ctx.P.M_S_Ctx, Universal.Length (T_11) / 8); declare function RFLX_Process_Data_Pre (Length : RFLX_Types.Length) return Boolean is (Universal.Message.Has_Buffer (Ctx.P.M_R_Ctx) diff --git a/tests/feature/session_sequence_append/generated/rflx-test-session.adb b/tests/feature/session_sequence_append/generated/rflx-test-session.adb index 17b2685fa..e09356052 100644 --- a/tests/feature/session_sequence_append/generated/rflx-test-session.adb +++ b/tests/feature/session_sequence_append/generated/rflx-test-session.adb @@ -114,7 +114,7 @@ is pragma Assert (Universal.Message.Sufficient_Space (Ctx.P.Message_Ctx, Universal.Message.F_Message_Type)); Universal.Message.Set_Message_Type (Ctx.P.Message_Ctx, Universal.MT_Options); pragma Assert (Universal.Message.Sufficient_Space (Ctx.P.Message_Ctx, Universal.Message.F_Length)); - Universal.Message.Set_Length (Ctx.P.Message_Ctx, Universal.Length (T_0) / Universal.Length (8)); + Universal.Message.Set_Length (Ctx.P.Message_Ctx, Universal.Length (T_0) / 8); if not Universal.Message.Valid_Length (Ctx.P.Message_Ctx, Universal.Message.F_Options, Universal.Options.Byte_Size (Options_Ctx)) then Ctx.P.Next_State := S_Final; pragma Assert (Process_Invariant); diff --git a/tests/feature/session_setting_of_message_fields/generated/rflx-test-session.adb b/tests/feature/session_setting_of_message_fields/generated/rflx-test-session.adb index 0957b017b..e358ca451 100644 --- a/tests/feature/session_setting_of_message_fields/generated/rflx-test-session.adb +++ b/tests/feature/session_setting_of_message_fields/generated/rflx-test-session.adb @@ -143,7 +143,7 @@ is pragma Assert (Universal.Message.Sufficient_Space (Local_Message_Ctx, Universal.Message.F_Message_Type)); Universal.Message.Set_Message_Type (Local_Message_Ctx, Universal.MT_Data); pragma Assert (Universal.Message.Sufficient_Space (Local_Message_Ctx, Universal.Message.F_Length)); - Universal.Message.Set_Length (Local_Message_Ctx, Universal.Length (T_7) / Universal.Length (8)); + Universal.Message.Set_Length (Local_Message_Ctx, Universal.Length (T_7) / 8); declare pragma Warnings (Off, "is not modified, could be declared constant"); RFLX_Ctx_P_Message_Ctx_Tmp : Universal.Message.Context := Ctx.P.Message_Ctx; diff --git a/tests/feature/session_variable_initialization/generated/rflx-test-session.adb b/tests/feature/session_variable_initialization/generated/rflx-test-session.adb index 8a88f00fa..2717d93ab 100644 --- a/tests/feature/session_variable_initialization/generated/rflx-test-session.adb +++ b/tests/feature/session_variable_initialization/generated/rflx-test-session.adb @@ -63,8 +63,8 @@ is is Local : Universal.Value := 2; T_0 : Universal.Value; - T_2 : Universal.Value; - T_3 : Universal.Value; + T_2 : RFLX.RFLX_Types.Base_Integer; + T_3 : RFLX.RFLX_Types.Base_Integer; T_1 : RFLX.RFLX_Types.Base_Integer; function Process_Invariant return Boolean is (Ctx.P.Slots.Slot_Ptr_1 = null) @@ -82,9 +82,9 @@ is end if; T_0 := Universal.Message.Get_Value (Ctx.P.Message_Ctx); -- tests/feature/session_variable_initialization/test.rflx:22:19 - T_2 := 255 - T_0; + T_2 := 255 - RFLX.RFLX_Types.Base_Integer (T_0); -- tests/feature/session_variable_initialization/test.rflx:22:19 - if not (RFLX.RFLX_Types.Base_Integer (Local) <= RFLX.RFLX_Types.Base_Integer (T_2)) then + if not (RFLX.RFLX_Types.Base_Integer (Local) <= T_2) then Ctx.P.Next_State := S_Final; pragma Assert (Process_Invariant); goto Finalize_Process; @@ -96,7 +96,7 @@ is -- tests/feature/session_variable_initialization/test.rflx:24:20 T_3 := 255 - 20; -- tests/feature/session_variable_initialization/test.rflx:24:20 - if not (RFLX.RFLX_Types.Base_Integer (Ctx.P.Uninitialized_Global) <= RFLX.RFLX_Types.Base_Integer (T_3)) then + if not (RFLX.RFLX_Types.Base_Integer (Ctx.P.Uninitialized_Global) <= T_3) then Ctx.P.Next_State := S_Final; pragma Assert (Process_Invariant); goto Finalize_Process; @@ -110,7 +110,7 @@ is pragma Assert (Universal.Message.Sufficient_Space (Ctx.P.Message_Ctx, Universal.Message.F_Message_Type)); Universal.Message.Set_Message_Type (Ctx.P.Message_Ctx, Universal.MT_Value); pragma Assert (Universal.Message.Sufficient_Space (Ctx.P.Message_Ctx, Universal.Message.F_Length)); - Universal.Message.Set_Length (Ctx.P.Message_Ctx, Universal.Length (T_1) / Universal.Length (8)); + Universal.Message.Set_Length (Ctx.P.Message_Ctx, Universal.Length (T_1) / 8); pragma Assert (Universal.Message.Sufficient_Space (Ctx.P.Message_Ctx, Universal.Message.F_Value)); Universal.Message.Set_Value (Ctx.P.Message_Ctx, Ctx.P.Global); if RFLX.RFLX_Types.Base_Integer (Local) < RFLX.RFLX_Types.Base_Integer (Ctx.P.Global) then diff --git a/tests/unit/expr_conv_test.py b/tests/unit/expr_conv_test.py index c7876932c..f83ea2bf9 100644 --- a/tests/unit/expr_conv_test.py +++ b/tests/unit/expr_conv_test.py @@ -227,14 +227,14 @@ def test_to_ir_neg() -> None: id_generator(), ) == ir.ComplexIntExpr( [ - ir.VarDecl("T_0", INT_TY), + ir.VarDecl("T_0", rty.BASE_INTEGER), ir.Assign( "T_0", ir.Add(ir.IntVar("X", INT_TY), ir.IntVar("Y", INT_TY)), - INT_TY, + rty.BASE_INTEGER, ), ], - ir.Neg(ir.IntVar("T_0", INT_TY)), + ir.Neg(ir.IntVar("T_0", rty.BASE_INTEGER)), ) @@ -260,10 +260,14 @@ def test_to_ir_add_mul( # type: ignore[misc] id_generator(), ) == ir.ComplexIntExpr( [ - ir.VarDecl("T_0", INT_TY), - ir.Assign("T_0", ir_op(ir.IntVar("Y", INT_TY), ir.IntVar("Z", INT_TY)), INT_TY), + ir.VarDecl("T_0", rty.BASE_INTEGER), + ir.Assign( + "T_0", + ir_op(ir.IntVar("Y", INT_TY), ir.IntVar("Z", INT_TY)), + rty.BASE_INTEGER, + ), ], - ir_op(ir.IntVar("X", INT_TY), ir.IntVar("T_0", INT_TY)), + ir_op(ir.IntVar("X", INT_TY), ir.IntVar("T_0", rty.BASE_INTEGER)), ) assert expr_conv.to_ir( op( @@ -276,10 +280,14 @@ def test_to_ir_add_mul( # type: ignore[misc] id_generator(), ) == ir.ComplexIntExpr( [ - ir.VarDecl("T_0", INT_TY), - ir.Assign("T_0", ir_op(ir.IntVar("X", INT_TY), ir.IntVar("Y", INT_TY)), INT_TY), + ir.VarDecl("T_0", rty.BASE_INTEGER), + ir.Assign( + "T_0", + ir_op(ir.IntVar("X", INT_TY), ir.IntVar("Y", INT_TY)), + rty.BASE_INTEGER, + ), ], - ir_op(ir.IntVar("T_0", INT_TY), ir.IntVar("Z", INT_TY)), + ir_op(ir.IntVar("T_0", rty.BASE_INTEGER), ir.IntVar("Z", INT_TY)), ) @@ -312,10 +320,14 @@ def test_to_ir_sub_div_pow_mod( # type: ignore[misc] id_generator(), ) == ir.ComplexIntExpr( [ - ir.VarDecl("T_0", INT_TY), - ir.Assign("T_0", ir_op(ir.IntVar("X", INT_TY), ir.IntVar("Y", INT_TY)), INT_TY), + ir.VarDecl("T_0", rty.BASE_INTEGER), + ir.Assign( + "T_0", + ir_op(ir.IntVar("X", INT_TY), ir.IntVar("Y", INT_TY)), + rty.BASE_INTEGER, + ), ], - ir_op(ir.IntVar("T_0", INT_TY), ir.IntVar("Z", INT_TY)), + ir_op(ir.IntVar("T_0", rty.BASE_INTEGER), ir.IntVar("Z", INT_TY)), ) @@ -575,7 +587,7 @@ def test_to_ir_comprehension() -> None: "X", Selected(Variable("M", type_=rty.Message("M")), "Y", type_=rty.Sequence("S", INT_TY)), Add(Variable("X", type_=INT_TY), Variable("Y", type_=INT_TY), Number(1)), - Less(Sub(Variable("X", type_=INT_TY), Number(1)), Number(100)), + Less(Sub(Variable("X", type_=INT_TY), Number(1)), Number(ir.INT_MAX)), ), id_generator(), ) == ir.ComplexExpr( @@ -585,17 +597,21 @@ def test_to_ir_comprehension() -> None: ir.ObjFieldAccess("M", ID("Y"), MSG_TY), ir.ComplexExpr( [ - ir.VarDecl("T_0", INT_TY), - ir.Assign("T_0", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVal(1)), INT_TY), + ir.VarDecl("T_0", rty.BASE_INTEGER), + ir.Assign( + "T_0", + ir.Add(ir.IntVar("Y", INT_TY), ir.IntVal(1)), + rty.BASE_INTEGER, + ), ], - ir.Add(ir.IntVar("X", INT_TY), ir.IntVar("T_0", INT_TY)), + ir.Add(ir.IntVar("X", INT_TY), ir.IntVar("T_0", rty.BASE_INTEGER)), ), ir.ComplexBoolExpr( [ - ir.VarDecl("T_1", INT_TY), + ir.VarDecl("T_1", rty.BASE_INTEGER), ir.Assign("T_1", ir.Sub(ir.IntVar("X", INT_TY), ir.IntVal(1)), rty.BOOLEAN), ], - ir.Less(ir.IntVar("T_1", INT_TY), ir.IntVal(100)), + ir.Less(ir.IntVar("T_1", rty.BASE_INTEGER), ir.IntVal(ir.INT_MAX)), ), ), ) @@ -625,14 +641,14 @@ def test_to_ir_message_aggregate( # type: ignore[misc] id_generator(), ) == ir.ComplexExpr( [ - ir.VarDecl("T_0", INT_TY), - ir.Assign("T_0", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVal(1)), INT_TY), + ir.VarDecl("T_0", rty.BASE_INTEGER), + ir.Assign("T_0", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVal(1)), rty.BASE_INTEGER), ], ir_agg( "X", { ID("Y"): ir.ObjFieldAccess("M", ID("Y"), MSG_TY), - ID("Z"): ir.Add(ir.IntVar("X", INT_TY), ir.IntVar("T_0", INT_TY)), + ID("Z"): ir.Add(ir.IntVar("X", INT_TY), ir.IntVar("T_0", rty.BASE_INTEGER)), }, MSG_TY, ), diff --git a/tests/unit/expr_test.py b/tests/unit/expr_test.py index 454564e13..0a262b2d6 100644 --- a/tests/unit/expr_test.py +++ b/tests/unit/expr_test.py @@ -617,7 +617,7 @@ def test_number_hashable() -> None: def test_math_expr_type(operation: Callable[[Expr, Expr], Expr]) -> None: assert_type( operation(Variable("X", type_=INT_TY), Variable("Y", type_=INT_TY)), - INT_TY, + rty.BASE_INTEGER, ) @@ -2051,7 +2051,7 @@ def test_comprehension_type() -> None: Add(Variable("X"), Variable("Z", type_=INT_TY)), TRUE, ), - rty.Aggregate(INT_TY), + rty.Aggregate(rty.BASE_INTEGER), ) assert_type( Comprehension( diff --git a/tests/unit/ir_test.py b/tests/unit/ir_test.py index ba984a4b9..eb317175b 100644 --- a/tests/unit/ir_test.py +++ b/tests/unit/ir_test.py @@ -626,15 +626,15 @@ def test_add_to_z3_expr() -> None: def test_add_preconditions() -> None: assert ir.Add( - ir.IntVar("X", INT_TY), + ir.IntVar("X", rty.BASE_INTEGER), ir.IntVal(1), origin=expr.Add(expr.Variable("X"), expr.Number(1)), ).preconditions(id_generator()) == [ ir.Cond( ir.LessEqual(ir.IntVar("X", INT_TY), ir.IntVar("T_0", INT_TY)), [ - ir.VarDecl("T_0", INT_TY), - ir.Assign("T_0", ir.Sub(ir.IntVal(100), ir.IntVal(1)), INT_TY), + ir.VarDecl("T_0", rty.BASE_INTEGER), + ir.Assign("T_0", ir.Sub(ir.IntVal(ir.INT_MAX), ir.IntVal(1)), rty.BASE_INTEGER), ], ), ] @@ -669,8 +669,8 @@ def test_mul_preconditions() -> None: ir.Cond( ir.LessEqual(ir.IntVar("X", INT_TY), ir.IntVar("T_0", INT_TY)), [ - ir.VarDecl("T_0", INT_TY), - ir.Assign("T_0", ir.Div(ir.IntVal(100), ir.IntVal(1)), INT_TY), + ir.VarDecl("T_0", rty.BASE_INTEGER), + ir.Assign("T_0", ir.Div(ir.IntVal(ir.INT_MAX), ir.IntVal(1)), rty.BASE_INTEGER), ], ), ] @@ -703,10 +703,10 @@ def test_pow_to_z3_expr() -> None: def test_pow_preconditions() -> None: assert ir.Pow(ir.IntVar("X", INT_TY), ir.IntVal(1)).preconditions(id_generator()) == [ ir.Cond( - ir.LessEqual(ir.IntVar("T_0", INT_TY), ir.IntVal(100)), + ir.LessEqual(ir.IntVar("T_0", INT_TY), ir.IntVal(ir.INT_MAX)), [ - ir.VarDecl("T_0", INT_TY), - ir.Assign("T_0", ir.Pow(ir.IntVar("X", INT_TY), ir.IntVal(1)), INT_TY), + ir.VarDecl("T_0", rty.BASE_INTEGER), + ir.Assign("T_0", ir.Pow(ir.IntVar("X", INT_TY), ir.IntVal(1)), rty.BASE_INTEGER), ], ), ] @@ -1253,7 +1253,7 @@ def test_agg_str() -> None: def test_agg_type() -> None: - assert ir.Agg([ir.IntVar("X", INT_TY), ir.IntVal(10)]).type_ == rty.Aggregate(INT_TY) + assert ir.Agg([ir.IntVar("X", INT_TY), ir.IntVal(10)]).type_ == rty.Aggregate(rty.BASE_INTEGER) def test_agg_substituted() -> None: @@ -1462,17 +1462,17 @@ def test_add_required_checks() -> None: PROOF_MANAGER, id_generator(), ) == [ - ir.VarDecl("T_0", INT_TY), - ir.Assign("T_0", ir.Sub(ir.IntVal(100), ir.IntVal(1)), INT_TY), - ir.Check(ir.LessEqual(ir.IntVar("Y", INT_TY), ir.IntVar("T_0", INT_TY))), + ir.VarDecl("T_0", rty.BASE_INTEGER), + ir.Assign("T_0", ir.Sub(ir.IntVal(100), ir.IntVal(1)), rty.BASE_INTEGER), + ir.Check(ir.LessEqual(ir.IntVar("Y", INT_TY), ir.IntVar("T_0", rty.BASE_INTEGER))), ir.Assign("A", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVal(1)), INT_TY), ir.Check(ir.NotEqual(ir.IntVar("Z", INT_TY), ir.IntVal(0))), ir.Assign("B", ir.Div(ir.IntVar("A", INT_TY), ir.IntVar("Z", INT_TY)), INT_TY), ir.Check(ir.GreaterEqual(ir.IntVar("B", INT_TY), ir.IntVal(1))), ir.Assign("X", ir.Sub(ir.IntVar("B", INT_TY), ir.IntVal(1)), INT_TY), ir.Assign("Z", ir.IntVal(0), INT_TY), - ir.VarDecl("T_1", INT_TY), - ir.Assign("T_1", ir.Sub(ir.IntVal(100), ir.IntVal(1)), INT_TY), - ir.Check(ir.LessEqual(ir.IntVar("Z", INT_TY), ir.IntVar("T_1", INT_TY))), + ir.VarDecl("T_1", rty.BASE_INTEGER), + ir.Assign("T_1", ir.Sub(ir.IntVal(100), ir.IntVal(1)), rty.BASE_INTEGER), + ir.Check(ir.LessEqual(ir.IntVar("Z", INT_TY), ir.IntVar("T_1", rty.BASE_INTEGER))), ir.Assign("C", ir.Add(ir.IntVar("Z", INT_TY), ir.IntVal(1)), INT_TY), ] diff --git a/tests/unit/model/statement_test.py b/tests/unit/model/statement_test.py index 3518c7d6e..0f9276796 100644 --- a/tests/unit/model/statement_test.py +++ b/tests/unit/model/statement_test.py @@ -16,7 +16,7 @@ def test_variable_assignment_to_ir() -> None: expr.Add(expr.Variable("Y", type_=INT_TY), expr.Number(1)), INT_TY, ).to_ir(id_generator()) == [ - ir.Assign("X", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVal(1)), INT_TY), + ir.Assign("X", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVal(1)), rty.BASE_INTEGER), ] assert stmt.VariableAssignment( "X", @@ -26,9 +26,13 @@ def test_variable_assignment_to_ir() -> None: ), INT_TY, ).to_ir(id_generator()) == [ - ir.VarDecl("T_0", INT_TY), + ir.VarDecl("T_0", rty.BASE_INTEGER), ir.Assign("T_0", ir.Sub(ir.IntVar("Z", INT_TY), ir.IntVal(1)), rty.BASE_INTEGER), - ir.Assign("X", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVar("T_0", rty.BASE_INTEGER)), INT_TY), + ir.Assign( + "X", + ir.Add(ir.IntVar("Y", INT_TY), ir.IntVar("T_0", rty.BASE_INTEGER)), + rty.BASE_INTEGER, + ), ] @@ -46,7 +50,7 @@ def test_message_field_assignment_to_ir() -> None: ), MSG_TY, ).to_ir(id_generator()) == [ - ir.VarDecl("T_0", INT_TY), + ir.VarDecl("T_0", rty.BASE_INTEGER), ir.Assign("T_0", ir.Add(ir.IntVar("Z", INT_TY), ir.IntVal(1)), rty.BASE_INTEGER), ir.FieldAssign( "X", @@ -70,7 +74,7 @@ def test_append_to_ir() -> None: ), SEQ_TY, ).to_ir(id_generator()) == [ - ir.VarDecl("T_0", INT_TY), + ir.VarDecl("T_0", rty.BASE_INTEGER), ir.Assign("T_0", ir.Add(ir.IntVar("Z", INT_TY), ir.IntVal(1)), rty.BASE_INTEGER), ir.Append("X", ir.Add(ir.IntVar("Y", INT_TY), ir.IntVar("T_0", rty.BASE_INTEGER)), SEQ_TY), ] diff --git a/tests/unit/typing__test.py b/tests/unit/typing__test.py index 871eded8f..69e904fac 100644 --- a/tests/unit/typing__test.py +++ b/tests/unit/typing__test.py @@ -111,8 +111,16 @@ def test_base_integer_is_compatible(base_integer: Type, other: Type, expected: b @pytest.mark.parametrize( ("universal_integer", "other", "expected"), [ - (UniversalInteger(Bounds(10, 100)), Any(), UniversalInteger(Bounds(10, 100))), - (UniversalInteger(Bounds(10, 100)), BASE_INTEGER, BASE_INTEGER), + ( + UniversalInteger(Bounds(10, 100)), + Any(), + UniversalInteger(Bounds(10, 100)), + ), + ( + UniversalInteger(Bounds(10, 100)), + BASE_INTEGER, + BASE_INTEGER, + ), ( UniversalInteger(Bounds(10, 100)), UniversalInteger(Bounds(10, 100)), @@ -121,15 +129,23 @@ def test_base_integer_is_compatible(base_integer: Type, other: Type, expected: b ( UniversalInteger(Bounds(10, 100)), Integer("A", Bounds(10, 100)), - Integer("A", Bounds(10, 100)), + BASE_INTEGER, ), ( UniversalInteger(Bounds(20, 80)), Integer("A", Bounds(10, 100)), - Integer("A", Bounds(10, 100)), + BASE_INTEGER, + ), + ( + UniversalInteger(Bounds(10, 100)), + Undefined(), + Undefined(), + ), + ( + UniversalInteger(Bounds(10, 100)), + ENUMERATION_B, + Undefined(), ), - (UniversalInteger(Bounds(10, 100)), Undefined(), Undefined()), - (UniversalInteger(Bounds(10, 100)), ENUMERATION_B, Undefined()), ], ) def test_universal_integer_common_type( @@ -181,12 +197,12 @@ def test_universal_integer_is_compatible( ( Integer("A", Bounds(10, 100)), Integer("A", Bounds(10, 100)), - Integer("A", Bounds(10, 100)), + BASE_INTEGER, ), ( Integer("A", Bounds(10, 100)), UniversalInteger(Bounds(10, 100)), - Integer("A", Bounds(10, 100)), + BASE_INTEGER, ), ( Integer("A", Bounds(10, 100)), @@ -196,7 +212,7 @@ def test_universal_integer_is_compatible( ( Integer("A", Bounds(10, 100)), UniversalInteger(Bounds(0, 200)), - Integer("A", Bounds(10, 100)), + BASE_INTEGER, ), ( Integer("A", Bounds(10, 100)), @@ -550,9 +566,8 @@ def test_channel_is_compatible(channel: Type, other: Type, expected: bool) -> No [ Integer("A", Bounds(10, 100)), Integer("A", Bounds(10, 100)), - Integer("A", Bounds(10, 100)), ], - Integer("A", Bounds(10, 100)), + BASE_INTEGER, ), ( [ @@ -560,7 +575,7 @@ def test_channel_is_compatible(channel: Type, other: Type, expected: bool) -> No Integer("A", Bounds(10, 100)), UniversalInteger(Bounds(50, 100)), ], - Integer("A", Bounds(10, 100)), + BASE_INTEGER, ), ( [ @@ -568,7 +583,7 @@ def test_channel_is_compatible(channel: Type, other: Type, expected: bool) -> No Integer("A", Bounds(10, 100)), UniversalInteger(Bounds(20, 200)), ], - Integer("A", Bounds(10, 100)), + BASE_INTEGER, ), ( [