Skip to content

Commit

Permalink
Refactor handling of expression conversions in IR
Browse files Browse the repository at this point in the history
Ref. eng/recordflux/RecordFlux!1685
  • Loading branch information
treiher committed Oct 21, 2024
1 parent 696622b commit 4e67acf
Showing 1 changed file with 72 additions and 77 deletions.
149 changes: 72 additions & 77 deletions rflx/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2200,111 +2200,106 @@ def add_conversions(statements: Sequence[Stmt]) -> list[Stmt]:


@singledispatch
def _convert_expression(
expression: Expr, # noqa: ARG001
target_type: ty.Type, # noqa: ARG001
) -> Expr:
raise NotImplementedError
def _convert_expression(expression: Expr, target_type: ty.Type) -> Expr:
if target_type.is_compatible_strong(expression.type_) or not isinstance(
target_type,
(ty.Integer, ty.Enumeration),
):
return expression


@_convert_expression.register(MsgAgg)
@_convert_expression.register(DeltaMsgAgg)
def _(
expression: MsgAgg | DeltaMsgAgg,
target_type: ty.Type, # noqa: ARG001
) -> Expr:
field_values: dict[ID, Expr] = {
f: _convert_expression(v, expression.type_.types[f])
for f, v in expression.field_values.items()
}

return expression.__class__(
expression.identifier,
field_values,
expression.type_,
expression.origin,
)
assert False


@_convert_expression.register(BinaryIntExpr)
@_convert_expression.register(IntExpr)
@_convert_expression.register(Expr)
def _(
expression: BinaryIntExpr | IntExpr | Expr,
target_type: ty.Type,
) -> Expr:
result: Expr
def _(expression: BinaryIntExpr, target_type: ty.Type) -> Expr:
assert isinstance(target_type, ty.Integer)

if (
target_type.is_compatible_strong(expression.type_)
and not isinstance(expression, BinaryIntExpr)
or not isinstance(target_type, (ty.Integer, ty.Enumeration))
):
return expression

if isinstance(expression, BinaryIntExpr):
assert isinstance(target_type, ty.Integer)
left = (
return expression.__class__(
(
expression.left
if target_type.is_compatible_strong(expression.left.type_)
else IntConversion(
target_type,
expression.left,
expression.left.origin,
)
)
right = (
),
(
expression.right
if target_type.is_compatible_strong(expression.right.type_)
else IntConversion(
target_type,
expression.right,
expression.right.origin,
)
)
result = expression.__class__(
left,
right,
origin=expression.origin,
)
),
origin=expression.origin,
)

elif isinstance(expression, IntExpr):
assert isinstance(target_type, ty.Integer)
result = IntConversion(
target_type,
expression,
expression.origin,
)

elif isinstance(expression, CaseExpr):
assert isinstance(target_type, ty.Integer)
@_convert_expression.register(IntExpr)
def _(expression: IntExpr, target_type: ty.Type) -> Expr:
if target_type.is_compatible_strong(expression.type_) or not isinstance(
target_type,
ty.Integer,
):
return expression

return IntConversion(
target_type,
expression,
expression.origin,
)


@_convert_expression.register(CaseExpr)
def _(expression: CaseExpr, target_type: ty.Type) -> Expr:
if target_type.is_compatible_strong(expression.type_) or not isinstance(
target_type,
ty.Integer,
):
return expression

choices = []
choices = []

for k, v in expression.choices:
assert isinstance(v, IntExpr)
choices.append(
for k, v in expression.choices:
assert isinstance(v, IntExpr)
choices.append(
(
k,
(
k,
(
v
if target_type.is_compatible_strong(v.type_)
else IntConversion(target_type, v, v.origin)
),
v
if target_type.is_compatible_strong(v.type_)
else IntConversion(target_type, v, v.origin)
),
)

result = expression.__class__(
expression.expression,
choices,
target_type,
origin=expression.origin,
),
)

else:
assert False
return expression.__class__(
expression.expression,
choices,
target_type,
origin=expression.origin,
)

return result

@_convert_expression.register(MsgAgg)
@_convert_expression.register(DeltaMsgAgg)
def _(
expression: MsgAgg | DeltaMsgAgg,
target_type: ty.Type, # noqa: ARG001
) -> Expr:
field_values: dict[ID, Expr] = {
f: _convert_expression(v, expression.type_.types[f])
for f, v in expression.field_values.items()
}

return expression.__class__(
expression.identifier,
field_values,
expression.type_,
expression.origin,
)


def add_checks(statements: Sequence[Stmt], variable_id: Generator[ID, None, None]) -> list[Stmt]:
Expand Down

0 comments on commit 4e67acf

Please sign in to comment.