Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[codegen]: relax the filter for augassign oob check #4497

Merged
merged 12 commits into from
Feb 27, 2025
139 changes: 111 additions & 28 deletions tests/functional/codegen/features/test_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,71 @@ def test_augassign_oob(get_contract, tx_failed, source):
"source",
[
"""
@external
def entry() -> DynArray[uint256, 2]:
a: DynArray[uint256, 2] = [1, 1]
a[1] += a[1]
return a
""",
"""
@external
def entry() -> DynArray[uint256, 2]:
a: uint256 = 1
a += a
b: DynArray[uint256, 2] = [a, a]
b[0] -= b[0]
b[0] += b[1] // 2
return b
""",
"""
a: DynArray[uint256, 2]

def read() -> uint256:
return self.a[1]

@external
def entry() -> DynArray[uint256, 2]:
self.a = [1, 1]
self.a[1] += self.read()
return self.a
""",
"""
interface Foo:
def foo() -> uint256: nonpayable

@external
def foo() -> uint256:
return 1

@external
def entry() -> DynArray[uint256, 2]:
# memory variable, can't be overwritten by extcall, so there
# is no panic
a: DynArray[uint256, 2] = [1, 1]
a[1] += extcall Foo(self).foo()
return a
""",
"""
interface Foo:
def foo() -> uint256: nonpayable

def get_foo() -> uint256:
return extcall Foo(self).foo()

@external
def foo() -> uint256:
return 1

@external
def entry() -> DynArray[uint256, 2]:
# memory variable, can't be overwritten by extcall, so there
# is no panic
a: DynArray[uint256, 2] = [1, 1]
# extcall hidden inside internal function
a[1] += self.get_foo()
return a
""",
"""
a: public(DynArray[uint256, 2])

interface Foo:
Expand All @@ -115,57 +180,75 @@ def foo() -> uint256:
@external
def entry() -> DynArray[uint256, 2]:
self.a = [1, 1]
# panics due to staticcall
self.a[1] += staticcall Foo(self).foo()
return self.a
"""
""",
],
)
@pytest.mark.xfail(strict=True, raises=CodegenPanic)
def test_augassign_rhs_references_lhs(get_contract, tx_failed, source):
# xfail here (with panic):
def test_augassign_rhs_references_lhs2(get_contract, source):
c = get_contract(source)

assert c.entry() == [1, 2]


@pytest.mark.requires_evm_version("cancun")
def test_augassign_rhs_references_lhs_transient(get_contract):
source = """
x: transient(DynArray[uint256, 2])

def read() -> uint256:
return self.x[0]

@external
def entry() -> DynArray[uint256, 2]:
self.x = [1, 1]
# test augassign with state read hidden behind function call
self.x[0] += self.read()
# augassign with direct state read
self.x[1] += self.x[0]
return self.x
"""
c = get_contract(source)

assert c.entry() == [3, 2]


@pytest.mark.parametrize(
"source",
[
"""
x: transient(DynArray[uint256, 2])

def write() -> uint256:
return self.x.pop()

@external
def entry() -> DynArray[uint256, 2]:
a: DynArray[uint256, 2] = [1, 1]
a[1] += a[1]
return a
""",
"""
@external
def entry() -> DynArray[uint256, 2]:
a: uint256 = 1
a += a
b: DynArray[uint256, 2] = [a, a]
b[0] -= b[0]
b[0] += b[1] // 2
return b
self.x = [1, 1]
# hide state write behind function call
self.x[1] += self.write()
return self.x
""",
"""
a: DynArray[uint256, 2]

def read() -> uint256:
return self.a[1]
x: transient(DynArray[uint256, 2])

@external
def entry() -> DynArray[uint256, 2]:
self.a = [1, 1]
self.a[1] += self.read()
return self.a
self.x = [1, 1]
# direct state write
self.x[1] += self.x.pop()
return self.x
""",
],
)
def test_augassign_rhs_references_lhs2(get_contract, source):
@pytest.mark.requires_evm_version("cancun")
@pytest.mark.xfail(strict=True, raises=CodegenPanic)
def test_augassign_rhs_references_lhs_transient2(get_contract, tx_failed, source):
# xfail here (with panic):
c = get_contract(source)
assert c.entry() == [1, 2]

# not reached until the panic is fixed
with tx_failed(c):
c.entry()


@pytest.mark.parametrize(
Expand Down
12 changes: 12 additions & 0 deletions vyper/codegen/ir_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,18 @@ def contains_risky_call(self):

return ret

@cached_property
def contains_writeable_call(self):
ret = self.value in ("call", "delegatecall", "create", "create2")

for arg in self.args:
ret |= arg.contains_writeable_call

if getattr(self, "is_self_call", False):
ret |= self.invoked_function_ir.func_ir.contains_writeable_call

return ret

@cached_property
def contains_self_call(self):
return getattr(self, "is_self_call", False) or any(x.contains_self_call for x in self.args)
Expand Down
4 changes: 3 additions & 1 deletion vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@ def parse_AugAssign(self):
if var.typ._is_prim_word:
continue
# oob - GHSA-4w26-8p97-f4jp
if var in right.variable_writes or right.contains_risky_call:
if var in right.variable_writes or (
var.is_state_variable() and right.contains_writeable_call
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: immutables count as state variables

):
raise CodegenPanic("unreachable")

with target.cache_when_complex("_loc") as (b, target):
Expand Down
1 change: 1 addition & 0 deletions vyper/semantics/analysis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def set_position(self, position: VarOffset) -> None:
assert isinstance(position, VarOffset) # sanity check
self.position = position

# TODO: convert to property
def is_state_variable(self):
non_state_locations = (DataLocation.UNSET, DataLocation.MEMORY, DataLocation.CALLDATA)
# `self` gets a VarInfo, but it is not considered a state
Expand Down
Loading