diff --git a/src/ethereum_test_vm/bytecode.py b/src/ethereum_test_vm/bytecode.py index a4876c6857..a9907df601 100644 --- a/src/ethereum_test_vm/bytecode.py +++ b/src/ethereum_test_vm/bytecode.py @@ -148,12 +148,35 @@ def __add__(self, other: "Bytecode | int | None") -> "Bytecode": a_min, a_max = self.min_stack_height, self.max_stack_height b_pop, b_push = other.popped_stack_items, other.pushed_stack_items b_min, b_max = other.min_stack_height, other.max_stack_height - a_out = a_min - a_pop + a_push - c_pop = max(0, a_pop + (b_pop - a_push)) - c_push = max(0, a_push + b_push - b_pop) - c_min = a_min if a_out >= b_min else (b_min - a_out) + a_min - c_max = max(a_max + max(0, b_min - a_out), b_max + max(0, a_out - b_min)) + # NOTE: "_pop" is understood as the number of elements required by an instruction or + # bytecode to be popped off the stack before it starts returning (pushing). + + # Auxiliary variables representing "stages" of the execution of `c = a + b` bytecode: + # Assume starting point 0 as reference: + a_start = 0 + # A (potentially) pops some elements and reaches its "bottom", might be negative: + a_bottom = a_start - a_pop + # After this A pushes some elements, then B pops and reaches its "bottom": + b_bottom = a_bottom + a_push - b_pop + + # C's bottom is either at the bottom of A or B: + c_bottom = min(a_bottom, b_bottom) + if c_bottom == a_bottom: + # C pops the same as A to reach its bottom, then the rest of A and B are C's "push" + c_pop = a_pop + c_push = a_push - b_pop + b_push + else: + # A and B are C's "pop" to reach its bottom, then pushes the same as B + c_pop = a_pop - a_push + b_pop + c_push = b_push + + # C's minimum required stack is either A's or B's shifted by the net stack balance of A + c_min = max(a_min, b_min + a_pop - a_push) + + # C starts from c_min, then reaches max either in the spot where A reached a_max or in the + # spot where B reached b_max, after A had completed. + c_max = max(c_min + a_max - a_min, c_min - a_pop + a_push + b_max - b_min) return Bytecode( bytes(self) + bytes(other), diff --git a/src/ethereum_test_vm/tests/test_vm.py b/src/ethereum_test_vm/tests/test_vm.py index 2077e4b1e8..c999f06ca8 100644 --- a/src/ethereum_test_vm/tests/test_vm.py +++ b/src/ethereum_test_vm/tests/test_vm.py @@ -383,6 +383,21 @@ def test_macros(): pytest.param( Op.POP(Op.CALL(1, 2, 3, 4, 5, 6, 7)), 0, 0, 7, 0, id="POP(CALL(1, 2, 3, 4, 5, 6, 7))" ), + pytest.param( + Op.PUSH0 * 2 + Op.PUSH0 + Op.ADD + Op.PUSH0 + Op.POP * 2, 0, 1, 3, 0, id="parens1" + ), + pytest.param( + Op.PUSH0 * 2 + (Op.PUSH0 + Op.ADD + Op.PUSH0 + Op.POP * 2), 0, 1, 3, 0, id="parens2" + ), + pytest.param( + Op.PUSH0 * 2 + Op.PUSH0 + (Op.ADD + Op.PUSH0 + Op.POP * 2), 0, 1, 3, 0, id="parens3" + ), + pytest.param( + Op.PUSH0 * 2 + Op.PUSH0 + (Op.ADD + Op.PUSH0) + Op.POP * 2, 0, 1, 3, 0, id="parens4" + ), + pytest.param( + Op.PUSH0 * 2 + (Op.PUSH0 + Op.ADD + Op.PUSH0) + Op.POP * 2, 0, 1, 3, 0, id="parens5" + ), ], ) def test_bytecode_properties(