From fca894c7fc0e3d628aa2f49f89903aaf14f4f299 Mon Sep 17 00:00:00 2001 From: Johannes Kanig Date: Fri, 5 Apr 2024 16:14:10 +0900 Subject: [PATCH] Make Call expressions have a mandatory type When creating a Call expression (mostly during code generation), a type must now be provided. In some rare cases (unit tests and conversion back from Ada expressions), the Undefined type is used. Ref. eng/recordflux/RecordFlux#1365 --- rflx/ada.py | 4 +- rflx/expression.py | 4 +- rflx/generator/common.py | 72 ++++++++++++++----- rflx/generator/generator.py | 26 +++++-- rflx/generator/message.py | 21 +++++- rflx/generator/parser.py | 3 +- rflx/generator/serializer.py | 4 +- rflx/model/session.py | 4 +- rflx/specification/parser.py | 1 + rflx/typing_.py | 12 ++++ tests/integration/specification_model_test.py | 7 +- tests/property/strategies.py | 6 +- tests/unit/ada_test.py | 3 +- tests/unit/expression_test.py | 51 +++++++------ tests/unit/generator/generator_test.py | 28 ++++++-- tests/unit/model/session_test.py | 48 +++++++++---- tests/unit/specification/grammar_test.py | 21 ++++-- 17 files changed, 230 insertions(+), 85 deletions(-) diff --git a/rflx/ada.py b/rflx/ada.py index a83665e4d..e53128b84 100644 --- a/rflx/ada.py +++ b/rflx/ada.py @@ -12,7 +12,7 @@ from typing_extensions import Self -from rflx import expression as expr +from rflx import expression as expr, typing_ as rty from rflx.common import Base, file_name, indent, indent_next, unique from rflx.contract import invariant from rflx.identifier import ID, StrID @@ -611,7 +611,7 @@ def _representation(self) -> str: def rflx_expr(self) -> expr.Call: assert not self.named_arguments - return expr.Call(self.identifier, [a.rflx_expr() for a in self.arguments]) + return expr.Call(self.identifier, rty.UNDEFINED, [a.rflx_expr() for a in self.arguments]) class Slice(Name): diff --git a/rflx/expression.py b/rflx/expression.py index 5b9affdf7..c5b64d3c5 100644 --- a/rflx/expression.py +++ b/rflx/expression.py @@ -2035,9 +2035,9 @@ class Call(Name): def __init__( # noqa: PLR0913 self, identifier: StrID, + type_: rty.Type, args: Optional[Sequence[Expr]] = None, immutable: bool = False, - type_: rty.Type = rty.UNDEFINED, argument_types: Optional[Sequence[rty.Type]] = None, location: Optional[Location] = None, ) -> None: @@ -2158,9 +2158,9 @@ def substituted( assert isinstance(expr, Call) return expr.__class__( expr.identifier, + expr.type_, [a.substituted(func) for a in expr.args], expr.immutable, - expr.type_, expr.argument_types, expr.location, ) diff --git a/rflx/generator/common.py b/rflx/generator/common.py index d60dd6171..cfe161ff1 100644 --- a/rflx/generator/common.py +++ b/rflx/generator/common.py @@ -6,7 +6,7 @@ from collections.abc import Callable from typing import Optional -from rflx import expression as expr, model +from rflx import expression as expr, model, typing_ as rty from rflx.ada import ( TRUE, Add, @@ -66,17 +66,24 @@ class Debug(enum.Enum): EXTERNAL = enum.auto() +def type_to_id(type_: rty.NamedType) -> ID: + if type_.identifier.parent == BUILTINS_PACKAGE: + return const.TYPES * type_.identifier.name + + return type_.identifier + + def substitution( message: model.Message, prefix: str, embedded: bool = False, public: bool = False, - target_type: ID = const.TYPES_BASE_INT, + target_type: rty.NamedType = rty.BASE_INTEGER, ) -> Callable[[expr.Expr], expr.Expr]: facts = substitution_facts(message, prefix, embedded, public, target_type) def type_conversion(expression: expr.Expr) -> expr.Expr: - return expr.Call(target_type, [expression]) + return expr.Call(type_to_id(target_type), target_type, [expression]) def func( # noqa: PLR0912 expression: expr.Expr, @@ -119,6 +126,7 @@ def byte_aggregate(aggregate: expr.Aggregate) -> expr.Aggregate: expr.ValueRange( expr.Call( const.TYPES_TO_INDEX, + rty.INDEX, [ expr.Selected( expr.Indexed( @@ -131,6 +139,7 @@ def byte_aggregate(aggregate: expr.Aggregate) -> expr.Aggregate: ), expr.Call( const.TYPES_TO_INDEX, + rty.INDEX, [ expr.Selected( expr.Indexed( @@ -147,6 +156,7 @@ def byte_aggregate(aggregate: expr.Aggregate) -> expr.Aggregate: ) equal_call = expr.Call( "Equal", + rty.BOOLEAN, [expr.Variable("Ctx"), expr.Variable(field.affixed_name), aggregate], ) return equal_call if isinstance(expression, expr.Equal) else expr.Not(equal_call) @@ -168,12 +178,18 @@ def byte_aggregate(aggregate: expr.Aggregate) -> expr.Aggregate: if boolean_literal and other: return expression.__class__( other, - type_conversion(expr.Call("To_Base_Integer", [boolean_literal])), + type_conversion( + expr.Call("To_Base_Integer", rty.BASE_INTEGER, [boolean_literal]), + ), ) def field_value(field: model.Field) -> expr.Expr: if public: - return expr.Call(f"Get_{field.name}", [expr.Variable("Ctx")]) + return expr.Call( + f"Get_{field.name}", + message.field_types[field].type_, + [expr.Variable("Ctx")], + ) return expr.Selected( expr.Indexed( expr.Variable(ID("Ctx") * "Cursors" if not embedded else "Cursors"), @@ -212,19 +228,24 @@ def substitution_facts( prefix: str, embedded: bool = False, public: bool = False, - target_type: ID = const.TYPES_BASE_INT, + target_type: rty.NamedType = rty.BASE_INTEGER, ) -> dict[expr.Name, expr.Expr]: def prefixed(name: str) -> expr.Expr: return expr.Variable(ID("Ctx") * name) if not embedded else expr.Variable(name) first = prefixed("First") - last = expr.Call("Written_Last", [expr.Variable("Ctx")]) if public else prefixed("Written_Last") + last = ( + expr.Call("Written_Last", rty.BIT_LENGTH, [expr.Variable("Ctx")]) + if public + else prefixed("Written_Last") + ) cursors = prefixed("Cursors") def field_first(field: model.Field) -> expr.Expr: if public: return expr.Call( "Field_First", + rty.BIT_INDEX, [expr.Variable("Ctx"), expr.Variable(field.affixed_name)], ) return expr.Selected(expr.Indexed(cursors, expr.Variable(field.affixed_name)), "First") @@ -233,6 +254,7 @@ def field_last(field: model.Field) -> expr.Expr: if public: return expr.Call( "Field_Last", + rty.BIT_LENGTH, [expr.Variable("Ctx"), expr.Variable(field.affixed_name)], ) return expr.Selected(expr.Indexed(cursors, expr.Variable(field.affixed_name)), "Last") @@ -241,6 +263,7 @@ def field_size(field: model.Field) -> expr.Expr: if public: return expr.Call( "Field_Size", + rty.BIT_LENGTH, [expr.Variable("Ctx"), expr.Variable(field.affixed_name)], ) return expr.Add( @@ -254,8 +277,16 @@ def field_size(field: model.Field) -> expr.Expr: def parameter_value(parameter: model.Field, parameter_type: model.Type) -> expr.Expr: if isinstance(parameter_type, model.Enumeration): if embedded: - return expr.Call("To_Base_Integer", [expr.Variable(parameter.name)]) - return expr.Call("To_Base_Integer", [expr.Variable("Ctx" * parameter.identifier)]) + return expr.Call( + "To_Base_Integer", + rty.BASE_INTEGER, + [expr.Variable(parameter.name)], + ) + return expr.Call( + "To_Base_Integer", + rty.BASE_INTEGER, + [expr.Variable("Ctx" * parameter.identifier)], + ) if isinstance(parameter_type, model.Scalar): if embedded: return expr.Variable(parameter.name) @@ -268,7 +299,8 @@ def field_value(field: model.Field, field_type: model.Type) -> expr.Expr: if public: return expr.Call( "To_Base_Integer", - [expr.Call(f"Get_{field.name}", [expr.Variable("Ctx")])], + rty.BASE_INTEGER, + [expr.Call(f"Get_{field.name}", field_type.type_, [expr.Variable("Ctx")])], ) return expr.Selected( expr.Indexed(cursors, expr.Variable(field.affixed_name)), @@ -276,7 +308,7 @@ def field_value(field: model.Field, field_type: model.Type) -> expr.Expr: ) if isinstance(field_type, model.Scalar): if public: - return expr.Call(f"Get_{field.name}", [expr.Variable("Ctx")]) + return expr.Call(f"Get_{field.name}", field_type.type_, [expr.Variable("Ctx")]) return expr.Selected( expr.Indexed(cursors, expr.Variable(field.affixed_name)), "Value", @@ -287,7 +319,7 @@ def field_value(field: model.Field, field_type: model.Type) -> expr.Expr: assert False, f'unexpected type "{type(field_type).__name__}"' def type_conversion(expression: expr.Expr) -> expr.Expr: - return expr.Call(target_type, [expression]) + return expr.Call(type_to_id(target_type), target_type, [expression]) return { expr.First("Message"): type_conversion(first), @@ -305,14 +337,20 @@ def type_conversion(expression: expr.Expr) -> expr.Expr: for f, t in message.field_types.items() }, **{ - expr.Literal(l): type_conversion(expr.Call("To_Base_Integer", [expr.Variable(l)])) + expr.Literal(l): type_conversion( + expr.Call("To_Base_Integer", rty.BASE_INTEGER, [expr.Variable(l)]), + ) for t in message.types.values() if isinstance(t, model.Enumeration) and t != model.BOOLEAN for l in t.literals }, **{ expr.Literal(t.package * l): type_conversion( - expr.Call("To_Base_Integer", [expr.Variable(prefix * t.package * l)]), + expr.Call( + "To_Base_Integer", + rty.BASE_INTEGER, + [expr.Variable(prefix * t.package * l)], + ), ) for t in message.types.values() if isinstance(t, model.Enumeration) and t != model.BOOLEAN @@ -348,14 +386,14 @@ def link_property(link: model.Link, unique: bool) -> Expr: field_type.size if isinstance(field_type, model.Scalar) else link.size.substituted( - substitution(message, prefix, embedded, target_type=const.TYPES_BIT_LENGTH), + substitution(message, prefix, embedded, target_type=rty.BIT_LENGTH), ).simplified() ) first = ( prefixed("First") if link.source == model.INITIAL else link.first.substituted( - substitution(message, prefix, embedded, target_type=const.TYPES_BIT_INDEX), + substitution(message, prefix, embedded, target_type=rty.BIT_INDEX), ) .substituted( mapping={ @@ -947,7 +985,7 @@ def substituted(expression: expr.Expr) -> Expr: substitution( message, prefix, - target_type=const.TYPES_BIT_LENGTH, + target_type=rty.BIT_LENGTH, embedded=True, ), ) diff --git a/rflx/generator/generator.py b/rflx/generator/generator.py index 84d8b74b1..3f95853b8 100644 --- a/rflx/generator/generator.py +++ b/rflx/generator/generator.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import Optional -from rflx import __version__, expression as expr +from rflx import __version__, expression as expr, typing_ as rty from rflx.ada import ( FALSE, TRUE, @@ -1103,7 +1103,11 @@ def _create_contains_function( if isinstance(t, Enumeration) and t.always_valid: condition = expr.AndThen( expr.Selected( - expr.Call(pdu_identifier * f"Get_{f.name}", [expr.Variable("Ctx")]), + expr.Call( + pdu_identifier * f"Get_{f.name}", + t.type_, + [expr.Variable("Ctx")], + ), "Known", ), condition, @@ -1113,11 +1117,19 @@ def _create_contains_function( mapping={ expr.Variable(f.name): ( expr.Selected( - expr.Call(pdu_identifier * f"Get_{f.name}", [expr.Variable("Ctx")]), + expr.Call( + pdu_identifier * f"Get_{f.name}", + t.type_, + [expr.Variable("Ctx")], + ), "Enum", ) if isinstance(t, Enumeration) and t.always_valid - else expr.Call(pdu_identifier * f"Get_{f.name}", [expr.Variable("Ctx")]) + else expr.Call( + pdu_identifier * f"Get_{f.name}", + t.type_, + [expr.Variable("Ctx")], + ) ) for f, t in condition_fields.items() }, @@ -1591,7 +1603,7 @@ def _refinement_conditions( pdu_identifier = self._prefix * refinement.pdu.identifier conditions: list[expr.Expr] = [ - expr.Call(pdu_identifier * "Has_Buffer", [expr.Variable(pdu_context)]), + expr.Call(pdu_identifier * "Has_Buffer", rty.BOOLEAN, [expr.Variable(pdu_context)]), ] if null_sdu: @@ -1599,6 +1611,7 @@ def _refinement_conditions( [ expr.Call( pdu_identifier * "Well_Formed", + rty.BOOLEAN, [ expr.Variable(pdu_context), expr.Variable(pdu_identifier * refinement.field.affixed_name), @@ -1607,6 +1620,7 @@ def _refinement_conditions( expr.Not( expr.Call( pdu_identifier * "Present", + rty.BOOLEAN, [ expr.Variable(pdu_context), expr.Variable(pdu_identifier * refinement.field.affixed_name), @@ -1619,6 +1633,7 @@ def _refinement_conditions( conditions.append( expr.Call( pdu_identifier * "Present", + rty.BOOLEAN, [ expr.Variable(pdu_context), expr.Variable(pdu_identifier * refinement.field.affixed_name), @@ -1630,6 +1645,7 @@ def _refinement_conditions( [ expr.Call( pdu_identifier * "Valid", + rty.BOOLEAN, [ expr.Variable(pdu_context), expr.Variable(pdu_identifier * f.affixed_name), diff --git a/rflx/generator/message.py b/rflx/generator/message.py index a3e3c1359..9cfbcb5c9 100644 --- a/rflx/generator/message.py +++ b/rflx/generator/message.py @@ -3,7 +3,7 @@ from collections import abc from typing import Union -from rflx import expression as expr +from rflx import expression as expr, typing_ as rty from rflx.ada import ( FALSE, NULL, @@ -413,6 +413,7 @@ def create_valid_predecessors_invariant_function( if l.source in composite_fields else "Valid" ), + rty.BOOLEAN, [ expr.Indexed( expr.Variable("Cursors"), @@ -634,6 +635,7 @@ def create_field_first_internal_function(message: Message, prefix: str) -> UnitP def recursive_call(fld: Field) -> expr.Expr: return expr.Call( "Field_First_Internal", + rty.BIT_INDEX, [ expr.Variable("Cursors"), expr.Variable("First"), @@ -648,6 +650,7 @@ def recursive_call(fld: Field) -> expr.Expr: def field_size_internal_call(fld: expr.Variable) -> expr.Expr: return expr.Call( "Field_Size_Internal", + rty.BIT_LENGTH, [ expr.Variable("Cursors"), expr.Variable("First"), @@ -672,6 +675,7 @@ def link_first_expr(link: Link) -> tuple[expr.Expr, expr.Expr]: expr.AndThen( expr.Call( "Well_Formed", + rty.BOOLEAN, [ expr.Indexed( expr.Variable("Cursors"), @@ -699,7 +703,7 @@ def fld_first_expr(fld: Field) -> expr.Expr: first_expr = [link_first_expr(fld) for fld in incoming] return expr.IfExpr( first_expr, - expr.Call("RFLX_Types.Unreachable"), + expr.Call("RFLX_Types.Unreachable", rty.BOOLEAN), ) assert first_node != fld return expr.Add( @@ -2114,12 +2118,18 @@ def condition(field: Field, message: Message) -> Expr: c: expr.Expr = expr.Or(*[l.condition for l in message.outgoing(field)]) c = c.substituted( mapping={ - expr.Size(field.name): expr.Call(const.TYPES_BASE_INT, [expr.Variable("Size")]), + expr.Size(field.name): expr.Call( + const.TYPES_BASE_INT, + rty.BASE_INTEGER, + [expr.Variable("Size")], + ), expr.Last(field.name): expr.Call( const.TYPES_BASE_INT, + rty.BASE_INTEGER, [ expr.Call( "Field_Last", + rty.BIT_LENGTH, [ expr.Variable("Ctx"), expr.Variable(field.affixed_name, immutable=True), @@ -3611,12 +3621,14 @@ def func(expression: expr.Expr) -> expr.Expr: if isinstance(field_type, Enumeration): return expr.Call( "To_Base_Integer", + rty.BASE_INTEGER, [expr.Variable("Struct" * expression.identifier)], ) if isinstance(field_type, Scalar): return expr.Call( const.TYPES_BASE_INT, + rty.BASE_INTEGER, [expr.Variable("Struct" * expression.identifier)], ) @@ -3656,6 +3668,7 @@ def _create_to_context_procedure(prefix: str, message: Message) -> UnitPart: lambda x: ( expr.Call( const.TYPES_BIT_LENGTH, + rty.BIT_LENGTH, [expr.Variable("Struct" * x.identifier)], ) if isinstance(x, expr.Variable) @@ -3765,6 +3778,7 @@ def substitute(expression: expr.Expr) -> expr.Expr: ): return expr.Call( f"Field_Size_{expression.prefix.identifier}", + rty.BIT_LENGTH, [expr.Variable("Struct")], ) if ( @@ -3773,6 +3787,7 @@ def substitute(expression: expr.Expr) -> expr.Expr: ): return expr.Call( const.TYPES_BIT_LENGTH, + rty.BIT_LENGTH, [ expr.Selected( expr.Variable("Struct"), diff --git a/rflx/generator/parser.py b/rflx/generator/parser.py index 8f49577f7..41009252d 100644 --- a/rflx/generator/parser.py +++ b/rflx/generator/parser.py @@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence -import rflx.expression as expr +from rflx import expression as expr, typing_ as rty from rflx.ada import ( TRUE, Add, @@ -1180,6 +1180,7 @@ def valid_message_condition(self, message: Message, well_formed: bool = False) - and isinstance(message.field_types[l.source], Composite) else "Valid" ), + rty.BOOLEAN, [ expr.Variable("Ctx"), expr.Variable(l.source.affixed_name, immutable=True), diff --git a/rflx/generator/serializer.py b/rflx/generator/serializer.py index 2ff4d9614..ac20739d1 100644 --- a/rflx/generator/serializer.py +++ b/rflx/generator/serializer.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Optional -from rflx import expression as expr +from rflx import expression as expr, typing_ as rty from rflx.ada import ( TRUE, Add, @@ -121,7 +121,7 @@ def create_valid_size_function(self, message: Message) -> UnitPart: common.substitution( message, self.prefix, - target_type=const.TYPES_BIT_LENGTH, + target_type=rty.BIT_LENGTH, ), ) .simplified() diff --git a/rflx/model/session.py b/rflx/model/session.py index 871ccd3ce..95b675898 100644 --- a/rflx/model/session.py +++ b/rflx/model/session.py @@ -1034,9 +1034,9 @@ def normalize_identifiers( if expression.identifier in functions_map: return expr.Call( ID(functions_map[expression.identifier], location=expression.identifier.location), + expression.type_, [], expression.immutable, - expression.type_, location=expression.location, ) if expression.identifier in variables_map: @@ -1050,9 +1050,9 @@ def normalize_identifiers( if isinstance(expression, expr.Call) and expression.identifier in functions_map: return expr.Call( ID(functions_map[expression.identifier], location=expression.identifier.location), + expression.type_, expression.args, expression.immutable, - expression.type_, expression.argument_types, location=expression.location, ) diff --git a/rflx/specification/parser.py b/rflx/specification/parser.py index 95fe8fed6..69e3f5262 100644 --- a/rflx/specification/parser.py +++ b/rflx/specification/parser.py @@ -624,6 +624,7 @@ def create_call(error: RecordFluxError, expression: lang.Expr, filename: Path) - assert isinstance(expression, lang.Call) return expr.Call( create_id(error, expression.f_identifier, filename), + rty.UNDEFINED, [create_expression(error, a, filename) for a in expression.f_arguments], location=node_location(expression, filename), ) diff --git a/rflx/typing_.py b/rflx/typing_.py index 5f7aa0425..bb6c5628e 100644 --- a/rflx/typing_.py +++ b/rflx/typing_.py @@ -478,3 +478,15 @@ def _undefined_type(location: Optional[Location], description: str = "") -> Reco Bounds(0, 2**const.MAX_SCALAR_SIZE - 1), location=Location((0, 0), Path(str(const.BUILTINS_PACKAGE)), (0, 0)), ) + +INDEX = Integer( + const.BUILTINS_PACKAGE * "Index", + Bounds(1, 2**31 - 1), + location=Location((0, 0), Path(str(const.BUILTINS_PACKAGE)), (0, 0)), +) + +BIT_INDEX = Integer( + const.BUILTINS_PACKAGE * "Bit_Index", + Bounds(1, BIT_LENGTH.bounds.upper), + location=Location((0, 0), Path(str(const.BUILTINS_PACKAGE)), (0, 0)), +) diff --git a/tests/integration/specification_model_test.py b/tests/integration/specification_model_test.py index e9f400854..12b9bd88e 100644 --- a/tests/integration/specification_model_test.py +++ b/tests/integration/specification_model_test.py @@ -6,7 +6,7 @@ import pytest -from rflx import expression as expr +from rflx import expression as expr, typing_ as rty from rflx.error import ERROR_CONFIG, RecordFluxError from rflx.model import ( BOOLEAN, @@ -600,7 +600,10 @@ def test_consistency_specification_parsing_generation(tmp_path: Path) -> None: "null", condition=expr.And( expr.Equal(expr.Variable("Z"), expr.TRUE), - expr.Equal(expr.Call("G", [expr.Variable("F")]), expr.TRUE), + expr.Equal( + expr.Call("G", rty.BOOLEAN, [expr.Variable("F")]), + expr.TRUE, + ), ), description="rfc1149.txt+45:4-47:8", ), diff --git a/tests/property/strategies.py b/tests/property/strategies.py index e14bebd28..c480aca3d 100644 --- a/tests/property/strategies.py +++ b/tests/property/strategies.py @@ -7,7 +7,7 @@ from hypothesis import assume, strategies as st -from rflx import error, expression as expr +from rflx import error, expression as expr, typing_ as rty from rflx.identifier import ID from rflx.model import ( BUILTIN_TYPES, @@ -378,7 +378,9 @@ def attributes(draw: Draw, elements: st.SearchStrategy[expr.Expr]) -> expr.Expr: @st.composite def calls(draw: Draw, elements: st.SearchStrategy[expr.Expr]) -> expr.Call: - return draw(st.builds(expr.Call, identifiers(), st.lists(elements, min_size=1))) + return draw( + st.builds(expr.Call, identifiers(), st.just(rty.Undefined), st.lists(elements, min_size=1)), + ) @st.composite diff --git a/tests/unit/ada_test.py b/tests/unit/ada_test.py index 901da09b2..b2df3745f 100644 --- a/tests/unit/ada_test.py +++ b/tests/unit/ada_test.py @@ -5,7 +5,7 @@ import pytest -from rflx import ada, expression as expr +from rflx import ada, expression as expr, typing_ as rty from rflx.identifier import ID from tests.utils import assert_equal @@ -199,6 +199,7 @@ def test_indexed_rflx_expr() -> None: def test_call_rflx_expr() -> None: assert ada.Call("X", [ada.Variable("Y"), ada.Variable("Z")]).rflx_expr() == expr.Call( "X", + rty.UNDEFINED, [expr.Variable("Y"), expr.Variable("Z")], ) diff --git a/tests/unit/expression_test.py b/tests/unit/expression_test.py index 3aada7da0..c0a8e57c0 100644 --- a/tests/unit/expression_test.py +++ b/tests/unit/expression_test.py @@ -809,8 +809,8 @@ def test_neg_to_ir() -> None: def test_add_str() -> None: assert str(Add(Variable("X"), Number(1))) == "X + 1" assert str(-Add(Variable("X"), Number(1))) == "-X - 1" - assert str(Add(Number(1), Call("Test", []))) == "1 + Test" - assert str(Add(Number(1), -Call("Test", []))) == "1 - Test" + assert str(Add(Number(1), Call("Test", rty.BASE_INTEGER, []))) == "1 + Test" + assert str(Add(Number(1), -Call("Test", rty.BASE_INTEGER, []))) == "1 - Test" assert str(Add()) == "0" @@ -1236,6 +1236,7 @@ def test_attribute_type(attribute: Callable[[Expr], Expr], expr: Expr, expected: Opaque( Call( "X", + rty.UNDEFINED, [Variable("Y", location=Location((10, 30)))], location=Location((10, 20)), ), @@ -1267,12 +1268,12 @@ def test_attribute_substituted() -> None: Number(-42), ) assert_equal( - First("X").substituted(lambda x: Call("Y") if x == Variable("X") else x), - First(Call("Y")), + First("X").substituted(lambda x: Call("Y", rty.BASE_INTEGER) if x == Variable("X") else x), + First(Call("Y", rty.BASE_INTEGER)), ) assert_equal( - -First("X").substituted(lambda x: Call("Y") if x == Variable("X") else x), - -First(Call("Y")), + -First("X").substituted(lambda x: Call("Y", rty.BASE_INTEGER) if x == Variable("X") else x), + -First(Call("Y", rty.BASE_INTEGER)), ) assert_equal( -First("X").substituted( @@ -1306,7 +1307,10 @@ def test_attribute_str() -> None: def test_attribute_variables() -> None: assert First("X").variables() == [Variable("X")] - assert First(Call("X", [Variable("Y")])).variables() == [Variable("X"), Variable("Y")] + assert First(Call("X", rty.BASE_INTEGER, [Variable("Y")])).variables() == [ + Variable("X"), + Variable("Y"), + ] @pytest.mark.parametrize( @@ -1325,7 +1329,7 @@ def test_attribute_z3expr(attribute: Expr, z3name: str) -> None: def test_attribute_z3expr_error() -> None: with pytest.raises(Z3TypeError): - First(Call("X")).z3expr() + First(Call("X", rty.BASE_INTEGER)).z3expr() @pytest.mark.parametrize( @@ -2150,7 +2154,7 @@ def test_expr_substituted_pre() -> None: with pytest.raises(AssertionError): Selected(Variable("X"), "F").substituted(lambda x: x, {}) # pragma: no branch with pytest.raises(AssertionError): - Call("Sub").substituted(lambda x: x, {}) # pragma: no branch + Call("Sub", rty.BASE_INTEGER).substituted(lambda x: x, {}) # pragma: no branch with pytest.raises(AssertionError): ForAllOf("X", Variable("Y"), Variable("Z")).substituted( # pragma: no branch lambda x: x, @@ -2389,8 +2393,8 @@ def test_call_type() -> None: assert_type( Call( "X", + rty.BOOLEAN, [Variable("Y", type_=rty.Integer("A"))], - type_=rty.BOOLEAN, argument_types=[rty.Integer("A")], ), rty.BOOLEAN, @@ -2399,18 +2403,23 @@ def test_call_type() -> None: def test_call_type_error() -> None: assert_type_error( - Call("X", [Variable("Y", location=Location((10, 30)))], location=Location((10, 20))), + Call( + "X", + rty.UNDEFINED, + [Variable("Y", location=Location((10, 30)))], + location=Location((10, 20)), + ), r'^:10:30: model: error: undefined variable "Y"\n' r':10:20: model: error: undefined function "X"$', ) assert_type_error( Call( "X", + rty.BOOLEAN, [ Variable("Y", type_=rty.AnyInteger(), location=Location((10, 30))), Variable("Z", type_=rty.BOOLEAN, location=Location((10, 40))), ], - type_=rty.BOOLEAN, argument_types=[ rty.BOOLEAN, rty.AnyInteger(), @@ -2424,50 +2433,52 @@ def test_call_type_error() -> None: def test_call_variables() -> None: - result = Call("Sub", [Variable("A"), Variable("B")]).variables() + result = Call("Sub", rty.BASE_INTEGER, [Variable("A"), Variable("B")]).variables() expected = [Variable("Sub"), Variable("A"), Variable("B")] assert result == expected def test_call_findall() -> None: - assert Call("X", [Variable("Y"), Variable("Z")]).findall(lambda x: isinstance(x, Variable)) == [ + assert Call("X", rty.BASE_INTEGER, [Variable("Y"), Variable("Z")]).findall( + lambda x: isinstance(x, Variable), + ) == [ Variable("Y"), Variable("Z"), ] def test_call_str() -> None: - assert str(Call("Test", [])) == "Test" + assert str(Call("Test", rty.BASE_INTEGER, [])) == "Test" def test_call_neg() -> None: - assert -Call("Test", []) == Neg(Call("Test", [])) + assert -Call("Test", rty.BASE_INTEGER, []) == Neg(Call("Test", rty.BASE_INTEGER, [])) def test_call_to_ir() -> None: assert Call( "X", + INT_TY, [Variable("Y", type_=rty.BOOLEAN), Variable("Z", type_=INT_TY)], - type_=INT_TY, ).to_ir(id_generator()) == ir.ComplexExpr( [], ir.IntCall("X", [ir.BoolVar("Y"), ir.IntVar("Z", INT_TY)], [rty.BOOLEAN, INT_TY], INT_TY), ) assert Call( "X", + rty.BOOLEAN, [Variable("Y", type_=rty.BOOLEAN), Variable("Z", type_=rty.BOOLEAN)], - type_=rty.BOOLEAN, ).to_ir(id_generator()) == ir.ComplexExpr( [], ir.BoolCall("X", [ir.BoolVar("Y"), ir.BoolVar("Z")], [rty.BOOLEAN, rty.BOOLEAN]), ) assert Call( "X", + rty.BOOLEAN, [ And(Variable("X", type_=rty.BOOLEAN), TRUE), Add(Variable("Y", type_=INT_TY), Number(1)), ], - type_=rty.BOOLEAN, ).to_ir(id_generator()) == ir.ComplexExpr( [], ir.BoolCall( @@ -2481,8 +2492,8 @@ def test_call_to_ir() -> None: ) assert Call( "X", + MSG_TY, [Variable("Y", type_=rty.BOOLEAN), Variable("Z", type_=INT_TY)], - type_=MSG_TY, ).to_ir(id_generator()) == ir.ComplexExpr( [], ir.ObjCall("X", [ir.BoolVar("Y"), ir.IntVar("Z", INT_TY)], [rty.BOOLEAN, INT_TY], MSG_TY), diff --git a/tests/unit/generator/generator_test.py b/tests/unit/generator/generator_test.py index d7f765f06..2b7b818a0 100644 --- a/tests/unit/generator/generator_test.py +++ b/tests/unit/generator/generator_test.py @@ -154,6 +154,11 @@ def test_generate_partial_update(tmp_path: Path) -> None: Generator().generate(models.ethernet_model(), Integration(), tmp_path) +def test_type_translation() -> None: + assert (common.type_to_id(rty.BASE_INTEGER)) == const.TYPES_BASE_INT + assert (common.type_to_id(rty.NamedType("P::mytype"))) == ID("P::mytype") + + @pytest.mark.parametrize("model", models.spark_test_models()) def test_equality(model: Callable[[], Model], tmp_path: Path) -> None: assert_equal_code(model(), Integration(), GENERATED_DIR, tmp_path, accept_extra_files=True) @@ -186,6 +191,7 @@ def test_substitution_relation_aggregate( expr.ValueRange( expr.Call( const.TYPES_TO_INDEX, + rty.INDEX, [ expr.Selected( expr.Indexed( @@ -198,6 +204,7 @@ def test_substitution_relation_aggregate( ), expr.Call( const.TYPES_TO_INDEX, + rty.INDEX, [ expr.Selected( expr.Indexed( @@ -215,6 +222,7 @@ def test_substitution_relation_aggregate( else: equal_call = expr.Call( "Equal", + rty.BOOLEAN, [ expr.Variable("Ctx"), expr.Variable("F_Value"), @@ -235,14 +243,22 @@ def test_substitution_relation_aggregate( ( expr.Variable("Value"), expr.TRUE, - expr.Call("RFLX_Types::Base_Integer", [expr.Variable("Value")]), - expr.Call("RFLX_Types::Base_Integer", [expr.Call("To_Base_Integer", [expr.TRUE])]), + expr.Call("RFLX_Types::Base_Integer", rty.BASE_INTEGER, [expr.Variable("Value")]), + expr.Call( + "RFLX_Types::Base_Integer", + rty.BASE_INTEGER, + [expr.Call("To_Base_Integer", rty.BASE_INTEGER, [expr.TRUE])], + ), ), ( expr.FALSE, expr.Variable("Value"), - expr.Call("RFLX_Types::Base_Integer", [expr.Variable("Value")]), - expr.Call("RFLX_Types::Base_Integer", [expr.Call("To_Base_Integer", [expr.FALSE])]), + expr.Call("RFLX_Types::Base_Integer", rty.BASE_INTEGER, [expr.Variable("Value")]), + expr.Call( + "RFLX_Types::Base_Integer", + rty.BASE_INTEGER, + [expr.Call("To_Base_Integer", rty.BASE_INTEGER, [expr.FALSE])], + ), ), ], ) @@ -264,11 +280,11 @@ def test_substitution_relation_boolean_literal( [ ( (expr.Variable("Length"), expr.Number(1)), - (expr.Call("Get_Length", [expr.Variable("Ctx")]), expr.Number(1)), + (expr.Call("Get_Length", rty.BASE_INTEGER, [expr.Variable("Ctx")]), expr.Number(1)), ), ( (expr.Number(1), expr.Variable("Length")), - (expr.Number(1), expr.Call("Get_Length", [expr.Variable("Ctx")])), + (expr.Number(1), expr.Call("Get_Length", rty.BASE_INTEGER, [expr.Variable("Ctx")])), ), ((expr.Number(1), expr.Variable("Unknown")), (expr.Number(1), expr.Variable("Unknown"))), ], diff --git a/tests/unit/model/session_test.py b/tests/unit/model/session_test.py index 24ffc0035..7d6413502 100644 --- a/tests/unit/model/session_test.py +++ b/tests/unit/model/session_test.py @@ -49,7 +49,10 @@ def test_str() -> None: "null", condition=expr.And( expr.Equal(expr.Variable("Z"), expr.TRUE), - expr.Equal(expr.Call("G", [expr.Variable("F")]), expr.TRUE), + expr.Equal( + expr.Call("G", rty.BOOLEAN, [expr.Variable("F")]), + expr.TRUE, + ), ), description="rfc1149.txt+45:4-47:8", ), @@ -133,7 +136,10 @@ def test_identifier_normalization(monkeypatch: pytest.MonkeyPatch) -> None: "null", condition=expr.And( expr.Equal(expr.Variable("z"), expr.TRUE), - expr.Equal(expr.Call("g", [expr.Variable("f")]), expr.TRUE), + expr.Equal( + expr.Call("g", rty.BOOLEAN, [expr.Variable("f")]), + expr.TRUE, + ), ), ), Transition("a"), @@ -257,6 +263,7 @@ def test_inconsistent_identifier_casing() -> None: expr.Equal( expr.Call( ID("g", location=Location((7, 7))), + rty.BOOLEAN, [expr.Variable(ID("f", location=Location((8, 8))))], ), expr.TRUE, @@ -821,6 +828,7 @@ def test_call_to_undeclared_function() -> None: "Global", expr.Call( "UndefSub", + rty.UNDEFINED, [expr.Variable("Global")], location=Location((10, 20)), ), @@ -855,6 +863,7 @@ def test_call_undeclared_variable() -> None: "Result", expr.Call( "SubProg", + rty.BOOLEAN, [expr.Variable("Undefined", location=Location((10, 20)))], ), ), @@ -884,6 +893,7 @@ def test_call_invalid_argument_type() -> None: "Result", expr.Call( "Function", + rty.BOOLEAN, [expr.Variable("Channel", location=Location((10, 20)))], ), ), @@ -919,6 +929,7 @@ def test_call_missing_arguments() -> None: "Result", expr.Call( "Function", + rty.BOOLEAN, location=Location((10, 20)), ), ), @@ -948,6 +959,7 @@ def test_call_too_many_arguments() -> None: "Result", expr.Call( "Function", + rty.BOOLEAN, [expr.TRUE, expr.Number(1)], location=Location((10, 20)), ), @@ -1154,6 +1166,7 @@ def test_undeclared_variable_in_function_call() -> None: "Result", expr.Call( "SubProg", + rty.BOOLEAN, [expr.Variable("Undefined", location=Location((10, 20)))], ), ), @@ -1629,6 +1642,7 @@ def test_assignment_opaque_function_undef_parameter() -> None: expr.Opaque( expr.Call( "Sub", + rty.OPAQUE, [expr.Variable("UndefData", location=Location((10, 20)))], ), ), @@ -1659,7 +1673,7 @@ def test_assignment_opaque_function_result() -> None: stmt.VariableAssignment( "Data", expr.Opaque( - expr.Call("Sub", [expr.Variable("Data")]), + expr.Call("Sub", rty.OPAQUE, [expr.Variable("Data")]), ), ), ], @@ -1951,7 +1965,7 @@ def test_undefined_type_in_parameters(parameters: abc.Sequence[decl.FormalDeclar transitions=[ Transition( target=ID("null"), - condition=expr.Equal(expr.Call("X", [expr.TRUE]), expr.TRUE), + condition=expr.Equal(expr.Call("X", rty.BOOLEAN, [expr.TRUE]), expr.TRUE), ), Transition( target=ID("Start"), @@ -2326,7 +2340,7 @@ def test_missing_exception_transition() -> None: actions=[ stmt.VariableAssignment( "Tag", - expr.Call("SubProg"), + expr.Call("SubProg", models.tlv_tag().type_), ), ], ), @@ -2383,15 +2397,15 @@ def test_resolving_of_function_calls() -> None: global_decl = session.declarations[ID("Global")] assert isinstance(global_decl, decl.VariableDeclaration) - assert global_decl.expression == expr.Call("Func") + assert global_decl.expression == expr.Call("Func", rty.BOOLEAN) local_decl = session.states[0].declarations[ID("Local")] assert isinstance(local_decl, decl.VariableDeclaration) - assert local_decl.expression == expr.Call("Func") + assert local_decl.expression == expr.Call("Func", rty.BOOLEAN) local_stmt = session.states[0].actions[0] assert isinstance(local_stmt, stmt.VariableAssignment) - assert local_stmt.expression == expr.Call("Func") + assert local_stmt.expression == expr.Call("Func", rty.BOOLEAN) @pytest.mark.parametrize( @@ -2905,11 +2919,11 @@ def test_state_normalization( "Msg", expr.Call( "Func", - args=[expr.Opaque("Msg2")], - type_=rty.Message( + rty.Message( "M", is_definite=True, ), + args=[expr.Opaque("Msg2")], ), ), ], @@ -2937,10 +2951,10 @@ def test_state_normalization( "Msg", expr.Call( "Func", - args=[expr.Opaque("Msg2")], type_=rty.Structure( "M", ), + args=[expr.Opaque("Msg2")], ), ), ], @@ -3060,7 +3074,7 @@ def test_message_assignment_from_function() -> None: transitions=[Transition(target=ID("null"))], exception_transition=Transition(target=ID("null")), declarations=[decl.VariableDeclaration("Msg", "Null_Msg::Message")], - actions=[stmt.VariableAssignment("Msg", expr.Call("SubProg"))], + actions=[stmt.VariableAssignment("Msg", expr.Call("SubProg", rty.BASE_INTEGER))], ), ], declarations=[], @@ -3093,7 +3107,10 @@ def test_unchecked_session_checked() -> None: "null", condition=expr.And( expr.Equal(expr.Variable("Z"), expr.TRUE), - expr.Equal(expr.Call("G", [expr.Variable("F")]), expr.TRUE), + expr.Equal( + expr.Call("G", rty.BOOLEAN, [expr.Variable("F")]), + expr.TRUE, + ), ), description="rfc1149.txt+45:4-47:8", ), @@ -3139,7 +3156,10 @@ def test_unchecked_session_checked() -> None: "null", condition=expr.And( expr.Equal(expr.Variable("Z"), expr.TRUE), - expr.Equal(expr.Call("G", [expr.Variable("F")]), expr.TRUE), + expr.Equal( + expr.Call("G", rty.BOOLEAN, [expr.Variable("F")]), + expr.TRUE, + ), ), description="rfc1149.txt+45:4-47:8", ), diff --git a/tests/unit/specification/grammar_test.py b/tests/unit/specification/grammar_test.py index 2822485bf..579271586 100644 --- a/tests/unit/specification/grammar_test.py +++ b/tests/unit/specification/grammar_test.py @@ -17,7 +17,7 @@ create_statement, diagnostics_to_error, ) -from rflx.typing_ import BOOLEAN +from rflx.typing_ import BOOLEAN, UNDEFINED from tests.utils import parse, parse_bool_expression, parse_expression, parse_math_expression @@ -292,7 +292,10 @@ def test_mathematical_expression(string: str, expected: expr.Expr) -> None: ("string", "expected"), [ ("X + Y", expr.Add(expr.Variable("X"), expr.Variable("Y"))), - ("X + Y (Z)", expr.Add(expr.Variable("X"), expr.Call("Y", [expr.Variable("Z")]))), + ( + "X + Y (Z)", + expr.Add(expr.Variable("X"), expr.Call("Y", UNDEFINED, [expr.Variable("Z")])), + ), ], ) def test_extended_mathematical_expression(string: str, expected: expr.Expr) -> None: @@ -326,7 +329,10 @@ def test_boolean_expression(string: str, expected: expr.Expr) -> None: ("string", "expected"), [ ("X and Y", expr.And(expr.Variable("X"), expr.Variable("Y"))), - ("X and Y (Z)", expr.And(expr.Variable("X"), expr.Call("Y", [expr.Variable("Z")]))), + ( + "X and Y (Z)", + expr.And(expr.Variable("X"), expr.Call("Y", UNDEFINED, [expr.Variable("Z")])), + ), ], ) def test_extended_boolean_expression(string: str, expected: expr.Expr) -> None: @@ -433,7 +439,7 @@ def test_boolean_expression_error(string: str, error: expr.Expr, extended: bool) ), ( 'X (A, "S", 42)', - expr.Call("X", [expr.Variable("A"), expr.String("S"), expr.Number(42)]), + expr.Call("X", UNDEFINED, [expr.Variable("A"), expr.String("S"), expr.Number(42)]), ), ("X::Y (A)", expr.Conversion("X::Y", expr.Variable("A"))), ("X'(Y => Z)", expr.MessageAggregate("X", {ID("Y"): expr.Variable("Z")})), @@ -521,8 +527,11 @@ def test_expression_base(string: str, expected: expr.Expr) -> None: ), ), ("X::Y (Z) = 42", expr.Equal(expr.Conversion("X::Y", expr.Variable("Z")), expr.Number(42))), - ("X (Y).Z", expr.Selected(expr.Call("X", [expr.Variable("Y")]), "Z")), - ("X (Y).Z'Size", expr.Size(expr.Selected(expr.Call("X", [expr.Variable("Y")]), "Z"))), + ("X (Y).Z", expr.Selected(expr.Call("X", UNDEFINED, [expr.Variable("Y")]), "Z")), + ( + "X (Y).Z'Size", + expr.Size(expr.Selected(expr.Call("X", UNDEFINED, [expr.Variable("Y")]), "Z")), + ), ( "G::E not in P::S (E.D).V", expr.NotIn(