Skip to content

Commit

Permalink
Make Call expressions have a mandatory type
Browse files Browse the repository at this point in the history
When creating a Call expression (mostly during code generation), a type
must now be provided. In some rare cases (unit tests and conversion back
from Ada expressions), the Undefined type is used.

Ref. eng/recordflux/RecordFlux#1365
  • Loading branch information
kanigsson committed Apr 9, 2024
1 parent d08b28a commit fca894c
Show file tree
Hide file tree
Showing 17 changed files with 230 additions and 85 deletions.
4 changes: 2 additions & 2 deletions rflx/ada.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from typing_extensions import Self

from rflx import expression as expr
from rflx import expression as expr, typing_ as rty
from rflx.common import Base, file_name, indent, indent_next, unique
from rflx.contract import invariant
from rflx.identifier import ID, StrID
Expand Down Expand Up @@ -611,7 +611,7 @@ def _representation(self) -> str:

def rflx_expr(self) -> expr.Call:
assert not self.named_arguments
return expr.Call(self.identifier, [a.rflx_expr() for a in self.arguments])
return expr.Call(self.identifier, rty.UNDEFINED, [a.rflx_expr() for a in self.arguments])


class Slice(Name):
Expand Down
4 changes: 2 additions & 2 deletions rflx/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2035,9 +2035,9 @@ class Call(Name):
def __init__( # noqa: PLR0913
self,
identifier: StrID,
type_: rty.Type,
args: Optional[Sequence[Expr]] = None,
immutable: bool = False,
type_: rty.Type = rty.UNDEFINED,
argument_types: Optional[Sequence[rty.Type]] = None,
location: Optional[Location] = None,
) -> None:
Expand Down Expand Up @@ -2158,9 +2158,9 @@ def substituted(
assert isinstance(expr, Call)
return expr.__class__(
expr.identifier,
expr.type_,
[a.substituted(func) for a in expr.args],
expr.immutable,
expr.type_,
expr.argument_types,
expr.location,
)
Expand Down
72 changes: 55 additions & 17 deletions rflx/generator/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Callable
from typing import Optional

from rflx import expression as expr, model
from rflx import expression as expr, model, typing_ as rty
from rflx.ada import (
TRUE,
Add,
Expand Down Expand Up @@ -66,17 +66,24 @@ class Debug(enum.Enum):
EXTERNAL = enum.auto()


def type_to_id(type_: rty.NamedType) -> ID:
if type_.identifier.parent == BUILTINS_PACKAGE:
return const.TYPES * type_.identifier.name

return type_.identifier


def substitution(
message: model.Message,
prefix: str,
embedded: bool = False,
public: bool = False,
target_type: ID = const.TYPES_BASE_INT,
target_type: rty.NamedType = rty.BASE_INTEGER,
) -> Callable[[expr.Expr], expr.Expr]:
facts = substitution_facts(message, prefix, embedded, public, target_type)

def type_conversion(expression: expr.Expr) -> expr.Expr:
return expr.Call(target_type, [expression])
return expr.Call(type_to_id(target_type), target_type, [expression])

def func( # noqa: PLR0912
expression: expr.Expr,
Expand Down Expand Up @@ -119,6 +126,7 @@ def byte_aggregate(aggregate: expr.Aggregate) -> expr.Aggregate:
expr.ValueRange(
expr.Call(
const.TYPES_TO_INDEX,
rty.INDEX,
[
expr.Selected(
expr.Indexed(
Expand All @@ -131,6 +139,7 @@ def byte_aggregate(aggregate: expr.Aggregate) -> expr.Aggregate:
),
expr.Call(
const.TYPES_TO_INDEX,
rty.INDEX,
[
expr.Selected(
expr.Indexed(
Expand All @@ -147,6 +156,7 @@ def byte_aggregate(aggregate: expr.Aggregate) -> expr.Aggregate:
)
equal_call = expr.Call(
"Equal",
rty.BOOLEAN,
[expr.Variable("Ctx"), expr.Variable(field.affixed_name), aggregate],
)
return equal_call if isinstance(expression, expr.Equal) else expr.Not(equal_call)
Expand All @@ -168,12 +178,18 @@ def byte_aggregate(aggregate: expr.Aggregate) -> expr.Aggregate:
if boolean_literal and other:
return expression.__class__(
other,
type_conversion(expr.Call("To_Base_Integer", [boolean_literal])),
type_conversion(
expr.Call("To_Base_Integer", rty.BASE_INTEGER, [boolean_literal]),
),
)

def field_value(field: model.Field) -> expr.Expr:
if public:
return expr.Call(f"Get_{field.name}", [expr.Variable("Ctx")])
return expr.Call(
f"Get_{field.name}",
message.field_types[field].type_,
[expr.Variable("Ctx")],
)
return expr.Selected(
expr.Indexed(
expr.Variable(ID("Ctx") * "Cursors" if not embedded else "Cursors"),
Expand Down Expand Up @@ -212,19 +228,24 @@ def substitution_facts(
prefix: str,
embedded: bool = False,
public: bool = False,
target_type: ID = const.TYPES_BASE_INT,
target_type: rty.NamedType = rty.BASE_INTEGER,
) -> dict[expr.Name, expr.Expr]:
def prefixed(name: str) -> expr.Expr:
return expr.Variable(ID("Ctx") * name) if not embedded else expr.Variable(name)

first = prefixed("First")
last = expr.Call("Written_Last", [expr.Variable("Ctx")]) if public else prefixed("Written_Last")
last = (
expr.Call("Written_Last", rty.BIT_LENGTH, [expr.Variable("Ctx")])
if public
else prefixed("Written_Last")
)
cursors = prefixed("Cursors")

def field_first(field: model.Field) -> expr.Expr:
if public:
return expr.Call(
"Field_First",
rty.BIT_INDEX,
[expr.Variable("Ctx"), expr.Variable(field.affixed_name)],
)
return expr.Selected(expr.Indexed(cursors, expr.Variable(field.affixed_name)), "First")
Expand All @@ -233,6 +254,7 @@ def field_last(field: model.Field) -> expr.Expr:
if public:
return expr.Call(
"Field_Last",
rty.BIT_LENGTH,
[expr.Variable("Ctx"), expr.Variable(field.affixed_name)],
)
return expr.Selected(expr.Indexed(cursors, expr.Variable(field.affixed_name)), "Last")
Expand All @@ -241,6 +263,7 @@ def field_size(field: model.Field) -> expr.Expr:
if public:
return expr.Call(
"Field_Size",
rty.BIT_LENGTH,
[expr.Variable("Ctx"), expr.Variable(field.affixed_name)],
)
return expr.Add(
Expand All @@ -254,8 +277,16 @@ def field_size(field: model.Field) -> expr.Expr:
def parameter_value(parameter: model.Field, parameter_type: model.Type) -> expr.Expr:
if isinstance(parameter_type, model.Enumeration):
if embedded:
return expr.Call("To_Base_Integer", [expr.Variable(parameter.name)])
return expr.Call("To_Base_Integer", [expr.Variable("Ctx" * parameter.identifier)])
return expr.Call(
"To_Base_Integer",
rty.BASE_INTEGER,
[expr.Variable(parameter.name)],
)
return expr.Call(
"To_Base_Integer",
rty.BASE_INTEGER,
[expr.Variable("Ctx" * parameter.identifier)],
)
if isinstance(parameter_type, model.Scalar):
if embedded:
return expr.Variable(parameter.name)
Expand All @@ -268,15 +299,16 @@ def field_value(field: model.Field, field_type: model.Type) -> expr.Expr:
if public:
return expr.Call(
"To_Base_Integer",
[expr.Call(f"Get_{field.name}", [expr.Variable("Ctx")])],
rty.BASE_INTEGER,
[expr.Call(f"Get_{field.name}", field_type.type_, [expr.Variable("Ctx")])],
)
return expr.Selected(
expr.Indexed(cursors, expr.Variable(field.affixed_name)),
"Value",
)
if isinstance(field_type, model.Scalar):
if public:
return expr.Call(f"Get_{field.name}", [expr.Variable("Ctx")])
return expr.Call(f"Get_{field.name}", field_type.type_, [expr.Variable("Ctx")])
return expr.Selected(
expr.Indexed(cursors, expr.Variable(field.affixed_name)),
"Value",
Expand All @@ -287,7 +319,7 @@ def field_value(field: model.Field, field_type: model.Type) -> expr.Expr:
assert False, f'unexpected type "{type(field_type).__name__}"'

def type_conversion(expression: expr.Expr) -> expr.Expr:
return expr.Call(target_type, [expression])
return expr.Call(type_to_id(target_type), target_type, [expression])

return {
expr.First("Message"): type_conversion(first),
Expand All @@ -305,14 +337,20 @@ def type_conversion(expression: expr.Expr) -> expr.Expr:
for f, t in message.field_types.items()
},
**{
expr.Literal(l): type_conversion(expr.Call("To_Base_Integer", [expr.Variable(l)]))
expr.Literal(l): type_conversion(
expr.Call("To_Base_Integer", rty.BASE_INTEGER, [expr.Variable(l)]),
)
for t in message.types.values()
if isinstance(t, model.Enumeration) and t != model.BOOLEAN
for l in t.literals
},
**{
expr.Literal(t.package * l): type_conversion(
expr.Call("To_Base_Integer", [expr.Variable(prefix * t.package * l)]),
expr.Call(
"To_Base_Integer",
rty.BASE_INTEGER,
[expr.Variable(prefix * t.package * l)],
),
)
for t in message.types.values()
if isinstance(t, model.Enumeration) and t != model.BOOLEAN
Expand Down Expand Up @@ -348,14 +386,14 @@ def link_property(link: model.Link, unique: bool) -> Expr:
field_type.size
if isinstance(field_type, model.Scalar)
else link.size.substituted(
substitution(message, prefix, embedded, target_type=const.TYPES_BIT_LENGTH),
substitution(message, prefix, embedded, target_type=rty.BIT_LENGTH),
).simplified()
)
first = (
prefixed("First")
if link.source == model.INITIAL
else link.first.substituted(
substitution(message, prefix, embedded, target_type=const.TYPES_BIT_INDEX),
substitution(message, prefix, embedded, target_type=rty.BIT_INDEX),
)
.substituted(
mapping={
Expand Down Expand Up @@ -947,7 +985,7 @@ def substituted(expression: expr.Expr) -> Expr:
substitution(
message,
prefix,
target_type=const.TYPES_BIT_LENGTH,
target_type=rty.BIT_LENGTH,
embedded=True,
),
)
Expand Down
26 changes: 21 additions & 5 deletions rflx/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pathlib import Path
from typing import Optional

from rflx import __version__, expression as expr
from rflx import __version__, expression as expr, typing_ as rty
from rflx.ada import (
FALSE,
TRUE,
Expand Down Expand Up @@ -1103,7 +1103,11 @@ def _create_contains_function(
if isinstance(t, Enumeration) and t.always_valid:
condition = expr.AndThen(
expr.Selected(
expr.Call(pdu_identifier * f"Get_{f.name}", [expr.Variable("Ctx")]),
expr.Call(
pdu_identifier * f"Get_{f.name}",
t.type_,
[expr.Variable("Ctx")],
),
"Known",
),
condition,
Expand All @@ -1113,11 +1117,19 @@ def _create_contains_function(
mapping={
expr.Variable(f.name): (
expr.Selected(
expr.Call(pdu_identifier * f"Get_{f.name}", [expr.Variable("Ctx")]),
expr.Call(
pdu_identifier * f"Get_{f.name}",
t.type_,
[expr.Variable("Ctx")],
),
"Enum",
)
if isinstance(t, Enumeration) and t.always_valid
else expr.Call(pdu_identifier * f"Get_{f.name}", [expr.Variable("Ctx")])
else expr.Call(
pdu_identifier * f"Get_{f.name}",
t.type_,
[expr.Variable("Ctx")],
)
)
for f, t in condition_fields.items()
},
Expand Down Expand Up @@ -1591,14 +1603,15 @@ def _refinement_conditions(
pdu_identifier = self._prefix * refinement.pdu.identifier

conditions: list[expr.Expr] = [
expr.Call(pdu_identifier * "Has_Buffer", [expr.Variable(pdu_context)]),
expr.Call(pdu_identifier * "Has_Buffer", rty.BOOLEAN, [expr.Variable(pdu_context)]),
]

if null_sdu:
conditions.extend(
[
expr.Call(
pdu_identifier * "Well_Formed",
rty.BOOLEAN,
[
expr.Variable(pdu_context),
expr.Variable(pdu_identifier * refinement.field.affixed_name),
Expand All @@ -1607,6 +1620,7 @@ def _refinement_conditions(
expr.Not(
expr.Call(
pdu_identifier * "Present",
rty.BOOLEAN,
[
expr.Variable(pdu_context),
expr.Variable(pdu_identifier * refinement.field.affixed_name),
Expand All @@ -1619,6 +1633,7 @@ def _refinement_conditions(
conditions.append(
expr.Call(
pdu_identifier * "Present",
rty.BOOLEAN,
[
expr.Variable(pdu_context),
expr.Variable(pdu_identifier * refinement.field.affixed_name),
Expand All @@ -1630,6 +1645,7 @@ def _refinement_conditions(
[
expr.Call(
pdu_identifier * "Valid",
rty.BOOLEAN,
[
expr.Variable(pdu_context),
expr.Variable(pdu_identifier * f.affixed_name),
Expand Down
Loading

0 comments on commit fca894c

Please sign in to comment.