diff --git a/tests/functional/codegen/features/test_assignment.py b/tests/functional/codegen/features/test_assignment.py index 53d02dfbff..d74db6f5a7 100644 --- a/tests/functional/codegen/features/test_assignment.py +++ b/tests/functional/codegen/features/test_assignment.py @@ -1,5 +1,6 @@ import pytest +from vyper.evm.opcodes import version_check from vyper.exceptions import CodegenPanic, ImmutableViolation, InvalidType, TypeMismatch @@ -103,6 +104,71 @@ def test_augassign_oob(get_contract, tx_failed, source): "source", [ """ +@external +def entry() -> DynArray[uint256, 2]: + a: DynArray[uint256, 2] = [1, 1] + a[1] += a[1] + return a + """, + """ +@external +def entry() -> DynArray[uint256, 2]: + a: uint256 = 1 + a += a + b: DynArray[uint256, 2] = [a, a] + b[0] -= b[0] + b[0] += b[1] // 2 + return b + """, + """ +a: DynArray[uint256, 2] + +def read() -> uint256: + return self.a[1] + +@external +def entry() -> DynArray[uint256, 2]: + self.a = [1, 1] + self.a[1] += self.read() + return self.a + """, + """ +interface Foo: + def foo() -> uint256: nonpayable + +@external +def foo() -> uint256: + return 1 + +@external +def entry() -> DynArray[uint256, 2]: + # memory variable, can't be overwritten by extcall, so there + # is no panic + a: DynArray[uint256, 2] = [1, 1] + a[1] += extcall Foo(self).foo() + return a + """, + """ +interface Foo: + def foo() -> uint256: nonpayable + +def get_foo() -> uint256: + return extcall Foo(self).foo() + +@external +def foo() -> uint256: + return 1 + +@external +def entry() -> DynArray[uint256, 2]: + # memory variable, can't be overwritten by extcall, so there + # is no panic + a: DynArray[uint256, 2] = [1, 1] + # extcall hidden inside internal function + a[1] += self.get_foo() + return a + """, + """ a: public(DynArray[uint256, 2]) interface Foo: @@ -115,57 +181,78 @@ def foo() -> uint256: @external def entry() -> DynArray[uint256, 2]: self.a = [1, 1] - # panics due to staticcall self.a[1] += staticcall Foo(self).foo() return self.a - """ + """, ], ) -@pytest.mark.xfail(strict=True, raises=CodegenPanic) -def test_augassign_rhs_references_lhs(get_contract, tx_failed, source): - # xfail here (with panic): +def test_augassign_rhs_references_lhs2(get_contract, source): c = get_contract(source) - assert c.entry() == [1, 2] +@pytest.mark.requires_evm_version("cancun") +def test_augassign_rhs_references_lhs_transient(get_contract): + source = """ +x: transient(DynArray[uint256, 2]) + +def read() -> uint256: + return self.x[0] + +@external +def entry() -> DynArray[uint256, 2]: + self.x = [1, 1] + # test augassign with state read hidden behind function call + self.x[0] += self.read() + # augassign with direct state read + self.x[1] += self.x[0] + return self.x + """ + c = get_contract(source) + + assert c.entry() == [2, 3] + + @pytest.mark.parametrize( "source", [ """ +x: transient(DynArray[uint256, 2]) + +def write() -> uint256: + return self.x.pop() + @external def entry() -> DynArray[uint256, 2]: - a: DynArray[uint256, 2] = [1, 1] - a[1] += a[1] - return a - """, - """ -@external -def entry() -> DynArray[uint256, 2]: - a: uint256 = 1 - a += a - b: DynArray[uint256, 2] = [a, a] - b[0] -= b[0] - b[0] += b[1] // 2 - return b + self.x = [1, 1] + # hide state write behind function call + self.x[1] += self.write() + return self.x """, """ -a: DynArray[uint256, 2] - -def read() -> uint256: - return self.a[1] +x: transient(DynArray[uint256, 2]) @external def entry() -> DynArray[uint256, 2]: - self.a = [1, 1] - self.a[1] += self.read() - return self.a + self.x = [1, 1] + # direct state write + self.x[1] += self.x.pop() + return self.x """, ], ) -def test_augassign_rhs_references_lhs2(get_contract, source): +@pytest.mark.xfail(strict=True, raises=CodegenPanic) +def test_augassign_rhs_references_lhs_transient2(get_contract, tx_failed, source): + if not version_check(begin="cancun"): + # no transient available before cancun + pytest.skip() + + # xfail here (with panic): c = get_contract(source) - assert c.entry() == [1, 2] + + # not reached until the panic is fixed + with tx_failed(c): + c.entry() @pytest.mark.parametrize( diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index 6ed6345674..38cd41a84d 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -498,6 +498,18 @@ def contains_risky_call(self): return ret + @cached_property + def contains_writeable_call(self): + ret = self.value in ("call", "delegatecall", "create", "create2") + + for arg in self.args: + ret |= arg.contains_writeable_call + + if getattr(self, "is_self_call", False): + ret |= self.invoked_function_ir.func_ir.contains_writeable_call + + return ret + @cached_property def contains_self_call(self): return getattr(self, "is_self_call", False) or any(x.contains_self_call for x in self.args) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 388dfba629..ffda836373 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -293,7 +293,9 @@ def parse_AugAssign(self): if var.typ._is_prim_word: continue # oob - GHSA-4w26-8p97-f4jp - if var in right.variable_writes or right.contains_risky_call: + if var in right.variable_writes or ( + var.is_state_variable() and right.contains_writeable_call + ): raise CodegenPanic("unreachable") with target.cache_when_complex("_loc") as (b, target): diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index adfc7540a0..c7bfbc11aa 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -205,6 +205,7 @@ def set_position(self, position: VarOffset) -> None: assert isinstance(position, VarOffset) # sanity check self.position = position + # TODO: convert to property def is_state_variable(self): non_state_locations = (DataLocation.UNSET, DataLocation.MEMORY, DataLocation.CALLDATA) # `self` gets a VarInfo, but it is not considered a state