diff --git a/anta/runner.py b/anta/runner.py index 7fc7d347c..2d9a3ed2f 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -10,7 +10,7 @@ import os import resource from collections import defaultdict -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from anta import GITHUB_SUGGESTION from anta.logger import anta_log_exception, exc_to_str @@ -18,10 +18,13 @@ from anta.tools import Catchtime if TYPE_CHECKING: + from collections.abc import Coroutine + from anta.catalog import AntaCatalog, AntaTestDefinition from anta.device import AntaDevice from anta.inventory import AntaInventory from anta.result_manager import ResultManager + from anta.result_manager.models import TestResult logger = logging.getLogger(__name__) @@ -108,7 +111,7 @@ async def setup_inventory(inventory: AntaInventory, tags: set[str] | None, devic return selected_inventory -async def prepare_tests( +def prepare_tests( inventory: AntaInventory, catalog: AntaCatalog, tests: set[str] | None, tags: set[str] | None ) -> defaultdict[AntaDevice, set[AntaTestDefinition]] | None: """Prepare the tests to run. @@ -154,7 +157,37 @@ async def prepare_tests( return device_to_tests -async def main( # noqa: PLR0913,C901 +def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinition]]) -> list[Coroutine[Any, Any, TestResult]]: + """Get the coroutines for the ANTA run. + + Args: + ---- + selected_tests: A mapping of devices to the tests to run. The selected tests are generated by the `prepare_tests` function. + + Returns + ------- + The list of coroutines to run. + """ + coros = [] + for device, test_definitions in selected_tests.items(): + for test in test_definitions: + try: + test_instance = test.test(device=device, inputs=test.inputs) + coros.append(test_instance.test()) + except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught + # An AntaTest instance is potentially user-defined code. + # We need to catch everything and exit gracefully with an error message. + message = "\n".join( + [ + f"There is an error when creating test {test.test.__module__}.{test.test.__name__}.", + f"If this is not a custom test implementation: {GITHUB_SUGGESTION}", + ], + ) + anta_log_exception(e, message, logger) + return coros + + +async def main( # noqa: PLR0913 manager: ResultManager, inventory: AntaInventory, catalog: AntaCatalog, @@ -196,7 +229,7 @@ async def main( # noqa: PLR0913,C901 return with Catchtime(logger=logger, message="Preparing the tests"): - selected_tests = await prepare_tests(selected_inventory, catalog, tests, tags) + selected_tests = prepare_tests(selected_inventory, catalog, tests, tags) if selected_tests is None: return @@ -217,34 +250,19 @@ async def main( # noqa: PLR0913,C901 "Please consult the ANTA FAQ." ) - coros = [] - for device, test_definitions in selected_tests.items(): - for test in test_definitions: - try: - test_instance = test.test(device=device, inputs=test.inputs) - coros.append(test_instance.test()) - except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught - # An AntaTest instance is potentially user-defined code. - # We need to catch everything and exit gracefully with an error message. - message = "\n".join( - [ - f"There is an error when creating test {test.test.__module__}.{test.test.__name__}.", - f"If this is not a custom test implementation: {GITHUB_SUGGESTION}", - ], - ) - anta_log_exception(e, message, logger) + coroutines = get_coroutines(selected_tests) if dry_run: logger.info("Dry-run mode, exiting before running the tests.") - for coro in coros: + for coro in coroutines: coro.close() return if AntaTest.progress is not None: - AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coros)) + AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coroutines)) with Catchtime(logger=logger, message="Running ANTA tests"): - test_results = await asyncio.gather(*coros) + test_results = await asyncio.gather(*coroutines) for r in test_results: manager.add(r) diff --git a/tests/units/test_runner.py b/tests/units/test_runner.py index e3f9536da..955149d09 100644 --- a/tests/units/test_runner.py +++ b/tests/units/test_runner.py @@ -162,7 +162,7 @@ async def test_prepare_tests( caplog.set_level(logging.INFO) catalog: AntaCatalog = AntaCatalog.parse(str(DATA_DIR / "test_catalog_with_tags.yml")) - selected_tests = await prepare_tests(inventory=test_inventory, catalog=catalog, tags=tags, tests=None) + selected_tests = prepare_tests(inventory=test_inventory, catalog=catalog, tags=tags, tests=None) if selected_tests is None: assert expected_tests_count == 0 @@ -180,7 +180,7 @@ async def test_prepare_tests_with_specific_tests(caplog: pytest.LogCaptureFixtur caplog.set_level(logging.INFO) catalog: AntaCatalog = AntaCatalog.parse(str(DATA_DIR / "test_catalog_with_tags.yml")) - selected_tests = await prepare_tests(inventory=test_inventory, catalog=catalog, tags=None, tests={"VerifyMlagStatus", "VerifyUptime"}) + selected_tests = prepare_tests(inventory=test_inventory, catalog=catalog, tags=None, tests={"VerifyMlagStatus", "VerifyUptime"}) assert selected_tests is not None assert len(selected_tests) == 3