diff --git a/pyteal/ast/abi/__init__.py b/pyteal/ast/abi/__init__.py index 983cf7685..c789c3556 100644 --- a/pyteal/ast/abi/__init__.py +++ b/pyteal/ast/abi/__init__.py @@ -49,6 +49,7 @@ make, size_of, type_spec_from_annotation, + contains_type_spec, ) __all__ = [ @@ -103,4 +104,5 @@ "size_of", "algosdk_from_annotation", "algosdk_from_type_spec", + "contains_type_spec", ] diff --git a/pyteal/ast/abi/util.py b/pyteal/ast/abi/util.py index 5b39509df..b01675b6d 100644 --- a/pyteal/ast/abi/util.py +++ b/pyteal/ast/abi/util.py @@ -1,4 +1,4 @@ -from typing import TypeVar, Any, Literal, get_origin, get_args, cast +from typing import Sequence, TypeVar, Any, Literal, get_origin, get_args, cast import algosdk.abi @@ -235,6 +235,29 @@ def type_spec_from_annotation(annotation: Any) -> TypeSpec: T = TypeVar("T", bound=BaseType) +def contains_type_spec(ts: TypeSpec, targets: Sequence[TypeSpec]) -> bool: + from pyteal.ast.abi.array_dynamic import DynamicArrayTypeSpec + from pyteal.ast.abi.array_static import StaticArrayTypeSpec + from pyteal.ast.abi.tuple import TupleTypeSpec + + stack: list[TypeSpec] = [ts] + + while stack: + current = stack.pop() + if current in targets: + return True + + match current: + case TupleTypeSpec(): + stack.extend(current.value_type_specs()) + case DynamicArrayTypeSpec(): + stack.append(current.value_type_spec()) + case StaticArrayTypeSpec(): + stack.append(current.value_type_spec()) + + return False + + def size_of(t: type[T]) -> int: """Get the size in bytes of an ABI type. Must be a static type""" diff --git a/pyteal/ast/subroutine.py b/pyteal/ast/subroutine.py index 75c98ae76..221de6306 100644 --- a/pyteal/ast/subroutine.py +++ b/pyteal/ast/subroutine.py @@ -595,6 +595,17 @@ def method_signature(self, overriding_name: str = None) -> str: "Only registrable methods may return a method signature" ) + ret_type = self.type_of() + if isinstance(ret_type, abi.TypeSpec) and abi.contains_type_spec( + ret_type, + [ + abi.AccountTypeSpec(), + abi.AssetTypeSpec(), + abi.ApplicationTypeSpec(), + ], + ): + raise TealInputError("Reference types may not be used as return values") + args = [str(v) for v in self.subroutine.abi_args.values()] if overriding_name is None: overriding_name = self.name() diff --git a/pyteal/ast/subroutine_test.py b/pyteal/ast/subroutine_test.py index 3d1934cdf..a58df675f 100644 --- a/pyteal/ast/subroutine_test.py +++ b/pyteal/ast/subroutine_test.py @@ -5,7 +5,7 @@ from dataclasses import dataclass import pyteal as pt -from pyteal.ast.subroutine import evaluate_subroutine +from pyteal.ast.subroutine import ABIReturnSubroutine, evaluate_subroutine options = pt.CompileOptions(version=5) @@ -237,6 +237,37 @@ def fn_w_tuple1arg( case.definition.method_signature() +def test_subroutine_return_reference(): + @ABIReturnSubroutine + def invalid_ret_type(*, output: pt.abi.Account): + return output.set(0) + + with pytest.raises(pt.TealInputError): + invalid_ret_type.method_signature() + + @ABIReturnSubroutine + def invalid_ret_type_collection( + *, output: pt.abi.Tuple2[pt.abi.Account, pt.abi.Uint64] + ): + return output.set(pt.abi.Account(), pt.abi.Uint64()) + + with pytest.raises(pt.TealInputError): + invalid_ret_type_collection.method_signature() + + @ABIReturnSubroutine + def invalid_ret_type_collection_nested( + *, output: pt.abi.DynamicArray[pt.abi.Tuple2[pt.abi.Account, pt.abi.Uint64]] + ): + return output.set( + pt.abi.make( + pt.abi.DynamicArray[pt.abi.Tuple2[pt.abi.Account, pt.abi.Uint64]] + ) + ) + + with pytest.raises(pt.TealInputError): + invalid_ret_type_collection_nested.method_signature() + + def test_subroutine_definition_validate(): """ DFS through SubroutineDefinition.validate()'s logic