Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[codegen]: fix transient codegen for slice and extract32 #3874

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions tests/functional/builtins/codegen/test_extract32.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
def test_extract32_extraction(tx_failed, get_contract_with_gas_estimation):
extract32_code = """
y: Bytes[100]
import pytest

from vyper.evm.opcodes import version_check


@pytest.mark.parametrize("location", ["storage", "transient"])
def test_extract32_extraction(tx_failed, get_contract_with_gas_estimation, location):
if location == "transient" and not version_check(begin="cancun"):
pytest.skip(
"Skipping test as storage_location is 'transient' and EVM version is pre-Cancun"
)
if location == "storage":
decl = "y: Bytes[100]"
elif location == "transient":
decl = "y: transient(Bytes[100])"
else:
raise Exception("unreachable")
extract32_code = f"""
{decl}
@external
def extrakt32(inp: Bytes[100], index: uint256) -> bytes32:
return extract32(inp, index)
Expand Down Expand Up @@ -43,8 +59,6 @@ def extrakt32_storage(index: uint256, inp: Bytes[100]) -> bytes32:
with tx_failed():
c.extrakt32(S, i)

print("Passed bytes32 extraction test")


def test_extract32_code(tx_failed, get_contract_with_gas_estimation):
extract32_code = """
Expand Down Expand Up @@ -84,5 +98,3 @@ def foq(inp: Bytes[32]) -> address:

with tx_failed():
c.foq(b"crow" * 8)

print("Passed extract32 test")
34 changes: 30 additions & 4 deletions tests/functional/builtins/codegen/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from vyper.compiler import compile_code
from vyper.compiler.settings import OptimizationLevel, Settings
from vyper.evm.opcodes import version_check
from vyper.exceptions import ArgumentException, TypeMismatch

_fun_bytes32_bounds = [(0, 32), (3, 29), (27, 5), (0, 5), (5, 3), (30, 2)]
Expand Down Expand Up @@ -93,7 +94,9 @@ def _get_contract():
assert c.do_splice() == bytesdata[start : start + length]


@pytest.mark.parametrize("location", ("storage", "calldata", "memory", "literal", "code"))
@pytest.mark.parametrize(
"location", ["storage", "transient", "calldata", "memory", "literal", "code"]
)
@pytest.mark.parametrize("use_literal_start", (True, False))
@pytest.mark.parametrize("use_literal_length", (True, False))
@pytest.mark.parametrize("opt_level", list(OptimizationLevel))
Expand All @@ -112,13 +115,23 @@ def test_slice_bytes_fuzz(
use_literal_length,
length_bound,
):
if location == "transient" and not version_check(begin="cancun"):
pytest.skip(
"Skipping test as storage_location is 'transient' and EVM version is pre-Cancun"
)
preamble = ""
if location == "memory":
spliced_code = f"foo: Bytes[{length_bound}] = inp"
foo = "foo"
elif location == "storage":
preamble = f"""
foo: Bytes[{length_bound}]
"""
spliced_code = "self.foo = inp"
foo = "self.foo"
elif location == "transient":
preamble = f"""
foo: transient(Bytes[{length_bound}])
"""
spliced_code = "self.foo = inp"
foo = "self.foo"
Expand Down Expand Up @@ -194,10 +207,23 @@ def _get_contract():
assert c.do_slice(bytesdata, start, length) == bytesdata[start:end], code


def test_slice_private(get_contract):
@pytest.mark.parametrize("location", ["storage", "transient"])
def test_slice_private(get_contract, location):
if location == "transient" and not version_check(begin="cancun"):
pytest.skip(
"Skipping test as storage_location is 'transient' and EVM version is pre-Cancun"
)

# test there are no buffer overruns in the slice function
code = """
bytez: public(String[12])
if location == "storage":
decl = "bytez: public(String[12])"
elif location == "transient":
decl = "bytez: public(transient(String[12]))"
else:
raise Exception("unreachable")

code = f"""
{decl}

@internal
def _slice(start: uint256, length: uint256):
Expand Down
134 changes: 44 additions & 90 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from vyper.codegen.abi_encoder import abi_encode
from vyper.codegen.context import Context, VariableRecord
from vyper.codegen.core import (
LOAD,
STORE,
IRnode,
add_ofst,
Expand Down Expand Up @@ -36,7 +37,7 @@
from vyper.codegen.expr import Expr
from vyper.codegen.ir_node import Encoding, scope_multi
from vyper.codegen.keccak256_helper import keccak256_helper
from vyper.evm.address_space import MEMORY, STORAGE
from vyper.evm.address_space import MEMORY
from vyper.exceptions import (
ArgumentException,
CompilerPanic,
Expand Down Expand Up @@ -378,7 +379,7 @@ def build_IR(self, expr, args, kwargs, context):

# add 32 bytes to the buffer size bc word access might
# be unaligned (see below)
if src.location == STORAGE:
if src.location.word_addressable:
buflen += 32

# Get returntype string or bytes
Expand All @@ -405,8 +406,8 @@ def build_IR(self, expr, args, kwargs, context):
src_data = bytes_data_ptr(src)

# general case. byte-for-byte copy
if src.location == STORAGE:
# because slice uses byte-addressing but storage
if src.location.word_addressable:
# because slice uses byte-addressing but storage/tstorage
# is word-aligned, this algorithm starts at some number
# of bytes before the data section starts, and might copy
# an extra word. the pseudocode is:
Expand Down Expand Up @@ -838,19 +839,6 @@ class ECMul(_ECArith):
_precompile = 0x7


def _generic_element_getter(op):
def f(index):
return IRnode.from_list(
[op, ["add", "_sub", ["add", 32, ["mul", 32, index]]]], typ=INT128_T
)

return f


def _storage_element_getter(index):
return IRnode.from_list(["sload", ["add", "_sub", ["add", 1, index]]], typ=INT128_T)


class Extract32(BuiltinFunctionT):
_id = "extract32"
_inputs = [("b", BytesT.any()), ("start", IntegerT.unsigneds())]
Expand Down Expand Up @@ -882,81 +870,47 @@ def infer_kwarg_types(self, node):

@process_inputs
def build_IR(self, expr, args, kwargs, context):
sub, index = args
bytez, index = args
ret_type = kwargs["output_type"]

# Get length and specific element
if sub.location == STORAGE:
lengetter = IRnode.from_list(["sload", "_sub"], typ=INT128_T)
elementgetter = _storage_element_getter

else:
op = sub.location.load_op
lengetter = IRnode.from_list([op, "_sub"], typ=INT128_T)
elementgetter = _generic_element_getter(op)

# TODO rewrite all this with cache_when_complex and bitshifts

# Special case: index known to be a multiple of 32
if isinstance(index.value, int) and not index.value % 32:
o = IRnode.from_list(
[
"with",
"_sub",
sub,
elementgetter(
["div", clamp2(0, index, ["sub", lengetter, 32], signed=True), 32]
),
],
typ=ret_type,
annotation="extracting 32 bytes",
)
# General case
else:
o = IRnode.from_list(
[
"with",
"_sub",
sub,
[
"with",
"_len",
lengetter,
[
"with",
"_index",
clamp2(0, index, ["sub", "_len", 32], signed=True),
[
"with",
"_mi32",
["mod", "_index", 32],
[
"with",
"_di32",
["div", "_index", 32],
[
"if",
"_mi32",
[
"add",
["mul", elementgetter("_di32"), ["exp", 256, "_mi32"]],
[
"div",
elementgetter(["add", "_di32", 1]),
["exp", 256, ["sub", 32, "_mi32"]],
],
],
elementgetter("_di32"),
],
],
],
],
],
],
typ=ret_type,
annotation="extract32",
)
return IRnode.from_list(clamp_basetype(o), typ=ret_type)
def finalize(ret):
annotation = "extract32"
ret = IRnode.from_list(ret, typ=ret_type, annotation=annotation)
return clamp_basetype(ret)

with bytez.cache_when_complex("_sub") as (b1, bytez):
# merge
length = get_bytearray_length(bytez)
index = clamp2(0, index, ["sub", length, 32], signed=True)
with index.cache_when_complex("_index") as (b2, index):
assert not index.typ.is_signed

# "easy" case, byte- addressed locations:
if bytez.location.word_scale == 32:
word = LOAD(add_ofst(bytes_data_ptr(bytez), index))
return finalize(b1.resolve(b2.resolve(word)))

# storage and transient storage, word-addressed
assert bytez.location.word_scale == 1

slot = IRnode.from_list(["div", index, 32])
# byte offset within the slot
byte_ofst = IRnode.from_list(["mod", index, 32])

with byte_ofst.cache_when_complex("byte_ofst") as (
b3,
byte_ofst,
), slot.cache_when_complex("slot") as (b4, slot):
# perform two loads and merge
w1 = LOAD(add_ofst(bytes_data_ptr(bytez), slot))
w2 = LOAD(add_ofst(bytes_data_ptr(bytez), ["add", slot, 1]))

left_bytes = shl(["mul", 8, byte_ofst], w1)
right_bytes = shr(["mul", 8, ["sub", 32, byte_ofst]], w2)
merged = ["or", left_bytes, right_bytes]

ret = ["if", byte_ofst, merged, left_bytes]
return finalize(b1.resolve(b2.resolve(b3.resolve(b4.resolve(ret)))))


class AsWeiValue(BuiltinFunctionT):
Expand Down
4 changes: 4 additions & 0 deletions vyper/evm/address_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class AddrSpace:
# TODO maybe make positional instead of defaulting to None
store_op: Optional[str] = None

@property
def word_addressable(self) -> bool:
return self.word_scale == 1


# alternative:
# class Memory(AddrSpace):
Expand Down
Loading