From 7083155fc8e82010acad98b8a1e739a23e69be69 Mon Sep 17 00:00:00 2001 From: Tobias Reiher Date: Fri, 28 Jun 2024 19:24:08 +0200 Subject: [PATCH] Add removal of unused temporary variables Ref. None --- librapidflux/src/ty.rs | 15 +- rflx/identifier.py | 6 +- rflx/ir.py | 205 +++++++++++++++++++-- tests/unit/generator/session_test.py | 8 + tests/unit/ir_test.py | 258 ++++++++++++++++++++++++++- 5 files changed, 475 insertions(+), 17 deletions(-) diff --git a/librapidflux/src/ty.rs b/librapidflux/src/ty.rs index c03b66e72..315e61801 100644 --- a/librapidflux/src/ty.rs +++ b/librapidflux/src/ty.rs @@ -3,7 +3,7 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; #[must_use] -#[derive(Clone, PartialEq, Serialize, Deserialize)] +#[derive(Clone, PartialEq, Serialize, Deserialize, Debug)] pub struct Bounds { lower: i128, upper: i128, @@ -113,4 +113,17 @@ mod tests { assert_eq!(result.lower, expected.lower); assert_eq!(result.upper, expected.upper); } + + #[test] + fn test_bounds_serde() { + let bounds = Bounds::new(1, 2); + let bytes = bincode::serialize(&bounds).expect("failed to serialize"); + let deserialized_bounds = bincode::deserialize(&bytes).expect("failed to deserialize"); + assert_eq!(bounds, deserialized_bounds); + } + + #[test] + fn test_bounds_display() { + assert_eq!(Bounds::new(1, 2).to_string(), "1 .. 2"); + } } diff --git a/rflx/identifier.py b/rflx/identifier.py index ff2707c8f..99b6d8827 100644 --- a/rflx/identifier.py +++ b/rflx/identifier.py @@ -1,15 +1,17 @@ from __future__ import annotations from collections.abc import Generator -from typing import Union +from typing import Final, Union from rflx.rapidflux import ID as ID StrID = Union[str, ID] +ID_PREFIX: Final = "T_" + def id_generator() -> Generator[ID, None, None]: i = 0 while True: - yield ID(f"T_{i}") + yield ID(f"{ID_PREFIX}{i}") i += 1 diff --git a/rflx/ir.py b/rflx/ir.py index 8c9ced845..3ef682b38 100644 --- a/rflx/ir.py +++ b/rflx/ir.py @@ -17,7 +17,7 @@ from rflx.common import Base from rflx.const import MAX_SCALAR_SIZE, MP_CONTEXT from rflx.error import info -from rflx.identifier import ID, StrID +from rflx.identifier import ID, ID_PREFIX, StrID from rflx.rapidflux import Location, ty if TYPE_CHECKING: @@ -142,6 +142,11 @@ def __str__(self) -> str: def location(self) -> Optional[Location]: return self.origin.location if self.origin else None + @property + @abstractmethod + def accessed_vars(self) -> list[ID]: + raise NotImplementedError + def substituted(self, mapping: Mapping[ID, ID]) -> Stmt: raise NotImplementedError @@ -168,6 +173,10 @@ class VarDecl(Stmt): expression: Optional[ComplexExpr] = None origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [] + def preconditions(self, _variable_id: Generator[ID, None, None]) -> list[Cond]: return [] @@ -191,6 +200,10 @@ class Assign(Stmt): type_: rty.NamedType origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return self.expression.accessed_vars + def substituted(self, mapping: Mapping[ID, ID]) -> Assign: return Assign( mapping.get(self.target, self.target), @@ -237,6 +250,10 @@ class FieldAssign(Stmt): type_: rty.Message origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [self.message, *self.expression.accessed_vars] + def substituted(self, mapping: Mapping[ID, ID]) -> FieldAssign: return FieldAssign( mapping.get(self.message, self.message), @@ -270,6 +287,10 @@ class Append(Stmt): type_: rty.Sequence origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [self.sequence, *self.expression.accessed_vars] + def substituted(self, mapping: Mapping[ID, ID]) -> Append: return Append( mapping.get(self.sequence, self.sequence), @@ -295,6 +316,10 @@ class Extend(Stmt): type_: rty.Sequence origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [self.sequence, *self.expression.accessed_vars] + def substituted(self, mapping: Mapping[ID, ID]) -> Extend: return Extend( mapping.get(self.sequence, self.sequence), @@ -320,6 +345,13 @@ class Reset(Stmt): type_: rty.Any origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [ + self.identifier, + *[i for p in self.parameter_values.values() for i in p.accessed_vars], + ] + def substituted(self, mapping: Mapping[ID, ID]) -> Reset: return Reset( mapping.get(self.identifier, self.identifier), @@ -347,6 +379,10 @@ class ChannelStmt(Stmt): expression: BasicExpr origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [*self.expression.accessed_vars] + def substituted(self, mapping: Mapping[ID, ID]) -> ChannelStmt: return self.__class__( self.channel, @@ -383,6 +419,10 @@ class Check(Stmt): def location(self) -> Optional[Location]: return self.expression.location + @property + def accessed_vars(self) -> list[ID]: + return self.expression.accessed_vars + def substituted(self, mapping: Mapping[ID, ID]) -> Check: return Check( self.expression.substituted(mapping), @@ -437,6 +477,11 @@ def origin_str(self) -> str: def location(self) -> Optional[Location]: return self.origin.location if self.origin else None + @property + @abstractmethod + def accessed_vars(self) -> list[ID]: + raise NotImplementedError + def substituted(self: Self, mapping: Mapping[ID, ID]) -> Self: raise NotImplementedError @@ -493,6 +538,10 @@ class Var(BasicExpr): identifier: ID = field(converter=ID) origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [self.identifier] + def _update_str(self) -> None: self._str = intern(str(self.identifier)) @@ -550,7 +599,7 @@ def substituted(self, mapping: Mapping[ID, ID]) -> ObjVar: return self def to_z3_expr(self) -> z3.ExprRef: - raise NotImplementedError + return z3.BoolVal(val=True) @define(eq=False) @@ -563,6 +612,10 @@ class EnumLit(BasicExpr): def type_(self) -> rty.Enumeration: return self.enum_type + @property + def accessed_vars(self) -> list[ID]: + return [] + def substituted(self, _mapping: Mapping[ID, ID]) -> EnumLit: return self @@ -582,6 +635,10 @@ class IntVal(BasicIntExpr): def type_(self) -> rty.UniversalInteger: return rty.UniversalInteger(ty.Bounds(self.value, self.value)) + @property + def accessed_vars(self) -> list[ID]: + return [] + def substituted(self, _mapping: Mapping[ID, ID]) -> IntVal: return self @@ -597,6 +654,10 @@ class BoolVal(BasicBoolExpr): value: bool origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [] + def substituted(self, _mapping: Mapping[ID, ID]) -> BoolVal: return self @@ -613,6 +674,10 @@ class Attr(Expr): prefix_type: rty.Any origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [self.prefix] + def substituted(self, mapping: Mapping[ID, ID]) -> Attr: return self.__class__( mapping.get(self.prefix, self.prefix), @@ -750,6 +815,10 @@ class FieldAccessAttr(Expr): message_type: rty.Compound origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [self.message] + @property def field_type(self) -> rty.Any: type_ = self.message_type.field_types[self.field] @@ -804,6 +873,10 @@ class UnaryExpr(Expr): expression: BasicExpr origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return self.expression.accessed_vars + def substituted(self, mapping: Mapping[ID, ID]) -> UnaryExpr: return self.__class__(self.expression.substituted(mapping), self.origin) @@ -830,6 +903,10 @@ class BinaryExpr(Expr): right: BasicExpr origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [*self.left.accessed_vars, *self.right.accessed_vars] + def substituted(self, mapping: Mapping[ID, ID]) -> BinaryExpr: return self.__class__( self.left.substituted(mapping), @@ -1217,6 +1294,10 @@ class Call(Expr): origin: Optional[Origin] = None _preconditions: list[Cond] = field(init=False, factory=list) + @property + def accessed_vars(self) -> list[ID]: + return [i for a in self.arguments for i in a.accessed_vars] + def preconditions(self, _variable_id: Generator[ID, None, None]) -> list[Cond]: return self._preconditions @@ -1303,6 +1384,10 @@ class FieldAccess(Expr): message_type: rty.Compound origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [self.message] + def substituted(self, mapping: Mapping[ID, ID]) -> FieldAccess: return self.__class__( mapping.get(self.message, self.message), @@ -1381,6 +1466,14 @@ class IfExpr(Expr): else_expr: ComplexExpr origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [ + *self.condition.accessed_vars, + *self.then_expr.accessed_vars, + *self.else_expr.accessed_vars, + ] + def substituted(self, mapping: Mapping[ID, ID]) -> IfExpr: return self.__class__( self.condition.substituted(mapping), @@ -1447,6 +1540,10 @@ class Conversion(Expr): def type_(self) -> rty.Any: return self.target_type + @property + def accessed_vars(self) -> list[ID]: + return self.argument.accessed_vars + def substituted(self, mapping: Mapping[ID, ID]) -> Conversion: return self.__class__( self.identifier, @@ -1489,6 +1586,14 @@ class Comprehension(Expr): def type_(self) -> rty.Aggregate: return rty.Aggregate(self.selector.expr.type_) + @property + def accessed_vars(self) -> list[ID]: + return [ + *self.sequence.accessed_vars, + *self.selector.accessed_vars, + *self.condition.accessed_vars, + ] + def substituted(self, mapping: Mapping[ID, ID]) -> Comprehension: return self.__class__( mapping.get(self.iterator, self.iterator), @@ -1528,6 +1633,14 @@ class Find(Expr): def type_(self) -> rty.Any: return self.selector.expr.type_ + @property + def accessed_vars(self) -> list[ID]: + return [ + *self.sequence.accessed_vars, + *self.selector.accessed_vars, + *self.condition.accessed_vars, + ] + def substituted(self, mapping: Mapping[ID, ID]) -> Find: return self.__class__( mapping.get(self.iterator, self.iterator), @@ -1564,6 +1677,10 @@ class Agg(Expr): def type_(self) -> rty.Aggregate: return rty.Aggregate(rty.common_type([e.type_ for e in self.elements])) + @property + def accessed_vars(self) -> list[ID]: + return [i for e in self.elements for i in e.accessed_vars] + def substituted(self, mapping: Mapping[ID, ID]) -> Agg: return self.__class__( [e.substituted(mapping) for e in self.elements], @@ -1599,6 +1716,10 @@ class NamedAgg(Expr): def type_(self) -> rty.Any: raise NotImplementedError + @property + def accessed_vars(self) -> list[ID]: + return [i for (_, e) in self.elements for i in e.accessed_vars] + def substituted(self, mapping: Mapping[ID, ID]) -> NamedAgg: raise NotImplementedError @@ -1621,6 +1742,10 @@ class Str(Expr): def type_(self) -> rty.Sequence: return rty.OPAQUE + @property + def accessed_vars(self) -> list[ID]: + return [] + def substituted(self, _mapping: Mapping[ID, ID]) -> Str: return self @@ -1641,6 +1766,10 @@ class MsgAgg(Expr): type_: rty.Message origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [i for v in self.field_values.values() for i in v.accessed_vars] + def substituted(self, mapping: Mapping[ID, ID]) -> MsgAgg: return self.__class__( self.identifier, @@ -1671,6 +1800,10 @@ class DeltaMsgAgg(Expr): type_: rty.Message origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [i for v in self.field_values.values() for i in v.accessed_vars] + def substituted(self, mapping: Mapping[ID, ID]) -> DeltaMsgAgg: return self.__class__( self.identifier, @@ -1701,6 +1834,13 @@ class CaseExpr(Expr): type_: rty.Any origin: Optional[Origin] = None + @property + def accessed_vars(self) -> list[ID]: + return [ + *self.expression.accessed_vars, + *[i for (_, e) in self.choices for i in e.accessed_vars], + ] + def substituted(self, mapping: Mapping[ID, ID]) -> CaseExpr: return self.__class__( self.expression.substituted(mapping), @@ -1776,6 +1916,13 @@ def is_expr(self) -> bool: def is_basic_expr(self) -> bool: return self.is_expr() and isinstance(self.expr, BasicExpr) + @property + def accessed_vars(self) -> list[ID]: + return [ + *[i for s in self.stmts for i in s.accessed_vars], + *self.expr.accessed_vars, + ] + def substituted(self, mapping: Mapping[ID, ID]) -> ComplexExpr: return self.__class__( [s.substituted(mapping) for s in self.stmts], @@ -1848,7 +1995,12 @@ def __init__( # noqa: PLR0913 Transition( t.target, ComplexExpr( - add_required_checks(t.condition.stmts, manager, variable_id), + add_required_checks( + t.condition.stmts, + manager, + variable_id, + t.condition.expr.accessed_vars, + ), t.condition.expr, ), t.description, @@ -1857,7 +2009,7 @@ def __init__( # noqa: PLR0913 for t in s.transitions ], s.exception_transition, - add_required_checks(s.actions, manager, variable_id), + add_required_checks(s.actions, manager, variable_id, []), s.description, s.location, ) @@ -1894,7 +2046,11 @@ def add_checks(statements: Sequence[Stmt], variable_id: Generator[ID, None, None return result -def remove_unnecessary_checks(statements: Sequence[Stmt], manager: ProofManager) -> list[Stmt]: +def remove_unnecessary_checks( + statements: Sequence[Stmt], + manager: ProofManager, + accessed_vars: Sequence[ID], +) -> list[Stmt]: """Remove all checks that are always true.""" always_true: list[int] = [] @@ -1930,18 +2086,46 @@ def remove_unnecessary_checks(statements: Sequence[Stmt], manager: ProofManager) for i in reversed(always_true): result = [*result[:i], *result[i + 1 :]] - return remove_unused_assignments(result) + return remove_unused_temporary_variables(result, accessed_vars) + + +def remove_unused_temporary_variables( + statements: Sequence[Stmt], + accessed_vars: Sequence[ID], +) -> list[Stmt]: + """Remove all unused temporary variable declarations and assignments.""" + + used_vars = set(accessed_vars) + unused_statements = [] + for i, s in reversed(list(enumerate(statements))): + used_vars.update(s.accessed_vars) + if ( + isinstance(s, VarDecl) + and str(s.identifier).startswith(ID_PREFIX) + and s.identifier not in used_vars + ): + unused_statements.append(i) + if ( + isinstance(s, Assign) + and str(s.target).startswith(ID_PREFIX) + and s.target not in used_vars + ): + unused_statements.append(i) -def remove_unused_assignments(statements: Sequence[Stmt]) -> list[Stmt]: - # TODO(eng/recordflux/RecordFlux#1339): Add removal of unused assignments - return list(statements) + statements = list(statements) + + for i in unused_statements: + statements.pop(i) + + return statements def add_required_checks( statements: Sequence[Stmt], manager: ProofManager, variable_id: Generator[ID, None, None], + accessed_vars: Sequence[ID], ) -> list[Stmt]: """ Add check statements in places where preconditions are not always true. @@ -1950,7 +2134,8 @@ def add_required_checks( case, a check statement is added in front of the respective statement. The check statements in the resulting list mark the places where the code generator must insert explicit checks. """ - result = remove_unnecessary_checks(add_checks(statements, variable_id), manager) + + result = remove_unnecessary_checks(add_checks(statements, variable_id), manager, accessed_vars) for s in result: if isinstance(s, Check): diff --git a/tests/unit/generator/session_test.py b/tests/unit/generator/session_test.py index 049241785..270d42d1a 100644 --- a/tests/unit/generator/session_test.py +++ b/tests/unit/generator/session_test.py @@ -822,6 +822,10 @@ def test_session_declare_error( @define class UnknownStatement(ir.Stmt): + @property + def accessed_vars(self) -> list[ID]: + raise NotImplementedError + def preconditions(self, variable_id: typing.Generator[ID, None, None]) -> list[ir.Cond]: raise NotImplementedError @@ -1076,6 +1080,10 @@ class UnknownExpr(ir.Expr): def type_(self) -> rty.Any: return rty.Message("T") + @property + def accessed_vars(self) -> list[ID]: + raise NotImplementedError + def preconditions(self, variable_id: typing.Generator[ID, None, None]) -> list[ir.Cond]: raise NotImplementedError diff --git a/tests/unit/ir_test.py b/tests/unit/ir_test.py index 1c2c9ba1b..92b3f57bd 100644 --- a/tests/unit/ir_test.py +++ b/tests/unit/ir_test.py @@ -30,12 +30,27 @@ def test_stmt_location() -> None: ).location == Location((1, 2)) +def test_var_decl() -> None: + var_decl = ir.VarDecl("X", INT_TY) + assert var_decl.identifier == ID("X") + assert var_decl.type_ == INT_TY + assert var_decl.expression is None + + +def test_var_decl_accessed_vars() -> None: + assert ir.VarDecl("X", INT_TY).accessed_vars == [] + + def test_assign() -> None: assign = ir.Assign("X", ir.IntVar("Y", INT_TY), INT_TY) assert assign.target == ID("X") assert assign.expression == ir.IntVar("Y", INT_TY) +def test_assign_accessed_vars() -> None: + assert ir.Assign("X", ir.IntVar("Y", INT_TY), INT_TY).accessed_vars == [ID("Y")] + + def test_assign_str() -> None: assert str(ir.Assign("X", ir.IntVar("Y", INT_TY), INT_TY)) == "X := Y" @@ -73,6 +88,13 @@ def test_field_assign_type() -> None: assert ir.FieldAssign("X", "Y", ir.IntVar("Z", INT_TY), MSG_TY).type_ == MSG_TY +def test_field_assign_accessed_vars() -> None: + assert ir.FieldAssign("X", "Y", ir.IntVar("Z", INT_TY), MSG_TY).accessed_vars == [ + ID("X"), + ID("Z"), + ] + + def test_field_assign_substituted() -> None: assert ir.FieldAssign("X", "Y", ir.IntVar("Z", INT_TY), MSG_TY).substituted( { @@ -105,6 +127,13 @@ def test_append_str() -> None: assert str(ir.Append("X", ir.IntVar("Y", INT_TY), SEQ_TY)) == "X'Append (Y)" +def test_append_accessed_vars() -> None: + assert ir.Append("X", ir.IntVar("Y", INT_TY), SEQ_TY).accessed_vars == [ + ID("X"), + ID("Y"), + ] + + def test_append_substituted() -> None: assert ir.Append("X", ir.IntVar("Y", INT_TY), SEQ_TY).substituted( { @@ -126,6 +155,13 @@ def test_extend_str() -> None: assert str(ir.Extend("X", ir.IntVar("Y", INT_TY), SEQ_TY)) == "X'Extend (Y)" +def test_extend_accessed_vars() -> None: + assert ir.Extend("X", ir.IntVar("Y", INT_TY), SEQ_TY).accessed_vars == [ + ID("X"), + ID("Y"), + ] + + def test_extend_substituted() -> None: assert ir.Extend("X", ir.IntVar("Y", INT_TY), SEQ_TY).substituted( { @@ -148,6 +184,13 @@ def test_reset_str() -> None: assert str(ir.Reset("X", {ID("Y"): ir.IntVar("Z", INT_TY)}, MSG_TY)) == "X'Reset (Y => Z)" +def test_reset_accessed_vars() -> None: + assert ir.Reset("X", {ID("Y"): ir.IntVar("Z", INT_TY)}, MSG_TY).accessed_vars == [ + ID("X"), + ID("Z"), + ] + + def test_reset_substituted() -> None: assert ir.Reset("X", {ID("Y"): ir.IntVar("Z", INT_TY)}, MSG_TY).substituted( { @@ -176,6 +219,12 @@ def test_read_str() -> None: assert str(ir.Read("X", ir.IntVar("Y", INT_TY))) == "X'Read (Y)" +def test_read_accessed_vars() -> None: + assert ir.Read("X", ir.IntVar("Y", INT_TY)).accessed_vars == [ + ID("Y"), + ] + + def test_read_substituted() -> None: assert ir.Read("X", ir.IntVar("Y", INT_TY)).substituted( { @@ -197,6 +246,12 @@ def test_write_str() -> None: assert str(ir.Write("X", ir.IntVar("Y", INT_TY))) == "X'Write (Y)" +def test_write_accessed_vars() -> None: + assert ir.Read("X", ir.IntVar("Y", INT_TY)).accessed_vars == [ + ID("Y"), + ] + + def test_write_substituted() -> None: assert ir.Write("X", ir.IntVar("Y", INT_TY)).substituted( { @@ -214,6 +269,12 @@ def test_write_z3_expr() -> None: assert ir.Write("X", ir.ObjVar("Y", MSG_TY)).to_z3_expr() == z3.BoolVal(val=True) +def test_check_accessed_vars() -> None: + assert ir.Check(ir.BoolVar("X")).accessed_vars == [ + ID("X"), + ] + + def test_check_substituted() -> None: assert ir.Check(ir.BoolVar("X")).substituted( { @@ -257,6 +318,12 @@ def test_int_var_type() -> None: assert ir.IntVar("X", INT_TY).type_ == INT_TY +def test_int_var_accessed_vars() -> None: + assert ir.IntVar("X", INT_TY).accessed_vars == [ + ID("X"), + ] + + def test_int_var_substituted() -> None: assert ir.IntVar("X", INT_TY).substituted({}) == ir.IntVar("X", INT_TY) assert ir.IntVar("X", INT_TY).substituted( @@ -283,6 +350,12 @@ def test_bool_var_type() -> None: assert ir.BoolVar("X").type_ == rty.BOOLEAN +def test_bool_var_accessed_vars() -> None: + assert ir.BoolVar("X", INT_TY).accessed_vars == [ + ID("X"), + ] + + def test_bool_var_substituted() -> None: assert ir.BoolVar("X").substituted({}) == ir.BoolVar("X") assert ir.BoolVar("X").substituted( @@ -309,6 +382,12 @@ def test_obj_var_type() -> None: assert ir.ObjVar("X", ENUM_TY).type_ == ENUM_TY +def test_obj_var_accessed_vars() -> None: + assert ir.ObjVar("X", INT_TY).accessed_vars == [ + ID("X"), + ] + + def test_obj_var_substituted() -> None: assert ir.ObjVar("X", ENUM_TY).substituted({}) == ir.ObjVar("X", ENUM_TY) assert ir.ObjVar("X", ENUM_TY).substituted( @@ -320,8 +399,7 @@ def test_obj_var_substituted() -> None: def test_obj_var_to_z3_expr() -> None: - with pytest.raises(NotImplementedError): - ir.ObjVar("X", MSG_TY).to_z3_expr() + assert ir.ObjVar("X", MSG_TY).to_z3_expr() == z3.BoolVal(val=True) def test_enum_lit_str() -> None: @@ -332,6 +410,10 @@ def test_enum_lit_type() -> None: assert ir.EnumLit("Lit", ENUM_TY).type_ == ENUM_TY +def test_enum_lit_accessed_vars() -> None: + assert ir.EnumLit("Lit", ENUM_TY).accessed_vars == [] + + def test_enum_lit_substituted() -> None: assert ir.EnumLit("X", ENUM_TY).substituted( { @@ -349,6 +431,10 @@ def test_int_val_str() -> None: assert str(ir.IntVal(1)) == "1" +def test_int_val_accessed_vars() -> None: + assert ir.IntVal(1).accessed_vars == [] + + def test_int_val_to_z3_expr() -> None: assert ir.IntVal(1).to_z3_expr() == z3.IntVal(1) @@ -358,6 +444,10 @@ def test_bool_val_str() -> None: assert str(ir.BoolVal(value=False)) == "False" +def test_bool_val_accessed_vars() -> None: + assert ir.BoolVal(value=True).accessed_vars == [] + + def test_bool_val_to_z3_expr() -> None: assert ir.BoolVal(value=True).to_z3_expr() == z3.BoolVal(val=True) assert ir.BoolVal(value=False).to_z3_expr() == z3.BoolVal(val=False) @@ -402,6 +492,25 @@ def test_attr_type(attribute: ir.Attr, expected: rty.Type) -> None: assert attribute.type_ == expected +@pytest.mark.parametrize( + ("attribute"), + [ + (ir.Size("X", MSG_TY)), + (ir.Length("X", MSG_TY)), + (ir.First("X", MSG_TY)), + (ir.Last("X", MSG_TY)), + (ir.ValidChecksum("X", MSG_TY)), + (ir.Valid("X", MSG_TY)), + (ir.Present("X", MSG_TY)), + (ir.HasData("X", MSG_TY)), + (ir.Head("X", SEQ_TY)), + (ir.Opaque("X", MSG_TY)), + ], +) +def test_attr_accessed_vars(attribute: ir.Attr) -> None: + assert attribute.accessed_vars == [ID("X")] + + @pytest.mark.parametrize( ("attribute", "expected"), [ @@ -475,6 +584,18 @@ def test_field_access_attr_field_type() -> None: assert ir.FieldValid("X", "I", MSG_TY).field_type == INT_TY +@pytest.mark.parametrize( + ("attribute"), + [ + ir.FieldValid("X", "Y", MSG_TY), + ir.FieldPresent("X", "Y", MSG_TY), + ir.FieldSize("X", "Y", MSG_TY), + ], +) +def test_field_access_attr_accessed_vars(attribute: ir.FieldAccessAttr) -> None: + assert attribute.accessed_vars == [ID("X")] + + @pytest.mark.parametrize( ("attribute", "expected"), [ @@ -575,6 +696,40 @@ def test_binary_expr_origin_str(binary_expr: type[ir.BinaryExpr]) -> None: ) +@pytest.mark.parametrize( + "binary_expr", + [ + ir.Add, + ir.Sub, + ir.Mul, + ir.Div, + ir.Pow, + ir.Mod, + ir.And, + ir.Or, + ir.Less, + ir.LessEqual, + ir.Equal, + ir.GreaterEqual, + ir.Greater, + ir.NotEqual, + ], +) +def test_binary_expr_accessed_vars(binary_expr: type[ir.BinaryExpr]) -> None: + assert binary_expr(ir.IntVar("X", INT_TY), ir.IntVar("Y", INT_TY)).accessed_vars == [ + ID("X"), + ID("Y"), + ] + assert ( + binary_expr( + ir.IntVar("X", INT_TY), + ir.IntVal(1), + origin=ir.ConstructedOrigin("Z", None), + ).origin_str + == "Z" + ) + + @pytest.mark.parametrize( "binary_expr", [ @@ -612,6 +767,10 @@ def test_neg_type() -> None: assert ir.Neg(ir.IntVar("X", INT_TY)).type_ == INT_TY +def test_neg_accessed_vars() -> None: + assert ir.Neg(ir.IntVar("X", INT_TY)).accessed_vars == [ID("X")] + + def test_neg_to_z3_expr() -> None: assert ir.Neg(ir.IntVar("X", INT_TY)).to_z3_expr() == -z3.Int("X") @@ -894,6 +1053,15 @@ def test_obj_call_substituted() -> None: ) == ir.ObjCall("X", [ir.BoolVar("Z")], [], ENUM_TY) +def test_call_accessed_vars() -> None: + assert ir.IntCall( + "X", + [ir.IntVar("Y", INT_TY), ir.BoolVal(value=True)], + [], + INT_TY, + ).accessed_vars == [ID("Y")] + + def test_call_preconditions() -> None: call = ir.IntCall("X", [ir.IntVar("Y", INT_TY), ir.BoolVal(value=True)], [], INT_TY) assert not call.preconditions(id_generator()) @@ -911,6 +1079,10 @@ def test_int_field_access_type() -> None: assert ir.IntFieldAccess("M", "I", MSG_TY).type_ == INT_TY +def test_int_field_access_accessed_vars() -> None: + assert ir.IntFieldAccess("M", "I", MSG_TY).accessed_vars == [ID("M")] + + def test_int_field_access_substituted() -> None: assert ir.IntFieldAccess("M", "F", MSG_TY).substituted( {ID("M"): ID("X"), ID("F"): ID("Y")}, @@ -980,6 +1152,15 @@ def test_int_if_expr_type() -> None: ) +def test_int_if_expr_accessed_vars() -> None: + assert ir.IntIfExpr( + ir.BoolVar("X"), + ir.ComplexIntExpr([], ir.IntVar("Y", INT_TY)), + ir.ComplexIntExpr([], ir.IntVal(1)), + INT_TY, + ).accessed_vars == [ID("X"), ID("Y")] + + def test_int_if_expr_substituted() -> None: assert ir.IntIfExpr( ir.BoolVar("X"), @@ -1027,6 +1208,14 @@ def test_bool_if_expr_type() -> None: ) +def test_bool_if_expr_accessed_vars() -> None: + assert ir.BoolIfExpr( + ir.BoolVar("X"), + ir.ComplexBoolExpr([], ir.BoolVar("Y")), + ir.ComplexBoolExpr([], ir.BoolVal(value=False)), + ).accessed_vars == [ID("X"), ID("Y")] + + def test_bool_if_expr_substituted() -> None: assert ir.BoolIfExpr( ir.BoolVar("X"), @@ -1055,6 +1244,10 @@ def test_conversion_type() -> None: assert ir.Conversion("X", ir.IntVar("Y", INT_TY), INT_TY).type_ == INT_TY +def test_conversion_accessed_vars() -> None: + assert ir.Conversion("X", ir.IntVar("Y", INT_TY), INT_TY).accessed_vars == [ID("Y")] + + def test_conversion_substituted() -> None: assert ir.Conversion("X", ir.IntVar("Y", INT_TY), INT_TY).substituted( {ID("X"): ID("Y"), ID("Y"): ID("Z")}, @@ -1129,6 +1322,15 @@ def test_comprehension_type() -> None: ).type_ == rty.Aggregate(MSG_TY) +def test_comprehension_accessed_vars() -> None: + assert ir.Comprehension( + "X", + ir.ObjVar("Y", SEQ_TY), + ir.ComplexExpr([], ir.ObjVar("X", MSG_TY)), + ir.ComplexBoolExpr([], ir.BoolVal(value=True)), + ).accessed_vars == [ID("Y"), ID("X")] + + def test_comprehension_substituted() -> None: assert ir.Comprehension( "X", @@ -1213,6 +1415,15 @@ def test_find_type() -> None: ) +def test_find_accessed_vars() -> None: + assert ir.Find( + "X", + ir.ObjVar("Y", SEQ_TY), + ir.ComplexExpr([], ir.ObjVar("X", MSG_TY)), + ir.ComplexBoolExpr([], ir.BoolVal(value=True)), + ).accessed_vars == [ID("Y"), ID("X")] + + def test_find_substituted() -> None: assert ir.Find( "X", @@ -1256,6 +1467,10 @@ def test_agg_type() -> None: assert ir.Agg([ir.IntVar("X", INT_TY), ir.IntVal(10)]).type_ == rty.Aggregate(rty.BASE_INTEGER) +def test_agg_accessed_vars() -> None: + assert ir.Agg([ir.IntVar("X", INT_TY), ir.IntVal(10)]).accessed_vars == [ID("X")] + + def test_agg_substituted() -> None: assert ir.Agg([ir.IntVar("X", INT_TY), ir.IntVal(1)]).substituted({ID("X"): ID("Y")}) == ir.Agg( [ir.IntVar("Y", INT_TY), ir.IntVal(1)], @@ -1273,6 +1488,12 @@ def test_named_agg_str() -> None: ) +def test_named_agg_accessed_vars() -> None: + assert ir.NamedAgg( + [(ID("X"), ir.IntVar("Z", INT_TY)), (ID("Y"), ir.IntVal(1))], + ).accessed_vars == [ID("Z")] + + def test_str_str() -> None: assert str(ir.Str("X")) == '"X"' @@ -1281,6 +1502,10 @@ def test_str_type() -> None: assert ir.Str("X").type_ == rty.OPAQUE +def test_str_accessed_vars() -> None: + assert ir.Str("X").accessed_vars == [] + + def test_str_substituted() -> None: assert ir.Str("X").substituted({ID("X"): ID("Y")}) == ir.Str("X") @@ -1298,6 +1523,10 @@ def test_msg_agg_type() -> None: assert ir.MsgAgg("X", {}, MSG_TY).type_ == MSG_TY +def test_msg_agg_accessed_vars() -> None: + assert ir.MsgAgg("X", {ID("Y"): ir.IntVar("Z", INT_TY)}, MSG_TY).accessed_vars == [ID("Z")] + + def test_msg_agg_substituted() -> None: assert ir.MsgAgg("X", {ID("Y"): ir.IntVar("Z", INT_TY)}, MSG_TY).substituted( {ID("X"): ID("Y"), ID("Y"): ID("Z"), ID("Z"): ID("A")}, @@ -1317,6 +1546,10 @@ def test_delta_msg_agg_type() -> None: assert ir.DeltaMsgAgg("X", {}, MSG_TY).type_ == MSG_TY +def test_delta_msg_agg_accessed_vars() -> None: + assert ir.DeltaMsgAgg("X", {ID("Y"): ir.IntVar("Z", INT_TY)}, MSG_TY).accessed_vars == [ID("Z")] + + def test_delta_msg_agg_substituted() -> None: assert ir.DeltaMsgAgg("X", {ID("Y"): ir.IntVar("Z", INT_TY)}, MSG_TY).substituted( {ID("X"): ID("Y"), ID("Y"): ID("Z"), ID("Z"): ID("A")}, @@ -1366,6 +1599,23 @@ def test_case_expr_type() -> None: ) +def test_case_expr_accessed_vars() -> None: + assert ir.CaseExpr( + ir.IntVar("X", INT_TY), + [ + ( + [ir.IntVal(1), ir.IntVal(3)], + ir.IntVal(0), + ), + ( + [ir.IntVal(2)], + ir.IntVar("Y", INT_TY), + ), + ], + INT_TY, + ).accessed_vars == [ID("X"), ID("Y")] + + def test_case_expr_substituted() -> None: assert ir.CaseExpr( ir.IntVar("X", INT_TY), @@ -1433,6 +1683,7 @@ def test_add_required_checks() -> None: ], PROOF_MANAGER, id_generator(), + [], ) == [ ir.Check( ir.NotEqual( @@ -1461,6 +1712,7 @@ def test_add_required_checks() -> None: ], PROOF_MANAGER, id_generator(), + [], ) == [ ir.VarDecl("T_0", rty.BASE_INTEGER), ir.Assign("T_0", ir.Sub(ir.IntVal(100), ir.IntVal(1)), rty.BASE_INTEGER), @@ -1471,7 +1723,5 @@ def test_add_required_checks() -> None: ir.Check(ir.GreaterEqual(ir.IntVar("B", INT_TY), ir.IntVal(1))), ir.Assign("X", ir.Sub(ir.IntVar("B", INT_TY), ir.IntVal(1)), INT_TY), ir.Assign("Z", ir.IntVal(0), INT_TY), - ir.VarDecl("T_1", rty.BASE_INTEGER), - ir.Assign("T_1", ir.Sub(ir.IntVal(100), ir.IntVal(1)), rty.BASE_INTEGER), ir.Assign("C", ir.Add(ir.IntVar("Z", INT_TY), ir.IntVal(1)), INT_TY), ]