Skip to content

Commit

Permalink
Fix rejection of invalid parameter types and return types
Browse files Browse the repository at this point in the history
Ref. eng/recordflux/RecordFlux#977
  • Loading branch information
treiher committed Sep 13, 2024
1 parent c26f0c9 commit 49f8edf
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 61 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Fixed

- Rejection of invalid parameter types and return types in function declarations (eng/recordflux/RecordFlux#977)

## [0.24.0] - 2024-09-12

### Added
Expand Down Expand Up @@ -586,6 +592,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [0.1.0] - 2019-05-14

[Unreleased]: https://github.com/AdaCore/RecordFlux/compare/v0.24.0...HEAD
[0.24.0]: https://github.com/AdaCore/RecordFlux/compare/v0.23.0...v0.24.0
[0.23.0]: https://github.com/AdaCore/RecordFlux/compare/v0.22.0...v0.23.0
[0.22.0]: https://github.com/AdaCore/RecordFlux/compare/v0.21.0...v0.22.0
Expand Down
92 changes: 91 additions & 1 deletion rflx/model/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,13 +734,103 @@ def undefined_type(type_identifier: StrID, location: Location | None) -> None:
self.package,
)
if type_identifier in self.types:
a.type_ = self.types[type_identifier].type_
argument_type = self.types[type_identifier]
a.type_ = argument_type.type_
self._validate_function_parameter_type(type_identifier)
else:
a.type_ = rty.Any()
undefined_type(a.type_identifier, d.location)

return_type_id = type_decl.internal_type_identifier(
d.return_type,
self.package,
)
if return_type_id in self.types:
self._validate_function_return_type(
return_type_id,
)

visible_declarations[k] = d

def _validate_function_parameter_type(self, type_identifier: ID) -> None:
parameter_type = self.types[type_identifier]
if isinstance(parameter_type, Message) and not parameter_type.is_definite:
assert parameter_type.location
self.error.extend(
[
ErrorEntry(
"only definite messages can be used as function parameters",
Severity.ERROR,
type_identifier.location,
),
ErrorEntry(
"message type defined here",
Severity.NOTE,
parameter_type.location,
),
],
)
if (
not isinstance(parameter_type, (type_decl.Scalar, Message))
and parameter_type.identifier != rty.OPAQUE.identifier
):
assert type_identifier.location
self.error.extend(
[
ErrorEntry(
"invalid parameter type",
Severity.ERROR,
type_identifier.location,
[
Annotation(
"only scalars, definite messages and Opaque are allowed",
Severity.HELP,
type_identifier.location,
),
],
generate_default_annotation=False,
),
],
)

def _validate_function_return_type(self, type_identifier: ID) -> None:
return_type = self.types[type_identifier]
if isinstance(return_type, Message) and not return_type.is_definite:
assert return_type.location
self.error.extend(
[
ErrorEntry(
"only a definite message can be used as return type",
Severity.ERROR,
type_identifier.location,
),
ErrorEntry(
"message type defined here",
Severity.NOTE,
return_type.location,
),
],
)
if not isinstance(return_type, (type_decl.Scalar, Message)):
assert type_identifier.location
self.error.extend(
[
ErrorEntry(
"invalid return type",
Severity.ERROR,
type_identifier.location,
[
Annotation(
"only scalars and definite messages are allowed",
Severity.HELP,
type_identifier.location,
),
],
generate_default_annotation=False,
),
],
)

def _validate_actions(
self,
actions: Sequence[stmt.Statement],
Expand Down
164 changes: 104 additions & 60 deletions tests/unit/model/state_machine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,110 @@ def test_reset_incompatible() -> None:
)


@pytest.mark.parametrize(
("parameter_type", "expected_error"),
[
(
"TLV::Message",
r"<stdin>:1:2: error: only definite messages can be used as function parameters\n"
r"<stdin>:1:1: note: message type defined here",
),
(
"TLV::Messages",
r"<stdin>:1:2: error: invalid parameter type\n"
r"<stdin>:1:2: help: only scalars, definite messages and Opaque are allowed",
),
],
)
def test_function_declaration_invalid_parameter_type(
parameter_type: str,
expected_error: str,
) -> None:
assert_state_machine_model_error(
states=[
State(
"Start",
transitions=[Transition(target=ID("null"))],
exception_transition=(Transition(target=ID("null"))),
declarations=[],
actions=[
stmt.VariableAssignment(
"Result",
expr.Call("Function", rty.BOOLEAN, [expr.Variable("M")]),
),
],
),
],
declarations=[
decl.VariableDeclaration("M", parameter_type),
decl.VariableDeclaration("Result", "Boolean"),
],
parameters=[
decl.FunctionDeclaration(
"Function",
[decl.Argument("M", ID(parameter_type, location=Location((1, 2))))],
"Boolean",
),
],
types=[BOOLEAN, models.tlv_message(), models.tlv_messages()],
regex=rf"^{expected_error}$",
)


@pytest.mark.parametrize(
("return_type", "expected_error", "needs_exception_transition"),
[
(
"TLV::Message",
r"<stdin>:1:2: error: only a definite message can be used as return type\n"
r"<stdin>:1:1: note: message type defined here",
True,
),
(
"TLV::Messages",
r"<stdin>:1:2: error: invalid return type\n"
r"<stdin>:1:2: help: only scalars and definite messages are allowed",
False,
),
],
)
def test_function_declaration_invalid_return_type(
return_type: str,
expected_error: str,
needs_exception_transition: bool,
) -> None:
assert_state_machine_model_error(
states=[
State(
"Start",
transitions=[Transition(target=ID("null"))],
exception_transition=(
Transition(target=ID("null")) if needs_exception_transition else None
),
declarations=[],
actions=[
stmt.VariableAssignment(
"Result",
expr.Call("Function", rty.BOOLEAN, []),
),
],
),
],
declarations=[
decl.VariableDeclaration("Result", return_type),
],
parameters=[
decl.FunctionDeclaration(
"Function",
[],
ID(return_type, location=Location((1, 2))),
),
],
types=[BOOLEAN, models.tlv_message(), models.tlv_messages()],
regex=rf"^{expected_error}$",
)


def test_call_to_undeclared_function() -> None:
assert_state_machine_model_error(
states=[
Expand Down Expand Up @@ -1653,66 +1757,6 @@ def test_comprehension() -> None:
)


def test_assignment_opaque_function_undef_parameter() -> None:
assert_state_machine_model_error(
states=[
State(
"Start",
transitions=[Transition(target=ID("null"))],
exception_transition=Transition(target=ID("null")),
actions=[
stmt.VariableAssignment(
"Data",
expr.Opaque(
expr.Call(
"Sub",
rty.OPAQUE,
[expr.Variable("UndefData", location=Location((10, 20)))],
),
),
),
],
),
],
declarations=[
decl.VariableDeclaration("Data", "Opaque"),
],
parameters=[
decl.FunctionDeclaration("Sub", [decl.Argument("Param", "Opaque")], "TLV::Message"),
],
types=[BOOLEAN, OPAQUE, models.tlv_message()],
regex=r'^<stdin>:10:20: error: undefined variable "UndefData"$',
)


def test_assignment_opaque_function_result() -> None:
StateMachine(
identifier="P::S",
states=[
State(
"Start",
transitions=[Transition(target=ID("null"))],
exception_transition=Transition(target=ID("null")),
actions=[
stmt.VariableAssignment(
"Data",
expr.Opaque(
expr.Call("Sub", rty.OPAQUE, [expr.Variable("Data")]),
),
),
],
),
],
declarations=[
decl.VariableDeclaration("Data", "Opaque"),
],
parameters=[
decl.FunctionDeclaration("Sub", [decl.Argument("Param", "Opaque")], "TLV::Message"),
],
types=[BOOLEAN, OPAQUE, models.tlv_message()],
)


def test_message_field_assignment_with_invalid_field_name() -> None:
assert_state_machine_model_error(
states=[
Expand Down

0 comments on commit 49f8edf

Please sign in to comment.