Skip to content

Commit

Permalink
Improve simplification of if expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
treiher committed Aug 16, 2022
1 parent ac00822 commit 836462a
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 20 deletions.
37 changes: 37 additions & 0 deletions rflx/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,8 @@ def simplified(self) -> Expr:
if total != self.neutral_element():
terms.append(Number(total))

terms = self._simplified_if_expressions(terms)

if len(terms) == 1:
return terms[0]

Expand Down Expand Up @@ -513,6 +515,41 @@ def _simplified_boolean_expressions(self, terms: Sequence[Expr], total: int) ->

return terms

def _simplified_if_expressions(self, terms: list[Expr]) -> list[Expr]:
"""Merge if expressions which differ only in the condition."""

if not terms:
return []

t = terms[0]

if isinstance(t, IfExpr):
for i, u in enumerate(terms[1:]):
if (
isinstance(u, IfExpr)
and len(u.condition_expressions) == 1
and u.condition_expressions[0][1] == t.condition_expressions[0][1]
and u.else_expression == t.else_expression
):
return [
IfExpr(
[
(
Or(
t.condition_expressions[0][0],
u.condition_expressions[0][0],
).simplified(),
t.condition_expressions[0][1],
)
],
t.else_expression,
),
*self._simplified_if_expressions(terms[1 : i + 1] + terms[i + 2 :]),
]
return [t, *self._simplified_if_expressions(terms[1:])]

return terms

@abstractmethod
def operation(self, left: int, right: int) -> int:
raise NotImplementedError
Expand Down
72 changes: 72 additions & 0 deletions tests/unit/expression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,19 @@ def test_bin_expr_substituted_location() -> None:
assert expr.location


def test_ass_expr_str() -> None:
assert (
str(
Add(
Number(1),
IfExpr([(Variable("A"), Variable("B"))], Variable("C")),
IfExpr([(Variable("X"), Variable("Y"))], Variable("Z")),
)
)
== "1 + (if A then B else C) + (if X then Y else Z)"
)


def test_ass_expr_findall() -> None:
assert_equal(
And(Equal(Variable("X"), Number(1)), Less(Variable("Y"), Number(2))).findall(
Expand All @@ -254,6 +267,65 @@ def test_ass_expr_findall() -> None:
)


def test_ass_expr_simplified() -> None:
assert_equal(
Add(
Number(8),
IfExpr(
[
(
And(
Variable("A"),
Or(Variable("B"), Variable("C")),
Equal(Variable("D"), TRUE),
),
Variable("X"),
)
],
Variable("Y"),
),
Number(16),
IfExpr(
[
(
And(
Variable("A"),
Or(Variable("B"), Variable("C")),
Equal(Variable("D"), FALSE),
),
Variable("X"),
)
],
Variable("Y"),
),
Number(24),
).simplified(),
Add(
IfExpr(
[
(
Or(
And(
Variable("A"),
Or(Variable("B"), Variable("C")),
Equal(Variable("D"), TRUE),
),
And(
Variable("A"),
Or(Variable("B"), Variable("C")),
Equal(Variable("D"), FALSE),
),
),
Variable("X"),
)
],
Variable("Y"),
),
Number(48),
),
)


def test_ass_expr_substituted() -> None:
assert_equal(
And(Equal(Variable("X"), Number(1)), Variable("Y")).substituted(
Expand Down
51 changes: 31 additions & 20 deletions tests/unit/model/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2843,8 +2843,15 @@ def test_size() -> None:
Field("Data"): Selected(Variable("M"), "F"),
}
) == Add(
IfExpr([(Equal(Variable("X"), FALSE), Size(Selected(Variable("M"), "F")))], Number(0)),
IfExpr([(Equal(Variable("X"), TRUE), Size(Selected(Variable("M"), "F")))], Number(0)),
IfExpr(
[
(
Or(Equal(Variable("X"), FALSE), Equal(Variable("X"), TRUE)),
Size(Selected(Variable("M"), "F")),
)
],
Number(0),
),
Number(16),
)
assert variable_field_value.size(
Expand All @@ -2854,8 +2861,10 @@ def test_size() -> None:
Field("Data"): Variable("Z"),
}
) == Add(
IfExpr([(Equal(Variable("X"), FALSE), Size(Variable("Z")))], Number(0)),
IfExpr([(Equal(Variable("X"), TRUE), Size(Variable("Z")))], Number(0)),
IfExpr(
[(Or(Equal(Variable("X"), FALSE), Equal(Variable("X"), TRUE)), Size(Variable("Z")))],
Number(0),
),
Number(16),
)

Expand Down Expand Up @@ -2942,8 +2951,10 @@ def test_size() -> None:
32
)
assert path_dependent_fields.size({Field("A"): Variable("X")}) == Add(
IfExpr([(Equal(Variable("X"), Number(0)), Number(16))], Number(0)),
IfExpr([(Greater(Variable("X"), Number(0)), Number(16))], Number(0)),
IfExpr(
[(Or(Equal(Variable("X"), Number(0)), Greater(Variable("X"), Number(0))), Number(16))],
Number(0),
),
Number(16),
)

Expand Down Expand Up @@ -3115,16 +3126,10 @@ def test_size_subpath() -> None:
IfExpr(
[
(
Equal(Selected(Variable("X"), "Has_Data"), FALSE),
Size(Selected(Variable("M"), "F")),
)
],
Number(0),
),
IfExpr(
[
(
Equal(Selected(Variable("X"), "Has_Data"), TRUE),
Or(
Equal(Selected(Variable("X"), "Has_Data"), FALSE),
Equal(Selected(Variable("X"), "Has_Data"), TRUE),
),
Size(Selected(Variable("M"), "F")),
)
],
Expand All @@ -3141,10 +3146,16 @@ def test_size_subpath() -> None:
subpath=True,
) == Add(
IfExpr(
[(Equal(Selected(Variable("X"), "Has_Data"), FALSE), Size(Variable("Z")))], Number(0)
),
IfExpr(
[(Equal(Selected(Variable("X"), "Has_Data"), TRUE), Size(Variable("Z")))], Number(0)
[
(
Or(
Equal(Selected(Variable("X"), "Has_Data"), FALSE),
Equal(Selected(Variable("X"), "Has_Data"), TRUE),
),
Size(Variable("Z")),
)
],
Number(0),
),
Number(16),
)
Expand Down

0 comments on commit 836462a

Please sign in to comment.