diff --git a/secrets_env/collect.py b/secrets_env/collect.py index fad4214b..da033787 100644 --- a/secrets_env/collect.py +++ b/secrets_env/collect.py @@ -5,7 +5,7 @@ if typing.TYPE_CHECKING: from secrets_env.config0.parser import Config - from secrets_env.provider import ProviderBase, RequestSpec + from secrets_env.provider import Provider, RequestSpec logger = logging.getLogger(__name__) @@ -37,17 +37,17 @@ def read_values(config: Config) -> dict[str, str]: return output_values -def read1(provider: ProviderBase, name: str, spec: RequestSpec) -> str | None: +def read1(provider: Provider, name: str, spec: RequestSpec) -> str | None: """Read single value. This function wraps :py:meth:`secrets_env.provider.ProviderBase.get` and captures all exceptions. """ import secrets_env.exceptions - from secrets_env.provider import ProviderBase + from secrets_env.provider import Provider # type checking - if not isinstance(provider, ProviderBase): + if not isinstance(provider, Provider): raise TypeError( f'Expected "provider" to be a credential provider class, ' f"got {type(provider).__name__}" diff --git a/secrets_env/config0/parser.py b/secrets_env/config0/parser.py index 774bc12e..92fdf5c0 100644 --- a/secrets_env/config0/parser.py +++ b/secrets_env/config0/parser.py @@ -1,16 +1,17 @@ +from __future__ import annotations + import itertools import logging import re import typing -from typing import Any, Dict, Iterator, List, Optional, Tuple, TypedDict +from typing import Any, Iterator, TypedDict import secrets_env.exceptions import secrets_env.providers -from secrets_env.provider import RequestSpec from secrets_env.utils import ensure_dict, ensure_str if typing.TYPE_CHECKING: - from secrets_env.provider import ProviderBase + from secrets_env.provider import Provider, RequestSpec DEFAULT_PROVIDER_NAME = "main" @@ -29,11 +30,11 @@ class Request(TypedDict): class Config(TypedDict): """The parsed configurations.""" - providers: Dict[str, "ProviderBase"] - requests: List[Request] + providers: dict[str, Provider] + requests: list[Request] -def parse_config(data: dict) -> Optional[Config]: +def parse_config(data: dict) -> Config | None: """Parse and validate configs, build it into structured object.""" requests = get_requests(data) if not requests: @@ -48,11 +49,11 @@ def parse_config(data: dict) -> Optional[Config]: return Config(providers=providers, requests=requests) -def get_providers(data: dict) -> Dict[str, "ProviderBase"]: +def get_providers(data: dict) -> dict[str, Provider]: sections = list(extract_sources(data)) logger.debug("%d raw provider configs extracted", len(sections)) - providers: Dict[str, "ProviderBase"] = {} + providers: dict[str, Provider] = {} for data in sections: result = parse_source_item(data) if not result: @@ -75,7 +76,7 @@ def get_providers(data: dict) -> Dict[str, "ProviderBase"]: return providers -def extract_sources(data: dict) -> Iterator[Dict[str, Any]]: +def extract_sources(data: dict) -> Iterator[dict[str, Any]]: """Extracts both "source(s)" section and ensure the output is list of dict""" for item in itertools.chain( get_list(data, "source"), @@ -97,7 +98,7 @@ def get_list(data: dict, key: str) -> Iterator[dict]: logger.warning("Found invalid value in field %s", key) -def parse_source_item(config: dict) -> Optional[Tuple[str, "ProviderBase"]]: +def parse_source_item(config: dict) -> tuple[str, Provider] | None: # check name name = config.get("name") or DEFAULT_PROVIDER_NAME name, ok = ensure_str("source.name", name) @@ -117,7 +118,7 @@ def parse_source_item(config: dict) -> Optional[Tuple[str, "ProviderBase"]]: return name, provider -def get_requests(data: dict) -> List[Request]: +def get_requests(data: dict) -> list[Request]: # accept both keyword `secret(s)` raw = {} diff --git a/secrets_env/provider.py b/secrets_env/provider.py index 4f66b02e..fadee477 100644 --- a/secrets_env/provider.py +++ b/secrets_env/provider.py @@ -6,53 +6,46 @@ from __future__ import annotations import abc -import sys -import typing -from typing import Dict, Union +from typing import ClassVar, Union -if typing.TYPE_CHECKING and sys.version_info >= (3, 10): - from typing import TypeAlias +from pydantic import BaseModel +RequestSpec = Union[dict[str, str], str] +""":py:class:`RequestSpec` represents a path spec to read the value. -RequestSpec: TypeAlias = Union[Dict[str, str], str] -""":py:class:`RequestSpec` represents a secret spec (name/path) to be loaded. +It should be a :py:class:`dict` in most cases; or :py:class:`str` if this +provider accepts shortcut. """ -class ProviderBase(abc.ABC): - """Abstract base class for secret provider. All secret provider must implement +class Provider(BaseModel, abc.ABC): + """Abstract base class for secret provider. All provider must implement this interface. """ - @property - @abc.abstractmethod - def type(self) -> str: - """Provider name.""" + type: ClassVar[str] @abc.abstractmethod def get(self, spec: RequestSpec) -> str: - """Get secret value. + """Get secret. Parameters ---------- - spec : dict | str - Raw input from config file. - - It should be :py:class:`dict` in most cases; or :py:class:`str` if - this provider accepts shortcut. + path : dict | str + Raw input from config file for reading the secret value. Return ------ - The secret value. + The value Raises ------ - ConfigError - The path dict is malformed. - ValueNotFound - The path dict is correct but the secret not exists. - - Note - ---- - Key ``source`` is preserved in ``spec`` dictionary. + ValidationError + If the input format is invalid. + UnsupportedError + When this operation is not supported. + AuthenticationError + Failed during authentication. + LookupError + If the secret is not found. """ diff --git a/secrets_env/providers/__init__.py b/secrets_env/providers/__init__.py index 112af6f3..709b5641 100644 --- a/secrets_env/providers/__init__.py +++ b/secrets_env/providers/__init__.py @@ -1,33 +1,50 @@ -import typing +from __future__ import annotations -import secrets_env.exceptions +import logging +import typing if typing.TYPE_CHECKING: - from secrets_env.provider import ProviderBase + from secrets_env.provider import Provider DEFAULT_PROVIDER = "vault" +logger = logging.getLogger(__name__) + + +def get_provider(config: dict) -> Provider: + """ + Returns a provider instance based on the configuration. + + Raises + ------ + ValueError + If the provider type is not recognized. + ValidationError + If the provider configuration is invalid. + """ + type_ = config.get("type") + if not type_: + type_ = DEFAULT_PROVIDER + logger.warning("Provider type unspecified, using default: %s", type_) -def get_provider(data: dict) -> "ProviderBase": - type_ = data.get("type", DEFAULT_PROVIDER) - type_lower = type_.lower() + itype = type_.lower() # fmt: off - if type_lower == "null": - from . import null - return null.get_provider(type_, data) - if type_lower == "plain": - from . import plain - return plain.get_provider(type_, data) - if type_lower == "teleport": - from . import teleport - return teleport.get_provider(type_, data) - if type_lower == "vault": - from . import vault - return vault.get_provider(type_, data) - if type_lower.startswith("teleport+"): - from . import teleport - return teleport.get_adapted_provider(type_, data) + if itype == "null": + from secrets_env.providers.null import NullProvider + return NullProvider.model_validate(config) + if itype == "plain": + from secrets_env.providers.plain import PlainTextProvider + return PlainTextProvider.model_validate(config) + if itype == "teleport": + from secrets_env.providers.teleport import TeleportProvider + return TeleportProvider.model_validate(config) + if itype == "teleport+vault": + logger.error('"teleport+vault provider is not yet implemented') + raise NotImplementedError + if itype == "vault": + from secrets_env.providers.vault import VaultKvProvider + return VaultKvProvider.model_validate(config) # fmt: on - raise secrets_env.exceptions.ConfigError("Unknown provider type {}", type_) + raise ValueError(f"Unknown provider type {type_}") diff --git a/secrets_env/providers/null.py b/secrets_env/providers/null.py index 2c1dc342..5d439550 100644 --- a/secrets_env/providers/null.py +++ b/secrets_env/providers/null.py @@ -1,22 +1,18 @@ +from __future__ import annotations + import typing -from secrets_env.provider import ProviderBase +from secrets_env.provider import Provider if typing.TYPE_CHECKING: from secrets_env.provider import RequestSpec -class NullProvider(ProviderBase): +class NullProvider(Provider): """A provider that always returns empty string. This provider is preserved for debugging.""" - @property - def type(self) -> str: - return "null" + type = "null" - def get(self, spec: "RequestSpec") -> str: + def get(self, spec: RequestSpec) -> str: return "" - - -def get_provider(type_: str, data: dict) -> NullProvider: - return NullProvider() diff --git a/secrets_env/providers/plain.py b/secrets_env/providers/plain.py index 601dcbb2..50b8819b 100644 --- a/secrets_env/providers/plain.py +++ b/secrets_env/providers/plain.py @@ -1,30 +1,21 @@ +from __future__ import annotations + import typing -from secrets_env.provider import ProviderBase +from secrets_env.provider import Provider if typing.TYPE_CHECKING: from secrets_env.provider import RequestSpec -class PlainTextProvider(ProviderBase): +class PlainTextProvider(Provider): """Plain text provider returns text that is copied directly from the configuration file.""" - @property - def type(self) -> str: - return "plain" + type = "plain" - def get(self, spec: "RequestSpec") -> str: + def get(self, spec: RequestSpec) -> str: if isinstance(spec, str): - value = spec + return spec elif isinstance(spec, dict): - value = spec.get("value") - else: - raise TypeError( - f'Expected "spec" to be a string or dict, got {type(spec).__name__}' - ) - return value or "" - - -def get_provider(type_: str, data: dict) -> PlainTextProvider: - return PlainTextProvider() + return spec.get("value") or "" diff --git a/secrets_env/providers/teleport/__init__.py b/secrets_env/providers/teleport/__init__.py index 614a8777..ead2ae53 100644 --- a/secrets_env/providers/teleport/__init__.py +++ b/secrets_env/providers/teleport/__init__.py @@ -1,45 +1,70 @@ from __future__ import annotations +import logging import typing +from typing import Literal -import pydantic +from pydantic import BaseModel, model_validator -from secrets_env.exceptions import ConfigError +from secrets_env.provider import Provider +from secrets_env.providers.teleport.config import TeleportUserConfig if typing.TYPE_CHECKING: - from secrets_env.provider import ProviderBase - from secrets_env.providers.teleport.provider import TeleportProvider - -ADAPTER_PREFIX = "teleport+" - - -def get_provider(type_: str, data: dict) -> TeleportProvider: - from .config import TeleportUserConfig - from .provider import TeleportProvider - - cfg = TeleportUserConfig.model_validate(data) - return TeleportProvider(config=cfg) - - -def get_adapted_provider(type_: str, data: dict) -> ProviderBase: - from .adapters import get_adapter - from .config import TeleportUserConfig # noqa: TCH001 - - class TeleportAdapterConfig(pydantic.BaseModel): - """Config layout for using Teleport as an adapter.""" - - teleport: TeleportUserConfig - - iname = type_.lower() - if not iname.startswith(ADAPTER_PREFIX): - raise ConfigError("Not a Teleport compatible provider: {}", type_) - - subtype = type_[len(ADAPTER_PREFIX) :] - factory = get_adapter(subtype) - - # get connection parameter - app_param = TeleportAdapterConfig.model_validate(data) - conn_param = app_param.teleport.get_connection_param() - - # forward parameters to corresponding provider - return factory(subtype, data, conn_param) + from typing import Self + + from secrets_env.provider import RequestSpec + from secrets_env.providers.teleport.config import TeleportConnectionParameter + +logger = logging.getLogger(__name__) + + +class TeleportRequestSpec(BaseModel): + field: Literal["uri", "ca", "cert", "key", "cert+key"] + format: Literal["path", "pem"] = "path" + + @model_validator(mode="before") + @classmethod + def _accept_shortcut(cls, data: RequestSpec | Self) -> dict[str, str] | Self: + if isinstance(data, str): + return {"field": data} + return data + + +class TeleportProvider(Provider, TeleportUserConfig): + """Read certificates from Teleport.""" + + type = "teleport" + + def get(self, spec: RequestSpec) -> str: + ps = TeleportRequestSpec.model_validate(spec) + + if ps.field == "uri": + return self.connection_param.uri + elif ps.field == "ca": + return get_ca(self.connection_param, ps.format) + elif ps.field == "cert": + if ps.format == "path": + return str(self.connection_param.path_cert) + elif ps.format == "pem": + return self.connection_param.cert.decode() + elif ps.field == "key": + if ps.format == "path": + return str(self.connection_param.path_key) + elif ps.format == "pem": + return self.connection_param.key.decode() + elif ps.field == "cert+key": + if ps.format == "path": + return str(self.connection_param.path_cert_and_key) + elif ps.format == "pem": + return self.connection_param.cert_and_key.decode() + + raise RuntimeError + + +def get_ca(param: TeleportConnectionParameter, fmt: Literal["path", "pem"]) -> str: + if param.ca is None: + raise LookupError("CA is not available") + if fmt == "path": + return str(param.path_ca) + elif fmt == "pem": + return param.ca.decode() diff --git a/secrets_env/providers/teleport/adapters.py b/secrets_env/providers/teleport/adapters.py deleted file mode 100644 index 01ca070b..00000000 --- a/secrets_env/providers/teleport/adapters.py +++ /dev/null @@ -1,54 +0,0 @@ -from __future__ import annotations - -import logging -import typing - -from secrets_env.exceptions import ConfigError - -if typing.TYPE_CHECKING: - from secrets_env.provider import ProviderBase - from secrets_env.providers.teleport.config import TeleportConnectionParameter - - AdapterType = typing.Callable[ - [str, dict, TeleportConnectionParameter], ProviderBase - ] - -logger = logging.getLogger(__name__) - - -def get_adapter(name: str) -> AdapterType: - iname = name.lower() - if iname == "vault": - return adapt_vault_provider - - raise ConfigError("Unknown provider type {}", name) - - -def adapt_vault_provider( - type_: str, data: dict, param: TeleportConnectionParameter -) -> ProviderBase: - assert isinstance(data, dict) - from secrets_env.providers import vault - - # url - if (url := data.get("url")) and url != param.uri: - logger.warning("Overwrite source.url to %s", param.uri) - - data["url"] = param.uri - logger.debug("Set Vault URL to %s", param.uri) - - # ca - tls: dict = data.setdefault("tls", {}) - if param.path_ca: - tls["ca_cert"] = param.path_ca - logger.debug("Set Vault CA to %s", param.path_ca) - - # cert - tls["client_cert"] = param.path_cert - logger.debug("Set Vault client cert to %s", param.path_cert) - - # key - tls["client_key"] = param.path_key - logger.debug("Set Vault client key to %s", param.path_key) - - return vault.get_provider(type_, data) diff --git a/secrets_env/providers/teleport/config.py b/secrets_env/providers/teleport/config.py index c86c1a03..32892545 100644 --- a/secrets_env/providers/teleport/config.py +++ b/secrets_env/providers/teleport/config.py @@ -42,7 +42,8 @@ def _use_shortcut(cls, data): return {"app": data} return data - def get_connection_param(self) -> TeleportConnectionParameter: + @cached_property + def connection_param(self) -> TeleportConnectionParameter: """Get app connection parameter from Teleport CLI. Raises diff --git a/secrets_env/providers/teleport/provider.py b/secrets_env/providers/teleport/provider.py deleted file mode 100644 index 5e7c1f8d..00000000 --- a/secrets_env/providers/teleport/provider.py +++ /dev/null @@ -1,115 +0,0 @@ -from __future__ import annotations - -import logging -import typing -from functools import cached_property -from typing import Literal - -from secrets_env.exceptions import ConfigError, ValueNotFound -from secrets_env.provider import ProviderBase - -if typing.TYPE_CHECKING: - from secrets_env.provider import RequestSpec - from secrets_env.providers.teleport.config import ( - TeleportConnectionParameter, - TeleportUserConfig, - ) - -DEFAULT_OUTPUT_FORMAT = "path" - - -class OutputSpec(typing.NamedTuple): - field: Literal["uri", "ca", "cert", "key", "cert+key"] - format: Literal["path", "pem"] - - -logger = logging.getLogger(__name__) - - -class TeleportProvider(ProviderBase): - """Read certificates from Teleport.""" - - @property - def type(self) -> str: - return "teleport" - - def __init__(self, *, config: TeleportUserConfig) -> None: - self._config = config - - @cached_property - def tsh(self) -> TeleportConnectionParameter: - """Return teleport app connection information.""" - return self._config.get_connection_param() - - def get(self, spec: RequestSpec) -> str: - parsed = parse_spec(spec) - - if parsed.field == "uri": - return self.tsh.uri - - elif parsed.field == "ca": - # bypass cognitive complexity check - return get_ca(self.tsh, parsed.format) - - elif parsed.field == "cert": - if parsed.format == "path": - return str(self.tsh.path_cert) - elif parsed.format == "pem": - return self.tsh.cert.decode() - - elif parsed.field == "key": - if parsed.format == "path": - return str(self.tsh.path_key) - elif parsed.format == "pem": - return self.tsh.key.decode() - - elif parsed.field == "cert+key": - if parsed.format == "path": - return str(self.tsh.path_cert_and_key) - elif parsed.format == "pem": - return self.tsh.cert_and_key.decode() - - raise ConfigError("Invalid value spec: {}", spec) - - -def parse_spec(spec: RequestSpec) -> OutputSpec: - # extract - if isinstance(spec, str): - output_field = spec - output_format = DEFAULT_OUTPUT_FORMAT - elif isinstance(spec, dict): - output_field = spec.get("field") - output_format = spec.get("format", DEFAULT_OUTPUT_FORMAT) - else: - raise ConfigError( - "Expect dict for secrets path spec, got {}", type(spec).__name__ - ) - - # validate - if ( - False - or not isinstance(output_field, str) - or output_field.lower() not in ("uri", "ca", "cert", "key", "cert+key") - ): - raise ConfigError("Invalid field (secrets.VAR.field): {}", output_field) - - if ( - False - or not isinstance(output_format, str) - or output_format.lower() not in ("path", "pem") - ): - raise ConfigError("Invalid format (secrets.VAR.format): {}", output_format) - - return OutputSpec(output_field.lower(), output_format.lower()) # type: ignore[reportGeneralTypeIssues] - - -def get_ca( - conn_info: TeleportConnectionParameter, format_: Literal["path", "pem"] -) -> str: - if not conn_info.ca: - raise ValueNotFound("CA is not avaliable") - if format_ == "path": - return str(conn_info.path_ca) - elif format_ == "pem": - return conn_info.ca.decode() - raise RuntimeError diff --git a/secrets_env/providers/vault/__init__.py b/secrets_env/providers/vault/__init__.py index d78cb686..190df8f9 100644 --- a/secrets_env/providers/vault/__init__.py +++ b/secrets_env/providers/vault/__init__.py @@ -1,10 +1,3 @@ -from secrets_env.exceptions import ConfigError +__all__ = ["VaultKvProvider"] -from .config import get_connection_info -from .provider import KvProvider - - -def get_provider(type_: str, data: dict) -> KvProvider: - if not (cfg := get_connection_info(data)): - raise ConfigError("Invalid config for vault provider") - return KvProvider(**cfg) +from secrets_env.providers.vault.provider import VaultKvProvider diff --git a/secrets_env/providers/vault/auth/__init__.py b/secrets_env/providers/vault/auth/__init__.py index 4d4173a9..97dbad03 100644 --- a/secrets_env/providers/vault/auth/__init__.py +++ b/secrets_env/providers/vault/auth/__init__.py @@ -1,13 +1,15 @@ from __future__ import annotations +__all__ = ["Auth", "create_auth_by_name"] + import logging import typing +from secrets_env.providers.vault.auth.base import Auth + if typing.TYPE_CHECKING: from pydantic_core import Url - from secrets_env.providers.vault.auth.base import Auth - logger = logging.getLogger(__name__) diff --git a/secrets_env/providers/vault/auth/base.py b/secrets_env/providers/vault/auth/base.py index f8838cf2..fe55ff6b 100644 --- a/secrets_env/providers/vault/auth/base.py +++ b/secrets_env/providers/vault/auth/base.py @@ -31,8 +31,15 @@ def create(cls, url: Url, config: dict[str, Any]) -> Self: """ @abstractmethod - def login(self, client: httpx.Client) -> str | None: - """Login and get token.""" + def login(self, client: httpx.Client) -> str: + """ + Login and get Vault token. + + Raises + ------ + AuthenticationError + If the login fails. + """ class NoAuth(Auth): diff --git a/secrets_env/providers/vault/auth/userpass.py b/secrets_env/providers/vault/auth/userpass.py index 40205580..1086846f 100644 --- a/secrets_env/providers/vault/auth/userpass.py +++ b/secrets_env/providers/vault/auth/userpass.py @@ -6,6 +6,7 @@ from pydantic import PrivateAttr, SecretStr +from secrets_env.exceptions import AuthenticationError from secrets_env.providers.vault.auth.base import Auth from secrets_env.utils import ( create_keyring_login_key, @@ -76,7 +77,7 @@ def _get_password(cls, url: Url, username: str) -> str | None: return prompt(f"Password for {username}", hide_input=True) - def login(self, client: httpx.Client) -> str | None: + def login(self, client: httpx.Client) -> str: username = urllib.parse.quote(self.username) resp = client.post( f"/v1/auth/{self.vault_name}/login/{username}", @@ -88,14 +89,13 @@ def login(self, client: httpx.Client) -> str | None: ) if not resp.is_success: - logger.error("Failed to login with %s method", self.method) logger.debug( "Login failed. URL= %s, Code= %d. Msg= %s", resp.url, resp.status_code, resp.text, ) - return + raise AuthenticationError("Failed to login with %s method", self.method) return resp.json()["auth"]["client_token"] diff --git a/secrets_env/providers/vault/provider.py b/secrets_env/providers/vault/provider.py index 6e56278c..bb730057 100644 --- a/secrets_env/providers/vault/provider.py +++ b/secrets_env/providers/vault/provider.py @@ -2,70 +2,121 @@ import enum import logging -import re import typing from functools import cached_property from http import HTTPStatus -from typing import Dict, Literal, Union +from typing import Literal import httpx -from pydantic import BaseModel, Field, model_validator +from pydantic import ( + BaseModel, + Field, + InstanceOf, + PrivateAttr, + field_validator, + model_validator, + validate_call, +) import secrets_env.version -from secrets_env.exceptions import AuthenticationError, ValueNotFound -from secrets_env.provider import ProviderBase, RequestSpec +from secrets_env.exceptions import AuthenticationError +from secrets_env.provider import Provider +from secrets_env.providers.vault.config import VaultUserConfig from secrets_env.utils import LruDict, get_httpx_error_reason, log_httpx_response if typing.TYPE_CHECKING: - from pathlib import Path - from typing import Any, Self - - from secrets_env.providers.vault.auth.base import Auth - from secrets_env.providers.vault.config import CertTypes + from typing import Iterable, Iterator, Self, Sequence + from secrets_env.provider import RequestSpec + from secrets_env.providers.vault.auth import Auth logger = logging.getLogger(__name__) class Marker(enum.Enum): - NoMatch = enum.auto() - SecretNotExist = enum.auto() + """Internal marker for cache handling.""" + NoCache = enum.auto() + NotFound = enum.auto() -class SecretSource(typing.NamedTuple): - path: str - field: str +class VaultPath(BaseModel): + """Represents a path to a value in Vault.""" -if typing.TYPE_CHECKING: - KVVersion = Literal[1, 2] - VaultSecret = Dict[str, str] - VaultSecretQueryResult = Union[VaultSecret, Literal[Marker.SecretNotExist]] - - -class KvProvider(ProviderBase): - """Read secrets from Vault KV engine.""" - - def __init__( - self, - url: str, - auth: Auth, - *, - proxy: str | None = None, - ca_cert: Path | None = None, - client_cert: CertTypes | None = None, - ) -> None: - self.url = url - self.auth = auth - self.proxy = proxy - self.ca_cert = ca_cert - self.client_cert = client_cert - - self._secrets: LruDict[str, VaultSecretQueryResult] = LruDict() + path: str = Field(min_length=1) + field: tuple[str, ...] + + def __str__(self) -> str: + return f"{self.path}#{self.field_str}" @property - def type(self) -> str: - return "vault" + def field_str(self) -> str: + seq = [] + for f in self.field: + if "." in f: + seq.append(f'"{f}"') + else: + seq.append(f) + return ".".join(seq) + + @model_validator(mode="before") + @classmethod + def _create_from_str(cls, value: str | dict | Self) -> dict | Self: + if not isinstance(value, str): + return value + if value.count("#") != 1: + raise ValueError("Invalid format. Expected 'path#field'") + path, field = value.rsplit("#", 1) + return { + "path": path, + "field": field, + } + + @field_validator("field", mode="before") + @classmethod + def _accept_str_for_field(cls, value) -> Iterable[str]: + if isinstance(value, str): + return _split_field_str(value) + return value + + @field_validator("field", mode="after") + @classmethod + def _validate_field(cls, field: Sequence[str]) -> Sequence[str]: + if not field: + raise ValueError("Field cannot be empty") + if any(not f for f in field): + raise ValueError("Field cannot contain empty subpath") + return field + + +def _split_field_str(f: str) -> Iterator[str]: + """Split a field name into subsequences. By default, this function splits + the name by dots, with supportting of preserving the quoted subpaths. + """ + pos = 0 + while pos < len(f): + if f[pos] == '"': + # quoted + end = f.find('"', pos + 1) + if end == -1: + raise ValueError(f"Failed to parse field: {f}") + yield f[pos + 1 : end] + pos = end + 2 + else: + # simple + end = f.find(".", pos) + if end == -1: + end = len(f) + yield f[pos:end] + pos = end + 1 + + +class VaultKvProvider(Provider, VaultUserConfig): + """Read secrets from Hashicorp Vault KV engine.""" + + type = "vault" + + _cache: dict[str, dict | Marker] = PrivateAttr(default_factory=LruDict) @cached_property def client(self) -> httpx.Client: @@ -76,102 +127,101 @@ def client(self) -> httpx.Client: self.auth.method, ) - # initialize client - client_params: dict[str, Any] = {"base_url": self.url} - - if self.proxy: - logger.debug("Use proxy: %s", self.proxy) - client_params["proxies"] = self.proxy - if self.ca_cert: - logger.debug("CA installed: %s", self.ca_cert) - client_params["verify"] = self.ca_cert - if self.client_cert: - logger.debug("Client side certificate file installed: %s", self.client_cert) - client_params["cert"] = self.client_cert - - client = httpx.Client( - **client_params, - headers={ - "Accept": "application/json", - "User-Agent": ( - f"secrets.env/{secrets_env.version.__version__} " - f"python-httpx/{httpx.__version__}" - ), - }, - ) - - # install token + client = create_http_client(self) client.headers["X-Vault-Token"] = get_token(client, self.auth) return client def get(self, spec: RequestSpec) -> str: path = VaultPath.model_validate(spec) - return self.read_field(path.path, path.field) - - def read_secret(self, path: str) -> VaultSecret: - """Read secret from Vault. + secret = self._read_secret(path) + + for f in path.field: + try: + secret = secret[f] + except (KeyError, TypeError): + raise LookupError( + f'Field "{path.field_str}" not found in "{path.path}"' + ) from None + + if not isinstance(secret, str): + raise LookupError( + f'Field "{path.field_str}" in "{path.path}" is not point to a string value' + ) - Parameters - ---------- - path : str - Secret path + return secret - Returns - ------- - secret : dict - Secret data. Or 'SecretNotExist' marker when not found. + def _read_secret(self, path: VaultPath) -> dict: """ - if not isinstance(path, str): - raise TypeError( - f'Expected "path" to be a string, got {type(path).__name__}' - ) - - # try cache - result = self._secrets.get(path, Marker.NoMatch) + Get a secret from the Vault. A Vault "secret" is a object that contains + key-value pairs. - if result == Marker.NoMatch: - # not found in cache - start query - if secret := read_secret(self.client, path): - result = secret - else: - result = Marker.SecretNotExist - self._secrets[path] = result + This method wraps the `read_secret` method and cache the result. - # returns value - if result == Marker.SecretNotExist: - raise ValueNotFound("Secret {} not found", path) - return result + Raises + ------ + LookupError + If the secret is not found. + """ + result = self._cache.get(path.path, Marker.NoCache) - def read_field(self, path: str, field: str) -> str: - """Read only one field from Vault. + if result == Marker.NoCache: + result = read_secret(self.client, path.path) + if result is None: + result = Marker.NotFound + self._cache[path.path] = result - Parameters - ---------- - path : str - Secret path - field : str - Field name + if result == Marker.NotFound: + raise LookupError(f'Secret "{path}" not found') - Returns - ------- - value : str - The secret value if matched - """ - if not isinstance(field, str): - raise TypeError( - f'Expected "field" to be a string, got {type(field).__name__}' - ) + return result - secret = self.read_secret(path) - value = get_field(secret, field) - if value is None: - raise ValueNotFound("Secret {}#{} not found", path, field) - return value +@validate_call +def create_http_client(config: VaultUserConfig) -> httpx.Client: + logger.debug( + "Vault client initialization requested. URL= %s, Auth type= %s", + config.url, + config.auth.method, + ) + client_params = { + "base_url": str(config.url), + "headers": { + "Accept": "application/json", + "User-Agent": ( + f"secrets.env/{secrets_env.version.__version__} " + f"python-httpx/{httpx.__version__}" + ), + }, + } + + if config.proxy: + logger.debug("Proxy is set: %s", config.proxy) + client_params["proxy"] = str(config.proxy) + if config.tls.ca_cert: + logger.debug("CA cert is set: %s", config.tls.ca_cert) + client_params["verify"] = config.tls.ca_cert + if config.tls.client_cert and config.tls.client_key: + cert_pair = (config.tls.client_cert, config.tls.client_key) + logger.debug("Client cert pair is set: %s ", cert_pair) + client_params["cert"] = cert_pair + elif config.tls.client_cert: + logger.debug("Client cert is set: %s", config.tls.client_cert) + client_params["cert"] = config.tls.client_cert + + return httpx.Client(**client_params) + + +def get_token(client: InstanceOf[httpx.Client], auth: Auth) -> str: + """ + Request a token from the Vault server and verify it. -def get_token(client: httpx.Client, auth: Auth) -> str: + Raises + ------ + AuthenticationError + If the token cannot be retrieved or is invalid. + """ # login try: token = auth.login(client) @@ -180,9 +230,6 @@ def get_token(client: httpx.Client, auth: Auth) -> str: raise raise AuthenticationError("Encounter {} while retrieving token", reason) from e - if not token: - raise AuthenticationError("Absence of token information") - # verify if not is_authenticated(client, token): raise AuthenticationError("Invalid token") @@ -190,20 +237,14 @@ def get_token(client: httpx.Client, auth: Auth) -> str: return token -def is_authenticated(client: httpx.Client, token: str) -> bool: +@validate_call +def is_authenticated(client: InstanceOf[httpx.Client], token: str) -> bool: """Check is a token is authenticated. See also -------- https://developer.hashicorp.com/vault/api-docs/auth/token """ - if not isinstance(client, httpx.Client): - raise TypeError( - f'Expected "client" to be a httpx client, got {type(client).__name__}' - ) - if not isinstance(token, str): - raise TypeError(f'Expected "token" to be a string, got {type(token).__name__}') - logger.debug("Validate token for %s", client.base_url) resp = client.get("/v1/auth/token/lookup-self", headers={"X-Vault-Token": token}) @@ -211,180 +252,119 @@ def is_authenticated(client: httpx.Client, token: str) -> bool: return True logger.debug( - "Token verification failed. Code= %d. Msg= %s", + "Token verification failed. Code= %d (%s). Msg= %s", resp.status_code, + resp.reason_phrase, resp.json(), ) return False -def get_mount_point( - client: httpx.Client, path: str -) -> tuple[str | None, KVVersion | None]: - """Get mount point and KV engine version to a secret. - - Returns - ------- - mount_point : str - The path the secret engine mounted on. - version : int - The secret engine version - - See also - -------- - Vault HTTP API - https://developer.hashicorp.com/vault/api-docs/system/internal-ui-mounts - consul-template - https://github.com/hashicorp/consul-template/blob/v0.29.1/dependency/vault_common.go#L294-L357 - """ - if not isinstance(client, httpx.Client): - raise TypeError( - f'Expected "client" to be a httpx client, got {type(client).__name__}' - ) - if not isinstance(path, str): - raise TypeError(f'Expected "path" to be a string, got {type(path).__name__}') - - try: - resp = client.get(f"/v1/sys/internal/ui/mounts/{path}") - except httpx.HTTPError as e: - if not (reason := get_httpx_error_reason(e)): - raise - logger.error("Error occurred during checking metadata for %s: %s", path, reason) - return None, None - - if resp.is_success: - data = resp.json().get("data", {}) - - mount_point = data.get("path") - version = data.get("options", {}).get("version") - - if version == "2" and data.get("type") == "kv": - return mount_point, 2 - elif version == "1": - return mount_point, 1 - - logging.error("Unknown version %s for path %s", version, path) - logging.debug("Raw response: %s", resp) - return None, None - - elif resp.status_code == HTTPStatus.NOT_FOUND: - # 404 is expected on an older version of vault, default to version 1 - # https://github.com/hashicorp/consul-template/blob/v0.29.1/dependency/vault_common.go#L310-L311 - return "", 1 - - logger.error("Error occurred during checking metadata for %s", path) - log_httpx_response(logger, resp) - return None, None - - -def read_secret(client: httpx.Client, path: str) -> VaultSecret | None: +@validate_call +def read_secret(client: InstanceOf[httpx.Client], path: str) -> dict | None: """Read secret from Vault. See also -------- https://developer.hashicorp.com/vault/api-docs/secret/kv """ - if not isinstance(client, httpx.Client): - raise TypeError( - f'Expected "client" to be a httpx client, got {type(client).__name__}' - ) - if not isinstance(path, str): - raise TypeError(f'Expected "path" to be a string, got {type(path).__name__}') + mount = get_mount(client, path) + if not mount: + return - mount_point, version = get_mount_point(client, path) - if not mount_point: - return None + logger.debug("Secret %s is mounted at %s (kv%d)", path, mount.path, mount.version) - logger.debug("Secret %s is mounted at %s (kv%d)", path, mount_point, version) - - if version == 1: - url = f"/v1/{path}" + if mount.version == 2: + subpath = path.removeprefix(mount.path) + request_path = f"/v1/{mount.path}data/{subpath}" else: - subpath = path.removeprefix(mount_point) - url = f"/v1/{mount_point}data/{subpath}" + request_path = f"/v1/{path}" try: - resp = client.get(url) + resp = client.get(request_path) except httpx.HTTPError as e: if not (reason := get_httpx_error_reason(e)): raise logger.error("Error occurred during query secret %s: %s", path, reason) - return None + return if resp.is_success: data = resp.json() - if version == 1: - return data["data"] - elif version == 2: + if mount.version == 2: return data["data"]["data"] + else: + return data["data"] elif resp.status_code == HTTPStatus.NOT_FOUND: logger.error("Secret %s not found", path) - return None + return logger.error("Error occurred during query secret %s", path) log_httpx_response(logger, resp) - return None + return -def get_field(secret: dict, name: str) -> str | None: - """Traverse the secret data to get the field along with the given name.""" - for n in split_field(name): - if not isinstance(secret, dict): - return None - secret = typing.cast(dict, secret.get(n)) +class _RawMountMetadata(BaseModel): + """ + { + "data": { + "options": {"version": "1"}, + "path": "secrets/", + "type": "kv", + } + } + """ - if not isinstance(secret, str): - return None + data: _DataBlock - return secret + class _DataBlock(BaseModel): + options: _OptionBlock + path: str + type: str -def split_field(name: str) -> list[str]: - """Split a field name into subsequences. By default, this function splits - the name by dots, with supportting of preserving the quoted subpaths. - """ - pattern_quoted = re.compile(r'"([^"]+)"') - pattern_simple = re.compile(r"([\w-]+)") + class _OptionBlock(BaseModel): + version: str - seq = [] - pos = 0 - while pos < len(name): - # try match pattern - if m := pattern_simple.match(name, pos): - pass - elif m := pattern_quoted.match(name, pos): - pass - else: - break - seq.append(m.group(1)) +class MountMetadata(BaseModel): + """Represents a mount point and KV engine version to a secret.""" - # check remaining part - # +1 for skipping the dot (if exists) - pos = m.end() + 1 + path: str + version: Literal[1, 2] - if pos <= len(name): - raise ValueError(f"Failed to parse name: {name}") - return seq +@validate_call +def get_mount(client: InstanceOf[httpx.Client], path: str) -> MountMetadata | None: + """Get mount point and KV engine version to a secret. + See also + -------- + Vault HTTP API + https://developer.hashicorp.com/vault/api-docs/system/internal-ui-mounts + consul-template + https://github.com/hashicorp/consul-template/blob/v0.29.1/dependency/vault_common.go#L294-L357 + """ + try: + resp = client.get(f"/v1/sys/internal/ui/mounts/{path}") + except httpx.HTTPError as e: + if not (reason := get_httpx_error_reason(e)): + raise + logger.error("Error occurred during checking metadata for %s: %s", path, reason) + return -class VaultPath(BaseModel): - """Represents a path to a value in Vault.""" + if resp.is_success: + parsed = _RawMountMetadata.model_validate_json(resp.read()) + return MountMetadata( + path=parsed.data.path, + version=int(parsed.data.options.version), # type: ignore[reportArgumentType] + ) - path: str = Field(min_length=1) - field: str = Field(min_length=1) + elif resp.status_code == HTTPStatus.NOT_FOUND: + # 404 is expected on an older version of vault, default to version 1 + # https://github.com/hashicorp/consul-template/blob/v0.29.1/dependency/vault_common.go#L310-L311 + return MountMetadata(path="", version=1) - @model_validator(mode="before") - @classmethod - def _from_str(cls, value: str | dict | Self) -> dict | Self: - if not isinstance(value, str): - return value - if value.count("#") != 1: - raise ValueError("Invalid format. Expected 'path#field'") - path, field = value.rsplit("#", 1) - return { - "path": path, - "field": field, - } + logger.error("Error occurred during checking metadata for %s", path) + log_httpx_response(logger, resp) + return diff --git a/tests/config0/test_config__init__.py b/tests/config0/test_config__init__.py index 4d64f4eb..2e2b6ad7 100644 --- a/tests/config0/test_config__init__.py +++ b/tests/config0/test_config__init__.py @@ -4,7 +4,7 @@ import pytest import secrets_env.config0 as t -from secrets_env.provider import ProviderBase +from secrets_env.provider import Provider class TestLoadConfig: @@ -32,7 +32,7 @@ def assert_config_format(self, cfg: dict): for name, provider in cfg["providers"].items(): assert isinstance(name, str) - assert isinstance(provider, ProviderBase) + assert isinstance(provider, Provider) for request in cfg["requests"]: assert isinstance(request["name"], str) diff --git a/tests/config0/test_parser.py b/tests/config0/test_parser.py index 1d81834f..add1b1b4 100644 --- a/tests/config0/test_parser.py +++ b/tests/config0/test_parser.py @@ -4,14 +4,14 @@ import secrets_env.config0.parser as t from secrets_env.exceptions import AuthenticationError, ConfigError -from secrets_env.provider import ProviderBase +from secrets_env.provider import Provider @pytest.fixture() def _patch_get_provider(monkeypatch: pytest.MonkeyPatch): def mock_parser(data: dict): assert isinstance(data, dict) - return Mock(spec=ProviderBase) + return Mock(spec=Provider) monkeypatch.setattr("secrets_env.providers.get_provider", mock_parser) @@ -35,7 +35,7 @@ def test_success(self): } ) assert isinstance(cfg, dict) - assert isinstance(cfg["providers"]["main"], ProviderBase) + assert isinstance(cfg["providers"]["main"], Provider) assert cfg["requests"][0]["spec"] == "foobar" def test_no_request(self): @@ -73,15 +73,15 @@ def test_success(self): } ) assert len(result) == 3 - assert isinstance(result["main"], ProviderBase) - assert isinstance(result["provider 1"], ProviderBase) - assert isinstance(result["provider 2"], ProviderBase) + assert isinstance(result["main"], Provider) + assert isinstance(result["provider 1"], Provider) + assert isinstance(result["provider 2"], Provider) @pytest.mark.usefixtures("_patch_get_provider") def test_duplicated_name(self, caplog: pytest.LogCaptureFixture): result = t.get_providers({"source": [{"data": "dummy"}] * 2}) assert len(result) == 1 - assert isinstance(result["main"], ProviderBase) + assert isinstance(result["main"], Provider) assert "Duplicated source name main" in caplog.text @@ -122,13 +122,13 @@ class TestParseSourceItem: def test_success_1(self): name, provider = t.parse_source_item({}) assert name == "main" - assert isinstance(provider, ProviderBase) + assert isinstance(provider, Provider) @pytest.mark.usefixtures("_patch_get_provider") def test_success_2(self): name, provider = t.parse_source_item({"name": "test"}) assert name == "test" - assert isinstance(provider, ProviderBase) + assert isinstance(provider, Provider) def test_name_error(self): assert t.parse_source_item({"name": object()}) is None diff --git a/tests/providers/teleport/test_teleport.py b/tests/providers/teleport/test_teleport.py index 678cb8c8..fd9dab0d 100644 --- a/tests/providers/teleport/test_teleport.py +++ b/tests/providers/teleport/test_teleport.py @@ -1,52 +1,82 @@ -from unittest.mock import Mock +from pathlib import Path +from unittest.mock import Mock, PropertyMock import pytest -import secrets_env.providers.teleport as t -from secrets_env.exceptions import ConfigError -from secrets_env.provider import ProviderBase -from secrets_env.providers.teleport.config import TeleportUserConfig -from secrets_env.providers.teleport.provider import TeleportProvider +from secrets_env.providers.teleport import TeleportProvider, TeleportRequestSpec, get_ca +from secrets_env.providers.teleport.config import TeleportConnectionParameter -def test_get_provider(): - provider = t.get_provider("teleport", {"app": "test"}) - assert isinstance(provider, TeleportProvider) +class TestTeleportRequestSpec: + def test_success(self): + spec = TeleportRequestSpec.model_validate({"field": "ca", "format": "pem"}) + assert spec == TeleportRequestSpec(field="ca", format="pem") + def test_shortcut(self): + spec = TeleportRequestSpec.model_validate("uri") + assert spec == TeleportRequestSpec(field="uri", format="path") -class TestGetAdaptedProvider: - def test_success(self, monkeypatch: pytest.MonkeyPatch): - def mock_factory(subtype, data, param): - assert subtype == "Test" - assert isinstance(data, dict) - assert isinstance(param, TeleportUserConfig) - return Mock(spec=ProviderBase) - - def mock_get_adapter(subtype): - assert subtype == "Test" - return mock_factory +class TestTeleportProvider: + @pytest.fixture() + def provider( + self, monkeypatch: pytest.MonkeyPatch, conn_param: TeleportConnectionParameter + ): monkeypatch.setattr( - "secrets_env.providers.teleport.adapters.get_adapter", - mock_get_adapter, + TeleportProvider, "connection_param", PropertyMock(return_value=conn_param) ) + return TeleportProvider(app="test") + + def test_get_uri(self, provider: TeleportProvider): + assert provider.get("uri") == "https://example.com" + + def test_get_ca(self, provider: TeleportProvider): + expect = "subject=/C=XX/L=Default City/O=Test\n-----MOCK CERTIFICATE-----" + assert provider.get({"field": "ca", "format": "pem"}) == expect + with open(provider.get("ca")) as fd: + assert fd.read() == expect + + def test_get_cert(self, provider: TeleportProvider): + expect = "-----MOCK CERTIFICATE-----" + assert provider.get({"field": "cert", "format": "pem"}) == expect + with open(provider.get("cert")) as fd: + assert fd.read() == expect + + def test_get_key(self, provider: TeleportProvider): + expect = "-----MOCK PRIVATE KEY-----" + assert provider.get({"field": "key", "format": "pem"}) == expect + with open(provider.get("key")) as fd: + assert fd.read() == expect + + def test_get_cert_and_key(self, provider: TeleportProvider): + expect = "-----MOCK CERTIFICATE-----\n-----MOCK PRIVATE KEY-----" + assert provider.get({"field": "cert+key", "format": "pem"}) == expect + with open(provider.get("cert+key")) as fd: + assert fd.read() == expect + + def test_get_invalid(self, monkeypatch: pytest.MonkeyPatch): + spec = Mock(TeleportRequestSpec) + spec.field = "unknown" monkeypatch.setattr( - TeleportUserConfig, - "get_connection_param", - lambda _: Mock(spec=TeleportUserConfig), + TeleportRequestSpec, "model_validate", Mock(return_value=spec) ) - config = { - "teleport": { - "app": "test", - } - } - provider = t.get_adapted_provider("teleport+Test", config) + with pytest.raises(RuntimeError): + TeleportProvider(app="test").get("unknown") + - assert isinstance(provider, ProviderBase) +class TestGetCa: + def test_success(self): + param = Mock(TeleportConnectionParameter) + param.ca = b"-----MOCK CERTIFICATE-----" + param.path_ca = Path("path/to/ca") + + assert get_ca(param, "path") == "path/to/ca" + assert get_ca(param, "pem") == "-----MOCK CERTIFICATE-----" def test_fail(self): - with pytest.raises(ConfigError): - t.get_adapted_provider("not-teleport+other", {}) - with pytest.raises(ConfigError): # raise by get_adapter - t.get_adapted_provider("teleport+no-this-type", {}) + param = Mock(TeleportConnectionParameter) + param.ca = None + + with pytest.raises(LookupError, match="CA is not available"): + get_ca(param, "pem") diff --git a/tests/providers/teleport/test_teleport_adapters.py b/tests/providers/teleport/test_teleport_adapters.py deleted file mode 100644 index 0cc930f3..00000000 --- a/tests/providers/teleport/test_teleport_adapters.py +++ /dev/null @@ -1,39 +0,0 @@ -from pathlib import Path -from unittest.mock import Mock - -import pytest - -import secrets_env.providers.teleport.adapters as t -from secrets_env.exceptions import ConfigError -from secrets_env.provider import ProviderBase -from secrets_env.providers.teleport.config import TeleportConnectionParameter - - -def test_get_adapter(): - assert callable(t.get_adapter("Vault")) - - with pytest.raises(ConfigError): - t.get_adapter("no-this-type") - - -def test_adapt_vault_provider(monkeypatch: pytest.MonkeyPatch): - def mock_load(type_, data): - assert data["url"] == "https://example.com" - assert len(data["tls"]) == 2 - assert isinstance(data["tls"]["client_cert"], Path) - assert isinstance(data["tls"]["client_key"], Path) - return Mock(spec=ProviderBase) - - monkeypatch.setattr("secrets_env.providers.vault.get_provider", mock_load) - - provider = t.adapt_vault_provider( - type_="vault", - data={"url": "http://invalid.example.com", "auth": "oidc"}, - param=TeleportConnectionParameter( - uri="https://example.com", - ca=None, - cert=b"cert", - key=b"key", - ), - ) - assert isinstance(provider, ProviderBase) diff --git a/tests/providers/teleport/test_teleport_config.py b/tests/providers/teleport/test_teleport_config.py index d0d509a5..9008f13b 100644 --- a/tests/providers/teleport/test_teleport_config.py +++ b/tests/providers/teleport/test_teleport_config.py @@ -49,7 +49,7 @@ def _patch_which(self, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr("shutil.which", lambda _: "/mock/tsh") @pytest.mark.usefixtures("_patch_which") - def test_get_connection_param_1(self, monkeypatch: pytest.MonkeyPatch): + def test_connection_param_1(self, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr( "secrets_env.providers.teleport.config.call_version", lambda: True, @@ -60,10 +60,10 @@ def test_get_connection_param_1(self, monkeypatch: pytest.MonkeyPatch): ) cfg = TeleportUserConfig(app="test") - assert isinstance(cfg.get_connection_param(), TeleportConnectionParameter) + assert isinstance(cfg.connection_param, TeleportConnectionParameter) @pytest.mark.usefixtures("_patch_which") - def test_get_connection_param_2(self, monkeypatch: pytest.MonkeyPatch): + def test_connection_param_2(self, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr( "secrets_env.providers.teleport.config.call_version", lambda: True, @@ -82,28 +82,26 @@ def test_get_connection_param_2(self, monkeypatch: pytest.MonkeyPatch): ) cfg = TeleportUserConfig(app="test") - assert isinstance(cfg.get_connection_param(), TeleportConnectionParameter) + assert isinstance(cfg.connection_param, TeleportConnectionParameter) - def test_get_connection_param_missing_dependency( - self, monkeypatch: pytest.MonkeyPatch - ): + def test_connection_param_missing_dependency(self, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr("shutil.which", lambda _: None) cfg = TeleportUserConfig(app="test") with pytest.raises(UnsupportedError): - cfg.get_connection_param() + cfg.connection_param # noqa: B018 @pytest.mark.usefixtures("_patch_which") - def test_get_connection_param_version_error(self, monkeypatch: pytest.MonkeyPatch): + def test_connection_param_version_error(self, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr( "secrets_env.providers.teleport.config.call_version", lambda: False, ) cfg = TeleportUserConfig(app="test") with pytest.raises(RuntimeError): - cfg.get_connection_param() + cfg.connection_param # noqa: B018 @pytest.mark.usefixtures("_patch_which") - def test_get_connection_param_no_config(self, monkeypatch: pytest.MonkeyPatch): + def test_connection_param_no_config(self, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr( "secrets_env.providers.teleport.config.call_version", lambda: True, @@ -123,7 +121,7 @@ def test_get_connection_param_no_config(self, monkeypatch: pytest.MonkeyPatch): cfg = TeleportUserConfig(app="test") with pytest.raises(AuthenticationError): - cfg.get_connection_param() + cfg.connection_param # noqa: B018 class TestTeleportConnectionParameter: diff --git a/tests/providers/teleport/test_teleport_provider.py b/tests/providers/teleport/test_teleport_provider.py deleted file mode 100644 index fcbdd889..00000000 --- a/tests/providers/teleport/test_teleport_provider.py +++ /dev/null @@ -1,123 +0,0 @@ -import re -from pathlib import Path -from unittest.mock import Mock - -import pytest - -import secrets_env.providers.teleport.provider as t -from secrets_env.exceptions import ConfigError, ValueNotFound -from secrets_env.providers.teleport.config import ( - TeleportConnectionParameter, - TeleportUserConfig, -) - - -class TestTeleportProvider: - @pytest.fixture() - def provider(self): - return t.TeleportProvider(config=TeleportUserConfig(app="test")) - - def test_type(self, provider): - assert provider.type == "teleport" - - @pytest.mark.parametrize( - ("field", "expect"), - [ - ("CA", b"subject=/C=XX/L=Default City/O=Test\n-----MOCK CERTIFICATE-----"), - ("CERT", b"-----MOCK CERTIFICATE-----"), - ("KEY", b"-----MOCK PRIVATE KEY-----"), - ("cert+KEY", b"-----MOCK CERTIFICATE-----\n-----MOCK PRIVATE KEY-----"), - ], - ) - def test_get_path( - self, monkeypatch: pytest.MonkeyPatch, provider, conn_param, field, expect - ): - monkeypatch.setattr( - TeleportUserConfig, "get_connection_param", lambda _: conn_param - ) - path = provider.get({"field": field, "format": "path"}) - assert isinstance(path, str) - assert Path(path).is_file() - assert Path(path).read_bytes() == expect - - @pytest.mark.parametrize( - ("field", "expect"), - [ - ("URI", "https://example.com"), - ("CA", "subject=/C=XX/L=Default City/O=Test\n-----MOCK CERTIFICATE-----"), - ("CERT", "-----MOCK CERTIFICATE-----"), - ("KEY", "-----MOCK PRIVATE KEY-----"), - ("cert+KEY", "-----MOCK CERTIFICATE-----\n-----MOCK PRIVATE KEY-----"), - ], - ) - def test_get_pem( - self, monkeypatch: pytest.MonkeyPatch, provider, conn_param, field, expect - ): - monkeypatch.setattr( - TeleportUserConfig, "get_connection_param", lambda _: conn_param - ) - data = provider.get({"field": field, "format": "pem"}) - assert isinstance(data, str) - assert data == expect - - def test_get_error(self, monkeypatch: pytest.MonkeyPatch, provider): - monkeypatch.setattr( - t, "parse_spec", lambda _: Mock(spec=t.OutputSpec, field="unknown") - ) - with pytest.raises( - ConfigError, match=re.escape("Invalid value spec: {'mock': 'mocked'}") - ): - provider.get({"mock": "mocked"}) - - -class TestParseSpec: - def test_success(self): - assert t.parse_spec("ca") == t.OutputSpec("ca", "path") - assert t.parse_spec({"field": "ca"}) == t.OutputSpec("ca", "path") - assert t.parse_spec({"field": "ca", "format": "pem"}) == t.OutputSpec( - "ca", "pem" - ) - - def test_failed(self): - with pytest.raises( - ConfigError, match=re.escape("Invalid field (secrets.VAR.field): invalid") - ): - t.parse_spec({"field": "invalid"}) - - with pytest.raises( - ConfigError, match=re.escape("Invalid format (secrets.VAR.format): invalid") - ): - t.parse_spec({"field": "ca", "format": "invalid"}) - - def test_type_error(self): - with pytest.raises( - ConfigError, - match=re.escape("Expect dict for secrets path spec, got int"), - ): - t.parse_spec(1234) - - -def test_get_ca(conn_param): - # path - p = t.get_ca(conn_param, "path") - assert isinstance(p, str) - assert Path(p).is_file() - - # pem - d = t.get_ca(conn_param, "pem") - assert isinstance(d, str) - assert d == "subject=/C=XX/L=Default City/O=Test\n-----MOCK CERTIFICATE-----" - - # no CA - conn_param_no_ca = TeleportConnectionParameter( - uri="https://example.com", - ca=None, - cert=b"-----MOCK CERTIFICATE-----", - key=b"-----MOCK PRIVATE KEY-----", - ) - with pytest.raises(ValueNotFound): - t.get_ca(conn_param_no_ca, "path") - - # make linter happy - with pytest.raises(RuntimeError): - t.get_ca(conn_param, "no-this-format") diff --git a/tests/providers/test_plain.py b/tests/providers/test_plain.py deleted file mode 100644 index 454a3e52..00000000 --- a/tests/providers/test_plain.py +++ /dev/null @@ -1,16 +0,0 @@ -import pytest - -import secrets_env.providers.plain as t - - -def test(): - provider = t.get_provider("plain", {}) - assert provider.type == "plain" - - assert provider.get("foo") == "foo" - assert provider.get("") == "" - assert provider.get({"value": "foo"}) == "foo" - assert provider.get({"value": None}) == "" - - with pytest.raises(TypeError): - provider.get(None) diff --git a/tests/providers/test_providers.py b/tests/providers/test_providers.py new file mode 100644 index 00000000..d83164c5 --- /dev/null +++ b/tests/providers/test_providers.py @@ -0,0 +1,53 @@ +import pytest + +from secrets_env.providers import get_provider +from secrets_env.providers.null import NullProvider +from secrets_env.providers.plain import PlainTextProvider +from secrets_env.providers.teleport import TeleportProvider +from secrets_env.providers.vault import VaultKvProvider + + +class TestGetProvider: + def test_null(self): + provider = get_provider({"type": "null"}) + assert isinstance(provider, NullProvider) + + def test_plain(self): + provider = get_provider({"type": "plain"}) + assert isinstance(provider, PlainTextProvider) + + def test_teleport(self): + provider = get_provider({"type": "teleport", "app": "test"}) + assert isinstance(provider, TeleportProvider) + + def test_teleport_adapter(self): + with pytest.raises(NotImplementedError): + get_provider({"type": "teleport+vault"}) + + def test_vault(self): + provider = get_provider({"url": "https://example.com/", "auth": "null"}) + assert isinstance(provider, VaultKvProvider) + + def test_invalid(self): + with pytest.raises(ValueError, match="Unknown provider type invalid"): + get_provider({"type": "invalid"}) + + +class TestNullProvider: + def test_get(self): + provider = NullProvider() + assert provider.get("test") == "" + assert provider.get({"value": "test"}) == "" + + +class TestPlainTextProvider: + def test(self): + provider = PlainTextProvider() + assert provider.get("test") == "test" + assert provider.get("") == "" + + assert provider.get({"value": "test"}) == "test" + assert provider.get({"value": None}) == "" + assert provider.get({"value": ""}) == "" + + assert provider.get({"invalid": "foo"}) == "" diff --git a/tests/providers/test_providers___init__.py b/tests/providers/test_providers___init__.py deleted file mode 100644 index e1748d16..00000000 --- a/tests/providers/test_providers___init__.py +++ /dev/null @@ -1,45 +0,0 @@ -from unittest.mock import Mock - -import pytest - -import secrets_env.providers as t -from secrets_env.exceptions import ConfigError -from secrets_env.provider import ProviderBase - - -def mock_get_provider(type_, data): - return Mock(spec=ProviderBase) - - -class TestGetProvider: - def test_null(self): - p = t.get_provider({"type": "NULL"}) - assert p.get({}) == "" - - def test_vault(self, monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr( - "secrets_env.providers.vault.get_provider", mock_get_provider - ) - assert isinstance(t.get_provider({"type": "Vault"}), ProviderBase) - - def test_teleport(self, monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr( - "secrets_env.providers.teleport.get_provider", mock_get_provider - ) - assert isinstance(t.get_provider({"type": "Teleport"}), ProviderBase) - - def test_teleport_adapter(self, monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr( - "secrets_env.providers.teleport.get_adapted_provider", mock_get_provider - ) - assert isinstance(t.get_provider({"type": "Teleport+Test"}), ProviderBase) - - def test_plain(self, monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr( - "secrets_env.providers.plain.get_provider", mock_get_provider - ) - assert isinstance(t.get_provider({"type": "plain"}), ProviderBase) - - def test_not_found(self): - with pytest.raises(ConfigError): - t.get_provider({"type": "no-this-type"}) diff --git a/tests/providers/vault/auth/test_userpass.py b/tests/providers/vault/auth/test_userpass.py index 1089949f..3fd3ce25 100644 --- a/tests/providers/vault/auth/test_userpass.py +++ b/tests/providers/vault/auth/test_userpass.py @@ -8,6 +8,7 @@ from pydantic_core import Url import secrets_env.providers.vault.auth.userpass as t +from secrets_env.exceptions import AuthenticationError from secrets_env.providers.vault.auth.userpass import UserPasswordAuth @@ -122,10 +123,7 @@ class MockAuth(UserPasswordAuth): assert auth_obj.login(unittest_client) == "client-token" def test_login_fail( - self, - unittest_respx: respx.MockRouter, - unittest_client: httpx.Client, - caplog: pytest.LogCaptureFixture, + self, unittest_respx: respx.MockRouter, unittest_client: httpx.Client ): unittest_respx.post("/v1/auth/mock/login/user%40example.com").mock( return_value=httpx.Response(400) @@ -136,8 +134,9 @@ class MockAuth(UserPasswordAuth): vault_name = "mock" auth_obj = MockAuth(username="user@example.com", password="password") - assert auth_obj.login(unittest_client) is None - assert "Failed to login with MOCK method" in caplog.text + + with pytest.raises(AuthenticationError): + assert auth_obj.login(unittest_client) is None @pytest.mark.parametrize( diff --git a/tests/providers/vault/test_vault___init__.py b/tests/providers/vault/test_vault___init__.py deleted file mode 100644 index 7b1561d1..00000000 --- a/tests/providers/vault/test_vault___init__.py +++ /dev/null @@ -1,25 +0,0 @@ -import os - -import pytest - -import secrets_env.providers.vault as t -from secrets_env.exceptions import ConfigError -from secrets_env.providers.vault.provider import KvProvider - - -class TestGetProvider: - def test_success(self): - out = t.get_provider("vault", {"url": "https://example.com", "auth": "null"}) - assert isinstance(out, KvProvider) - - def test_fail(self): - if "VAULT_ADDR" in os.environ: - pytest.skip("VAULT_ADDR is set. Skipping test.") - with pytest.raises(ConfigError): - t.get_provider("vault", {}) - - def test_not_related(self): - if "VAULT_ADDR" in os.environ: - pytest.skip("VAULT_ADDR is set. Skipping test.") - with pytest.raises(ConfigError): - t.get_provider("something-else", {}) diff --git a/tests/providers/vault/test_vault_provider.py b/tests/providers/vault/test_vault_provider.py index 117d6c2a..3b1cdf5b 100644 --- a/tests/providers/vault/test_vault_provider.py +++ b/tests/providers/vault/test_vault_provider.py @@ -1,261 +1,338 @@ import os from pathlib import Path -from unittest.mock import Mock, PropertyMock, patch +from unittest.mock import Mock, PropertyMock import httpx -import httpx._config import pytest import respx -from pydantic import ValidationError - -import secrets_env.providers.vault.provider as t -from secrets_env.exceptions import AuthenticationError, ValueNotFound -from secrets_env.providers.vault.auth.base import Auth -from secrets_env.providers.vault.auth.token import TokenAuth -from secrets_env.providers.vault.provider import VaultPath +from pydantic_core import ValidationError + +from secrets_env.exceptions import AuthenticationError +from secrets_env.providers.vault.auth.base import Auth, NoAuth +from secrets_env.providers.vault.config import TlsConfig, VaultUserConfig +from secrets_env.providers.vault.provider import ( + MountMetadata, + VaultKvProvider, + VaultPath, + _split_field_str, + create_http_client, + get_mount, + get_token, + is_authenticated, + read_secret, +) @pytest.fixture() -def mock_client() -> httpx.Client: - client = Mock(spec=httpx.Client) - client.headers = {} - return client +def intl_provider() -> VaultKvProvider: + if "VAULT_ADDR" not in os.environ: + raise pytest.skip("VAULT_ADDR is not set") + if "VAULT_TOKEN" not in os.environ: + raise pytest.skip("VAULT_TOKEN is not set") + return VaultKvProvider(auth="token") @pytest.fixture() -def mock_auth(): - auth = Mock(spec=Auth) - auth.method = "mocked" - return auth - +def intl_client(intl_provider: VaultKvProvider) -> httpx.Client: + return intl_provider.client -class TestKvProvider: - """Unit tests for KvProvider""" - @pytest.fixture() - def provider(self, mock_auth: Auth) -> t.KvProvider: - return t.KvProvider("https://example.com/", mock_auth) - - def test_type(self, provider: t.KvProvider): - assert provider.type == "vault" +class TestVaultPath: + def test_success(self): + path = VaultPath.model_validate('foo#"bar.baz".qux') + assert path == VaultPath(path="foo", field=("bar.baz", "qux")) + assert str(path) == 'foo#"bar.baz".qux' - def test_client_success( - self, - monkeypatch: pytest.MonkeyPatch, - provider: t.KvProvider, - mock_client: httpx.Client, - ): - # setup - monkeypatch.setattr(t, "get_token", lambda c, a: "token") + def test_invalid(self): + # missing path + with pytest.raises(ValidationError): + VaultPath(path="", field=("b")) - patch_client = Mock(return_value=mock_client) - monkeypatch.setattr("httpx.Client", patch_client) + # missing path/field separator + with pytest.raises(ValidationError): + VaultPath.model_validate("foobar") - provider.proxy = "proxy" - provider.ca_cert = Mock(spec=Path) - provider.client_cert = Mock(spec=Path) + # too many path/field separator + with pytest.raises(ValidationError): + VaultPath.model_validate("foo#bar#baz") - # run twice for testing cache - assert provider.client is mock_client - assert provider.client is mock_client + # empty field subpath + with pytest.raises(ValidationError): + VaultPath(path="a", field=()) + with pytest.raises(ValidationError): + VaultPath(path="a", field=("b", "", "c")) + with pytest.raises(ValidationError): + VaultPath(path="a", field=("b", "")) - # test - assert mock_client.headers["X-Vault-Token"] == "token" - _, kwargs = patch_client.call_args - assert kwargs["base_url"] == "https://example.com/" - assert kwargs["proxies"] == "proxy" - assert isinstance(kwargs["verify"], Path) - assert isinstance(kwargs["cert"], Path) +class TestSplitFieldStr: + def test_success(self): + assert list(_split_field_str("foo")) == ["foo"] + assert list(_split_field_str("foo.bar.baz")) == ["foo", "bar", "baz"] + assert list(_split_field_str('foo."bar.baz"')) == ["foo", "bar.baz"] + assert list(_split_field_str('"foo.bar".baz')) == ["foo.bar", "baz"] + assert list(_split_field_str("")) == [] - def test_get_success(self, monkeypatch: pytest.MonkeyPatch, provider: t.KvProvider): - def mock_read_field(path, field): - assert path == "foo" - assert field == "bar" - return "secret" + def test_invalid(self): + with pytest.raises(ValueError, match=r"Failed to parse field:"): + list(_split_field_str('foo."bar.baz')) - monkeypatch.setattr(provider, "read_field", mock_read_field) - assert provider.get("foo#bar") == "secret" +class TestVaultKvProvider: + def test_client(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + "secrets_env.providers.vault.provider.create_http_client", + lambda _: Mock(httpx.Client, headers={}), + ) + provider = VaultKvProvider(url="https://vault.example.com", auth="null") + assert isinstance(provider.client, httpx.Client) - def test_get_fail(self, provider: t.KvProvider): - with pytest.raises(ValidationError): - provider.get({}) - with pytest.raises(ValidationError): - provider.get(1234) + @pytest.fixture() + def unittest_provider(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + VaultKvProvider, "client", PropertyMock(return_value=Mock(httpx.Client)) + ) + return VaultKvProvider(url="https://vault.example.com", auth="null") - def test_read_secret_success( - self, - monkeypatch: pytest.MonkeyPatch, - provider: t.KvProvider, - unittest_client: httpx.Client, + def test_get__success( + self, monkeypatch: pytest.MonkeyPatch, unittest_provider: VaultKvProvider ): - monkeypatch.setattr(provider, "_secrets", {}) monkeypatch.setattr( - t.KvProvider, "client", PropertyMock(return_value=unittest_client) + VaultKvProvider, "_read_secret", Mock(return_value={"bar": "test"}) ) + assert unittest_provider.get({"path": "foo", "field": "bar"}) == "test" - monkeypatch.setattr(t, "read_secret", lambda _1, _2: {"bar": "secret"}) - assert provider.read_secret("test-path") == {"bar": "secret"} - - def test_read_secret_cache( - self, monkeypatch: pytest.MonkeyPatch, provider: t.KvProvider + def test_get__too_depth( + self, monkeypatch: pytest.MonkeyPatch, unittest_provider: VaultKvProvider ): - monkeypatch.setattr(provider, "_secrets", {"test-path": {"bar": "secret"}}) - assert provider.read_secret("test-path") == {"bar": "secret"} + monkeypatch.setattr( + VaultKvProvider, "_read_secret", Mock(return_value={"bar": "test"}) + ) + with pytest.raises(LookupError, match='Field "bar.baz" not found in "foo"'): + unittest_provider.get({"path": "foo", "field": "bar.baz"}) - def test_read_secret_not_found( - self, - monkeypatch: pytest.MonkeyPatch, - provider: t.KvProvider, - unittest_client: httpx.Client, + def test_get__too_shallow( + self, monkeypatch: pytest.MonkeyPatch, unittest_provider: VaultKvProvider ): - monkeypatch.setattr(provider, "_secrets", {}) monkeypatch.setattr( - t.KvProvider, "client", PropertyMock(return_value=unittest_client) + VaultKvProvider, "_read_secret", Mock(return_value={"bar": {"baz": "test"}}) ) - monkeypatch.setattr(t, "read_secret", lambda _1, _2: None) - with pytest.raises(ValueNotFound): - provider.read_secret("test-secret") - - def test_read_secret_error(self, provider: t.KvProvider): - with pytest.raises(TypeError): - provider.read_secret(1234) + with pytest.raises( + LookupError, match='Field "bar" in "foo" is not point to a string value' + ): + unittest_provider.get({"path": "foo", "field": "bar"}) - def test_read_field_success( - self, monkeypatch: pytest.MonkeyPatch, provider: t.KvProvider + def test_read_secret__success( + self, monkeypatch: pytest.MonkeyPatch, unittest_provider: VaultKvProvider ): - monkeypatch.setattr(provider, "read_secret", lambda _: {"bar": "secret"}) - assert provider.read_field("foo", "bar") == "secret" + func = Mock(return_value={"foo": "bar"}) + monkeypatch.setattr("secrets_env.providers.vault.provider.read_secret", func) + + path = VaultPath(path="foo", field="bar") + assert unittest_provider._read_secret(path) == {"foo": "bar"} + assert unittest_provider._read_secret(path) == {"foo": "bar"} + + assert func.call_count == 1 + + client, path = func.call_args[0] + assert isinstance(client, httpx.Client) + assert path == "foo" - def test_read_field_fail( - self, monkeypatch: pytest.MonkeyPatch, provider: t.KvProvider + def test_read_secret__not_found( + self, monkeypatch: pytest.MonkeyPatch, unittest_provider: VaultKvProvider ): - with pytest.raises(TypeError): - provider.read_field(1234, "bar") + func = Mock(return_value=None) + monkeypatch.setattr("secrets_env.providers.vault.provider.read_secret", func) - monkeypatch.setattr(provider, "read_secret", lambda _: {}) - with pytest.raises(ValueNotFound): - provider.read_field("foo", "bar") + path = VaultPath(path="foo", field="bar") + with pytest.raises(LookupError): + unittest_provider._read_secret(path) + with pytest.raises(LookupError): + unittest_provider._read_secret(path) + assert func.call_count == 1 -class TestKvProviderUsingVaultConnection: - @pytest.fixture(scope="class") - def provider(self) -> t.KvProvider: - url = os.getenv("VAULT_ADDR") - token = os.getenv("VAULT_TOKEN") - if not url or not token: - pytest.skip("VAULT_ADDR or VAULT_TOKEN are not set") - return t.KvProvider(url, TokenAuth(token=token)) + def test_integration(self, intl_provider: VaultKvProvider): + assert intl_provider.get({"path": "kv2/test", "field": "foo"}) == "hello, world" + assert intl_provider.get('kv2/test#test."name.with-dot"') == "sample-value" - def test_client_success(self, provider: t.KvProvider): - with patch.object(t, "is_authenticated", return_value=True): - assert isinstance(provider.client, httpx.Client) - assert isinstance(provider.client, httpx.Client) # from cache - def test_get(self, provider: t.KvProvider): - assert provider.get("kv1/test#foo") == "hello" - assert provider.get({"path": "kv2/test", "field": "foo"}) == "hello, world" +class TestCreateHttpClient: + @pytest.mark.skipif("VAULT_ADDR" in os.environ, reason="VAULT_ADDR is set") + def test_basic(self): + config = VaultUserConfig( + url="https://vault.example.com", + auth="null", + ) - def test_read_secret_v1(self, provider: t.KvProvider): - secret_1 = provider.read_secret("kv1/test") - assert isinstance(secret_1, dict) - assert secret_1["foo"] == "hello" + client = create_http_client(config) - secret_2 = provider.read_secret("kv1/test") - assert secret_1 is secret_2 + assert isinstance(client, httpx.Client) + assert client.base_url == httpx.URL("https://vault.example.com/") - def test_read_secret_v2(self, provider: t.KvProvider): - secret = provider.read_secret("kv2/test") - assert isinstance(secret, dict) - assert secret["foo"] == "hello, world" + def test_proxy(self): + config = VaultUserConfig( + url="https://vault.example.com", + auth="null", + proxy="http://proxy.example.com", + ) - def test_read_field(self, provider: t.KvProvider): - assert provider.read_field("kv1/test", "foo") == "hello" - assert provider.read_field("kv2/test", 'test."name.with-dot"') == "sample-value" + client = create_http_client(config) + assert isinstance(client, httpx.Client) - with pytest.raises(ValueNotFound): - provider.read_field("kv2/test", "foo.no-extra-level") - with pytest.raises(ValueNotFound): - provider.read_field("kv2/test", "test.no-this-key") - with pytest.raises(ValueNotFound): - provider.read_field("secret/no-this-secret", "test") + @pytest.fixture() + def mock_httpx_client(self, monkeypatch: pytest.MonkeyPatch): + client = Mock(httpx.Client) + monkeypatch.setattr("httpx.Client", client) + return client + + def test_ca(self, mock_httpx_client: Mock): + config = VaultUserConfig( + url="https://vault.example.com", + auth="null", + tls=Mock( + TlsConfig, + ca_cert=Path("/mock/ca.pem"), + client_cert=None, + client_key=None, + ), + ) + create_http_client(config) + + _, kwargs = mock_httpx_client.call_args + assert kwargs["verify"] == Path("/mock/ca.pem") + + def test_client_cert(self, mock_httpx_client: Mock): + config = VaultUserConfig( + url="https://vault.example.com", + auth="null", + tls=Mock( + TlsConfig, + ca_cert=None, + client_cert=Path("/mock/client.pem"), + client_key=None, + ), + ) -class TestGetToken: - def test_success( - self, - mock_client: httpx.Client, - mock_auth: Auth, - monkeypatch: pytest.MonkeyPatch, - ): - mock_auth.login.return_value = "t0ken" - monkeypatch.setattr(t, "is_authenticated", lambda c, t: True) - assert t.get_token(mock_client, mock_auth) == "t0ken" + create_http_client(config) + + _, kwargs = mock_httpx_client.call_args + assert kwargs["cert"] == Path("/mock/client.pem") + + def test_client_cert_pair(self, mock_httpx_client: Mock): + config = VaultUserConfig( + url="https://vault.example.com", + auth="null", + tls=Mock( + TlsConfig, + ca_cert=None, + client_cert=Path("/mock/client.pem"), + client_key=Path("/mock/client.key"), + ), + ) - def test_no_token(self, mock_client: httpx.Client, mock_auth: Auth): - mock_auth.login.return_value = None - with pytest.raises(AuthenticationError, match="Absence of token information"): - t.get_token(mock_client, mock_auth) + create_http_client(config) - def test_not_authenticated( - self, - mock_client: httpx.Client, - mock_auth: Auth, - monkeypatch: pytest.MonkeyPatch, - ): - mock_auth.login.return_value = "t0ken" - monkeypatch.setattr(t, "is_authenticated", lambda c, t: False) + _, kwargs = mock_httpx_client.call_args + assert kwargs["cert"] == (Path("/mock/client.pem"), Path("/mock/client.key")) + + +class TestGetToken: + def test_success(self, monkeypatch: pytest.MonkeyPatch): + client = Mock(httpx.Client) + auth = NoAuth(token="t0ken") + monkeypatch.setattr( + "secrets_env.providers.vault.provider.is_authenticated", lambda c, t: True + ) + assert get_token(client, auth) == "t0ken" + + def test_authenticate_fail(self, monkeypatch: pytest.MonkeyPatch): + client = Mock(httpx.Client) + auth = NoAuth(token="t0ken") + monkeypatch.setattr( + "secrets_env.providers.vault.provider.is_authenticated", lambda c, t: False + ) with pytest.raises(AuthenticationError, match="Invalid token"): - t.get_token(mock_client, mock_auth) + get_token(client, auth) - def test_login_connection_error(self, mock_client: httpx.Client, mock_auth: Auth): - mock_auth.login.side_effect = httpx.ProxyError("test") + def test_login_connection_error(self): + client = Mock(httpx.Client) + auth = Mock(Auth) + auth.login.side_effect = httpx.ProxyError("test") with pytest.raises( AuthenticationError, match="Encounter proxy error while retrieving token" ): - t.get_token(mock_client, mock_auth) + get_token(client, auth) - def test_login_exception(self, mock_client: httpx.Client, mock_auth: Auth): - mock_auth.login.side_effect = httpx.HTTPError("test") + def test_login_exception(self): + client = Mock(httpx.Client) + auth = Mock(Auth) + auth.login.side_effect = httpx.HTTPError("test") with pytest.raises(httpx.HTTPError): - t.get_token(mock_client, mock_auth) + get_token(client, auth) + +class TestIsAuthenticated: -def test_is_authenticated(): - url = os.getenv("VAULT_ADDR") - token = os.getenv("VAULT_TOKEN") - if not url or not token: - pytest.skip("VAULT_ADDR or VAULT_TOKEN are not set") + def test_success(self, respx_mock: respx.MockRouter): + respx_mock.get("https://vault.example.com/v1/auth/token/lookup-self") - # success: use real client - client = httpx.Client(base_url=os.getenv("VAULT_ADDR")) - assert t.is_authenticated(client, os.getenv("VAULT_TOKEN")) - assert not t.is_authenticated(client, "invalid-token") + client = httpx.Client(base_url="https://vault.example.com") + assert is_authenticated(client, "test-token") is True - # type error - with pytest.raises(TypeError): - t.is_authenticated("http://example.com", "token") - with pytest.raises(TypeError): - t.is_authenticated(client, 1234) + def test_fail(self, respx_mock: respx.MockRouter): + respx_mock.get("https://vault.example.com/v1/auth/token/lookup-self").respond( + status_code=403, + json={"errors": ["mock permission denied"]}, + ) + + client = httpx.Client(base_url="https://vault.example.com") + assert is_authenticated(client, "test-token") is False + + def test_integration(self): + if "VAULT_ADDR" not in os.environ: + raise pytest.skip("VAULT_ADDR is not set") + if "VAULT_TOKEN" not in os.environ: + raise pytest.skip("VAULT_TOKEN is not set") + + client = httpx.Client(base_url=os.getenv("VAULT_ADDR")) + + assert is_authenticated(client, os.getenv("VAULT_TOKEN")) is True + assert is_authenticated(client, "invalid-token") is False -class TestGetMountPoint: +class TestReadSecret: @pytest.fixture() - def route(self, respx_mock: respx.MockRouter): - return respx_mock.get( - "https://example.com/v1/sys/internal/ui/mounts/secrets/test" + def _set_mount_kv2(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + "secrets_env.providers.vault.provider.get_mount", + lambda c, p: MountMetadata(path="secrets/", version=2), ) - def test_success_kv1(self, route: respx.Route, unittest_client: httpx.Client): - route.mock( + @pytest.mark.usefixtures("_set_mount_kv2") + def test_kv2( + self, + respx_mock: respx.MockRouter, + unittest_client: httpx.Client, + ): + respx_mock.get("https://example.com/v1/secrets/data/test").mock( httpx.Response( 200, json={ + "request_id": "9ababbb6-3749-cf2c-5a5b-85660e917e8e", + "lease_id": "", + "renewable": False, + "lease_duration": 0, "data": { - "options": {"version": "1"}, - "path": "secrets/", - "type": "kv", + "data": {"test": "mock"}, + "metadata": { + "created_time": "2022-09-20T15:57:45.143053836Z", + "custom_metadata": None, + "deletion_time": "", + "destroyed": False, + "version": 1, + }, }, "wrap_info": None, "warnings": None, @@ -263,131 +340,145 @@ def test_success_kv1(self, route: respx.Route, unittest_client: httpx.Client): }, ) ) - assert t.get_mount_point(unittest_client, "secrets/test") == ("secrets/", 1) - def test_success_kv2(self, route: respx.Route, unittest_client: httpx.Client): - route.mock( - httpx.Response( - 200, - json={ - "data": { - "options": {"version": "2"}, - "path": "secrets/", - "type": "kv", - }, - }, - ) - ) - assert t.get_mount_point(unittest_client, "secrets/test") == ("secrets/", 2) + assert read_secret(unittest_client, "secrets/test") == {"test": "mock"} - def test_success_legacy(self, route: respx.Route, unittest_client: httpx.Client): - route.mock(httpx.Response(404)) - assert t.get_mount_point(unittest_client, "secrets/test") == ("", 1) + def test_kv2_integration(self, intl_client: httpx.Client): + assert read_secret(intl_client, "kv2/test") == { + "foo": "hello, world", + "test": {"name.with-dot": "sample-value"}, + } - def test_not_ported_version( - self, route: respx.Route, unittest_client: httpx.Client + def test_kv1( + self, + monkeypatch: pytest.MonkeyPatch, + respx_mock: respx.MockRouter, + unittest_client: httpx.Client, ): - route.mock( + monkeypatch.setattr( + "secrets_env.providers.vault.provider.get_mount", + lambda c, p: MountMetadata(path="secrets/", version=1), + ) + respx_mock.get("https://example.com/v1/secrets/test").mock( httpx.Response( 200, json={ - "data": { - "path": "mock/", - "type": "kv", - "options": {"version": "99"}, - } + "request_id": "a8f28d97-8a9d-c9dd-4d86-e815083b33ad", + "lease_id": "", + "renewable": False, + "lease_duration": 2764800, + "data": {"test": "mock"}, + "wrap_info": None, + "warnings": None, + "auth": None, }, ) ) - assert t.get_mount_point(unittest_client, "secrets/test") == (None, None) - def test_bad_request( + assert read_secret(unittest_client, "secrets/test") == {"test": "mock"} + + def test_kv1_integration(self, intl_client: httpx.Client): + assert read_secret(intl_client, "kv1/test") == {"foo": "hello"} + + @pytest.mark.usefixtures("_set_mount_kv2") + def test_not_found( self, - route: respx.Route, + respx_mock: respx.MockRouter, unittest_client: httpx.Client, caplog: pytest.LogCaptureFixture, ): - route.mock(httpx.Response(400)) - assert t.get_mount_point(unittest_client, "secrets/test") == (None, None) - assert "Error occurred during checking metadata for secrets/test" in caplog.text + respx_mock.get("https://example.com/v1/secrets/data/test").mock( + httpx.Response(404) + ) + assert read_secret(unittest_client, "secrets/test") is None + assert "Secret secrets/test not found" in caplog.text + + def test_get_mount_error( + self, monkeypatch: pytest.MonkeyPatch, unittest_client: httpx.Client + ): + monkeypatch.setattr( + "secrets_env.providers.vault.provider.get_mount", lambda c, p: None + ) + assert read_secret(unittest_client, "secrets/test") is None + @pytest.mark.usefixtures("_set_mount_kv2") def test_connection_error( self, - route: respx.Route, + respx_mock: respx.MockRouter, unittest_client: httpx.Client, caplog: pytest.LogCaptureFixture, ): - route.mock(side_effect=httpx.ConnectError) - assert t.get_mount_point(unittest_client, "secrets/test") == (None, None) + respx_mock.get("https://example.com/v1/secrets/data/test").mock( + side_effect=httpx.ProxyError + ) + assert read_secret(unittest_client, "secrets/test") is None assert ( - "Error occurred during checking metadata for secrets/test: connection error" + "Error occurred during query secret secrets/test: proxy error" in caplog.text ) - def test_unhandled_exception( - self, route: respx.Route, unittest_client: httpx.Client + @pytest.mark.usefixtures("_set_mount_kv2") + def test_http_exception( + self, respx_mock: respx.MockRouter, unittest_client: httpx.Client ): - route.mock(side_effect=httpx.DecodingError) + respx_mock.get("https://example.com/v1/secrets/data/test").mock( + side_effect=httpx.DecodingError + ) with pytest.raises(httpx.DecodingError): - t.get_mount_point(unittest_client, "secrets/test") + read_secret(unittest_client, "secrets/test") - def test_type_error(self): - with pytest.raises(TypeError): - t.get_mount_point(1234, "secrets/test") - with pytest.raises(TypeError): - t.get_mount_point(Mock(spec=httpx.Client), 1234) + @pytest.mark.usefixtures("_set_mount_kv2") + def test_bad_request( + self, + respx_mock: respx.MockRouter, + unittest_client: httpx.Client, + caplog: pytest.LogCaptureFixture, + ): + respx_mock.get("https://example.com/v1/secrets/data/test").mock( + httpx.Response(499) + ) + assert read_secret(unittest_client, "secrets/test") is None + assert "Error occurred during query secret secrets/test" in caplog.text -class TestReadSecret: +class TestGetMount: @pytest.fixture() - def patch_get_mount_point(self): - with patch.object(t, "get_mount_point", return_value=("secrets/", 1)) as p: - yield p + def route(self, respx_mock: respx.MockRouter): + return respx_mock.get( + "https://example.com/v1/sys/internal/ui/mounts/secrets/test" + ) - @pytest.mark.usefixtures("patch_get_mount_point") - def test_kv1(self, respx_mock: respx.MockRouter, unittest_client: httpx.Client): - respx_mock.get("https://example.com/v1/secrets/test").mock( + def test_success_kv2(self, route: respx.Route, unittest_client: httpx.Client): + route.mock( httpx.Response( 200, json={ - "request_id": "a8f28d97-8a9d-c9dd-4d86-e815083b33ad", - "lease_id": "", - "renewable": False, - "lease_duration": 2764800, - "data": {"test": "mock"}, - "wrap_info": None, - "warnings": None, - "auth": None, + "data": { + "options": {"version": "2"}, + "path": "secrets/", + "type": "kv", + }, }, ) ) + assert get_mount(unittest_client, "secrets/test") == MountMetadata( + path="secrets/", version=2 + ) - with patch.object(t, "get_mount_point", return_value=("secrets/", 1)): - assert t.read_secret(unittest_client, "secrets/test") == {"test": "mock"} + def test_success_kv2_integration(self, intl_client: httpx.Client): + assert get_mount(intl_client, "kv2/test") == MountMetadata( + path="kv2/", version=2 + ) - def test_kv2( - self, - respx_mock: respx.MockRouter, - unittest_client: httpx.Client, - patch_get_mount_point: Mock, - ): - respx_mock.get("https://example.com/v1/secrets/data/test").mock( + def test_success_kv1(self, route: respx.Route, unittest_client: httpx.Client): + route.mock( httpx.Response( 200, json={ - "request_id": "9ababbb6-3749-cf2c-5a5b-85660e917e8e", - "lease_id": "", - "renewable": False, - "lease_duration": 0, "data": { - "data": {"test": "mock"}, - "metadata": { - "created_time": "2022-09-20T15:57:45.143053836Z", - "custom_metadata": None, - "deletion_time": "", - "destroyed": False, - "version": 1, - }, + "options": {"version": "1"}, + "path": "secrets/", + "type": "kv", }, "wrap_info": None, "warnings": None, @@ -395,102 +486,64 @@ def test_kv2( }, ) ) + assert get_mount(unittest_client, "secrets/test") == MountMetadata( + path="secrets/", version=1 + ) - patch_get_mount_point.return_value = ("secrets/", 2) - assert t.read_secret(unittest_client, "secrets/test") == {"test": "mock"} - - def test_get_mount_point_error( - self, unittest_client: httpx.Client, patch_get_mount_point: Mock - ): - patch_get_mount_point.return_value = (None, None) - assert t.read_secret(unittest_client, "secrets/test") is None - - @pytest.mark.usefixtures("patch_get_mount_point") - def test_connection_error( - self, - respx_mock: respx.MockRouter, - unittest_client: httpx.Client, - caplog: pytest.LogCaptureFixture, - ): - respx_mock.get("https://example.com/v1/secrets/test").mock( - side_effect=httpx.ProxyError + def test_success_kv1_integration(self, intl_client: httpx.Client): + assert get_mount(intl_client, "kv1/test") == MountMetadata( + path="kv1/", version=1 ) - assert t.read_secret(unittest_client, "secrets/test") is None - assert ( - "Error occurred during query secret secrets/test: proxy error" - in caplog.text + def test_success_legacy(self, route: respx.Route, unittest_client: httpx.Client): + route.mock(httpx.Response(404)) + assert get_mount(unittest_client, "secrets/test") == MountMetadata( + path="", version=1 ) - @pytest.mark.usefixtures("patch_get_mount_point") - def test_unhandled_exception( - self, respx_mock: respx.MockRouter, unittest_client: httpx.Client + def test_not_ported_version( + self, route: respx.Route, unittest_client: httpx.Client ): - respx_mock.get("https://example.com/v1/secrets/test").mock( - side_effect=httpx.DecodingError + route.mock( + httpx.Response( + 200, + json={ + "data": { + "path": "mock/", + "type": "kv", + "options": {"version": "99"}, + } + }, + ) ) - with pytest.raises(httpx.DecodingError): - t.read_secret(unittest_client, "secrets/test") - @pytest.mark.usefixtures("patch_get_mount_point") - def test_not_found( + with pytest.raises(ValidationError): + get_mount(unittest_client, "secrets/test") + + def test_bad_request( self, - respx_mock: respx.MockRouter, + route: respx.Route, unittest_client: httpx.Client, caplog: pytest.LogCaptureFixture, ): - respx_mock.get("https://example.com/v1/secrets/test").mock(httpx.Response(404)) - assert t.read_secret(unittest_client, "secrets/test") is None - assert "Secret secrets/test not found" in caplog.text + route.mock(httpx.Response(400)) + assert get_mount(unittest_client, "secrets/test") is None + assert "Error occurred during checking metadata for secrets/test" in caplog.text - @pytest.mark.usefixtures("patch_get_mount_point") - def test_bad_request( + def test_connection_error( self, - respx_mock: respx.MockRouter, + route: respx.Route, unittest_client: httpx.Client, caplog: pytest.LogCaptureFixture, ): - respx_mock.get("https://example.com/v1/secrets/test").mock(httpx.Response(499)) - assert t.read_secret(unittest_client, "secrets/test") is None - assert "Error occurred during query secret secrets/test" in caplog.text - - def test_type_error(self): - with pytest.raises(TypeError): - t.read_secret(1234, "secrets/test") - with pytest.raises(TypeError): - t.read_secret(Mock(spec=httpx.Client), 1234) - - -def test_split_field(): - assert t.split_field("aa") == ["aa"] - assert t.split_field("aa.bb") == ["aa", "bb"] - assert t.split_field('aa."bb.cc"') == ["aa", "bb.cc"] - assert t.split_field('"aa.bb".cc') == ["aa.bb", "cc"] - assert t.split_field('"aa.bb"') == ["aa.bb"] - - with pytest.raises(ValueError, match=r"Failed to parse name: "): - t.split_field("") - with pytest.raises(ValueError, match=r"Failed to parse name: .+"): - t.split_field(".") - with pytest.raises(ValueError, match=r"Failed to parse name: .+"): - t.split_field("aa.") - with pytest.raises(ValueError, match=r"Failed to parse name: .+"): - t.split_field(".aa") - - -class TestVaultPath: - def test_success(self): - path = VaultPath.model_validate("foo#bar") - assert path == VaultPath(path="foo", field="bar") - - def test_empty(self): - with pytest.raises(ValidationError): - VaultPath.model_validate("a#") - with pytest.raises(ValidationError): - VaultPath.model_validate("#b") + route.mock(side_effect=httpx.ConnectError) + assert get_mount(unittest_client, "secrets/test") is None + assert ( + "Error occurred during checking metadata for secrets/test: connection error" + in caplog.text + ) - def test_invalid(self): - with pytest.raises(ValidationError): - VaultPath.model_validate("foobar") - with pytest.raises(ValidationError): - VaultPath.model_validate("foo#bar#baz") + def test_http_exception(self, route: respx.Route, unittest_client: httpx.Client): + route.mock(side_effect=httpx.DecodingError) + with pytest.raises(httpx.DecodingError): + get_mount(unittest_client, "secrets/test") diff --git a/tests/test_collect.py b/tests/test_collect.py index a6768e0c..a1527c51 100644 --- a/tests/test_collect.py +++ b/tests/test_collect.py @@ -10,7 +10,7 @@ def test_read_values(caplog: pytest.LogCaptureFixture): def create_provider(return_value: str): - provider = Mock(spec=secrets_env.provider.ProviderBase) + provider = Mock(spec=secrets_env.provider.Provider) provider.get.return_value = return_value return provider @@ -50,7 +50,7 @@ def create_provider(return_value: str): class TestRead1: def setup_method(self): - self.provider = Mock(spec=secrets_env.provider.ProviderBase) + self.provider = Mock(spec=secrets_env.provider.Provider) type(self.provider).name = PropertyMock(return_value="mock") def test_success(self):