Skip to content

Commit

Permalink
Merge branch 'main' into fix_interface_module_failure_msg
Browse files Browse the repository at this point in the history
  • Loading branch information
geetanjalimanegslab committed Feb 27, 2025
2 parents a3cd02f + 9ba0284 commit 963e3f3
Show file tree
Hide file tree
Showing 42 changed files with 1,843 additions and 529 deletions.
33 changes: 21 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,16 @@ repos:
- --allow-past-years
- --fuzzy-match-generates-todo
- --comment-style
- '<!--| ~| -->'
- "<!--| ~| -->"

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.6
rev: v0.9.7
hooks:
- id: ruff
name: Run Ruff linter
args: [ --fix ]
- id: ruff-format
name: Run Ruff formatter
- id: ruff
name: Run Ruff linter
args: [--fix]
- id: ruff-format
name: Run Ruff formatter

- repo: https://github.com/pycqa/pylint
rev: "v3.3.4"
Expand All @@ -62,9 +62,9 @@ repos:
description: This hook runs pylint.
types: [python]
args:
- -rn # Only display messages
- -sn # Don't display the score
- --rcfile=pyproject.toml # Link to config file
- -rn # Only display messages
- -sn # Don't display the score
- --rcfile=pyproject.toml # Link to config file
additional_dependencies:
- anta[cli]
- types-PyYAML
Expand Down Expand Up @@ -123,5 +123,14 @@ repos:
pass_filenames: false
additional_dependencies:
- anta[cli]
# TODO: next can go once we have it added to anta properly
- numpydoc
- id: doc-snippets
name: Generate doc snippets
entry: >-
sh -c "docs/scripts/generate_doc_snippets.py"
language: python
types: [python]
files: anta/cli/
verbose: true
pass_filenames: false
additional_dependencies:
- anta[cli]
5 changes: 3 additions & 2 deletions anta/cli/nrfu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ def parse_args(self, ctx: click.Context, args: list[str]) -> list[str]:
if "--help" not in args:
raise

# remove the required params so that help can display
# Fake presence of the required params so that help can display
for param in self.params:
param.required = False
if param.required:
param.value_is_missing = lambda value: False # type: ignore[method-assign] # noqa: ARG005

return super().parse_args(ctx, args)

Expand Down
9 changes: 6 additions & 3 deletions anta/cli/nrfu/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def table(ctx: click.Context, group_by: Literal["device", "test"] | None) -> Non
help="Path to save report as a JSON file",
)
def json(ctx: click.Context, output: pathlib.Path | None) -> None:
"""ANTA command to check network state with JSON results."""
"""ANTA command to check network state with JSON results.
If no `--output` is specified, the output is printed to stdout.
"""
run_tests(ctx)
print_json(ctx, output=output)
exit_with_code(ctx)
Expand All @@ -72,11 +75,11 @@ def text(ctx: click.Context) -> None:
path_type=pathlib.Path,
),
show_envvar=True,
required=False,
required=True,
help="Path to save report as a CSV file",
)
def csv(ctx: click.Context, csv_output: pathlib.Path) -> None:
"""ANTA command to check network states with CSV result."""
"""ANTA command to check network state with CSV report."""
run_tests(ctx)
save_to_csv(ctx, csv_file=csv_output)
exit_with_code(ctx)
Expand Down
124 changes: 74 additions & 50 deletions anta/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
# https://github.com/pyca/cryptography/issues/7236#issuecomment-1131908472
CLIENT_KEYS = asyncssh.public_key.load_default_keypairs()

# Limit concurrency to 100 requests (HTTPX default) to avoid high-concurrency performance issues
# See: https://github.com/encode/httpx/issues/3215
MAX_CONCURRENT_REQUESTS = 100


class AntaCache:
"""Class to be used as cache.
Expand Down Expand Up @@ -296,6 +300,7 @@ async def copy(self, sources: list[Path], destination: Path, direction: Literal[
raise NotImplementedError(msg)


# pylint: disable=too-many-instance-attributes
class AsyncEOSDevice(AntaDevice):
"""Implementation of AntaDevice for EOS using aio-eapi.
Expand Down Expand Up @@ -388,6 +393,10 @@ def __init__( # noqa: PLR0913
host=host, port=ssh_port, username=username, password=password, client_keys=CLIENT_KEYS, **ssh_params
)

# In Python 3.9, Semaphore must be created within a running event loop
# TODO: Once we drop Python 3.9 support, initialize the semaphore here
self._command_semaphore: asyncio.Semaphore | None = None

def __rich_repr__(self) -> Iterator[tuple[str, Any]]:
"""Implement Rich Repr Protocol.
Expand Down Expand Up @@ -431,6 +440,15 @@ def _keys(self) -> tuple[Any, ...]:
"""
return (self._session.host, self._session.port)

async def _get_semaphore(self) -> asyncio.Semaphore:
"""Return the semaphore, initializing it if needed.
TODO: Remove this method once we drop Python 3.9 support.
"""
if self._command_semaphore is None:
self._command_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
return self._command_semaphore

async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None:
"""Collect device command output from EOS using aio-eapi.
Expand All @@ -445,57 +463,63 @@ async def _collect(self, command: AntaCommand, *, collection_id: str | None = No
collection_id
An identifier used to build the eAPI request ID.
"""
commands: list[dict[str, str | int]] = []
if self.enable and self._enable_password is not None:
commands.append(
{
"cmd": "enable",
"input": str(self._enable_password),
},
)
elif self.enable:
# No password
commands.append({"cmd": "enable"})
commands += [{"cmd": command.command, "revision": command.revision}] if command.revision else [{"cmd": command.command}]
try:
response: list[dict[str, Any] | str] = await self._session.cli(
commands=commands,
ofmt=command.ofmt,
version=command.version,
req_id=f"ANTA-{collection_id}-{id(command)}" if collection_id else f"ANTA-{id(command)}",
) # type: ignore[assignment] # multiple commands returns a list
# Do not keep response of 'enable' command
command.output = response[-1]
except asynceapi.EapiCommandError as e:
# This block catches exceptions related to EOS issuing an error.
self._log_eapi_command_error(command, e)
except TimeoutException as e:
# This block catches Timeout exceptions.
command.errors = [exc_to_str(e)]
timeouts = self._session.timeout.as_dict()
logger.error(
"%s occurred while sending a command to %s. Consider increasing the timeout.\nCurrent timeouts: Connect: %s | Read: %s | Write: %s | Pool: %s",
exc_to_str(e),
self.name,
timeouts["connect"],
timeouts["read"],
timeouts["write"],
timeouts["pool"],
)
except (ConnectError, OSError) as e:
# This block catches OSError and socket issues related exceptions.
command.errors = [exc_to_str(e)]
if (isinstance(exc := e.__cause__, httpcore.ConnectError) and isinstance(os_error := exc.__context__, OSError)) or isinstance(os_error := e, OSError): # pylint: disable=no-member
if isinstance(os_error.__cause__, OSError):
os_error = os_error.__cause__
logger.error("A local OS error occurred while connecting to %s: %s.", self.name, os_error)
else:
semaphore = await self._get_semaphore()

async with semaphore:
commands: list[dict[str, str | int]] = []
if self.enable and self._enable_password is not None:
commands.append(
{
"cmd": "enable",
"input": str(self._enable_password),
},
)
elif self.enable:
# No password
commands.append({"cmd": "enable"})
commands += [{"cmd": command.command, "revision": command.revision}] if command.revision else [{"cmd": command.command}]
try:
response: list[dict[str, Any] | str] = await self._session.cli(
commands=commands,
ofmt=command.ofmt,
version=command.version,
req_id=f"ANTA-{collection_id}-{id(command)}" if collection_id else f"ANTA-{id(command)}",
) # type: ignore[assignment] # multiple commands returns a list
# Do not keep response of 'enable' command
command.output = response[-1]
except asynceapi.EapiCommandError as e:
# This block catches exceptions related to EOS issuing an error.
self._log_eapi_command_error(command, e)
except TimeoutException as e:
# This block catches Timeout exceptions.
command.errors = [exc_to_str(e)]
timeouts = self._session.timeout.as_dict()
logger.error(
"%s occurred while sending a command to %s. Consider increasing the timeout.\nCurrent timeouts: Connect: %s | Read: %s | Write: %s | Pool: %s",
exc_to_str(e),
self.name,
timeouts["connect"],
timeouts["read"],
timeouts["write"],
timeouts["pool"],
)
except (ConnectError, OSError) as e:
# This block catches OSError and socket issues related exceptions.
command.errors = [exc_to_str(e)]
# pylint: disable=no-member
if (isinstance(exc := e.__cause__, httpcore.ConnectError) and isinstance(os_error := exc.__context__, OSError)) or isinstance(
os_error := e, OSError
):
if isinstance(os_error.__cause__, OSError):
os_error = os_error.__cause__
logger.error("A local OS error occurred while connecting to %s: %s.", self.name, os_error)
else:
anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger)
except HTTPError as e:
# This block catches most of the httpx Exceptions and logs a general message.
command.errors = [exc_to_str(e)]
anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger)
except HTTPError as e:
# This block catches most of the httpx Exceptions and logs a general message.
command.errors = [exc_to_str(e)]
anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger)
logger.debug("%s: %s", self.name, command)
logger.debug("%s: %s", self.name, command)

def _log_eapi_command_error(self, command: AntaCommand, e: asynceapi.EapiCommandError) -> None:
"""Appropriately log the eapi command error."""
Expand Down
10 changes: 6 additions & 4 deletions anta/input_models/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ class Host(BaseModel):
source: IPv4Address | IPv6Address | Interface
"""Source address IP or egress interface to use."""
vrf: str = "default"
"""VRF context. Defaults to `default`."""
"""VRF context."""
repeat: int = 2
"""Number of ping repetition. Defaults to 2."""
"""Number of ping repetition."""
size: int = 100
"""Specify datagram size. Defaults to 100."""
"""Specify datagram size."""
df_bit: bool = False
"""Enable do not fragment bit in IP header. Defaults to False."""
"""Enable do not fragment bit in IP header."""
reachable: bool = True
"""Indicates whether the destination should be reachable."""

def __str__(self) -> str:
"""Return a human-readable string representation of the Host for reporting.
Expand Down
31 changes: 29 additions & 2 deletions anta/input_models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@

from __future__ import annotations

from typing import Literal
from ipaddress import IPv4Interface
from typing import Any, Literal
from warnings import warn

from pydantic import BaseModel, ConfigDict

from anta.custom_types import Interface, PortChannelInterface


class InterfaceState(BaseModel):
"""Model for an interface state."""
"""Model for an interface state.
TODO: Need to review this class name in ANTA v2.0.0.
"""

model_config = ConfigDict(extra="forbid")
name: Interface
Expand All @@ -33,6 +38,10 @@ class InterfaceState(BaseModel):
Can be enabled in the `VerifyLACPInterfacesStatus` tests.
"""
primary_ip: IPv4Interface | None = None
"""Primary IPv4 address in CIDR notation. Required field in the `VerifyInterfaceIPv4` test."""
secondary_ips: list[IPv4Interface] | None = None
"""List of secondary IPv4 addresses in CIDR notation. Can be provided in the `VerifyInterfaceIPv4` test."""

def __str__(self) -> str:
"""Return a human-readable string representation of the InterfaceState for reporting.
Expand All @@ -46,3 +55,21 @@ def __str__(self) -> str:
if self.portchannel is not None:
base_string += f" Port-Channel: {self.portchannel}"
return base_string


class InterfaceDetail(InterfaceState): # pragma: no cover
"""Alias for the InterfaceState model to maintain backward compatibility.
When initialized, it will emit a deprecation warning and call the InterfaceState model.
TODO: Remove this class in ANTA v2.0.0.
"""

def __init__(self, **data: Any) -> None: # noqa: ANN401
"""Initialize the InterfaceState class, emitting a depreciation warning."""
warn(
message="InterfaceDetail model is deprecated and will be removed in ANTA v2.0.0. Use the InterfaceState model instead.",
category=DeprecationWarning,
stacklevel=2,
)
super().__init__(**data)
2 changes: 1 addition & 1 deletion anta/input_models/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ class NTPServer(BaseModel):

def __str__(self) -> str:
"""Representation of the NTPServer model."""
return f"{self.server_address} (Preferred: {self.preferred}, Stratum: {self.stratum})"
return f"NTP Server: {self.server_address} Preferred: {self.preferred} Stratum: {self.stratum}"
9 changes: 8 additions & 1 deletion anta/tests/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class VerifyReachability(AntaTest):
vrf: MGMT
df_bit: True
size: 100
reachable: true
- source: Management0
destination: 8.8.8.8
vrf: MGMT
Expand All @@ -47,6 +48,7 @@ class VerifyReachability(AntaTest):
vrf: default
df_bit: True
size: 100
reachable: false
```
"""

Expand Down Expand Up @@ -89,9 +91,14 @@ def test(self) -> None:
self.result.is_success()

for command, host in zip(self.instance_commands, self.inputs.hosts):
if f"{host.repeat} received" not in command.json_output["messages"][0]:
# Verifies the network is reachable
if host.reachable and f"{host.repeat} received" not in command.json_output["messages"][0]:
self.result.is_failure(f"{host} - Unreachable")

# Verifies the network is unreachable.
if not host.reachable and f"{host.repeat} received" in command.json_output["messages"][0]:
self.result.is_failure(f"{host} - Destination is expected to be unreachable but found reachable.")


class VerifyLLDPNeighbors(AntaTest):
"""Verifies the connection status of the specified LLDP (Link Layer Discovery Protocol) neighbors.
Expand Down
Loading

0 comments on commit 963e3f3

Please sign in to comment.