Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[codegen]: deallocate variables after last use #4219

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
3 changes: 1 addition & 2 deletions examples/tokens/ERC1155ownable.vy
Original file line number Diff line number Diff line change
Expand Up @@ -404,5 +404,4 @@ def supportsInterface(interfaceId: bytes4) -> bool:
ERC165_INTERFACE_ID,
ERC1155_INTERFACE_ID,
ERC1155_INTERFACE_ID_METADATA,
]

]
7 changes: 6 additions & 1 deletion vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2128,6 +2128,7 @@ def build_IR(self, expr, args, kwargs, context):

arg = args[0]
# TODO: reify decimal and integer sqrt paths (see isqrt)
# TODO: rewrite this in IR, or move it to pure vyper stdlib
with arg.cache_when_complex("x") as (b1, arg):
sqrt_code = """
assert x >= 0.0
Expand Down Expand Up @@ -2156,7 +2157,11 @@ def build_IR(self, expr, args, kwargs, context):
new_var_pos = context.new_internal_variable(x_type)
placeholder_copy = ["mstore", new_var_pos, arg]
# Create input variables.
variables = {"x": VariableRecord(name="x", pos=new_var_pos, typ=x_type, mutable=False)}
variables = {
"x": VariableRecord(
name="x", pos=new_var_pos, typ=x_type, system=True, mutable=False
)
}
# Dictionary to update new (i.e. typecheck) namespace
variables_2 = {"x": VarInfo(DecimalT())}
# Generate inline IR.
Expand Down
28 changes: 27 additions & 1 deletion vyper/codegen/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class VariableRecord:
defined_at: Any = None
is_internal: bool = False
alloca: Optional[Alloca] = None
system: bool = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't we rather modify the analysis than create a new type of variable given it's used only at one place?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, it's a bit of a kludge, until we refactor sqrt to be either pure IR or a library function


# the following members are probably dead
is_immutable: bool = False
Expand Down Expand Up @@ -121,6 +122,8 @@ def __init__(

self.settings = get_global_settings()

self._to_deallocate = set()

def is_constant(self):
return self.constancy is Constancy.Constant or self.in_range_expr

Expand Down Expand Up @@ -201,7 +204,7 @@ def deallocate_variable(self, varname, var):

# sanity check the type's size hasn't changed since allocation.
n = var.typ.memory_bytes_required
assert n == var.size
assert n == var.size, var

if self.settings.experimental_codegen:
# do not deallocate at this stage because this will break
Expand All @@ -212,6 +215,29 @@ def deallocate_variable(self, varname, var):

del self.vars[var.name]

def mark_for_deallocation(self, varname):
# for variables get deallocated anyway
if varname in self.forvars:
return
if self.vars[varname].system:
return
self._to_deallocate.add(varname)

# "mark-and-sweep", haha
def sweep(self):
tmp = set()
for varname in self._to_deallocate:
var = self.vars[varname]
for s in self._scopes:
if s not in var.blockscopes:
# defer deallocation until we hit the end of its scope
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this accurate? we defer the dealoc until we hit its scope again (not necessarily its end)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or if "it" refers to the scope then still we might have to wait until multiple scopes finish

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"it" refers to the scope of the variable

Copy link
Collaborator

@cyberthirst cyberthirst Sep 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but we sweep after statements, right?

a: uint256 = 0
for i: uint256 in range(1, 3):
  a += i # <- can't dealoc, defer
b: uint256 = 0 # a already deallocated, swept after the for
b += 1
return b  # actual scope end but a is already deallocated

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, this is the idea here

tmp.add(varname)
break
else:
self.deallocate_variable(varname, self.vars[varname])

self._to_deallocate = tmp

def _new_variable(
self,
name: str,
Expand Down
8 changes: 7 additions & 1 deletion vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,14 @@ def parse_Name(self):

# local variable
if varname in self.context.vars:
ret = self.context.lookup_var(varname).as_ir_node()
var = self.context.lookup_var(varname)
ret = var.as_ir_node()
ret._referenced_variables = {varinfo}

last_use = self.expr._metadata.get("last_use", False)
if last_use and var.location == MEMORY:
self.context.mark_for_deallocation(varname)

return ret

if varinfo.is_constant:
Expand Down
26 changes: 25 additions & 1 deletion vyper/codegen/function_definitions/common.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,37 @@
from collections import defaultdict
from dataclasses import dataclass
from functools import cached_property
from typing import Optional
from typing import TYPE_CHECKING, Optional

import vyper.ast as vy_ast
from vyper.codegen.context import Constancy, Context
from vyper.codegen.ir_node import IRnode
from vyper.codegen.memory_allocator import MemoryAllocator
from vyper.evm.opcodes import version_check
from vyper.exceptions import CompilerPanic
from vyper.semantics.types import VyperType
from vyper.semantics.types.function import ContractFunctionT, StateMutability
from vyper.semantics.types.module import ModuleT
from vyper.utils import MemoryPositions

if TYPE_CHECKING:
from vyper.semantics.analysis.base import VarInfo


def analyse_last_use(fn_ast: vy_ast.FunctionDef):
counts: dict[VarInfo, int] = defaultdict(lambda: 0)
for stmt in fn_ast.body:
for expr in stmt.get_descendants(vy_ast.ExprNode):
info = expr._expr_info
if info is None:
continue
for r in info._reads:
counts[r.variable] += 1
if r.variable._use_count < counts[r.variable]: # pragma: nocover
raise CompilerPanic("unreachable!")
if r.variable._use_count == counts[r.variable]:
expr._metadata["last_use"] = True


@dataclass
class FrameInfo:
Expand Down Expand Up @@ -114,6 +135,9 @@ def initialize_context(
):
init_ir_info(func_t)

assert isinstance(func_t.ast_def, vy_ast.FunctionDef) # help mypy
analyse_last_use(func_t.ast_def)

# calculate starting frame
callees = func_t.called_functions
# we start our function frame from the largest callee frame
Expand Down
2 changes: 2 additions & 0 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def __init__(self, node: vy_ast.VyperNode, context: Context) -> None:
self.ir_node.annotation = self.stmt.get("node_source_code")
self.ir_node.ast_source = self.stmt

context.sweep()

def parse_Expr(self):
return Expr(self.stmt.value, self.context, is_stmt=True).ir_node

Expand Down
1 change: 1 addition & 0 deletions vyper/semantics/analysis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def __hash__(self):
def __post_init__(self):
self.position = None
self._modification_count = 0
self._use_count = 0

@property
def getter_ast(self) -> Optional[vy_ast.VyperNode]:
Expand Down
3 changes: 3 additions & 0 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,9 @@ def visit(self, node, typ):
if var_access is not None:
info._reads.add(var_access)

for r in info._reads:
r.variable._use_count += 1

if self.function_analyzer:
for s in self.function_analyzer.loop_variables:
if s is None:
Expand Down
Loading