Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(anta): Refactor runner again #656

Merged
merged 3 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 41 additions & 23 deletions anta/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,21 @@
import os
import resource
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from anta import GITHUB_SUGGESTION
from anta.logger import anta_log_exception, exc_to_str
from anta.models import AntaTest
from anta.tools import Catchtime

if TYPE_CHECKING:
from collections.abc import Coroutine

from anta.catalog import AntaCatalog, AntaTestDefinition
from anta.device import AntaDevice
from anta.inventory import AntaInventory
from anta.result_manager import ResultManager
from anta.result_manager.models import TestResult

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -108,7 +111,7 @@ async def setup_inventory(inventory: AntaInventory, tags: set[str] | None, devic
return selected_inventory


async def prepare_tests(
def prepare_tests(
inventory: AntaInventory, catalog: AntaCatalog, tests: set[str] | None, tags: set[str] | None
) -> defaultdict[AntaDevice, set[AntaTestDefinition]] | None:
"""Prepare the tests to run.
Expand Down Expand Up @@ -154,7 +157,37 @@ async def prepare_tests(
return device_to_tests


async def main( # noqa: PLR0913,C901
def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinition]]) -> list[Coroutine[Any, Any, TestResult]]:
"""Get the coroutines for the ANTA run.

Args:
----
selected_tests: A mapping of devices to the tests to run. The selected tests are generated by the `prepare_tests` function.

Returns
-------
The list of coroutines to run.
"""
coros = []
for device, test_definitions in selected_tests.items():
for test in test_definitions:
try:
test_instance = test.test(device=device, inputs=test.inputs)
coros.append(test_instance.test())
except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught
# An AntaTest instance is potentially user-defined code.
# We need to catch everything and exit gracefully with an error message.
message = "\n".join(
[
f"There is an error when creating test {test.test.__module__}.{test.test.__name__}.",
f"If this is not a custom test implementation: {GITHUB_SUGGESTION}",
],
)
anta_log_exception(e, message, logger)
return coros


async def main( # noqa: PLR0913
manager: ResultManager,
inventory: AntaInventory,
catalog: AntaCatalog,
Expand Down Expand Up @@ -196,7 +229,7 @@ async def main( # noqa: PLR0913,C901
return

with Catchtime(logger=logger, message="Preparing the tests"):
selected_tests = await prepare_tests(selected_inventory, catalog, tests, tags)
selected_tests = prepare_tests(selected_inventory, catalog, tests, tags)
if selected_tests is None:
return

Expand All @@ -217,34 +250,19 @@ async def main( # noqa: PLR0913,C901
"Please consult the ANTA FAQ."
)

coros = []
for device, test_definitions in selected_tests.items():
for test in test_definitions:
try:
test_instance = test.test(device=device, inputs=test.inputs)
coros.append(test_instance.test())
except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught
# An AntaTest instance is potentially user-defined code.
# We need to catch everything and exit gracefully with an error message.
message = "\n".join(
[
f"There is an error when creating test {test.test.__module__}.{test.test.__name__}.",
f"If this is not a custom test implementation: {GITHUB_SUGGESTION}",
],
)
anta_log_exception(e, message, logger)
coroutines = get_coroutines(selected_tests)

if dry_run:
logger.info("Dry-run mode, exiting before running the tests.")
for coro in coros:
for coro in coroutines:
coro.close()
return

if AntaTest.progress is not None:
AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coros))
AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coroutines))

with Catchtime(logger=logger, message="Running ANTA tests"):
test_results = await asyncio.gather(*coros)
test_results = await asyncio.gather(*coroutines)
for r in test_results:
manager.add(r)

Expand Down
4 changes: 2 additions & 2 deletions tests/units/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ async def test_prepare_tests(
caplog.set_level(logging.INFO)

catalog: AntaCatalog = AntaCatalog.parse(str(DATA_DIR / "test_catalog_with_tags.yml"))
selected_tests = await prepare_tests(inventory=test_inventory, catalog=catalog, tags=tags, tests=None)
selected_tests = prepare_tests(inventory=test_inventory, catalog=catalog, tags=tags, tests=None)

if selected_tests is None:
assert expected_tests_count == 0
Expand All @@ -180,7 +180,7 @@ async def test_prepare_tests_with_specific_tests(caplog: pytest.LogCaptureFixtur
caplog.set_level(logging.INFO)

catalog: AntaCatalog = AntaCatalog.parse(str(DATA_DIR / "test_catalog_with_tags.yml"))
selected_tests = await prepare_tests(inventory=test_inventory, catalog=catalog, tags=None, tests={"VerifyMlagStatus", "VerifyUptime"})
selected_tests = prepare_tests(inventory=test_inventory, catalog=catalog, tags=None, tests={"VerifyMlagStatus", "VerifyUptime"})

assert selected_tests is not None
assert len(selected_tests) == 3
Expand Down
Loading