diff --git a/pyteal/ast/abi/transaction.py b/pyteal/ast/abi/transaction.py index 8bf79e685..1f4b166a1 100644 --- a/pyteal/ast/abi/transaction.py +++ b/pyteal/ast/abi/transaction.py @@ -3,10 +3,10 @@ from pyteal.ast.abi.type import BaseType, ComputedValue, TypeSpec from pyteal.ast.expr import Expr from pyteal.ast.int import Int -from pyteal.ast.txn import TxnObject +from pyteal.ast.txn import TxnObject, TxnType from pyteal.ast.gtxn import Gtxn from pyteal.types import TealType -from pyteal.errors import TealInputError +from pyteal.errors import TealInputError, TealInternalError T = TypeVar("T", bound=BaseType) @@ -43,6 +43,17 @@ def byte_length_static(self) -> int: def storage_type(self) -> TealType: return TealType.uint64 + def txn_type_enum(self) -> Expr: + """Get the integer transaction type value this TransactionTypeSpec represents. + + See :any:`TxnType` for the complete list. + + If this is a generic TransactionTypeSpec, i.e. type :code:`txn`, this method will raise an error, since this type does not represent a single transaction type. + """ + raise TealInternalError( + "abi.TransactionTypeSpec does not represent a specific transaction type" + ) + def __eq__(self, other: object) -> bool: return type(self) is type(other) @@ -66,7 +77,9 @@ def type_spec(self) -> TransactionTypeSpec: def get(self) -> TxnObject: return Gtxn[self.index()] - def set(self: T, value: Union[int, Expr, "Transaction", ComputedValue[T]]) -> Expr: + def _set_index( + self: T, value: Union[int, Expr, "Transaction", ComputedValue[T]] + ) -> Expr: match value: case ComputedValue(): return self._set_with_computed_type(value) @@ -106,6 +119,9 @@ def new_instance(self) -> "PaymentTransaction": def annotation_type(self) -> "type[PaymentTransaction]": return PaymentTransaction + def txn_type_enum(self) -> Expr: + return TxnType.Payment + def __str__(self) -> str: return TransactionType.Payment.value @@ -128,6 +144,9 @@ def new_instance(self) -> "KeyRegisterTransaction": def annotation_type(self) -> "type[KeyRegisterTransaction]": return KeyRegisterTransaction + def txn_type_enum(self) -> Expr: + return TxnType.KeyRegistration + def __str__(self) -> str: return TransactionType.KeyRegistration.value @@ -150,6 +169,9 @@ def new_instance(self) -> "AssetConfigTransaction": def annotation_type(self) -> "type[AssetConfigTransaction]": return AssetConfigTransaction + def txn_type_enum(self) -> Expr: + return TxnType.AssetConfig + def __str__(self) -> str: return TransactionType.AssetConfig.value @@ -172,6 +194,9 @@ def new_instance(self) -> "AssetFreezeTransaction": def annotation_type(self) -> "type[AssetFreezeTransaction]": return AssetFreezeTransaction + def txn_type_enum(self) -> Expr: + return TxnType.AssetFreeze + def __str__(self) -> str: return TransactionType.AssetFreeze.value @@ -194,6 +219,9 @@ def new_instance(self) -> "AssetTransferTransaction": def annotation_type(self) -> "type[AssetTransferTransaction]": return AssetTransferTransaction + def txn_type_enum(self) -> Expr: + return TxnType.AssetTransfer + def __str__(self) -> str: return TransactionType.AssetTransfer.value @@ -216,6 +244,9 @@ def new_instance(self) -> "ApplicationCallTransaction": def annotation_type(self) -> "type[ApplicationCallTransaction]": return ApplicationCallTransaction + def txn_type_enum(self) -> Expr: + return TxnType.ApplicationCall + def __str__(self) -> str: return TransactionType.ApplicationCall.value diff --git a/pyteal/ast/abi/transaction_test.py b/pyteal/ast/abi/transaction_test.py index fffbe8f00..5803b8f32 100644 --- a/pyteal/ast/abi/transaction_test.py +++ b/pyteal/ast/abi/transaction_test.py @@ -14,29 +14,46 @@ class TransactionTypeTest: ts: abi.TransactionTypeSpec t: abi.Transaction s: str + txn_type_enum: pt.Expr | None TransactionValues: List[TransactionTypeTest] = [ - TransactionTypeTest(abi.TransactionTypeSpec(), abi.Transaction(), "txn"), + TransactionTypeTest(abi.TransactionTypeSpec(), abi.Transaction(), "txn", None), TransactionTypeTest( - abi.KeyRegisterTransactionTypeSpec(), abi.KeyRegisterTransaction(), "keyreg" + abi.KeyRegisterTransactionTypeSpec(), + abi.KeyRegisterTransaction(), + "keyreg", + pt.TxnType.KeyRegistration, ), TransactionTypeTest( - abi.PaymentTransactionTypeSpec(), abi.PaymentTransaction(), "pay" + abi.PaymentTransactionTypeSpec(), + abi.PaymentTransaction(), + "pay", + pt.TxnType.Payment, ), TransactionTypeTest( - abi.AssetConfigTransactionTypeSpec(), abi.AssetConfigTransaction(), "acfg" + abi.AssetConfigTransactionTypeSpec(), + abi.AssetConfigTransaction(), + "acfg", + pt.TxnType.AssetConfig, ), TransactionTypeTest( - abi.AssetFreezeTransactionTypeSpec(), abi.AssetFreezeTransaction(), "afrz" + abi.AssetFreezeTransactionTypeSpec(), + abi.AssetFreezeTransaction(), + "afrz", + pt.TxnType.AssetFreeze, ), TransactionTypeTest( - abi.AssetTransferTransactionTypeSpec(), abi.AssetTransferTransaction(), "axfer" + abi.AssetTransferTransactionTypeSpec(), + abi.AssetTransferTransaction(), + "axfer", + pt.TxnType.AssetTransfer, ), TransactionTypeTest( abi.ApplicationCallTransactionTypeSpec(), abi.ApplicationCallTransaction(), "appl", + pt.TxnType.ApplicationCall, ), ] @@ -56,6 +73,18 @@ def test_TransactionTypeSpec_new_instance(): assert isinstance(tv.ts.new_instance(), abi.Transaction) +def test_TransactionTypeSpec_txn_type_enum(): + for tv in TransactionValues: + if tv.txn_type_enum is None: + with pytest.raises( + pt.TealInternalError, + match=r"abi.TransactionTypeSpec does not represent a specific transaction type$", + ): + tv.ts.txn_type_enum() + else: + assert tv.ts.txn_type_enum() is tv.txn_type_enum + + def test_TransactionTypeSpec_eq(): for tv in TransactionValues: assert tv.ts == tv.ts @@ -91,10 +120,10 @@ def test_Transaction_get(): assert isinstance(expr, pt.TxnObject) -def test_Transaction_set(): +def test_Transaction__set_index(): for tv in TransactionValues: val_to_set = 2 - expr = tv.t.set(val_to_set) + expr = tv.t._set_index(val_to_set) assert expr.type_of() == pt.TealType.none assert expr.has_return() is False diff --git a/pyteal/ast/router.py b/pyteal/ast/router.py index 6c9f225f4..bc58a74d0 100644 --- a/pyteal/ast/router.py +++ b/pyteal/ast/router.py @@ -23,8 +23,8 @@ from pyteal.ast.assert_ import Assert from pyteal.ast.cond import Cond from pyteal.ast.expr import Expr -from pyteal.ast.app import OnComplete, EnumInt -from pyteal.ast.int import Int +from pyteal.ast.app import OnComplete +from pyteal.ast.int import Int, EnumInt from pyteal.ast.seq import Seq from pyteal.ast.methodsig import MethodSignature from pyteal.ast.naryexpr import And, Or @@ -374,7 +374,7 @@ def wrap_handler( tuplify = len(app_arg_vals) > METHOD_ARG_NUM_CUTOFF # only transaction args (these are omitted from app args) - txn_arg_vals: list[abi.BaseType] = [ + txn_arg_vals: list[abi.Transaction] = [ ats for ats in arg_vals if isinstance(ats, abi.Transaction) ] @@ -413,12 +413,18 @@ def wrap_handler( # and subtract that from the current index to get the absolute position # in the group - txn_decode_instructions: list[Expr] = [ - cast(abi.Transaction, arg_val).set( - Txn.group_index() - Int(txn_arg_len - idx) + txn_decode_instructions: list[Expr] = [] + + for idx, arg_val in enumerate(txn_arg_vals): + txn_decode_instructions.append( + arg_val._set_index(Txn.group_index() - Int(txn_arg_len - idx)) ) - for idx, arg_val in enumerate(txn_arg_vals) - ] + spec = arg_val.type_spec() + if type(spec) is not abi.TransactionTypeSpec: + # this is a specific transaction type + txn_decode_instructions.append( + Assert(arg_val.get().type_enum() == spec.txn_type_enum()) + ) decode_instructions += txn_decode_instructions diff --git a/pyteal/ast/router_test.py b/pyteal/ast/router_test.py index 157e9c41c..c7abd54fd 100644 --- a/pyteal/ast/router_test.py +++ b/pyteal/ast/router_test.py @@ -162,10 +162,13 @@ def multiple_txn( appl: pt.abi.ApplicationCallTransaction, axfer: pt.abi.AssetTransferTransaction, pay: pt.abi.PaymentTransaction, + any_txn: pt.abi.Transaction, *, output: pt.abi.Uint64, ): - return output.set(appl.get().fee() + axfer.get().fee() + pay.get().fee()) + return output.set( + appl.get().fee() + axfer.get().fee() + pay.get().fee() + any_txn.get().fee() + ) GOOD_SUBROUTINE_CASES: list[pt.ABIReturnSubroutine | pt.SubroutineFnWrapper] = [ @@ -470,7 +473,7 @@ def test_wrap_handler_method_call(): app_arg_cnt = len(app_args) - txn_args = [ + txn_args: list[pt.abi.Transaction] = [ arg for arg in args if arg.type_spec() in pt.abi.TransactionTypeSpecs ] @@ -496,14 +499,19 @@ def test_wrap_handler_method_call(): ] if len(txn_args) > 0: - loading.extend( - [ - typing.cast(pt.abi.Transaction, txn_arg).set( + for idx, txn_arg in enumerate(txn_args): + loading.append( + txn_arg._set_index( pt.Txn.group_index() - pt.Int(len(txn_args) - idx) ) - for idx, txn_arg in enumerate(txn_args) - ] - ) + ) + if str(txn_arg.type_spec()) != "txn": + loading.append( + pt.Assert( + txn_arg.get().type_enum() + == txn_arg.type_spec().txn_type_enum() + ) + ) if app_arg_cnt > pt.METHOD_ARG_NUM_CUTOFF: loading.extend( @@ -541,12 +549,17 @@ def test_wrap_handler_method_txn_types(): pt.abi.ApplicationCallTransaction(), pt.abi.AssetTransferTransaction(), pt.abi.PaymentTransaction(), + pt.abi.Transaction(), ] output_temp = pt.abi.Uint64() expected_ast = pt.Seq( - args[0].set(pt.Txn.group_index() - pt.Int(3)), - args[1].set(pt.Txn.group_index() - pt.Int(2)), - args[2].set(pt.Txn.group_index() - pt.Int(1)), + args[0]._set_index(pt.Txn.group_index() - pt.Int(4)), + pt.Assert(args[0].get().type_enum() == pt.TxnType.ApplicationCall), + args[1]._set_index(pt.Txn.group_index() - pt.Int(3)), + pt.Assert(args[1].get().type_enum() == pt.TxnType.AssetTransfer), + args[2]._set_index(pt.Txn.group_index() - pt.Int(2)), + pt.Assert(args[2].get().type_enum() == pt.TxnType.Payment), + args[3]._set_index(pt.Txn.group_index() - pt.Int(1)), multiple_txn(*args).store_into(output_temp), pt.abi.MethodReturn(output_temp), pt.Approve(),