Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[venom]: add load elimination #4265

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions tests/unit/compiler/venom/test_load_elimination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from vyper.venom.analysis.analysis import IRAnalysesCache
from vyper.venom.basicblock import IRLiteral, IRVariable
from vyper.venom.context import IRContext
from vyper.venom.passes.load_elimination import LoadElimination


def test_simple_load_elimination():
ctx = IRContext()
fn = ctx.create_function("test")

bb = fn.get_basic_block()

ptr = IRLiteral(11)
bb.append_instruction("mload", ptr)
bb.append_instruction("mload", ptr)
bb.append_instruction("stop")

ac = IRAnalysesCache(fn)
LoadElimination(ac, fn).run_pass()

assert len([inst for inst in bb.instructions if inst.opcode == "mload"]) == 1

inst0, inst1, inst2 = bb.instructions

assert inst0.opcode == "mload"
assert inst1.opcode == "store"
assert inst1.operands[0] == inst0.output
assert inst2.opcode == "stop"


def test_equivalent_var_elimination():
ctx = IRContext()
fn = ctx.create_function("test")

bb = fn.get_basic_block()

ptr1 = bb.append_instruction("store", IRLiteral(11))
ptr2 = bb.append_instruction("store", ptr1)
bb.append_instruction("mload", ptr1)
bb.append_instruction("mload", ptr2)
bb.append_instruction("stop")

ac = IRAnalysesCache(fn)
LoadElimination(ac, fn).run_pass()

assert len([inst for inst in bb.instructions if inst.opcode == "mload"]) == 1

inst0, inst1, inst2, inst3, inst4 = bb.instructions

assert inst0.opcode == "store"
assert inst1.opcode == "store"
assert inst2.opcode == "mload"
assert inst2.operands[0] == inst0.output
assert inst3.opcode == "store"
assert inst3.operands[0] == inst2.output
assert inst4.opcode == "stop"


def test_elimination_barrier():
ctx = IRContext()
fn = ctx.create_function("test")

bb = fn.get_basic_block()

ptr = IRLiteral(11)
bb.append_instruction("mload", ptr)

arbitrary = IRVariable("%100")
# fence, writes to memory
bb.append_instruction("staticcall", arbitrary, arbitrary, arbitrary, arbitrary)

bb.append_instruction("mload", ptr)
bb.append_instruction("stop")

ac = IRAnalysesCache(fn)

instructions = bb.instructions.copy()
LoadElimination(ac, fn).run_pass()

assert instructions == bb.instructions # no change


def test_store_load_elimination():
ctx = IRContext()
fn = ctx.create_function("test")

bb = fn.get_basic_block()

val = IRLiteral(55)
ptr1 = bb.append_instruction("store", IRLiteral(11))
ptr2 = bb.append_instruction("store", ptr1)
bb.append_instruction("mstore", val, ptr1)
bb.append_instruction("mload", ptr2)
bb.append_instruction("stop")

ac = IRAnalysesCache(fn)
LoadElimination(ac, fn).run_pass()

assert len([inst for inst in bb.instructions if inst.opcode == "mload"]) == 0

inst0, inst1, inst2, inst3, inst4 = bb.instructions

assert inst0.opcode == "store"
assert inst1.opcode == "store"
assert inst2.opcode == "mstore"
assert inst3.opcode == "store"
assert inst3.operands[0] == inst2.operands[0]
assert inst4.opcode == "stop"


def test_store_load_barrier():
ctx = IRContext()
fn = ctx.create_function("test")

bb = fn.get_basic_block()

val = IRLiteral(55)
ptr1 = bb.append_instruction("store", IRLiteral(11))
ptr2 = bb.append_instruction("store", ptr1)
bb.append_instruction("mstore", val, ptr1)

arbitrary = IRVariable("%100")
# fence, writes to memory
bb.append_instruction("staticcall", arbitrary, arbitrary, arbitrary, arbitrary)

bb.append_instruction("mload", ptr2)
bb.append_instruction("stop")

ac = IRAnalysesCache(fn)

instructions = bb.instructions.copy()
LoadElimination(ac, fn).run_pass()

assert instructions == bb.instructions
4 changes: 4 additions & 0 deletions vyper/venom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
AlgebraicOptimizationPass,
BranchOptimizationPass,
DFTPass,
LoadElimination,
MakeSSA,
Mem2Var,
RemoveUnusedVariablesPass,
Expand Down Expand Up @@ -52,8 +53,11 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel) -> None:
Mem2Var(ac, fn).run_pass()
MakeSSA(ac, fn).run_pass()
SCCP(ac, fn).run_pass()

StoreElimination(ac, fn).run_pass()
SimplifyCFGPass(ac, fn).run_pass()
LoadElimination(ac, fn).run_pass()

AlgebraicOptimizationPass(ac, fn).run_pass()
# NOTE: MakeSSA is after algebraic optimization it currently produces
# smaller code by adding some redundant phi nodes. This is not a
Expand Down
1 change: 1 addition & 0 deletions vyper/venom/passes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .algebraic_optimization import AlgebraicOptimizationPass
from .branch_optimization import BranchOptimizationPass
from .dft import DFTPass
from .load_elimination import LoadElimination
from .make_ssa import MakeSSA
from .mem2var import Mem2Var
from .normalization import NormalizationPass
Expand Down
74 changes: 74 additions & 0 deletions vyper/venom/passes/load_elimination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from vyper.venom.effects import Effects
from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis, VarEquivalenceAnalysis
from vyper.venom.passes.base_pass import IRPass


class LoadElimination(IRPass):
"""
Eliminate sloads, mloads and tloads
"""

# should this be renamed to EffectsElimination?

def run_pass(self):
self.equivalence = self.analyses_cache.request_analysis(VarEquivalenceAnalysis)

for bb in self.function.get_basic_blocks():
self._process_bb(bb)

self.analyses_cache.invalidate_analysis(LivenessAnalysis)
self.analyses_cache.invalidate_analysis(DFGAnalysis)

def equivalent(self, op1, op2):
return op1 == op2 or self.equivalence.equivalent(op1, op2)

def _process_bb(self, bb):
transient = ()
storage = ()
memory = ()

for inst in bb.instructions:
if Effects.MEMORY in inst.get_write_effects():
memory = ()
if Effects.STORAGE in inst.get_write_effects():
storage = ()
if Effects.TRANSIENT in inst.get_write_effects():
transient = ()

if inst.opcode == "mstore":
# mstore [val, ptr]
memory = (inst.operands[1], inst.operands[0])
if inst.opcode == "sstore":
storage = (inst.operands[1], inst.operands[0])
if inst.opcode == "tstore":
transient = (inst.operands[1], inst.operands[0])

if inst.opcode == "mload":
prev_memory = memory
memory = (inst.operands[0], inst.output)
if not prev_memory:
continue
if not self.equivalent(inst.operands[0], prev_memory[0]):
continue
inst.opcode = "store"
inst.operands = [prev_memory[1]]

if inst.opcode == "sload":
prev_storage = storage
storage = (inst.operands[0], inst.output)
if not prev_storage:
continue
if not self.equivalent(inst.operands[0], prev_storage[0]):
continue
inst.opcode = "store"
inst.operands = [prev_storage[1]]

if inst.opcode == "tload":
prev_transient = transient
transient = (inst.operands[0], inst.output)
if not prev_transient:
continue
if not self.equivalent(inst.operands[0], prev_transient[0]):
continue
inst.opcode = "store"
inst.operands = [prev_transient[1]]
Loading