diff --git a/pyteal/ast/router.py b/pyteal/ast/router.py index 27ec1d51b..994af7e07 100644 --- a/pyteal/ast/router.py +++ b/pyteal/ast/router.py @@ -29,7 +29,7 @@ from pyteal.ast.methodsig import MethodSignature from pyteal.ast.naryexpr import And, Or from pyteal.ast.txn import Txn -from pyteal.ast.return_ import Approve +from pyteal.ast.return_ import Approve, Reject class CallConfig(IntFlag): @@ -50,7 +50,7 @@ class CallConfig(IntFlag): CREATE = 2 ALL = 3 - def condition_under_config(self) -> Expr | int: + def approval_condition_under_config(self) -> Expr | int: match self: case CallConfig.NEVER: return 0 @@ -63,6 +63,19 @@ def condition_under_config(self) -> Expr | int: case _: raise TealInternalError(f"unexpected CallConfig {self}") + def clear_state_condition_under_config(self) -> int: + match self: + case CallConfig.NEVER: + return 0 + case CallConfig.CALL: + return 1 + case CallConfig.CREATE | CallConfig.ALL: + raise TealInputError( + "Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation" + ) + case _: + raise TealInputError(f"unexpected CallConfig {self}") + CallConfig.__module__ = "pyteal" @@ -101,7 +114,7 @@ def approval_cond(self) -> Expr | int: else: cond_list = [] for config, oc in config_oc_pairs: - config_cond = config.condition_under_config() + config_cond = config.approval_condition_under_config() match config_cond: case Expr(): cond_list.append(And(Txn.on_completion() == oc, config_cond)) @@ -116,7 +129,7 @@ def approval_cond(self) -> Expr | int: return Or(*cond_list) def clear_state_cond(self) -> Expr | int: - return self.clear_state.condition_under_config() + return self.clear_state.clear_state_condition_under_config() @dataclass(frozen=True) @@ -214,7 +227,11 @@ def approval_construction(self) -> Optional[Expr]: cond_body = wrapped_handler case CallConfig.CALL | CallConfig.CREATE: cond_body = Seq( - Assert(cast(Expr, oca.call_config.condition_under_config())), + Assert( + cast( + Expr, oca.call_config.approval_condition_under_config() + ) + ), wrapped_handler, ) case _: @@ -233,29 +250,16 @@ def clear_state_construction(self) -> Optional[Expr]: if self.clear_state.is_empty(): return None - wrapped_handler = ASTBuilder.wrap_handler( + # call this to make sure we error if the CallConfig is CREATE or ALL + self.clear_state.call_config.clear_state_condition_under_config() + + return ASTBuilder.wrap_handler( False, cast( Expr | SubroutineFnWrapper | ABIReturnSubroutine, self.clear_state.action, ), ) - match self.clear_state.call_config: - case CallConfig.ALL: - return wrapped_handler - case CallConfig.CALL | CallConfig.CREATE: - return Seq( - Assert( - cast( - Expr, self.clear_state.call_config.condition_under_config() - ) - ), - wrapped_handler, - ) - case _: - raise TealInternalError( - f"Unexpected CallConfig: {self.clear_state.call_config!r}" - ) BareCallActions.__module__ = "pyteal" @@ -475,7 +479,7 @@ def add_method_to_ast( def program_construction(self) -> Expr: if not self.conditions_n_branches: - raise TealInputError("ABIRouter: Cannot build program with an empty AST") + return Reject() return Cond(*[[n.condition, n.branch] for n in self.conditions_n_branches]) @@ -659,6 +663,9 @@ def build_program(self) -> tuple[Expr, Expr, sdk_abi.Contract]: Constructs ASTs for approval and clear-state programs from the registered methods in the router, also generates a JSON object of contract to allow client read and call the methods easily. + Note that if no methods or bare app call actions have been registered to either the approval + or clear state programs, then that program will reject all transactions. + Returns: approval_program: AST for approval program clear_state_program: AST for clear-state program @@ -681,6 +688,9 @@ def compile_program( Combining `build_program` and `compileTeal`, compiles built Approval and ClearState programs and returns Contract JSON object for off-chain calling. + Note that if no methods or bare app call actions have been registered to either the approval + or clear state programs, then that program will reject all transactions. + Returns: approval_program: compiled approval program clear_state_program: compiled clear-state program diff --git a/pyteal/ast/router_test.py b/pyteal/ast/router_test.py index 4390f0109..157e9c41c 100644 --- a/pyteal/ast/router_test.py +++ b/pyteal/ast/router_test.py @@ -266,8 +266,8 @@ def camel_to_snake(name: str) -> str: def test_call_config(): for cc in pt.CallConfig: - cond_on_cc: pt.Expr | int = cc.condition_under_config() - match cond_on_cc: + approval_cond_on_cc: pt.Expr | int = cc.approval_condition_under_config() + match approval_cond_on_cc: case pt.Expr(): expected_cc = ( (pt.Txn.application_id() == pt.Int(0)) @@ -275,11 +275,34 @@ def test_call_config(): else (pt.Txn.application_id() != pt.Int(0)) ) with pt.TealComponent.Context.ignoreExprEquality(): - assert assemble_helper(cond_on_cc) == assemble_helper(expected_cc) + assert assemble_helper(approval_cond_on_cc) == assemble_helper( + expected_cc + ) case int(): - assert cond_on_cc == int(cc) & 1 + assert approval_cond_on_cc == int(cc) & 1 + case _: + raise pt.TealInternalError( + f"unexpected approval_cond_on_cc {approval_cond_on_cc}" + ) + + if cc in (pt.CallConfig.CREATE, pt.CallConfig.ALL): + with pytest.raises( + pt.TealInputError, + match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", + ): + cc.clear_state_condition_under_config() + continue + + clear_state_cond_on_cc: int = cc.clear_state_condition_under_config() + match clear_state_cond_on_cc: + case 0: + assert cc == pt.CallConfig.NEVER + case 1: + assert cc == pt.CallConfig.CALL case _: - raise pt.TealInternalError(f"unexpected cond_on_cc {cond_on_cc}") + raise pt.TealInternalError( + f"unexpected clear_state_cond_on_cc {clear_state_cond_on_cc}" + ) def test_method_config(): @@ -304,18 +327,14 @@ def test_method_config(): match mc.clear_state: case pt.CallConfig.NEVER: assert mc.clear_state_cond() == 0 - case pt.CallConfig.ALL: - assert mc.clear_state_cond() == 1 case pt.CallConfig.CALL: - with pt.TealComponent.Context.ignoreExprEquality(): - assert assemble_helper( - mc.clear_state_cond() - ) == assemble_helper(pt.Txn.application_id() != pt.Int(0)) - case pt.CallConfig.CREATE: - with pt.TealComponent.Context.ignoreExprEquality(): - assert assemble_helper( - mc.clear_state_cond() - ) == assemble_helper(pt.Txn.application_id() == pt.Int(0)) + assert mc.clear_state_cond() == 1 + case pt.CallConfig.CREATE | pt.CallConfig.ALL: + with pytest.raises( + pt.TealInputError, + match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", + ): + mc.clear_state_cond() if mc.is_never() or all( getattr(mc, i) == pt.CallConfig.NEVER for i, _ in approval_check_names_n_ocs @@ -330,7 +349,9 @@ def test_method_config(): continue list_of_cc = [ ( - typing.cast(pt.CallConfig, getattr(mc, i)).condition_under_config(), + typing.cast( + pt.CallConfig, getattr(mc, i) + ).approval_condition_under_config(), oc, ) for i, oc in approval_check_names_n_ocs @@ -624,3 +645,146 @@ def test_contract_json_obj(): sdk_contract = sdk_abi.Contract(contract_name, method_list) contract = router.contract_construct() assert contract == sdk_contract + + +def test_build_program_all_empty(): + router = pt.Router("test") + + approval, clear_state, contract = router.build_program() + + expected_empty_program = pt.TealSimpleBlock( + [ + pt.TealOp(None, pt.Op.int, 0), + pt.TealOp(None, pt.Op.return_), + ] + ) + + with pt.TealComponent.Context.ignoreExprEquality(): + assert assemble_helper(approval) == expected_empty_program + assert assemble_helper(clear_state) == expected_empty_program + + expected_contract = sdk_abi.Contract("test", []) + assert contract == expected_contract + + +def test_build_program_approval_empty(): + router = pt.Router( + "test", + pt.BareCallActions(clear_state=pt.OnCompleteAction.call_only(pt.Approve())), + ) + + approval, clear_state, contract = router.build_program() + + expected_empty_program = pt.TealSimpleBlock( + [ + pt.TealOp(None, pt.Op.int, 0), + pt.TealOp(None, pt.Op.return_), + ] + ) + + with pt.TealComponent.Context.ignoreExprEquality(): + assert assemble_helper(approval) == expected_empty_program + assert assemble_helper(clear_state) != expected_empty_program + + expected_contract = sdk_abi.Contract("test", []) + assert contract == expected_contract + + +def test_build_program_clear_state_empty(): + router = pt.Router( + "test", pt.BareCallActions(no_op=pt.OnCompleteAction.always(pt.Approve())) + ) + + approval, clear_state, contract = router.build_program() + + expected_empty_program = pt.TealSimpleBlock( + [ + pt.TealOp(None, pt.Op.int, 0), + pt.TealOp(None, pt.Op.return_), + ] + ) + + with pt.TealComponent.Context.ignoreExprEquality(): + assert assemble_helper(approval) != expected_empty_program + assert assemble_helper(clear_state) == expected_empty_program + + expected_contract = sdk_abi.Contract("test", []) + assert contract == expected_contract + + +def test_build_program_clear_state_invalid_config(): + for config in (pt.CallConfig.CREATE, pt.CallConfig.ALL): + bareCalls = pt.BareCallActions( + clear_state=pt.OnCompleteAction(action=pt.Approve(), call_config=config) + ) + with pytest.raises( + pt.TealInputError, + match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", + ): + pt.Router("test", bareCalls) + + router = pt.Router("test") + + @pt.ABIReturnSubroutine + def clear_state_method(): + return pt.Approve() + + with pytest.raises( + pt.TealInputError, + match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", + ): + router.add_method_handler( + clear_state_method, + method_config=pt.MethodConfig(clear_state=config), + ) + + +def test_build_program_clear_state_valid_config(): + action = pt.If(pt.Txn.fee() == pt.Int(4)).Then(pt.Approve()).Else(pt.Reject()) + config = pt.CallConfig.CALL + + router_with_bare_call = pt.Router( + "test", + pt.BareCallActions( + clear_state=pt.OnCompleteAction(action=action, call_config=config) + ), + ) + _, actual_clear_state_with_bare_call, _ = router_with_bare_call.build_program() + + expected_clear_state_with_bare_call = assemble_helper( + pt.Cond([pt.Txn.application_args.length() == pt.Int(0), action]) + ) + + with pt.TealComponent.Context.ignoreExprEquality(): + assert ( + assemble_helper(actual_clear_state_with_bare_call) + == expected_clear_state_with_bare_call + ) + + router_with_method = pt.Router("test") + + @pt.ABIReturnSubroutine + def clear_state_method(): + return action + + router_with_method.add_method_handler( + clear_state_method, method_config=pt.MethodConfig(clear_state=config) + ) + + _, actual_clear_state_with_method, _ = router_with_method.build_program() + + expected_clear_state_with_method = assemble_helper( + pt.Cond( + [ + pt.Txn.application_args[0] + == pt.MethodSignature("clear_state_method()void"), + pt.Seq(clear_state_method(), pt.Approve()), + ] + ) + ) + + with pt.TealComponent.Context.ignoreExprEquality(): + assert ( + assemble_helper(actual_clear_state_with_method) + == expected_clear_state_with_method + ) diff --git a/pyteal/compiler/compiler_test.py b/pyteal/compiler/compiler_test.py index 1aa88a6cc..fb1172a2d 100644 --- a/pyteal/compiler/compiler_test.py +++ b/pyteal/compiler/compiler_test.py @@ -2830,10 +2830,6 @@ def approve_if_odd(condition_encoding: pt.abi.Uint32) -> pt.Expr: bnz main_l5 err main_l5: -txn ApplicationID -int 0 -!= -assert txna ApplicationArgs 1 int 0 extract_uint32 @@ -2843,10 +2839,6 @@ def approve_if_odd(condition_encoding: pt.abi.Uint32) -> pt.Expr: int 1 return main_l6: -txn ApplicationID -int 0 -!= -assert callsub log1_1 store 1 byte 0x151f7c75 @@ -2857,18 +2849,10 @@ def approve_if_odd(condition_encoding: pt.abi.Uint32) -> pt.Expr: int 1 return main_l7: -txn ApplicationID -int 0 -!= -assert callsub emptyreturnsubroutine_0 int 1 return main_l8: -txn ApplicationID -int 0 -!= -assert int 1 return @@ -3369,10 +3353,6 @@ def approve_if_odd(condition_encoding: pt.abi.Uint32) -> pt.Expr: bnz main_l4 err main_l4: -txn ApplicationID -int 0 -!= -assert txna ApplicationArgs 1 int 0 extract_uint32 @@ -3382,10 +3362,6 @@ def approve_if_odd(condition_encoding: pt.abi.Uint32) -> pt.Expr: int 1 return main_l5: -txn ApplicationID -int 0 -!= -assert callsub log1_1 store 1 byte 0x151f7c75 @@ -3396,10 +3372,6 @@ def approve_if_odd(condition_encoding: pt.abi.Uint32) -> pt.Expr: int 1 return main_l6: -txn ApplicationID -int 0 -!= -assert callsub emptyreturnsubroutine_0 int 1 return