Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a check in method_signature to disallow reference return types #368

Merged
merged 8 commits into from
May 31, 2022
Merged
2 changes: 2 additions & 0 deletions pyteal/ast/abi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
make,
size_of,
type_spec_from_annotation,
contains_type_spec,
)

__all__ = [
Expand Down Expand Up @@ -103,4 +104,5 @@
"size_of",
"algosdk_from_annotation",
"algosdk_from_type_spec",
"contains_type_spec",
]
25 changes: 24 additions & 1 deletion pyteal/ast/abi/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar, Any, Literal, get_origin, get_args, cast
from typing import Sequence, TypeVar, Any, Literal, get_origin, get_args, cast

import algosdk.abi

Expand Down Expand Up @@ -235,6 +235,29 @@ def type_spec_from_annotation(annotation: Any) -> TypeSpec:
T = TypeVar("T", bound=BaseType)


def contains_type_spec(ts: TypeSpec, targets: Sequence[TypeSpec]) -> bool:
from pyteal.ast.abi.array_dynamic import DynamicArrayTypeSpec
from pyteal.ast.abi.array_static import StaticArrayTypeSpec
from pyteal.ast.abi.tuple import TupleTypeSpec

stack: list[TypeSpec] = [ts]

while stack:
current = stack.pop()
if current in targets:
return True

match current:
case TupleTypeSpec():
stack.extend(current.value_type_specs())
case DynamicArrayTypeSpec():
stack.append(current.value_type_spec())
case StaticArrayTypeSpec():
stack.append(current.value_type_spec())

return False


def size_of(t: type[T]) -> int:
"""Get the size in bytes of an ABI type. Must be a static type"""

Expand Down
11 changes: 11 additions & 0 deletions pyteal/ast/subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,17 @@ def method_signature(self, overriding_name: str = None) -> str:
"Only registrable methods may return a method signature"
)

ret_type = self.type_of()
if isinstance(ret_type, abi.TypeSpec) and abi.contains_type_spec(
ret_type,
[
abi.AccountTypeSpec(),
abi.AssetTypeSpec(),
abi.ApplicationTypeSpec(),
],
):
raise TealInputError("Reference types may not be used as return values")

args = [str(v) for v in self.subroutine.abi_args.values()]
if overriding_name is None:
overriding_name = self.name()
Expand Down
33 changes: 32 additions & 1 deletion pyteal/ast/subroutine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass

import pyteal as pt
from pyteal.ast.subroutine import evaluate_subroutine
from pyteal.ast.subroutine import ABIReturnSubroutine, evaluate_subroutine

options = pt.CompileOptions(version=5)

Expand Down Expand Up @@ -237,6 +237,37 @@ def fn_w_tuple1arg(
case.definition.method_signature()


def test_subroutine_return_reference():
@ABIReturnSubroutine
def invalid_ret_type(*, output: pt.abi.Account):
return output.set(0)

with pytest.raises(pt.TealInputError):
invalid_ret_type.method_signature()

@ABIReturnSubroutine
def invalid_ret_type_collection(
*, output: pt.abi.Tuple2[pt.abi.Account, pt.abi.Uint64]
):
return output.set(pt.abi.Account(), pt.abi.Uint64())

with pytest.raises(pt.TealInputError):
invalid_ret_type_collection.method_signature()

@ABIReturnSubroutine
def invalid_ret_type_collection_nested(
*, output: pt.abi.DynamicArray[pt.abi.Tuple2[pt.abi.Account, pt.abi.Uint64]]
):
return output.set(
pt.abi.make(
pt.abi.DynamicArray[pt.abi.Tuple2[pt.abi.Account, pt.abi.Uint64]]
)
)

with pytest.raises(pt.TealInputError):
invalid_ret_type_collection_nested.method_signature()


def test_subroutine_definition_validate():
"""
DFS through SubroutineDefinition.validate()'s logic
Expand Down