From 4e67acf7e3d44c2bed3bcdca448d766bc8622734 Mon Sep 17 00:00:00 2001 From: Tobias Reiher Date: Thu, 17 Oct 2024 17:32:16 +0200 Subject: [PATCH] Refactor handling of expression conversions in IR Ref. eng/recordflux/RecordFlux!1685 --- rflx/ir.py | 149 ++++++++++++++++++++++++++--------------------------- 1 file changed, 72 insertions(+), 77 deletions(-) diff --git a/rflx/ir.py b/rflx/ir.py index 78cf1be53..ca511ae1c 100644 --- a/rflx/ir.py +++ b/rflx/ir.py @@ -2200,51 +2200,22 @@ def add_conversions(statements: Sequence[Stmt]) -> list[Stmt]: @singledispatch -def _convert_expression( - expression: Expr, # noqa: ARG001 - target_type: ty.Type, # noqa: ARG001 -) -> Expr: - raise NotImplementedError +def _convert_expression(expression: Expr, target_type: ty.Type) -> Expr: + if target_type.is_compatible_strong(expression.type_) or not isinstance( + target_type, + (ty.Integer, ty.Enumeration), + ): + return expression - -@_convert_expression.register(MsgAgg) -@_convert_expression.register(DeltaMsgAgg) -def _( - expression: MsgAgg | DeltaMsgAgg, - target_type: ty.Type, # noqa: ARG001 -) -> Expr: - field_values: dict[ID, Expr] = { - f: _convert_expression(v, expression.type_.types[f]) - for f, v in expression.field_values.items() - } - - return expression.__class__( - expression.identifier, - field_values, - expression.type_, - expression.origin, - ) + assert False @_convert_expression.register(BinaryIntExpr) -@_convert_expression.register(IntExpr) -@_convert_expression.register(Expr) -def _( - expression: BinaryIntExpr | IntExpr | Expr, - target_type: ty.Type, -) -> Expr: - result: Expr +def _(expression: BinaryIntExpr, target_type: ty.Type) -> Expr: + assert isinstance(target_type, ty.Integer) - if ( - target_type.is_compatible_strong(expression.type_) - and not isinstance(expression, BinaryIntExpr) - or not isinstance(target_type, (ty.Integer, ty.Enumeration)) - ): - return expression - - if isinstance(expression, BinaryIntExpr): - assert isinstance(target_type, ty.Integer) - left = ( + return expression.__class__( + ( expression.left if target_type.is_compatible_strong(expression.left.type_) else IntConversion( @@ -2252,8 +2223,8 @@ def _( expression.left, expression.left.origin, ) - ) - right = ( + ), + ( expression.right if target_type.is_compatible_strong(expression.right.type_) else IntConversion( @@ -2261,50 +2232,74 @@ def _( expression.right, expression.right.origin, ) - ) - result = expression.__class__( - left, - right, - origin=expression.origin, - ) + ), + origin=expression.origin, + ) - elif isinstance(expression, IntExpr): - assert isinstance(target_type, ty.Integer) - result = IntConversion( - target_type, - expression, - expression.origin, - ) - elif isinstance(expression, CaseExpr): - assert isinstance(target_type, ty.Integer) +@_convert_expression.register(IntExpr) +def _(expression: IntExpr, target_type: ty.Type) -> Expr: + if target_type.is_compatible_strong(expression.type_) or not isinstance( + target_type, + ty.Integer, + ): + return expression + + return IntConversion( + target_type, + expression, + expression.origin, + ) + + +@_convert_expression.register(CaseExpr) +def _(expression: CaseExpr, target_type: ty.Type) -> Expr: + if target_type.is_compatible_strong(expression.type_) or not isinstance( + target_type, + ty.Integer, + ): + return expression - choices = [] + choices = [] - for k, v in expression.choices: - assert isinstance(v, IntExpr) - choices.append( + for k, v in expression.choices: + assert isinstance(v, IntExpr) + choices.append( + ( + k, ( - k, - ( - v - if target_type.is_compatible_strong(v.type_) - else IntConversion(target_type, v, v.origin) - ), + v + if target_type.is_compatible_strong(v.type_) + else IntConversion(target_type, v, v.origin) ), - ) - - result = expression.__class__( - expression.expression, - choices, - target_type, - origin=expression.origin, + ), ) - else: - assert False + return expression.__class__( + expression.expression, + choices, + target_type, + origin=expression.origin, + ) - return result + +@_convert_expression.register(MsgAgg) +@_convert_expression.register(DeltaMsgAgg) +def _( + expression: MsgAgg | DeltaMsgAgg, + target_type: ty.Type, # noqa: ARG001 +) -> Expr: + field_values: dict[ID, Expr] = { + f: _convert_expression(v, expression.type_.types[f]) + for f, v in expression.field_values.items() + } + + return expression.__class__( + expression.identifier, + field_values, + expression.type_, + expression.origin, + ) def add_checks(statements: Sequence[Stmt], variable_id: Generator[ID, None, None]) -> list[Stmt]: