diff --git a/pyteal/__init__.pyi b/pyteal/__init__.pyi index 409e1aa53..420f674aa 100644 --- a/pyteal/__init__.pyi +++ b/pyteal/__init__.pyi @@ -114,6 +114,7 @@ __all__ = [ "SubroutineDefinition", "SubroutineDeclaration", "SubroutineCall", + "SubroutineFnWrapper", "ScratchSlot", "ScratchLoad", "ScratchStore", diff --git a/pyteal/ast/__init__.py b/pyteal/ast/__init__.py index 24b5240e0..03cb4daeb 100644 --- a/pyteal/ast/__init__.py +++ b/pyteal/ast/__init__.py @@ -111,6 +111,7 @@ SubroutineDefinition, SubroutineDeclaration, SubroutineCall, + SubroutineFnWrapper, ) from .while_ import While from .for_ import For @@ -221,6 +222,7 @@ "SubroutineDefinition", "SubroutineDeclaration", "SubroutineCall", + "SubroutineFnWrapper", "ScratchSlot", "ScratchLoad", "ScratchStore", diff --git a/pyteal/ast/subroutine.py b/pyteal/ast/subroutine.py index 69dd54e8c..136be53c9 100644 --- a/pyteal/ast/subroutine.py +++ b/pyteal/ast/subroutine.py @@ -1,8 +1,7 @@ -from typing import Callable, Tuple, List, Optional, cast, TYPE_CHECKING +from typing import Callable, List, Optional, TYPE_CHECKING from inspect import Parameter, signature -from functools import wraps -from ..types import TealType, require_type +from ..types import TealType from ..ir import TealOp, Op, TealBlock from ..errors import TealInputError, verifyTealVersion from .expr import Expr @@ -167,6 +166,39 @@ def has_return(self): SubroutineCall.__module__ = "pyteal" +class SubroutineFnWrapper: + def __init__( + self, + fnImplementation: Callable[..., Expr], + returnType: TealType, + name: str = None, + ) -> None: + self.subroutine = SubroutineDefinition( + fnImplementation, returnType=returnType, nameStr=name + ) + + def __call__(self, *args: Expr, **kwargs) -> Expr: + if len(kwargs) != 0: + raise TealInputError( + "Subroutine cannot be called with keyword arguments. Received keyword arguments: {}".format( + ",".join(kwargs.keys()) + ) + ) + return self.subroutine.invoke(list(args)) + + def name(self) -> str: + return self.subroutine.name() + + def type_of(self): + return self.subroutine.getDeclaration().type_of() + + def has_return(self): + return self.subroutine.getDeclaration().has_return() + + +SubroutineFnWrapper.__module__ = "pyteal" + + class Subroutine: """Used to create a PyTeal subroutine from a Python function. @@ -194,20 +226,12 @@ def __init__(self, returnType: TealType, name: str = None) -> None: self.returnType = returnType self.name = name - def __call__(self, fnImplementation: Callable[..., Expr]) -> Callable[..., Expr]: - subroutine = SubroutineDefinition(fnImplementation, self.returnType, self.name) - - @wraps(fnImplementation) - def subroutineCall(*args: Expr, **kwargs) -> Expr: - if len(kwargs) != 0: - raise TealInputError( - "Subroutine cannot be called with keyword arguments. Received keyword arguments: {}".format( - ",".join(kwargs.keys()) - ) - ) - return subroutine.invoke(list(args)) - - return subroutineCall + def __call__(self, fnImplementation: Callable[..., Expr]) -> SubroutineFnWrapper: + return SubroutineFnWrapper( + fnImplementation=fnImplementation, + returnType=self.returnType, + name=self.name, + ) Subroutine.__module__ = "pyteal" diff --git a/pyteal/ast/subroutine_test.py b/pyteal/ast/subroutine_test.py index b7157ab16..a283b3a55 100644 --- a/pyteal/ast/subroutine_test.py +++ b/pyteal/ast/subroutine_test.py @@ -148,7 +148,7 @@ def test_decorator(): def mySubroutine(a): return Return() - assert callable(mySubroutine) + assert isinstance(mySubroutine, SubroutineFnWrapper) invocation = mySubroutine(Int(1)) assert isinstance(invocation, SubroutineCall)