Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check transaction type arguments #427

Merged
merged 2 commits into from
Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions pyteal/ast/abi/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from pyteal.ast.abi.type import BaseType, ComputedValue, TypeSpec
from pyteal.ast.expr import Expr
from pyteal.ast.int import Int
from pyteal.ast.txn import TxnObject
from pyteal.ast.txn import TxnObject, TxnType
from pyteal.ast.gtxn import Gtxn
from pyteal.types import TealType
from pyteal.errors import TealInputError
from pyteal.errors import TealInputError, TealInternalError

T = TypeVar("T", bound=BaseType)

Expand Down Expand Up @@ -43,6 +43,17 @@ def byte_length_static(self) -> int:
def storage_type(self) -> TealType:
return TealType.uint64

def txn_type_enum(self) -> Expr:
"""Get the integer transaction type value this TransactionTypeSpec represents.

See :any:`TxnType` for the complete list.

If this is a generic TransactionTypeSpec, i.e. type :code:`txn`, this method will raise an error, since this type does not represent a single transaction type.
"""
raise TealInternalError(
"abi.TransactionTypeSpec does not represent a specific transaction type"
)

def __eq__(self, other: object) -> bool:
return type(self) is type(other)

Expand All @@ -66,7 +77,9 @@ def type_spec(self) -> TransactionTypeSpec:
def get(self) -> TxnObject:
return Gtxn[self.index()]

def set(self: T, value: Union[int, Expr, "Transaction", ComputedValue[T]]) -> Expr:
def _set_index(
self: T, value: Union[int, Expr, "Transaction", ComputedValue[T]]
) -> Expr:
match value:
case ComputedValue():
return self._set_with_computed_type(value)
Expand Down Expand Up @@ -106,6 +119,9 @@ def new_instance(self) -> "PaymentTransaction":
def annotation_type(self) -> "type[PaymentTransaction]":
return PaymentTransaction

def txn_type_enum(self) -> Expr:
return TxnType.Payment

def __str__(self) -> str:
return TransactionType.Payment.value

Expand All @@ -128,6 +144,9 @@ def new_instance(self) -> "KeyRegisterTransaction":
def annotation_type(self) -> "type[KeyRegisterTransaction]":
return KeyRegisterTransaction

def txn_type_enum(self) -> Expr:
return TxnType.KeyRegistration

def __str__(self) -> str:
return TransactionType.KeyRegistration.value

Expand All @@ -150,6 +169,9 @@ def new_instance(self) -> "AssetConfigTransaction":
def annotation_type(self) -> "type[AssetConfigTransaction]":
return AssetConfigTransaction

def txn_type_enum(self) -> Expr:
return TxnType.AssetConfig

def __str__(self) -> str:
return TransactionType.AssetConfig.value

Expand All @@ -172,6 +194,9 @@ def new_instance(self) -> "AssetFreezeTransaction":
def annotation_type(self) -> "type[AssetFreezeTransaction]":
return AssetFreezeTransaction

def txn_type_enum(self) -> Expr:
return TxnType.AssetFreeze

def __str__(self) -> str:
return TransactionType.AssetFreeze.value

Expand All @@ -194,6 +219,9 @@ def new_instance(self) -> "AssetTransferTransaction":
def annotation_type(self) -> "type[AssetTransferTransaction]":
return AssetTransferTransaction

def txn_type_enum(self) -> Expr:
return TxnType.AssetTransfer

def __str__(self) -> str:
return TransactionType.AssetTransfer.value

Expand All @@ -216,6 +244,9 @@ def new_instance(self) -> "ApplicationCallTransaction":
def annotation_type(self) -> "type[ApplicationCallTransaction]":
return ApplicationCallTransaction

def txn_type_enum(self) -> Expr:
return TxnType.ApplicationCall

def __str__(self) -> str:
return TransactionType.ApplicationCall.value

Expand Down
45 changes: 37 additions & 8 deletions pyteal/ast/abi/transaction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,46 @@ class TransactionTypeTest:
ts: abi.TransactionTypeSpec
t: abi.Transaction
s: str
txn_type_enum: pt.Expr | None


TransactionValues: List[TransactionTypeTest] = [
TransactionTypeTest(abi.TransactionTypeSpec(), abi.Transaction(), "txn"),
TransactionTypeTest(abi.TransactionTypeSpec(), abi.Transaction(), "txn", None),
TransactionTypeTest(
abi.KeyRegisterTransactionTypeSpec(), abi.KeyRegisterTransaction(), "keyreg"
abi.KeyRegisterTransactionTypeSpec(),
abi.KeyRegisterTransaction(),
"keyreg",
pt.TxnType.KeyRegistration,
),
TransactionTypeTest(
abi.PaymentTransactionTypeSpec(), abi.PaymentTransaction(), "pay"
abi.PaymentTransactionTypeSpec(),
abi.PaymentTransaction(),
"pay",
pt.TxnType.Payment,
),
TransactionTypeTest(
abi.AssetConfigTransactionTypeSpec(), abi.AssetConfigTransaction(), "acfg"
abi.AssetConfigTransactionTypeSpec(),
abi.AssetConfigTransaction(),
"acfg",
pt.TxnType.AssetConfig,
),
TransactionTypeTest(
abi.AssetFreezeTransactionTypeSpec(), abi.AssetFreezeTransaction(), "afrz"
abi.AssetFreezeTransactionTypeSpec(),
abi.AssetFreezeTransaction(),
"afrz",
pt.TxnType.AssetFreeze,
),
TransactionTypeTest(
abi.AssetTransferTransactionTypeSpec(), abi.AssetTransferTransaction(), "axfer"
abi.AssetTransferTransactionTypeSpec(),
abi.AssetTransferTransaction(),
"axfer",
pt.TxnType.AssetTransfer,
),
TransactionTypeTest(
abi.ApplicationCallTransactionTypeSpec(),
abi.ApplicationCallTransaction(),
"appl",
pt.TxnType.ApplicationCall,
),
]

Expand All @@ -56,6 +73,18 @@ def test_TransactionTypeSpec_new_instance():
assert isinstance(tv.ts.new_instance(), abi.Transaction)


def test_TransactionTypeSpec_txn_type_enum():
for tv in TransactionValues:
if tv.txn_type_enum is None:
with pytest.raises(
pt.TealInternalError,
match=r"abi.TransactionTypeSpec does not represent a specific transaction type$",
):
tv.ts.txn_type_enum()
else:
assert tv.ts.txn_type_enum() is tv.txn_type_enum


def test_TransactionTypeSpec_eq():
for tv in TransactionValues:
assert tv.ts == tv.ts
Expand Down Expand Up @@ -91,10 +120,10 @@ def test_Transaction_get():
assert isinstance(expr, pt.TxnObject)


def test_Transaction_set():
def test_Transaction__set_index():
for tv in TransactionValues:
val_to_set = 2
expr = tv.t.set(val_to_set)
expr = tv.t._set_index(val_to_set)

assert expr.type_of() == pt.TealType.none
assert expr.has_return() is False
Expand Down
22 changes: 14 additions & 8 deletions pyteal/ast/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from pyteal.ast.assert_ import Assert
from pyteal.ast.cond import Cond
from pyteal.ast.expr import Expr
from pyteal.ast.app import OnComplete, EnumInt
from pyteal.ast.int import Int
from pyteal.ast.app import OnComplete
from pyteal.ast.int import Int, EnumInt
from pyteal.ast.seq import Seq
from pyteal.ast.methodsig import MethodSignature
from pyteal.ast.naryexpr import And, Or
Expand Down Expand Up @@ -374,7 +374,7 @@ def wrap_handler(
tuplify = len(app_arg_vals) > METHOD_ARG_NUM_CUTOFF

# only transaction args (these are omitted from app args)
txn_arg_vals: list[abi.BaseType] = [
txn_arg_vals: list[abi.Transaction] = [
ats for ats in arg_vals if isinstance(ats, abi.Transaction)
]

Expand Down Expand Up @@ -413,12 +413,18 @@ def wrap_handler(
# and subtract that from the current index to get the absolute position
# in the group

txn_decode_instructions: list[Expr] = [
cast(abi.Transaction, arg_val).set(
Txn.group_index() - Int(txn_arg_len - idx)
txn_decode_instructions: list[Expr] = []

for idx, arg_val in enumerate(txn_arg_vals):
txn_decode_instructions.append(
arg_val._set_index(Txn.group_index() - Int(txn_arg_len - idx))
)
for idx, arg_val in enumerate(txn_arg_vals)
]
spec = arg_val.type_spec()
if type(spec) is not abi.TransactionTypeSpec:
# this is a specific transaction type
txn_decode_instructions.append(
Assert(arg_val.get().type_enum() == spec.txn_type_enum())
)

decode_instructions += txn_decode_instructions

Expand Down
35 changes: 24 additions & 11 deletions pyteal/ast/router_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,13 @@ def multiple_txn(
appl: pt.abi.ApplicationCallTransaction,
axfer: pt.abi.AssetTransferTransaction,
pay: pt.abi.PaymentTransaction,
any_txn: pt.abi.Transaction,
*,
output: pt.abi.Uint64,
):
return output.set(appl.get().fee() + axfer.get().fee() + pay.get().fee())
return output.set(
appl.get().fee() + axfer.get().fee() + pay.get().fee() + any_txn.get().fee()
)


GOOD_SUBROUTINE_CASES: list[pt.ABIReturnSubroutine | pt.SubroutineFnWrapper] = [
Expand Down Expand Up @@ -470,7 +473,7 @@ def test_wrap_handler_method_call():

app_arg_cnt = len(app_args)

txn_args = [
txn_args: list[pt.abi.Transaction] = [
arg for arg in args if arg.type_spec() in pt.abi.TransactionTypeSpecs
]

Expand All @@ -496,14 +499,19 @@ def test_wrap_handler_method_call():
]

if len(txn_args) > 0:
loading.extend(
[
typing.cast(pt.abi.Transaction, txn_arg).set(
for idx, txn_arg in enumerate(txn_args):
loading.append(
txn_arg._set_index(
pt.Txn.group_index() - pt.Int(len(txn_args) - idx)
)
for idx, txn_arg in enumerate(txn_args)
]
)
)
if str(txn_arg.type_spec()) != "txn":
loading.append(
pt.Assert(
txn_arg.get().type_enum()
== txn_arg.type_spec().txn_type_enum()
)
)

if app_arg_cnt > pt.METHOD_ARG_NUM_CUTOFF:
loading.extend(
Expand Down Expand Up @@ -541,12 +549,17 @@ def test_wrap_handler_method_txn_types():
pt.abi.ApplicationCallTransaction(),
pt.abi.AssetTransferTransaction(),
pt.abi.PaymentTransaction(),
pt.abi.Transaction(),
]
output_temp = pt.abi.Uint64()
expected_ast = pt.Seq(
args[0].set(pt.Txn.group_index() - pt.Int(3)),
args[1].set(pt.Txn.group_index() - pt.Int(2)),
args[2].set(pt.Txn.group_index() - pt.Int(1)),
args[0]._set_index(pt.Txn.group_index() - pt.Int(4)),
pt.Assert(args[0].get().type_enum() == pt.TxnType.ApplicationCall),
args[1]._set_index(pt.Txn.group_index() - pt.Int(3)),
pt.Assert(args[1].get().type_enum() == pt.TxnType.AssetTransfer),
args[2]._set_index(pt.Txn.group_index() - pt.Int(2)),
pt.Assert(args[2].get().type_enum() == pt.TxnType.Payment),
args[3]._set_index(pt.Txn.group_index() - pt.Int(1)),
multiple_txn(*args).store_into(output_temp),
pt.abi.MethodReturn(output_temp),
pt.Approve(),
Expand Down