From 836462a9645606bfc505bc704d5eaf45436486e1 Mon Sep 17 00:00:00 2001 From: Tobias Reiher Date: Thu, 11 Aug 2022 16:40:33 +0200 Subject: [PATCH] Improve simplification of if expressions --- rflx/expression.py | 37 ++++++++++++++++ tests/unit/expression_test.py | 72 ++++++++++++++++++++++++++++++++ tests/unit/model/message_test.py | 51 +++++++++++++--------- 3 files changed, 140 insertions(+), 20 deletions(-) diff --git a/rflx/expression.py b/rflx/expression.py index 010ecc1de..7e616d856 100644 --- a/rflx/expression.py +++ b/rflx/expression.py @@ -484,6 +484,8 @@ def simplified(self) -> Expr: if total != self.neutral_element(): terms.append(Number(total)) + terms = self._simplified_if_expressions(terms) + if len(terms) == 1: return terms[0] @@ -513,6 +515,41 @@ def _simplified_boolean_expressions(self, terms: Sequence[Expr], total: int) -> return terms + def _simplified_if_expressions(self, terms: list[Expr]) -> list[Expr]: + """Merge if expressions which differ only in the condition.""" + + if not terms: + return [] + + t = terms[0] + + if isinstance(t, IfExpr): + for i, u in enumerate(terms[1:]): + if ( + isinstance(u, IfExpr) + and len(u.condition_expressions) == 1 + and u.condition_expressions[0][1] == t.condition_expressions[0][1] + and u.else_expression == t.else_expression + ): + return [ + IfExpr( + [ + ( + Or( + t.condition_expressions[0][0], + u.condition_expressions[0][0], + ).simplified(), + t.condition_expressions[0][1], + ) + ], + t.else_expression, + ), + *self._simplified_if_expressions(terms[1 : i + 1] + terms[i + 2 :]), + ] + return [t, *self._simplified_if_expressions(terms[1:])] + + return terms + @abstractmethod def operation(self, left: int, right: int) -> int: raise NotImplementedError diff --git a/tests/unit/expression_test.py b/tests/unit/expression_test.py index 8bd772f4c..c586c6ca7 100644 --- a/tests/unit/expression_test.py +++ b/tests/unit/expression_test.py @@ -245,6 +245,19 @@ def test_bin_expr_substituted_location() -> None: assert expr.location +def test_ass_expr_str() -> None: + assert ( + str( + Add( + Number(1), + IfExpr([(Variable("A"), Variable("B"))], Variable("C")), + IfExpr([(Variable("X"), Variable("Y"))], Variable("Z")), + ) + ) + == "1 + (if A then B else C) + (if X then Y else Z)" + ) + + def test_ass_expr_findall() -> None: assert_equal( And(Equal(Variable("X"), Number(1)), Less(Variable("Y"), Number(2))).findall( @@ -254,6 +267,65 @@ def test_ass_expr_findall() -> None: ) +def test_ass_expr_simplified() -> None: + assert_equal( + Add( + Number(8), + IfExpr( + [ + ( + And( + Variable("A"), + Or(Variable("B"), Variable("C")), + Equal(Variable("D"), TRUE), + ), + Variable("X"), + ) + ], + Variable("Y"), + ), + Number(16), + IfExpr( + [ + ( + And( + Variable("A"), + Or(Variable("B"), Variable("C")), + Equal(Variable("D"), FALSE), + ), + Variable("X"), + ) + ], + Variable("Y"), + ), + Number(24), + ).simplified(), + Add( + IfExpr( + [ + ( + Or( + And( + Variable("A"), + Or(Variable("B"), Variable("C")), + Equal(Variable("D"), TRUE), + ), + And( + Variable("A"), + Or(Variable("B"), Variable("C")), + Equal(Variable("D"), FALSE), + ), + ), + Variable("X"), + ) + ], + Variable("Y"), + ), + Number(48), + ), + ) + + def test_ass_expr_substituted() -> None: assert_equal( And(Equal(Variable("X"), Number(1)), Variable("Y")).substituted( diff --git a/tests/unit/model/message_test.py b/tests/unit/model/message_test.py index 6b027a1ad..0c37d43df 100644 --- a/tests/unit/model/message_test.py +++ b/tests/unit/model/message_test.py @@ -2843,8 +2843,15 @@ def test_size() -> None: Field("Data"): Selected(Variable("M"), "F"), } ) == Add( - IfExpr([(Equal(Variable("X"), FALSE), Size(Selected(Variable("M"), "F")))], Number(0)), - IfExpr([(Equal(Variable("X"), TRUE), Size(Selected(Variable("M"), "F")))], Number(0)), + IfExpr( + [ + ( + Or(Equal(Variable("X"), FALSE), Equal(Variable("X"), TRUE)), + Size(Selected(Variable("M"), "F")), + ) + ], + Number(0), + ), Number(16), ) assert variable_field_value.size( @@ -2854,8 +2861,10 @@ def test_size() -> None: Field("Data"): Variable("Z"), } ) == Add( - IfExpr([(Equal(Variable("X"), FALSE), Size(Variable("Z")))], Number(0)), - IfExpr([(Equal(Variable("X"), TRUE), Size(Variable("Z")))], Number(0)), + IfExpr( + [(Or(Equal(Variable("X"), FALSE), Equal(Variable("X"), TRUE)), Size(Variable("Z")))], + Number(0), + ), Number(16), ) @@ -2942,8 +2951,10 @@ def test_size() -> None: 32 ) assert path_dependent_fields.size({Field("A"): Variable("X")}) == Add( - IfExpr([(Equal(Variable("X"), Number(0)), Number(16))], Number(0)), - IfExpr([(Greater(Variable("X"), Number(0)), Number(16))], Number(0)), + IfExpr( + [(Or(Equal(Variable("X"), Number(0)), Greater(Variable("X"), Number(0))), Number(16))], + Number(0), + ), Number(16), ) @@ -3115,16 +3126,10 @@ def test_size_subpath() -> None: IfExpr( [ ( - Equal(Selected(Variable("X"), "Has_Data"), FALSE), - Size(Selected(Variable("M"), "F")), - ) - ], - Number(0), - ), - IfExpr( - [ - ( - Equal(Selected(Variable("X"), "Has_Data"), TRUE), + Or( + Equal(Selected(Variable("X"), "Has_Data"), FALSE), + Equal(Selected(Variable("X"), "Has_Data"), TRUE), + ), Size(Selected(Variable("M"), "F")), ) ], @@ -3141,10 +3146,16 @@ def test_size_subpath() -> None: subpath=True, ) == Add( IfExpr( - [(Equal(Selected(Variable("X"), "Has_Data"), FALSE), Size(Variable("Z")))], Number(0) - ), - IfExpr( - [(Equal(Selected(Variable("X"), "Has_Data"), TRUE), Size(Variable("Z")))], Number(0) + [ + ( + Or( + Equal(Selected(Variable("X"), "Has_Data"), FALSE), + Equal(Selected(Variable("X"), "Has_Data"), TRUE), + ), + Size(Variable("Z")), + ) + ], + Number(0), ), Number(16), )