diff --git a/secrets_env/collect.py b/secrets_env/collect.py
index fad4214b..da033787 100644
--- a/secrets_env/collect.py
+++ b/secrets_env/collect.py
@@ -5,7 +5,7 @@
if typing.TYPE_CHECKING:
from secrets_env.config0.parser import Config
- from secrets_env.provider import ProviderBase, RequestSpec
+ from secrets_env.provider import Provider, RequestSpec
logger = logging.getLogger(__name__)
@@ -37,17 +37,17 @@ def read_values(config: Config) -> dict[str, str]:
return output_values
-def read1(provider: ProviderBase, name: str, spec: RequestSpec) -> str | None:
+def read1(provider: Provider, name: str, spec: RequestSpec) -> str | None:
"""Read single value.
This function wraps :py:meth:`secrets_env.provider.ProviderBase.get` and
captures all exceptions.
"""
import secrets_env.exceptions
- from secrets_env.provider import ProviderBase
+ from secrets_env.provider import Provider
# type checking
- if not isinstance(provider, ProviderBase):
+ if not isinstance(provider, Provider):
raise TypeError(
f'Expected "provider" to be a credential provider class, '
f"got {type(provider).__name__}"
diff --git a/secrets_env/config0/parser.py b/secrets_env/config0/parser.py
index 774bc12e..92fdf5c0 100644
--- a/secrets_env/config0/parser.py
+++ b/secrets_env/config0/parser.py
@@ -1,16 +1,17 @@
+from __future__ import annotations
+
import itertools
import logging
import re
import typing
-from typing import Any, Dict, Iterator, List, Optional, Tuple, TypedDict
+from typing import Any, Iterator, TypedDict
import secrets_env.exceptions
import secrets_env.providers
-from secrets_env.provider import RequestSpec
from secrets_env.utils import ensure_dict, ensure_str
if typing.TYPE_CHECKING:
- from secrets_env.provider import ProviderBase
+ from secrets_env.provider import Provider, RequestSpec
DEFAULT_PROVIDER_NAME = "main"
@@ -29,11 +30,11 @@ class Request(TypedDict):
class Config(TypedDict):
"""The parsed configurations."""
- providers: Dict[str, "ProviderBase"]
- requests: List[Request]
+ providers: dict[str, Provider]
+ requests: list[Request]
-def parse_config(data: dict) -> Optional[Config]:
+def parse_config(data: dict) -> Config | None:
"""Parse and validate configs, build it into structured object."""
requests = get_requests(data)
if not requests:
@@ -48,11 +49,11 @@ def parse_config(data: dict) -> Optional[Config]:
return Config(providers=providers, requests=requests)
-def get_providers(data: dict) -> Dict[str, "ProviderBase"]:
+def get_providers(data: dict) -> dict[str, Provider]:
sections = list(extract_sources(data))
logger.debug("%d raw provider configs extracted", len(sections))
- providers: Dict[str, "ProviderBase"] = {}
+ providers: dict[str, Provider] = {}
for data in sections:
result = parse_source_item(data)
if not result:
@@ -75,7 +76,7 @@ def get_providers(data: dict) -> Dict[str, "ProviderBase"]:
return providers
-def extract_sources(data: dict) -> Iterator[Dict[str, Any]]:
+def extract_sources(data: dict) -> Iterator[dict[str, Any]]:
"""Extracts both "source(s)" section and ensure the output is list of dict"""
for item in itertools.chain(
get_list(data, "source"),
@@ -97,7 +98,7 @@ def get_list(data: dict, key: str) -> Iterator[dict]:
logger.warning("Found invalid value in field %s", key)
-def parse_source_item(config: dict) -> Optional[Tuple[str, "ProviderBase"]]:
+def parse_source_item(config: dict) -> tuple[str, Provider] | None:
# check name
name = config.get("name") or DEFAULT_PROVIDER_NAME
name, ok = ensure_str("source.name", name)
@@ -117,7 +118,7 @@ def parse_source_item(config: dict) -> Optional[Tuple[str, "ProviderBase"]]:
return name, provider
-def get_requests(data: dict) -> List[Request]:
+def get_requests(data: dict) -> list[Request]:
# accept both keyword `secret(s)`
raw = {}
diff --git a/secrets_env/provider.py b/secrets_env/provider.py
index 4f66b02e..fadee477 100644
--- a/secrets_env/provider.py
+++ b/secrets_env/provider.py
@@ -6,53 +6,46 @@
from __future__ import annotations
import abc
-import sys
-import typing
-from typing import Dict, Union
+from typing import ClassVar, Union
-if typing.TYPE_CHECKING and sys.version_info >= (3, 10):
- from typing import TypeAlias
+from pydantic import BaseModel
+RequestSpec = Union[dict[str, str], str]
+""":py:class:`RequestSpec` represents a path spec to read the value.
-RequestSpec: TypeAlias = Union[Dict[str, str], str]
-""":py:class:`RequestSpec` represents a secret spec (name/path) to be loaded.
+It should be a :py:class:`dict` in most cases; or :py:class:`str` if this
+provider accepts shortcut.
"""
-class ProviderBase(abc.ABC):
- """Abstract base class for secret provider. All secret provider must implement
+class Provider(BaseModel, abc.ABC):
+ """Abstract base class for secret provider. All provider must implement
this interface.
"""
- @property
- @abc.abstractmethod
- def type(self) -> str:
- """Provider name."""
+ type: ClassVar[str]
@abc.abstractmethod
def get(self, spec: RequestSpec) -> str:
- """Get secret value.
+ """Get secret.
Parameters
----------
- spec : dict | str
- Raw input from config file.
-
- It should be :py:class:`dict` in most cases; or :py:class:`str` if
- this provider accepts shortcut.
+ path : dict | str
+ Raw input from config file for reading the secret value.
Return
------
- The secret value.
+ The value
Raises
------
- ConfigError
- The path dict is malformed.
- ValueNotFound
- The path dict is correct but the secret not exists.
-
- Note
- ----
- Key ``source`` is preserved in ``spec`` dictionary.
+ ValidationError
+ If the input format is invalid.
+ UnsupportedError
+ When this operation is not supported.
+ AuthenticationError
+ Failed during authentication.
+ LookupError
+ If the secret is not found.
"""
diff --git a/secrets_env/providers/__init__.py b/secrets_env/providers/__init__.py
index 112af6f3..709b5641 100644
--- a/secrets_env/providers/__init__.py
+++ b/secrets_env/providers/__init__.py
@@ -1,33 +1,50 @@
-import typing
+from __future__ import annotations
-import secrets_env.exceptions
+import logging
+import typing
if typing.TYPE_CHECKING:
- from secrets_env.provider import ProviderBase
+ from secrets_env.provider import Provider
DEFAULT_PROVIDER = "vault"
+logger = logging.getLogger(__name__)
+
+
+def get_provider(config: dict) -> Provider:
+ """
+ Returns a provider instance based on the configuration.
+
+ Raises
+ ------
+ ValueError
+ If the provider type is not recognized.
+ ValidationError
+ If the provider configuration is invalid.
+ """
+ type_ = config.get("type")
+ if not type_:
+ type_ = DEFAULT_PROVIDER
+ logger.warning("Provider type unspecified, using default: %s", type_)
-def get_provider(data: dict) -> "ProviderBase":
- type_ = data.get("type", DEFAULT_PROVIDER)
- type_lower = type_.lower()
+ itype = type_.lower()
# fmt: off
- if type_lower == "null":
- from . import null
- return null.get_provider(type_, data)
- if type_lower == "plain":
- from . import plain
- return plain.get_provider(type_, data)
- if type_lower == "teleport":
- from . import teleport
- return teleport.get_provider(type_, data)
- if type_lower == "vault":
- from . import vault
- return vault.get_provider(type_, data)
- if type_lower.startswith("teleport+"):
- from . import teleport
- return teleport.get_adapted_provider(type_, data)
+ if itype == "null":
+ from secrets_env.providers.null import NullProvider
+ return NullProvider.model_validate(config)
+ if itype == "plain":
+ from secrets_env.providers.plain import PlainTextProvider
+ return PlainTextProvider.model_validate(config)
+ if itype == "teleport":
+ from secrets_env.providers.teleport import TeleportProvider
+ return TeleportProvider.model_validate(config)
+ if itype == "teleport+vault":
+ logger.error('"teleport+vault provider is not yet implemented')
+ raise NotImplementedError
+ if itype == "vault":
+ from secrets_env.providers.vault import VaultKvProvider
+ return VaultKvProvider.model_validate(config)
# fmt: on
- raise secrets_env.exceptions.ConfigError("Unknown provider type {}", type_)
+ raise ValueError(f"Unknown provider type {type_}")
diff --git a/secrets_env/providers/null.py b/secrets_env/providers/null.py
index 2c1dc342..5d439550 100644
--- a/secrets_env/providers/null.py
+++ b/secrets_env/providers/null.py
@@ -1,22 +1,18 @@
+from __future__ import annotations
+
import typing
-from secrets_env.provider import ProviderBase
+from secrets_env.provider import Provider
if typing.TYPE_CHECKING:
from secrets_env.provider import RequestSpec
-class NullProvider(ProviderBase):
+class NullProvider(Provider):
"""A provider that always returns empty string. This provider is preserved
for debugging."""
- @property
- def type(self) -> str:
- return "null"
+ type = "null"
- def get(self, spec: "RequestSpec") -> str:
+ def get(self, spec: RequestSpec) -> str:
return ""
-
-
-def get_provider(type_: str, data: dict) -> NullProvider:
- return NullProvider()
diff --git a/secrets_env/providers/plain.py b/secrets_env/providers/plain.py
index 601dcbb2..50b8819b 100644
--- a/secrets_env/providers/plain.py
+++ b/secrets_env/providers/plain.py
@@ -1,30 +1,21 @@
+from __future__ import annotations
+
import typing
-from secrets_env.provider import ProviderBase
+from secrets_env.provider import Provider
if typing.TYPE_CHECKING:
from secrets_env.provider import RequestSpec
-class PlainTextProvider(ProviderBase):
+class PlainTextProvider(Provider):
"""Plain text provider returns text that is copied directly from the
configuration file."""
- @property
- def type(self) -> str:
- return "plain"
+ type = "plain"
- def get(self, spec: "RequestSpec") -> str:
+ def get(self, spec: RequestSpec) -> str:
if isinstance(spec, str):
- value = spec
+ return spec
elif isinstance(spec, dict):
- value = spec.get("value")
- else:
- raise TypeError(
- f'Expected "spec" to be a string or dict, got {type(spec).__name__}'
- )
- return value or ""
-
-
-def get_provider(type_: str, data: dict) -> PlainTextProvider:
- return PlainTextProvider()
+ return spec.get("value") or ""
diff --git a/secrets_env/providers/teleport/__init__.py b/secrets_env/providers/teleport/__init__.py
index 614a8777..ead2ae53 100644
--- a/secrets_env/providers/teleport/__init__.py
+++ b/secrets_env/providers/teleport/__init__.py
@@ -1,45 +1,70 @@
from __future__ import annotations
+import logging
import typing
+from typing import Literal
-import pydantic
+from pydantic import BaseModel, model_validator
-from secrets_env.exceptions import ConfigError
+from secrets_env.provider import Provider
+from secrets_env.providers.teleport.config import TeleportUserConfig
if typing.TYPE_CHECKING:
- from secrets_env.provider import ProviderBase
- from secrets_env.providers.teleport.provider import TeleportProvider
-
-ADAPTER_PREFIX = "teleport+"
-
-
-def get_provider(type_: str, data: dict) -> TeleportProvider:
- from .config import TeleportUserConfig
- from .provider import TeleportProvider
-
- cfg = TeleportUserConfig.model_validate(data)
- return TeleportProvider(config=cfg)
-
-
-def get_adapted_provider(type_: str, data: dict) -> ProviderBase:
- from .adapters import get_adapter
- from .config import TeleportUserConfig # noqa: TCH001
-
- class TeleportAdapterConfig(pydantic.BaseModel):
- """Config layout for using Teleport as an adapter."""
-
- teleport: TeleportUserConfig
-
- iname = type_.lower()
- if not iname.startswith(ADAPTER_PREFIX):
- raise ConfigError("Not a Teleport compatible provider: {}", type_)
-
- subtype = type_[len(ADAPTER_PREFIX) :]
- factory = get_adapter(subtype)
-
- # get connection parameter
- app_param = TeleportAdapterConfig.model_validate(data)
- conn_param = app_param.teleport.get_connection_param()
-
- # forward parameters to corresponding provider
- return factory(subtype, data, conn_param)
+ from typing import Self
+
+ from secrets_env.provider import RequestSpec
+ from secrets_env.providers.teleport.config import TeleportConnectionParameter
+
+logger = logging.getLogger(__name__)
+
+
+class TeleportRequestSpec(BaseModel):
+ field: Literal["uri", "ca", "cert", "key", "cert+key"]
+ format: Literal["path", "pem"] = "path"
+
+ @model_validator(mode="before")
+ @classmethod
+ def _accept_shortcut(cls, data: RequestSpec | Self) -> dict[str, str] | Self:
+ if isinstance(data, str):
+ return {"field": data}
+ return data
+
+
+class TeleportProvider(Provider, TeleportUserConfig):
+ """Read certificates from Teleport."""
+
+ type = "teleport"
+
+ def get(self, spec: RequestSpec) -> str:
+ ps = TeleportRequestSpec.model_validate(spec)
+
+ if ps.field == "uri":
+ return self.connection_param.uri
+ elif ps.field == "ca":
+ return get_ca(self.connection_param, ps.format)
+ elif ps.field == "cert":
+ if ps.format == "path":
+ return str(self.connection_param.path_cert)
+ elif ps.format == "pem":
+ return self.connection_param.cert.decode()
+ elif ps.field == "key":
+ if ps.format == "path":
+ return str(self.connection_param.path_key)
+ elif ps.format == "pem":
+ return self.connection_param.key.decode()
+ elif ps.field == "cert+key":
+ if ps.format == "path":
+ return str(self.connection_param.path_cert_and_key)
+ elif ps.format == "pem":
+ return self.connection_param.cert_and_key.decode()
+
+ raise RuntimeError
+
+
+def get_ca(param: TeleportConnectionParameter, fmt: Literal["path", "pem"]) -> str:
+ if param.ca is None:
+ raise LookupError("CA is not available")
+ if fmt == "path":
+ return str(param.path_ca)
+ elif fmt == "pem":
+ return param.ca.decode()
diff --git a/secrets_env/providers/teleport/adapters.py b/secrets_env/providers/teleport/adapters.py
deleted file mode 100644
index 01ca070b..00000000
--- a/secrets_env/providers/teleport/adapters.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from __future__ import annotations
-
-import logging
-import typing
-
-from secrets_env.exceptions import ConfigError
-
-if typing.TYPE_CHECKING:
- from secrets_env.provider import ProviderBase
- from secrets_env.providers.teleport.config import TeleportConnectionParameter
-
- AdapterType = typing.Callable[
- [str, dict, TeleportConnectionParameter], ProviderBase
- ]
-
-logger = logging.getLogger(__name__)
-
-
-def get_adapter(name: str) -> AdapterType:
- iname = name.lower()
- if iname == "vault":
- return adapt_vault_provider
-
- raise ConfigError("Unknown provider type {}", name)
-
-
-def adapt_vault_provider(
- type_: str, data: dict, param: TeleportConnectionParameter
-) -> ProviderBase:
- assert isinstance(data, dict)
- from secrets_env.providers import vault
-
- # url
- if (url := data.get("url")) and url != param.uri:
- logger.warning("Overwrite source.url to %s", param.uri)
-
- data["url"] = param.uri
- logger.debug("Set Vault URL to %s", param.uri)
-
- # ca
- tls: dict = data.setdefault("tls", {})
- if param.path_ca:
- tls["ca_cert"] = param.path_ca
- logger.debug("Set Vault CA to %s", param.path_ca)
-
- # cert
- tls["client_cert"] = param.path_cert
- logger.debug("Set Vault client cert to %s", param.path_cert)
-
- # key
- tls["client_key"] = param.path_key
- logger.debug("Set Vault client key to %s", param.path_key)
-
- return vault.get_provider(type_, data)
diff --git a/secrets_env/providers/teleport/config.py b/secrets_env/providers/teleport/config.py
index c86c1a03..32892545 100644
--- a/secrets_env/providers/teleport/config.py
+++ b/secrets_env/providers/teleport/config.py
@@ -42,7 +42,8 @@ def _use_shortcut(cls, data):
return {"app": data}
return data
- def get_connection_param(self) -> TeleportConnectionParameter:
+ @cached_property
+ def connection_param(self) -> TeleportConnectionParameter:
"""Get app connection parameter from Teleport CLI.
Raises
diff --git a/secrets_env/providers/teleport/provider.py b/secrets_env/providers/teleport/provider.py
deleted file mode 100644
index 5e7c1f8d..00000000
--- a/secrets_env/providers/teleport/provider.py
+++ /dev/null
@@ -1,115 +0,0 @@
-from __future__ import annotations
-
-import logging
-import typing
-from functools import cached_property
-from typing import Literal
-
-from secrets_env.exceptions import ConfigError, ValueNotFound
-from secrets_env.provider import ProviderBase
-
-if typing.TYPE_CHECKING:
- from secrets_env.provider import RequestSpec
- from secrets_env.providers.teleport.config import (
- TeleportConnectionParameter,
- TeleportUserConfig,
- )
-
-DEFAULT_OUTPUT_FORMAT = "path"
-
-
-class OutputSpec(typing.NamedTuple):
- field: Literal["uri", "ca", "cert", "key", "cert+key"]
- format: Literal["path", "pem"]
-
-
-logger = logging.getLogger(__name__)
-
-
-class TeleportProvider(ProviderBase):
- """Read certificates from Teleport."""
-
- @property
- def type(self) -> str:
- return "teleport"
-
- def __init__(self, *, config: TeleportUserConfig) -> None:
- self._config = config
-
- @cached_property
- def tsh(self) -> TeleportConnectionParameter:
- """Return teleport app connection information."""
- return self._config.get_connection_param()
-
- def get(self, spec: RequestSpec) -> str:
- parsed = parse_spec(spec)
-
- if parsed.field == "uri":
- return self.tsh.uri
-
- elif parsed.field == "ca":
- # bypass cognitive complexity check
- return get_ca(self.tsh, parsed.format)
-
- elif parsed.field == "cert":
- if parsed.format == "path":
- return str(self.tsh.path_cert)
- elif parsed.format == "pem":
- return self.tsh.cert.decode()
-
- elif parsed.field == "key":
- if parsed.format == "path":
- return str(self.tsh.path_key)
- elif parsed.format == "pem":
- return self.tsh.key.decode()
-
- elif parsed.field == "cert+key":
- if parsed.format == "path":
- return str(self.tsh.path_cert_and_key)
- elif parsed.format == "pem":
- return self.tsh.cert_and_key.decode()
-
- raise ConfigError("Invalid value spec: {}", spec)
-
-
-def parse_spec(spec: RequestSpec) -> OutputSpec:
- # extract
- if isinstance(spec, str):
- output_field = spec
- output_format = DEFAULT_OUTPUT_FORMAT
- elif isinstance(spec, dict):
- output_field = spec.get("field")
- output_format = spec.get("format", DEFAULT_OUTPUT_FORMAT)
- else:
- raise ConfigError(
- "Expect dict for secrets path spec, got {}", type(spec).__name__
- )
-
- # validate
- if (
- False
- or not isinstance(output_field, str)
- or output_field.lower() not in ("uri", "ca", "cert", "key", "cert+key")
- ):
- raise ConfigError("Invalid field (secrets.VAR.field): {}", output_field)
-
- if (
- False
- or not isinstance(output_format, str)
- or output_format.lower() not in ("path", "pem")
- ):
- raise ConfigError("Invalid format (secrets.VAR.format): {}", output_format)
-
- return OutputSpec(output_field.lower(), output_format.lower()) # type: ignore[reportGeneralTypeIssues]
-
-
-def get_ca(
- conn_info: TeleportConnectionParameter, format_: Literal["path", "pem"]
-) -> str:
- if not conn_info.ca:
- raise ValueNotFound("CA is not avaliable")
- if format_ == "path":
- return str(conn_info.path_ca)
- elif format_ == "pem":
- return conn_info.ca.decode()
- raise RuntimeError
diff --git a/secrets_env/providers/vault/__init__.py b/secrets_env/providers/vault/__init__.py
index d78cb686..190df8f9 100644
--- a/secrets_env/providers/vault/__init__.py
+++ b/secrets_env/providers/vault/__init__.py
@@ -1,10 +1,3 @@
-from secrets_env.exceptions import ConfigError
+__all__ = ["VaultKvProvider"]
-from .config import get_connection_info
-from .provider import KvProvider
-
-
-def get_provider(type_: str, data: dict) -> KvProvider:
- if not (cfg := get_connection_info(data)):
- raise ConfigError("Invalid config for vault provider")
- return KvProvider(**cfg)
+from secrets_env.providers.vault.provider import VaultKvProvider
diff --git a/secrets_env/providers/vault/auth/__init__.py b/secrets_env/providers/vault/auth/__init__.py
index 4d4173a9..97dbad03 100644
--- a/secrets_env/providers/vault/auth/__init__.py
+++ b/secrets_env/providers/vault/auth/__init__.py
@@ -1,13 +1,15 @@
from __future__ import annotations
+__all__ = ["Auth", "create_auth_by_name"]
+
import logging
import typing
+from secrets_env.providers.vault.auth.base import Auth
+
if typing.TYPE_CHECKING:
from pydantic_core import Url
- from secrets_env.providers.vault.auth.base import Auth
-
logger = logging.getLogger(__name__)
diff --git a/secrets_env/providers/vault/auth/base.py b/secrets_env/providers/vault/auth/base.py
index f8838cf2..fe55ff6b 100644
--- a/secrets_env/providers/vault/auth/base.py
+++ b/secrets_env/providers/vault/auth/base.py
@@ -31,8 +31,15 @@ def create(cls, url: Url, config: dict[str, Any]) -> Self:
"""
@abstractmethod
- def login(self, client: httpx.Client) -> str | None:
- """Login and get token."""
+ def login(self, client: httpx.Client) -> str:
+ """
+ Login and get Vault token.
+
+ Raises
+ ------
+ AuthenticationError
+ If the login fails.
+ """
class NoAuth(Auth):
diff --git a/secrets_env/providers/vault/auth/userpass.py b/secrets_env/providers/vault/auth/userpass.py
index 40205580..1086846f 100644
--- a/secrets_env/providers/vault/auth/userpass.py
+++ b/secrets_env/providers/vault/auth/userpass.py
@@ -6,6 +6,7 @@
from pydantic import PrivateAttr, SecretStr
+from secrets_env.exceptions import AuthenticationError
from secrets_env.providers.vault.auth.base import Auth
from secrets_env.utils import (
create_keyring_login_key,
@@ -76,7 +77,7 @@ def _get_password(cls, url: Url, username: str) -> str | None:
return prompt(f"Password for {username}", hide_input=True)
- def login(self, client: httpx.Client) -> str | None:
+ def login(self, client: httpx.Client) -> str:
username = urllib.parse.quote(self.username)
resp = client.post(
f"/v1/auth/{self.vault_name}/login/{username}",
@@ -88,14 +89,13 @@ def login(self, client: httpx.Client) -> str | None:
)
if not resp.is_success:
- logger.error("Failed to login with %s method", self.method)
logger.debug(
"Login failed. URL= %s, Code= %d. Msg= %s",
resp.url,
resp.status_code,
resp.text,
)
- return
+ raise AuthenticationError("Failed to login with %s method", self.method)
return resp.json()["auth"]["client_token"]
diff --git a/secrets_env/providers/vault/provider.py b/secrets_env/providers/vault/provider.py
index 6e56278c..bb730057 100644
--- a/secrets_env/providers/vault/provider.py
+++ b/secrets_env/providers/vault/provider.py
@@ -2,70 +2,121 @@
import enum
import logging
-import re
import typing
from functools import cached_property
from http import HTTPStatus
-from typing import Dict, Literal, Union
+from typing import Literal
import httpx
-from pydantic import BaseModel, Field, model_validator
+from pydantic import (
+ BaseModel,
+ Field,
+ InstanceOf,
+ PrivateAttr,
+ field_validator,
+ model_validator,
+ validate_call,
+)
import secrets_env.version
-from secrets_env.exceptions import AuthenticationError, ValueNotFound
-from secrets_env.provider import ProviderBase, RequestSpec
+from secrets_env.exceptions import AuthenticationError
+from secrets_env.provider import Provider
+from secrets_env.providers.vault.config import VaultUserConfig
from secrets_env.utils import LruDict, get_httpx_error_reason, log_httpx_response
if typing.TYPE_CHECKING:
- from pathlib import Path
- from typing import Any, Self
-
- from secrets_env.providers.vault.auth.base import Auth
- from secrets_env.providers.vault.config import CertTypes
+ from typing import Iterable, Iterator, Self, Sequence
+ from secrets_env.provider import RequestSpec
+ from secrets_env.providers.vault.auth import Auth
logger = logging.getLogger(__name__)
class Marker(enum.Enum):
- NoMatch = enum.auto()
- SecretNotExist = enum.auto()
+ """Internal marker for cache handling."""
+ NoCache = enum.auto()
+ NotFound = enum.auto()
-class SecretSource(typing.NamedTuple):
- path: str
- field: str
+class VaultPath(BaseModel):
+ """Represents a path to a value in Vault."""
-if typing.TYPE_CHECKING:
- KVVersion = Literal[1, 2]
- VaultSecret = Dict[str, str]
- VaultSecretQueryResult = Union[VaultSecret, Literal[Marker.SecretNotExist]]
-
-
-class KvProvider(ProviderBase):
- """Read secrets from Vault KV engine."""
-
- def __init__(
- self,
- url: str,
- auth: Auth,
- *,
- proxy: str | None = None,
- ca_cert: Path | None = None,
- client_cert: CertTypes | None = None,
- ) -> None:
- self.url = url
- self.auth = auth
- self.proxy = proxy
- self.ca_cert = ca_cert
- self.client_cert = client_cert
-
- self._secrets: LruDict[str, VaultSecretQueryResult] = LruDict()
+ path: str = Field(min_length=1)
+ field: tuple[str, ...]
+
+ def __str__(self) -> str:
+ return f"{self.path}#{self.field_str}"
@property
- def type(self) -> str:
- return "vault"
+ def field_str(self) -> str:
+ seq = []
+ for f in self.field:
+ if "." in f:
+ seq.append(f'"{f}"')
+ else:
+ seq.append(f)
+ return ".".join(seq)
+
+ @model_validator(mode="before")
+ @classmethod
+ def _create_from_str(cls, value: str | dict | Self) -> dict | Self:
+ if not isinstance(value, str):
+ return value
+ if value.count("#") != 1:
+ raise ValueError("Invalid format. Expected 'path#field'")
+ path, field = value.rsplit("#", 1)
+ return {
+ "path": path,
+ "field": field,
+ }
+
+ @field_validator("field", mode="before")
+ @classmethod
+ def _accept_str_for_field(cls, value) -> Iterable[str]:
+ if isinstance(value, str):
+ return _split_field_str(value)
+ return value
+
+ @field_validator("field", mode="after")
+ @classmethod
+ def _validate_field(cls, field: Sequence[str]) -> Sequence[str]:
+ if not field:
+ raise ValueError("Field cannot be empty")
+ if any(not f for f in field):
+ raise ValueError("Field cannot contain empty subpath")
+ return field
+
+
+def _split_field_str(f: str) -> Iterator[str]:
+ """Split a field name into subsequences. By default, this function splits
+ the name by dots, with supportting of preserving the quoted subpaths.
+ """
+ pos = 0
+ while pos < len(f):
+ if f[pos] == '"':
+ # quoted
+ end = f.find('"', pos + 1)
+ if end == -1:
+ raise ValueError(f"Failed to parse field: {f}")
+ yield f[pos + 1 : end]
+ pos = end + 2
+ else:
+ # simple
+ end = f.find(".", pos)
+ if end == -1:
+ end = len(f)
+ yield f[pos:end]
+ pos = end + 1
+
+
+class VaultKvProvider(Provider, VaultUserConfig):
+ """Read secrets from Hashicorp Vault KV engine."""
+
+ type = "vault"
+
+ _cache: dict[str, dict | Marker] = PrivateAttr(default_factory=LruDict)
@cached_property
def client(self) -> httpx.Client:
@@ -76,102 +127,101 @@ def client(self) -> httpx.Client:
self.auth.method,
)
- # initialize client
- client_params: dict[str, Any] = {"base_url": self.url}
-
- if self.proxy:
- logger.debug("Use proxy: %s", self.proxy)
- client_params["proxies"] = self.proxy
- if self.ca_cert:
- logger.debug("CA installed: %s", self.ca_cert)
- client_params["verify"] = self.ca_cert
- if self.client_cert:
- logger.debug("Client side certificate file installed: %s", self.client_cert)
- client_params["cert"] = self.client_cert
-
- client = httpx.Client(
- **client_params,
- headers={
- "Accept": "application/json",
- "User-Agent": (
- f"secrets.env/{secrets_env.version.__version__} "
- f"python-httpx/{httpx.__version__}"
- ),
- },
- )
-
- # install token
+ client = create_http_client(self)
client.headers["X-Vault-Token"] = get_token(client, self.auth)
return client
def get(self, spec: RequestSpec) -> str:
path = VaultPath.model_validate(spec)
- return self.read_field(path.path, path.field)
-
- def read_secret(self, path: str) -> VaultSecret:
- """Read secret from Vault.
+ secret = self._read_secret(path)
+
+ for f in path.field:
+ try:
+ secret = secret[f]
+ except (KeyError, TypeError):
+ raise LookupError(
+ f'Field "{path.field_str}" not found in "{path.path}"'
+ ) from None
+
+ if not isinstance(secret, str):
+ raise LookupError(
+ f'Field "{path.field_str}" in "{path.path}" is not point to a string value'
+ )
- Parameters
- ----------
- path : str
- Secret path
+ return secret
- Returns
- -------
- secret : dict
- Secret data. Or 'SecretNotExist' marker when not found.
+ def _read_secret(self, path: VaultPath) -> dict:
"""
- if not isinstance(path, str):
- raise TypeError(
- f'Expected "path" to be a string, got {type(path).__name__}'
- )
-
- # try cache
- result = self._secrets.get(path, Marker.NoMatch)
+ Get a secret from the Vault. A Vault "secret" is a object that contains
+ key-value pairs.
- if result == Marker.NoMatch:
- # not found in cache - start query
- if secret := read_secret(self.client, path):
- result = secret
- else:
- result = Marker.SecretNotExist
- self._secrets[path] = result
+ This method wraps the `read_secret` method and cache the result.
- # returns value
- if result == Marker.SecretNotExist:
- raise ValueNotFound("Secret {} not found", path)
- return result
+ Raises
+ ------
+ LookupError
+ If the secret is not found.
+ """
+ result = self._cache.get(path.path, Marker.NoCache)
- def read_field(self, path: str, field: str) -> str:
- """Read only one field from Vault.
+ if result == Marker.NoCache:
+ result = read_secret(self.client, path.path)
+ if result is None:
+ result = Marker.NotFound
+ self._cache[path.path] = result
- Parameters
- ----------
- path : str
- Secret path
- field : str
- Field name
+ if result == Marker.NotFound:
+ raise LookupError(f'Secret "{path}" not found')
- Returns
- -------
- value : str
- The secret value if matched
- """
- if not isinstance(field, str):
- raise TypeError(
- f'Expected "field" to be a string, got {type(field).__name__}'
- )
+ return result
- secret = self.read_secret(path)
- value = get_field(secret, field)
- if value is None:
- raise ValueNotFound("Secret {}#{} not found", path, field)
- return value
+@validate_call
+def create_http_client(config: VaultUserConfig) -> httpx.Client:
+ logger.debug(
+ "Vault client initialization requested. URL= %s, Auth type= %s",
+ config.url,
+ config.auth.method,
+ )
+ client_params = {
+ "base_url": str(config.url),
+ "headers": {
+ "Accept": "application/json",
+ "User-Agent": (
+ f"secrets.env/{secrets_env.version.__version__} "
+ f"python-httpx/{httpx.__version__}"
+ ),
+ },
+ }
+
+ if config.proxy:
+ logger.debug("Proxy is set: %s", config.proxy)
+ client_params["proxy"] = str(config.proxy)
+ if config.tls.ca_cert:
+ logger.debug("CA cert is set: %s", config.tls.ca_cert)
+ client_params["verify"] = config.tls.ca_cert
+ if config.tls.client_cert and config.tls.client_key:
+ cert_pair = (config.tls.client_cert, config.tls.client_key)
+ logger.debug("Client cert pair is set: %s ", cert_pair)
+ client_params["cert"] = cert_pair
+ elif config.tls.client_cert:
+ logger.debug("Client cert is set: %s", config.tls.client_cert)
+ client_params["cert"] = config.tls.client_cert
+
+ return httpx.Client(**client_params)
+
+
+def get_token(client: InstanceOf[httpx.Client], auth: Auth) -> str:
+ """
+ Request a token from the Vault server and verify it.
-def get_token(client: httpx.Client, auth: Auth) -> str:
+ Raises
+ ------
+ AuthenticationError
+ If the token cannot be retrieved or is invalid.
+ """
# login
try:
token = auth.login(client)
@@ -180,9 +230,6 @@ def get_token(client: httpx.Client, auth: Auth) -> str:
raise
raise AuthenticationError("Encounter {} while retrieving token", reason) from e
- if not token:
- raise AuthenticationError("Absence of token information")
-
# verify
if not is_authenticated(client, token):
raise AuthenticationError("Invalid token")
@@ -190,20 +237,14 @@ def get_token(client: httpx.Client, auth: Auth) -> str:
return token
-def is_authenticated(client: httpx.Client, token: str) -> bool:
+@validate_call
+def is_authenticated(client: InstanceOf[httpx.Client], token: str) -> bool:
"""Check is a token is authenticated.
See also
--------
https://developer.hashicorp.com/vault/api-docs/auth/token
"""
- if not isinstance(client, httpx.Client):
- raise TypeError(
- f'Expected "client" to be a httpx client, got {type(client).__name__}'
- )
- if not isinstance(token, str):
- raise TypeError(f'Expected "token" to be a string, got {type(token).__name__}')
-
logger.debug("Validate token for %s", client.base_url)
resp = client.get("/v1/auth/token/lookup-self", headers={"X-Vault-Token": token})
@@ -211,180 +252,119 @@ def is_authenticated(client: httpx.Client, token: str) -> bool:
return True
logger.debug(
- "Token verification failed. Code= %d. Msg= %s",
+ "Token verification failed. Code= %d (%s). Msg= %s",
resp.status_code,
+ resp.reason_phrase,
resp.json(),
)
return False
-def get_mount_point(
- client: httpx.Client, path: str
-) -> tuple[str | None, KVVersion | None]:
- """Get mount point and KV engine version to a secret.
-
- Returns
- -------
- mount_point : str
- The path the secret engine mounted on.
- version : int
- The secret engine version
-
- See also
- --------
- Vault HTTP API
- https://developer.hashicorp.com/vault/api-docs/system/internal-ui-mounts
- consul-template
- https://github.com/hashicorp/consul-template/blob/v0.29.1/dependency/vault_common.go#L294-L357
- """
- if not isinstance(client, httpx.Client):
- raise TypeError(
- f'Expected "client" to be a httpx client, got {type(client).__name__}'
- )
- if not isinstance(path, str):
- raise TypeError(f'Expected "path" to be a string, got {type(path).__name__}')
-
- try:
- resp = client.get(f"/v1/sys/internal/ui/mounts/{path}")
- except httpx.HTTPError as e:
- if not (reason := get_httpx_error_reason(e)):
- raise
- logger.error("Error occurred during checking metadata for %s: %s", path, reason)
- return None, None
-
- if resp.is_success:
- data = resp.json().get("data", {})
-
- mount_point = data.get("path")
- version = data.get("options", {}).get("version")
-
- if version == "2" and data.get("type") == "kv":
- return mount_point, 2
- elif version == "1":
- return mount_point, 1
-
- logging.error("Unknown version %s for path %s", version, path)
- logging.debug("Raw response: %s", resp)
- return None, None
-
- elif resp.status_code == HTTPStatus.NOT_FOUND:
- # 404 is expected on an older version of vault, default to version 1
- # https://github.com/hashicorp/consul-template/blob/v0.29.1/dependency/vault_common.go#L310-L311
- return "", 1
-
- logger.error("Error occurred during checking metadata for %s", path)
- log_httpx_response(logger, resp)
- return None, None
-
-
-def read_secret(client: httpx.Client, path: str) -> VaultSecret | None:
+@validate_call
+def read_secret(client: InstanceOf[httpx.Client], path: str) -> dict | None:
"""Read secret from Vault.
See also
--------
https://developer.hashicorp.com/vault/api-docs/secret/kv
"""
- if not isinstance(client, httpx.Client):
- raise TypeError(
- f'Expected "client" to be a httpx client, got {type(client).__name__}'
- )
- if not isinstance(path, str):
- raise TypeError(f'Expected "path" to be a string, got {type(path).__name__}')
+ mount = get_mount(client, path)
+ if not mount:
+ return
- mount_point, version = get_mount_point(client, path)
- if not mount_point:
- return None
+ logger.debug("Secret %s is mounted at %s (kv%d)", path, mount.path, mount.version)
- logger.debug("Secret %s is mounted at %s (kv%d)", path, mount_point, version)
-
- if version == 1:
- url = f"/v1/{path}"
+ if mount.version == 2:
+ subpath = path.removeprefix(mount.path)
+ request_path = f"/v1/{mount.path}data/{subpath}"
else:
- subpath = path.removeprefix(mount_point)
- url = f"/v1/{mount_point}data/{subpath}"
+ request_path = f"/v1/{path}"
try:
- resp = client.get(url)
+ resp = client.get(request_path)
except httpx.HTTPError as e:
if not (reason := get_httpx_error_reason(e)):
raise
logger.error("Error occurred during query secret %s: %s", path, reason)
- return None
+ return
if resp.is_success:
data = resp.json()
- if version == 1:
- return data["data"]
- elif version == 2:
+ if mount.version == 2:
return data["data"]["data"]
+ else:
+ return data["data"]
elif resp.status_code == HTTPStatus.NOT_FOUND:
logger.error("Secret %s not found", path)
- return None
+ return
logger.error("Error occurred during query secret %s", path)
log_httpx_response(logger, resp)
- return None
+ return
-def get_field(secret: dict, name: str) -> str | None:
- """Traverse the secret data to get the field along with the given name."""
- for n in split_field(name):
- if not isinstance(secret, dict):
- return None
- secret = typing.cast(dict, secret.get(n))
+class _RawMountMetadata(BaseModel):
+ """
+ {
+ "data": {
+ "options": {"version": "1"},
+ "path": "secrets/",
+ "type": "kv",
+ }
+ }
+ """
- if not isinstance(secret, str):
- return None
+ data: _DataBlock
- return secret
+ class _DataBlock(BaseModel):
+ options: _OptionBlock
+ path: str
+ type: str
-def split_field(name: str) -> list[str]:
- """Split a field name into subsequences. By default, this function splits
- the name by dots, with supportting of preserving the quoted subpaths.
- """
- pattern_quoted = re.compile(r'"([^"]+)"')
- pattern_simple = re.compile(r"([\w-]+)")
+ class _OptionBlock(BaseModel):
+ version: str
- seq = []
- pos = 0
- while pos < len(name):
- # try match pattern
- if m := pattern_simple.match(name, pos):
- pass
- elif m := pattern_quoted.match(name, pos):
- pass
- else:
- break
- seq.append(m.group(1))
+class MountMetadata(BaseModel):
+ """Represents a mount point and KV engine version to a secret."""
- # check remaining part
- # +1 for skipping the dot (if exists)
- pos = m.end() + 1
+ path: str
+ version: Literal[1, 2]
- if pos <= len(name):
- raise ValueError(f"Failed to parse name: {name}")
- return seq
+@validate_call
+def get_mount(client: InstanceOf[httpx.Client], path: str) -> MountMetadata | None:
+ """Get mount point and KV engine version to a secret.
+ See also
+ --------
+ Vault HTTP API
+ https://developer.hashicorp.com/vault/api-docs/system/internal-ui-mounts
+ consul-template
+ https://github.com/hashicorp/consul-template/blob/v0.29.1/dependency/vault_common.go#L294-L357
+ """
+ try:
+ resp = client.get(f"/v1/sys/internal/ui/mounts/{path}")
+ except httpx.HTTPError as e:
+ if not (reason := get_httpx_error_reason(e)):
+ raise
+ logger.error("Error occurred during checking metadata for %s: %s", path, reason)
+ return
-class VaultPath(BaseModel):
- """Represents a path to a value in Vault."""
+ if resp.is_success:
+ parsed = _RawMountMetadata.model_validate_json(resp.read())
+ return MountMetadata(
+ path=parsed.data.path,
+ version=int(parsed.data.options.version), # type: ignore[reportArgumentType]
+ )
- path: str = Field(min_length=1)
- field: str = Field(min_length=1)
+ elif resp.status_code == HTTPStatus.NOT_FOUND:
+ # 404 is expected on an older version of vault, default to version 1
+ # https://github.com/hashicorp/consul-template/blob/v0.29.1/dependency/vault_common.go#L310-L311
+ return MountMetadata(path="", version=1)
- @model_validator(mode="before")
- @classmethod
- def _from_str(cls, value: str | dict | Self) -> dict | Self:
- if not isinstance(value, str):
- return value
- if value.count("#") != 1:
- raise ValueError("Invalid format. Expected 'path#field'")
- path, field = value.rsplit("#", 1)
- return {
- "path": path,
- "field": field,
- }
+ logger.error("Error occurred during checking metadata for %s", path)
+ log_httpx_response(logger, resp)
+ return
diff --git a/tests/config0/test_config__init__.py b/tests/config0/test_config__init__.py
index 4d64f4eb..2e2b6ad7 100644
--- a/tests/config0/test_config__init__.py
+++ b/tests/config0/test_config__init__.py
@@ -4,7 +4,7 @@
import pytest
import secrets_env.config0 as t
-from secrets_env.provider import ProviderBase
+from secrets_env.provider import Provider
class TestLoadConfig:
@@ -32,7 +32,7 @@ def assert_config_format(self, cfg: dict):
for name, provider in cfg["providers"].items():
assert isinstance(name, str)
- assert isinstance(provider, ProviderBase)
+ assert isinstance(provider, Provider)
for request in cfg["requests"]:
assert isinstance(request["name"], str)
diff --git a/tests/config0/test_parser.py b/tests/config0/test_parser.py
index 1d81834f..add1b1b4 100644
--- a/tests/config0/test_parser.py
+++ b/tests/config0/test_parser.py
@@ -4,14 +4,14 @@
import secrets_env.config0.parser as t
from secrets_env.exceptions import AuthenticationError, ConfigError
-from secrets_env.provider import ProviderBase
+from secrets_env.provider import Provider
@pytest.fixture()
def _patch_get_provider(monkeypatch: pytest.MonkeyPatch):
def mock_parser(data: dict):
assert isinstance(data, dict)
- return Mock(spec=ProviderBase)
+ return Mock(spec=Provider)
monkeypatch.setattr("secrets_env.providers.get_provider", mock_parser)
@@ -35,7 +35,7 @@ def test_success(self):
}
)
assert isinstance(cfg, dict)
- assert isinstance(cfg["providers"]["main"], ProviderBase)
+ assert isinstance(cfg["providers"]["main"], Provider)
assert cfg["requests"][0]["spec"] == "foobar"
def test_no_request(self):
@@ -73,15 +73,15 @@ def test_success(self):
}
)
assert len(result) == 3
- assert isinstance(result["main"], ProviderBase)
- assert isinstance(result["provider 1"], ProviderBase)
- assert isinstance(result["provider 2"], ProviderBase)
+ assert isinstance(result["main"], Provider)
+ assert isinstance(result["provider 1"], Provider)
+ assert isinstance(result["provider 2"], Provider)
@pytest.mark.usefixtures("_patch_get_provider")
def test_duplicated_name(self, caplog: pytest.LogCaptureFixture):
result = t.get_providers({"source": [{"data": "dummy"}] * 2})
assert len(result) == 1
- assert isinstance(result["main"], ProviderBase)
+ assert isinstance(result["main"], Provider)
assert "Duplicated source name main" in caplog.text
@@ -122,13 +122,13 @@ class TestParseSourceItem:
def test_success_1(self):
name, provider = t.parse_source_item({})
assert name == "main"
- assert isinstance(provider, ProviderBase)
+ assert isinstance(provider, Provider)
@pytest.mark.usefixtures("_patch_get_provider")
def test_success_2(self):
name, provider = t.parse_source_item({"name": "test"})
assert name == "test"
- assert isinstance(provider, ProviderBase)
+ assert isinstance(provider, Provider)
def test_name_error(self):
assert t.parse_source_item({"name": object()}) is None
diff --git a/tests/providers/teleport/test_teleport.py b/tests/providers/teleport/test_teleport.py
index 678cb8c8..fd9dab0d 100644
--- a/tests/providers/teleport/test_teleport.py
+++ b/tests/providers/teleport/test_teleport.py
@@ -1,52 +1,82 @@
-from unittest.mock import Mock
+from pathlib import Path
+from unittest.mock import Mock, PropertyMock
import pytest
-import secrets_env.providers.teleport as t
-from secrets_env.exceptions import ConfigError
-from secrets_env.provider import ProviderBase
-from secrets_env.providers.teleport.config import TeleportUserConfig
-from secrets_env.providers.teleport.provider import TeleportProvider
+from secrets_env.providers.teleport import TeleportProvider, TeleportRequestSpec, get_ca
+from secrets_env.providers.teleport.config import TeleportConnectionParameter
-def test_get_provider():
- provider = t.get_provider("teleport", {"app": "test"})
- assert isinstance(provider, TeleportProvider)
+class TestTeleportRequestSpec:
+ def test_success(self):
+ spec = TeleportRequestSpec.model_validate({"field": "ca", "format": "pem"})
+ assert spec == TeleportRequestSpec(field="ca", format="pem")
+ def test_shortcut(self):
+ spec = TeleportRequestSpec.model_validate("uri")
+ assert spec == TeleportRequestSpec(field="uri", format="path")
-class TestGetAdaptedProvider:
- def test_success(self, monkeypatch: pytest.MonkeyPatch):
- def mock_factory(subtype, data, param):
- assert subtype == "Test"
- assert isinstance(data, dict)
- assert isinstance(param, TeleportUserConfig)
- return Mock(spec=ProviderBase)
-
- def mock_get_adapter(subtype):
- assert subtype == "Test"
- return mock_factory
+class TestTeleportProvider:
+ @pytest.fixture()
+ def provider(
+ self, monkeypatch: pytest.MonkeyPatch, conn_param: TeleportConnectionParameter
+ ):
monkeypatch.setattr(
- "secrets_env.providers.teleport.adapters.get_adapter",
- mock_get_adapter,
+ TeleportProvider, "connection_param", PropertyMock(return_value=conn_param)
)
+ return TeleportProvider(app="test")
+
+ def test_get_uri(self, provider: TeleportProvider):
+ assert provider.get("uri") == "https://example.com"
+
+ def test_get_ca(self, provider: TeleportProvider):
+ expect = "subject=/C=XX/L=Default City/O=Test\n-----MOCK CERTIFICATE-----"
+ assert provider.get({"field": "ca", "format": "pem"}) == expect
+ with open(provider.get("ca")) as fd:
+ assert fd.read() == expect
+
+ def test_get_cert(self, provider: TeleportProvider):
+ expect = "-----MOCK CERTIFICATE-----"
+ assert provider.get({"field": "cert", "format": "pem"}) == expect
+ with open(provider.get("cert")) as fd:
+ assert fd.read() == expect
+
+ def test_get_key(self, provider: TeleportProvider):
+ expect = "-----MOCK PRIVATE KEY-----"
+ assert provider.get({"field": "key", "format": "pem"}) == expect
+ with open(provider.get("key")) as fd:
+ assert fd.read() == expect
+
+ def test_get_cert_and_key(self, provider: TeleportProvider):
+ expect = "-----MOCK CERTIFICATE-----\n-----MOCK PRIVATE KEY-----"
+ assert provider.get({"field": "cert+key", "format": "pem"}) == expect
+ with open(provider.get("cert+key")) as fd:
+ assert fd.read() == expect
+
+ def test_get_invalid(self, monkeypatch: pytest.MonkeyPatch):
+ spec = Mock(TeleportRequestSpec)
+ spec.field = "unknown"
monkeypatch.setattr(
- TeleportUserConfig,
- "get_connection_param",
- lambda _: Mock(spec=TeleportUserConfig),
+ TeleportRequestSpec, "model_validate", Mock(return_value=spec)
)
- config = {
- "teleport": {
- "app": "test",
- }
- }
- provider = t.get_adapted_provider("teleport+Test", config)
+ with pytest.raises(RuntimeError):
+ TeleportProvider(app="test").get("unknown")
+
- assert isinstance(provider, ProviderBase)
+class TestGetCa:
+ def test_success(self):
+ param = Mock(TeleportConnectionParameter)
+ param.ca = b"-----MOCK CERTIFICATE-----"
+ param.path_ca = Path("path/to/ca")
+
+ assert get_ca(param, "path") == "path/to/ca"
+ assert get_ca(param, "pem") == "-----MOCK CERTIFICATE-----"
def test_fail(self):
- with pytest.raises(ConfigError):
- t.get_adapted_provider("not-teleport+other", {})
- with pytest.raises(ConfigError): # raise by get_adapter
- t.get_adapted_provider("teleport+no-this-type", {})
+ param = Mock(TeleportConnectionParameter)
+ param.ca = None
+
+ with pytest.raises(LookupError, match="CA is not available"):
+ get_ca(param, "pem")
diff --git a/tests/providers/teleport/test_teleport_adapters.py b/tests/providers/teleport/test_teleport_adapters.py
deleted file mode 100644
index 0cc930f3..00000000
--- a/tests/providers/teleport/test_teleport_adapters.py
+++ /dev/null
@@ -1,39 +0,0 @@
-from pathlib import Path
-from unittest.mock import Mock
-
-import pytest
-
-import secrets_env.providers.teleport.adapters as t
-from secrets_env.exceptions import ConfigError
-from secrets_env.provider import ProviderBase
-from secrets_env.providers.teleport.config import TeleportConnectionParameter
-
-
-def test_get_adapter():
- assert callable(t.get_adapter("Vault"))
-
- with pytest.raises(ConfigError):
- t.get_adapter("no-this-type")
-
-
-def test_adapt_vault_provider(monkeypatch: pytest.MonkeyPatch):
- def mock_load(type_, data):
- assert data["url"] == "https://example.com"
- assert len(data["tls"]) == 2
- assert isinstance(data["tls"]["client_cert"], Path)
- assert isinstance(data["tls"]["client_key"], Path)
- return Mock(spec=ProviderBase)
-
- monkeypatch.setattr("secrets_env.providers.vault.get_provider", mock_load)
-
- provider = t.adapt_vault_provider(
- type_="vault",
- data={"url": "http://invalid.example.com", "auth": "oidc"},
- param=TeleportConnectionParameter(
- uri="https://example.com",
- ca=None,
- cert=b"cert",
- key=b"key",
- ),
- )
- assert isinstance(provider, ProviderBase)
diff --git a/tests/providers/teleport/test_teleport_config.py b/tests/providers/teleport/test_teleport_config.py
index d0d509a5..9008f13b 100644
--- a/tests/providers/teleport/test_teleport_config.py
+++ b/tests/providers/teleport/test_teleport_config.py
@@ -49,7 +49,7 @@ def _patch_which(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr("shutil.which", lambda _: "/mock/tsh")
@pytest.mark.usefixtures("_patch_which")
- def test_get_connection_param_1(self, monkeypatch: pytest.MonkeyPatch):
+ def test_connection_param_1(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
"secrets_env.providers.teleport.config.call_version",
lambda: True,
@@ -60,10 +60,10 @@ def test_get_connection_param_1(self, monkeypatch: pytest.MonkeyPatch):
)
cfg = TeleportUserConfig(app="test")
- assert isinstance(cfg.get_connection_param(), TeleportConnectionParameter)
+ assert isinstance(cfg.connection_param, TeleportConnectionParameter)
@pytest.mark.usefixtures("_patch_which")
- def test_get_connection_param_2(self, monkeypatch: pytest.MonkeyPatch):
+ def test_connection_param_2(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
"secrets_env.providers.teleport.config.call_version",
lambda: True,
@@ -82,28 +82,26 @@ def test_get_connection_param_2(self, monkeypatch: pytest.MonkeyPatch):
)
cfg = TeleportUserConfig(app="test")
- assert isinstance(cfg.get_connection_param(), TeleportConnectionParameter)
+ assert isinstance(cfg.connection_param, TeleportConnectionParameter)
- def test_get_connection_param_missing_dependency(
- self, monkeypatch: pytest.MonkeyPatch
- ):
+ def test_connection_param_missing_dependency(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr("shutil.which", lambda _: None)
cfg = TeleportUserConfig(app="test")
with pytest.raises(UnsupportedError):
- cfg.get_connection_param()
+ cfg.connection_param # noqa: B018
@pytest.mark.usefixtures("_patch_which")
- def test_get_connection_param_version_error(self, monkeypatch: pytest.MonkeyPatch):
+ def test_connection_param_version_error(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
"secrets_env.providers.teleport.config.call_version",
lambda: False,
)
cfg = TeleportUserConfig(app="test")
with pytest.raises(RuntimeError):
- cfg.get_connection_param()
+ cfg.connection_param # noqa: B018
@pytest.mark.usefixtures("_patch_which")
- def test_get_connection_param_no_config(self, monkeypatch: pytest.MonkeyPatch):
+ def test_connection_param_no_config(self, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
"secrets_env.providers.teleport.config.call_version",
lambda: True,
@@ -123,7 +121,7 @@ def test_get_connection_param_no_config(self, monkeypatch: pytest.MonkeyPatch):
cfg = TeleportUserConfig(app="test")
with pytest.raises(AuthenticationError):
- cfg.get_connection_param()
+ cfg.connection_param # noqa: B018
class TestTeleportConnectionParameter:
diff --git a/tests/providers/teleport/test_teleport_provider.py b/tests/providers/teleport/test_teleport_provider.py
deleted file mode 100644
index fcbdd889..00000000
--- a/tests/providers/teleport/test_teleport_provider.py
+++ /dev/null
@@ -1,123 +0,0 @@
-import re
-from pathlib import Path
-from unittest.mock import Mock
-
-import pytest
-
-import secrets_env.providers.teleport.provider as t
-from secrets_env.exceptions import ConfigError, ValueNotFound
-from secrets_env.providers.teleport.config import (
- TeleportConnectionParameter,
- TeleportUserConfig,
-)
-
-
-class TestTeleportProvider:
- @pytest.fixture()
- def provider(self):
- return t.TeleportProvider(config=TeleportUserConfig(app="test"))
-
- def test_type(self, provider):
- assert provider.type == "teleport"
-
- @pytest.mark.parametrize(
- ("field", "expect"),
- [
- ("CA", b"subject=/C=XX/L=Default City/O=Test\n-----MOCK CERTIFICATE-----"),
- ("CERT", b"-----MOCK CERTIFICATE-----"),
- ("KEY", b"-----MOCK PRIVATE KEY-----"),
- ("cert+KEY", b"-----MOCK CERTIFICATE-----\n-----MOCK PRIVATE KEY-----"),
- ],
- )
- def test_get_path(
- self, monkeypatch: pytest.MonkeyPatch, provider, conn_param, field, expect
- ):
- monkeypatch.setattr(
- TeleportUserConfig, "get_connection_param", lambda _: conn_param
- )
- path = provider.get({"field": field, "format": "path"})
- assert isinstance(path, str)
- assert Path(path).is_file()
- assert Path(path).read_bytes() == expect
-
- @pytest.mark.parametrize(
- ("field", "expect"),
- [
- ("URI", "https://example.com"),
- ("CA", "subject=/C=XX/L=Default City/O=Test\n-----MOCK CERTIFICATE-----"),
- ("CERT", "-----MOCK CERTIFICATE-----"),
- ("KEY", "-----MOCK PRIVATE KEY-----"),
- ("cert+KEY", "-----MOCK CERTIFICATE-----\n-----MOCK PRIVATE KEY-----"),
- ],
- )
- def test_get_pem(
- self, monkeypatch: pytest.MonkeyPatch, provider, conn_param, field, expect
- ):
- monkeypatch.setattr(
- TeleportUserConfig, "get_connection_param", lambda _: conn_param
- )
- data = provider.get({"field": field, "format": "pem"})
- assert isinstance(data, str)
- assert data == expect
-
- def test_get_error(self, monkeypatch: pytest.MonkeyPatch, provider):
- monkeypatch.setattr(
- t, "parse_spec", lambda _: Mock(spec=t.OutputSpec, field="unknown")
- )
- with pytest.raises(
- ConfigError, match=re.escape("Invalid value spec: {'mock': 'mocked'}")
- ):
- provider.get({"mock": "mocked"})
-
-
-class TestParseSpec:
- def test_success(self):
- assert t.parse_spec("ca") == t.OutputSpec("ca", "path")
- assert t.parse_spec({"field": "ca"}) == t.OutputSpec("ca", "path")
- assert t.parse_spec({"field": "ca", "format": "pem"}) == t.OutputSpec(
- "ca", "pem"
- )
-
- def test_failed(self):
- with pytest.raises(
- ConfigError, match=re.escape("Invalid field (secrets.VAR.field): invalid")
- ):
- t.parse_spec({"field": "invalid"})
-
- with pytest.raises(
- ConfigError, match=re.escape("Invalid format (secrets.VAR.format): invalid")
- ):
- t.parse_spec({"field": "ca", "format": "invalid"})
-
- def test_type_error(self):
- with pytest.raises(
- ConfigError,
- match=re.escape("Expect dict for secrets path spec, got int"),
- ):
- t.parse_spec(1234)
-
-
-def test_get_ca(conn_param):
- # path
- p = t.get_ca(conn_param, "path")
- assert isinstance(p, str)
- assert Path(p).is_file()
-
- # pem
- d = t.get_ca(conn_param, "pem")
- assert isinstance(d, str)
- assert d == "subject=/C=XX/L=Default City/O=Test\n-----MOCK CERTIFICATE-----"
-
- # no CA
- conn_param_no_ca = TeleportConnectionParameter(
- uri="https://example.com",
- ca=None,
- cert=b"-----MOCK CERTIFICATE-----",
- key=b"-----MOCK PRIVATE KEY-----",
- )
- with pytest.raises(ValueNotFound):
- t.get_ca(conn_param_no_ca, "path")
-
- # make linter happy
- with pytest.raises(RuntimeError):
- t.get_ca(conn_param, "no-this-format")
diff --git a/tests/providers/test_plain.py b/tests/providers/test_plain.py
deleted file mode 100644
index 454a3e52..00000000
--- a/tests/providers/test_plain.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import pytest
-
-import secrets_env.providers.plain as t
-
-
-def test():
- provider = t.get_provider("plain", {})
- assert provider.type == "plain"
-
- assert provider.get("foo") == "foo"
- assert provider.get("") == ""
- assert provider.get({"value": "foo"}) == "foo"
- assert provider.get({"value": None}) == ""
-
- with pytest.raises(TypeError):
- provider.get(None)
diff --git a/tests/providers/test_providers.py b/tests/providers/test_providers.py
new file mode 100644
index 00000000..d83164c5
--- /dev/null
+++ b/tests/providers/test_providers.py
@@ -0,0 +1,53 @@
+import pytest
+
+from secrets_env.providers import get_provider
+from secrets_env.providers.null import NullProvider
+from secrets_env.providers.plain import PlainTextProvider
+from secrets_env.providers.teleport import TeleportProvider
+from secrets_env.providers.vault import VaultKvProvider
+
+
+class TestGetProvider:
+ def test_null(self):
+ provider = get_provider({"type": "null"})
+ assert isinstance(provider, NullProvider)
+
+ def test_plain(self):
+ provider = get_provider({"type": "plain"})
+ assert isinstance(provider, PlainTextProvider)
+
+ def test_teleport(self):
+ provider = get_provider({"type": "teleport", "app": "test"})
+ assert isinstance(provider, TeleportProvider)
+
+ def test_teleport_adapter(self):
+ with pytest.raises(NotImplementedError):
+ get_provider({"type": "teleport+vault"})
+
+ def test_vault(self):
+ provider = get_provider({"url": "https://example.com/", "auth": "null"})
+ assert isinstance(provider, VaultKvProvider)
+
+ def test_invalid(self):
+ with pytest.raises(ValueError, match="Unknown provider type invalid"):
+ get_provider({"type": "invalid"})
+
+
+class TestNullProvider:
+ def test_get(self):
+ provider = NullProvider()
+ assert provider.get("test") == ""
+ assert provider.get({"value": "test"}) == ""
+
+
+class TestPlainTextProvider:
+ def test(self):
+ provider = PlainTextProvider()
+ assert provider.get("test") == "test"
+ assert provider.get("") == ""
+
+ assert provider.get({"value": "test"}) == "test"
+ assert provider.get({"value": None}) == ""
+ assert provider.get({"value": ""}) == ""
+
+ assert provider.get({"invalid": "foo"}) == ""
diff --git a/tests/providers/test_providers___init__.py b/tests/providers/test_providers___init__.py
deleted file mode 100644
index e1748d16..00000000
--- a/tests/providers/test_providers___init__.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from unittest.mock import Mock
-
-import pytest
-
-import secrets_env.providers as t
-from secrets_env.exceptions import ConfigError
-from secrets_env.provider import ProviderBase
-
-
-def mock_get_provider(type_, data):
- return Mock(spec=ProviderBase)
-
-
-class TestGetProvider:
- def test_null(self):
- p = t.get_provider({"type": "NULL"})
- assert p.get({}) == ""
-
- def test_vault(self, monkeypatch: pytest.MonkeyPatch):
- monkeypatch.setattr(
- "secrets_env.providers.vault.get_provider", mock_get_provider
- )
- assert isinstance(t.get_provider({"type": "Vault"}), ProviderBase)
-
- def test_teleport(self, monkeypatch: pytest.MonkeyPatch):
- monkeypatch.setattr(
- "secrets_env.providers.teleport.get_provider", mock_get_provider
- )
- assert isinstance(t.get_provider({"type": "Teleport"}), ProviderBase)
-
- def test_teleport_adapter(self, monkeypatch: pytest.MonkeyPatch):
- monkeypatch.setattr(
- "secrets_env.providers.teleport.get_adapted_provider", mock_get_provider
- )
- assert isinstance(t.get_provider({"type": "Teleport+Test"}), ProviderBase)
-
- def test_plain(self, monkeypatch: pytest.MonkeyPatch):
- monkeypatch.setattr(
- "secrets_env.providers.plain.get_provider", mock_get_provider
- )
- assert isinstance(t.get_provider({"type": "plain"}), ProviderBase)
-
- def test_not_found(self):
- with pytest.raises(ConfigError):
- t.get_provider({"type": "no-this-type"})
diff --git a/tests/providers/vault/auth/test_userpass.py b/tests/providers/vault/auth/test_userpass.py
index 1089949f..3fd3ce25 100644
--- a/tests/providers/vault/auth/test_userpass.py
+++ b/tests/providers/vault/auth/test_userpass.py
@@ -8,6 +8,7 @@
from pydantic_core import Url
import secrets_env.providers.vault.auth.userpass as t
+from secrets_env.exceptions import AuthenticationError
from secrets_env.providers.vault.auth.userpass import UserPasswordAuth
@@ -122,10 +123,7 @@ class MockAuth(UserPasswordAuth):
assert auth_obj.login(unittest_client) == "client-token"
def test_login_fail(
- self,
- unittest_respx: respx.MockRouter,
- unittest_client: httpx.Client,
- caplog: pytest.LogCaptureFixture,
+ self, unittest_respx: respx.MockRouter, unittest_client: httpx.Client
):
unittest_respx.post("/v1/auth/mock/login/user%40example.com").mock(
return_value=httpx.Response(400)
@@ -136,8 +134,9 @@ class MockAuth(UserPasswordAuth):
vault_name = "mock"
auth_obj = MockAuth(username="user@example.com", password="password")
- assert auth_obj.login(unittest_client) is None
- assert "Failed to login with MOCK method" in caplog.text
+
+ with pytest.raises(AuthenticationError):
+ assert auth_obj.login(unittest_client) is None
@pytest.mark.parametrize(
diff --git a/tests/providers/vault/test_vault___init__.py b/tests/providers/vault/test_vault___init__.py
deleted file mode 100644
index 7b1561d1..00000000
--- a/tests/providers/vault/test_vault___init__.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import os
-
-import pytest
-
-import secrets_env.providers.vault as t
-from secrets_env.exceptions import ConfigError
-from secrets_env.providers.vault.provider import KvProvider
-
-
-class TestGetProvider:
- def test_success(self):
- out = t.get_provider("vault", {"url": "https://example.com", "auth": "null"})
- assert isinstance(out, KvProvider)
-
- def test_fail(self):
- if "VAULT_ADDR" in os.environ:
- pytest.skip("VAULT_ADDR is set. Skipping test.")
- with pytest.raises(ConfigError):
- t.get_provider("vault", {})
-
- def test_not_related(self):
- if "VAULT_ADDR" in os.environ:
- pytest.skip("VAULT_ADDR is set. Skipping test.")
- with pytest.raises(ConfigError):
- t.get_provider("something-else", {})
diff --git a/tests/providers/vault/test_vault_provider.py b/tests/providers/vault/test_vault_provider.py
index 117d6c2a..3b1cdf5b 100644
--- a/tests/providers/vault/test_vault_provider.py
+++ b/tests/providers/vault/test_vault_provider.py
@@ -1,261 +1,338 @@
import os
from pathlib import Path
-from unittest.mock import Mock, PropertyMock, patch
+from unittest.mock import Mock, PropertyMock
import httpx
-import httpx._config
import pytest
import respx
-from pydantic import ValidationError
-
-import secrets_env.providers.vault.provider as t
-from secrets_env.exceptions import AuthenticationError, ValueNotFound
-from secrets_env.providers.vault.auth.base import Auth
-from secrets_env.providers.vault.auth.token import TokenAuth
-from secrets_env.providers.vault.provider import VaultPath
+from pydantic_core import ValidationError
+
+from secrets_env.exceptions import AuthenticationError
+from secrets_env.providers.vault.auth.base import Auth, NoAuth
+from secrets_env.providers.vault.config import TlsConfig, VaultUserConfig
+from secrets_env.providers.vault.provider import (
+ MountMetadata,
+ VaultKvProvider,
+ VaultPath,
+ _split_field_str,
+ create_http_client,
+ get_mount,
+ get_token,
+ is_authenticated,
+ read_secret,
+)
@pytest.fixture()
-def mock_client() -> httpx.Client:
- client = Mock(spec=httpx.Client)
- client.headers = {}
- return client
+def intl_provider() -> VaultKvProvider:
+ if "VAULT_ADDR" not in os.environ:
+ raise pytest.skip("VAULT_ADDR is not set")
+ if "VAULT_TOKEN" not in os.environ:
+ raise pytest.skip("VAULT_TOKEN is not set")
+ return VaultKvProvider(auth="token")
@pytest.fixture()
-def mock_auth():
- auth = Mock(spec=Auth)
- auth.method = "mocked"
- return auth
-
+def intl_client(intl_provider: VaultKvProvider) -> httpx.Client:
+ return intl_provider.client
-class TestKvProvider:
- """Unit tests for KvProvider"""
- @pytest.fixture()
- def provider(self, mock_auth: Auth) -> t.KvProvider:
- return t.KvProvider("https://example.com/", mock_auth)
-
- def test_type(self, provider: t.KvProvider):
- assert provider.type == "vault"
+class TestVaultPath:
+ def test_success(self):
+ path = VaultPath.model_validate('foo#"bar.baz".qux')
+ assert path == VaultPath(path="foo", field=("bar.baz", "qux"))
+ assert str(path) == 'foo#"bar.baz".qux'
- def test_client_success(
- self,
- monkeypatch: pytest.MonkeyPatch,
- provider: t.KvProvider,
- mock_client: httpx.Client,
- ):
- # setup
- monkeypatch.setattr(t, "get_token", lambda c, a: "token")
+ def test_invalid(self):
+ # missing path
+ with pytest.raises(ValidationError):
+ VaultPath(path="", field=("b"))
- patch_client = Mock(return_value=mock_client)
- monkeypatch.setattr("httpx.Client", patch_client)
+ # missing path/field separator
+ with pytest.raises(ValidationError):
+ VaultPath.model_validate("foobar")
- provider.proxy = "proxy"
- provider.ca_cert = Mock(spec=Path)
- provider.client_cert = Mock(spec=Path)
+ # too many path/field separator
+ with pytest.raises(ValidationError):
+ VaultPath.model_validate("foo#bar#baz")
- # run twice for testing cache
- assert provider.client is mock_client
- assert provider.client is mock_client
+ # empty field subpath
+ with pytest.raises(ValidationError):
+ VaultPath(path="a", field=())
+ with pytest.raises(ValidationError):
+ VaultPath(path="a", field=("b", "", "c"))
+ with pytest.raises(ValidationError):
+ VaultPath(path="a", field=("b", ""))
- # test
- assert mock_client.headers["X-Vault-Token"] == "token"
- _, kwargs = patch_client.call_args
- assert kwargs["base_url"] == "https://example.com/"
- assert kwargs["proxies"] == "proxy"
- assert isinstance(kwargs["verify"], Path)
- assert isinstance(kwargs["cert"], Path)
+class TestSplitFieldStr:
+ def test_success(self):
+ assert list(_split_field_str("foo")) == ["foo"]
+ assert list(_split_field_str("foo.bar.baz")) == ["foo", "bar", "baz"]
+ assert list(_split_field_str('foo."bar.baz"')) == ["foo", "bar.baz"]
+ assert list(_split_field_str('"foo.bar".baz')) == ["foo.bar", "baz"]
+ assert list(_split_field_str("")) == []
- def test_get_success(self, monkeypatch: pytest.MonkeyPatch, provider: t.KvProvider):
- def mock_read_field(path, field):
- assert path == "foo"
- assert field == "bar"
- return "secret"
+ def test_invalid(self):
+ with pytest.raises(ValueError, match=r"Failed to parse field:"):
+ list(_split_field_str('foo."bar.baz'))
- monkeypatch.setattr(provider, "read_field", mock_read_field)
- assert provider.get("foo#bar") == "secret"
+class TestVaultKvProvider:
+ def test_client(self, monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr(
+ "secrets_env.providers.vault.provider.create_http_client",
+ lambda _: Mock(httpx.Client, headers={}),
+ )
+ provider = VaultKvProvider(url="https://vault.example.com", auth="null")
+ assert isinstance(provider.client, httpx.Client)
- def test_get_fail(self, provider: t.KvProvider):
- with pytest.raises(ValidationError):
- provider.get({})
- with pytest.raises(ValidationError):
- provider.get(1234)
+ @pytest.fixture()
+ def unittest_provider(self, monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr(
+ VaultKvProvider, "client", PropertyMock(return_value=Mock(httpx.Client))
+ )
+ return VaultKvProvider(url="https://vault.example.com", auth="null")
- def test_read_secret_success(
- self,
- monkeypatch: pytest.MonkeyPatch,
- provider: t.KvProvider,
- unittest_client: httpx.Client,
+ def test_get__success(
+ self, monkeypatch: pytest.MonkeyPatch, unittest_provider: VaultKvProvider
):
- monkeypatch.setattr(provider, "_secrets", {})
monkeypatch.setattr(
- t.KvProvider, "client", PropertyMock(return_value=unittest_client)
+ VaultKvProvider, "_read_secret", Mock(return_value={"bar": "test"})
)
+ assert unittest_provider.get({"path": "foo", "field": "bar"}) == "test"
- monkeypatch.setattr(t, "read_secret", lambda _1, _2: {"bar": "secret"})
- assert provider.read_secret("test-path") == {"bar": "secret"}
-
- def test_read_secret_cache(
- self, monkeypatch: pytest.MonkeyPatch, provider: t.KvProvider
+ def test_get__too_depth(
+ self, monkeypatch: pytest.MonkeyPatch, unittest_provider: VaultKvProvider
):
- monkeypatch.setattr(provider, "_secrets", {"test-path": {"bar": "secret"}})
- assert provider.read_secret("test-path") == {"bar": "secret"}
+ monkeypatch.setattr(
+ VaultKvProvider, "_read_secret", Mock(return_value={"bar": "test"})
+ )
+ with pytest.raises(LookupError, match='Field "bar.baz" not found in "foo"'):
+ unittest_provider.get({"path": "foo", "field": "bar.baz"})
- def test_read_secret_not_found(
- self,
- monkeypatch: pytest.MonkeyPatch,
- provider: t.KvProvider,
- unittest_client: httpx.Client,
+ def test_get__too_shallow(
+ self, monkeypatch: pytest.MonkeyPatch, unittest_provider: VaultKvProvider
):
- monkeypatch.setattr(provider, "_secrets", {})
monkeypatch.setattr(
- t.KvProvider, "client", PropertyMock(return_value=unittest_client)
+ VaultKvProvider, "_read_secret", Mock(return_value={"bar": {"baz": "test"}})
)
- monkeypatch.setattr(t, "read_secret", lambda _1, _2: None)
- with pytest.raises(ValueNotFound):
- provider.read_secret("test-secret")
-
- def test_read_secret_error(self, provider: t.KvProvider):
- with pytest.raises(TypeError):
- provider.read_secret(1234)
+ with pytest.raises(
+ LookupError, match='Field "bar" in "foo" is not point to a string value'
+ ):
+ unittest_provider.get({"path": "foo", "field": "bar"})
- def test_read_field_success(
- self, monkeypatch: pytest.MonkeyPatch, provider: t.KvProvider
+ def test_read_secret__success(
+ self, monkeypatch: pytest.MonkeyPatch, unittest_provider: VaultKvProvider
):
- monkeypatch.setattr(provider, "read_secret", lambda _: {"bar": "secret"})
- assert provider.read_field("foo", "bar") == "secret"
+ func = Mock(return_value={"foo": "bar"})
+ monkeypatch.setattr("secrets_env.providers.vault.provider.read_secret", func)
+
+ path = VaultPath(path="foo", field="bar")
+ assert unittest_provider._read_secret(path) == {"foo": "bar"}
+ assert unittest_provider._read_secret(path) == {"foo": "bar"}
+
+ assert func.call_count == 1
+
+ client, path = func.call_args[0]
+ assert isinstance(client, httpx.Client)
+ assert path == "foo"
- def test_read_field_fail(
- self, monkeypatch: pytest.MonkeyPatch, provider: t.KvProvider
+ def test_read_secret__not_found(
+ self, monkeypatch: pytest.MonkeyPatch, unittest_provider: VaultKvProvider
):
- with pytest.raises(TypeError):
- provider.read_field(1234, "bar")
+ func = Mock(return_value=None)
+ monkeypatch.setattr("secrets_env.providers.vault.provider.read_secret", func)
- monkeypatch.setattr(provider, "read_secret", lambda _: {})
- with pytest.raises(ValueNotFound):
- provider.read_field("foo", "bar")
+ path = VaultPath(path="foo", field="bar")
+ with pytest.raises(LookupError):
+ unittest_provider._read_secret(path)
+ with pytest.raises(LookupError):
+ unittest_provider._read_secret(path)
+ assert func.call_count == 1
-class TestKvProviderUsingVaultConnection:
- @pytest.fixture(scope="class")
- def provider(self) -> t.KvProvider:
- url = os.getenv("VAULT_ADDR")
- token = os.getenv("VAULT_TOKEN")
- if not url or not token:
- pytest.skip("VAULT_ADDR or VAULT_TOKEN are not set")
- return t.KvProvider(url, TokenAuth(token=token))
+ def test_integration(self, intl_provider: VaultKvProvider):
+ assert intl_provider.get({"path": "kv2/test", "field": "foo"}) == "hello, world"
+ assert intl_provider.get('kv2/test#test."name.with-dot"') == "sample-value"
- def test_client_success(self, provider: t.KvProvider):
- with patch.object(t, "is_authenticated", return_value=True):
- assert isinstance(provider.client, httpx.Client)
- assert isinstance(provider.client, httpx.Client) # from cache
- def test_get(self, provider: t.KvProvider):
- assert provider.get("kv1/test#foo") == "hello"
- assert provider.get({"path": "kv2/test", "field": "foo"}) == "hello, world"
+class TestCreateHttpClient:
+ @pytest.mark.skipif("VAULT_ADDR" in os.environ, reason="VAULT_ADDR is set")
+ def test_basic(self):
+ config = VaultUserConfig(
+ url="https://vault.example.com",
+ auth="null",
+ )
- def test_read_secret_v1(self, provider: t.KvProvider):
- secret_1 = provider.read_secret("kv1/test")
- assert isinstance(secret_1, dict)
- assert secret_1["foo"] == "hello"
+ client = create_http_client(config)
- secret_2 = provider.read_secret("kv1/test")
- assert secret_1 is secret_2
+ assert isinstance(client, httpx.Client)
+ assert client.base_url == httpx.URL("https://vault.example.com/")
- def test_read_secret_v2(self, provider: t.KvProvider):
- secret = provider.read_secret("kv2/test")
- assert isinstance(secret, dict)
- assert secret["foo"] == "hello, world"
+ def test_proxy(self):
+ config = VaultUserConfig(
+ url="https://vault.example.com",
+ auth="null",
+ proxy="http://proxy.example.com",
+ )
- def test_read_field(self, provider: t.KvProvider):
- assert provider.read_field("kv1/test", "foo") == "hello"
- assert provider.read_field("kv2/test", 'test."name.with-dot"') == "sample-value"
+ client = create_http_client(config)
+ assert isinstance(client, httpx.Client)
- with pytest.raises(ValueNotFound):
- provider.read_field("kv2/test", "foo.no-extra-level")
- with pytest.raises(ValueNotFound):
- provider.read_field("kv2/test", "test.no-this-key")
- with pytest.raises(ValueNotFound):
- provider.read_field("secret/no-this-secret", "test")
+ @pytest.fixture()
+ def mock_httpx_client(self, monkeypatch: pytest.MonkeyPatch):
+ client = Mock(httpx.Client)
+ monkeypatch.setattr("httpx.Client", client)
+ return client
+
+ def test_ca(self, mock_httpx_client: Mock):
+ config = VaultUserConfig(
+ url="https://vault.example.com",
+ auth="null",
+ tls=Mock(
+ TlsConfig,
+ ca_cert=Path("/mock/ca.pem"),
+ client_cert=None,
+ client_key=None,
+ ),
+ )
+ create_http_client(config)
+
+ _, kwargs = mock_httpx_client.call_args
+ assert kwargs["verify"] == Path("/mock/ca.pem")
+
+ def test_client_cert(self, mock_httpx_client: Mock):
+ config = VaultUserConfig(
+ url="https://vault.example.com",
+ auth="null",
+ tls=Mock(
+ TlsConfig,
+ ca_cert=None,
+ client_cert=Path("/mock/client.pem"),
+ client_key=None,
+ ),
+ )
-class TestGetToken:
- def test_success(
- self,
- mock_client: httpx.Client,
- mock_auth: Auth,
- monkeypatch: pytest.MonkeyPatch,
- ):
- mock_auth.login.return_value = "t0ken"
- monkeypatch.setattr(t, "is_authenticated", lambda c, t: True)
- assert t.get_token(mock_client, mock_auth) == "t0ken"
+ create_http_client(config)
+
+ _, kwargs = mock_httpx_client.call_args
+ assert kwargs["cert"] == Path("/mock/client.pem")
+
+ def test_client_cert_pair(self, mock_httpx_client: Mock):
+ config = VaultUserConfig(
+ url="https://vault.example.com",
+ auth="null",
+ tls=Mock(
+ TlsConfig,
+ ca_cert=None,
+ client_cert=Path("/mock/client.pem"),
+ client_key=Path("/mock/client.key"),
+ ),
+ )
- def test_no_token(self, mock_client: httpx.Client, mock_auth: Auth):
- mock_auth.login.return_value = None
- with pytest.raises(AuthenticationError, match="Absence of token information"):
- t.get_token(mock_client, mock_auth)
+ create_http_client(config)
- def test_not_authenticated(
- self,
- mock_client: httpx.Client,
- mock_auth: Auth,
- monkeypatch: pytest.MonkeyPatch,
- ):
- mock_auth.login.return_value = "t0ken"
- monkeypatch.setattr(t, "is_authenticated", lambda c, t: False)
+ _, kwargs = mock_httpx_client.call_args
+ assert kwargs["cert"] == (Path("/mock/client.pem"), Path("/mock/client.key"))
+
+
+class TestGetToken:
+ def test_success(self, monkeypatch: pytest.MonkeyPatch):
+ client = Mock(httpx.Client)
+ auth = NoAuth(token="t0ken")
+ monkeypatch.setattr(
+ "secrets_env.providers.vault.provider.is_authenticated", lambda c, t: True
+ )
+ assert get_token(client, auth) == "t0ken"
+
+ def test_authenticate_fail(self, monkeypatch: pytest.MonkeyPatch):
+ client = Mock(httpx.Client)
+ auth = NoAuth(token="t0ken")
+ monkeypatch.setattr(
+ "secrets_env.providers.vault.provider.is_authenticated", lambda c, t: False
+ )
with pytest.raises(AuthenticationError, match="Invalid token"):
- t.get_token(mock_client, mock_auth)
+ get_token(client, auth)
- def test_login_connection_error(self, mock_client: httpx.Client, mock_auth: Auth):
- mock_auth.login.side_effect = httpx.ProxyError("test")
+ def test_login_connection_error(self):
+ client = Mock(httpx.Client)
+ auth = Mock(Auth)
+ auth.login.side_effect = httpx.ProxyError("test")
with pytest.raises(
AuthenticationError, match="Encounter proxy error while retrieving token"
):
- t.get_token(mock_client, mock_auth)
+ get_token(client, auth)
- def test_login_exception(self, mock_client: httpx.Client, mock_auth: Auth):
- mock_auth.login.side_effect = httpx.HTTPError("test")
+ def test_login_exception(self):
+ client = Mock(httpx.Client)
+ auth = Mock(Auth)
+ auth.login.side_effect = httpx.HTTPError("test")
with pytest.raises(httpx.HTTPError):
- t.get_token(mock_client, mock_auth)
+ get_token(client, auth)
+
+class TestIsAuthenticated:
-def test_is_authenticated():
- url = os.getenv("VAULT_ADDR")
- token = os.getenv("VAULT_TOKEN")
- if not url or not token:
- pytest.skip("VAULT_ADDR or VAULT_TOKEN are not set")
+ def test_success(self, respx_mock: respx.MockRouter):
+ respx_mock.get("https://vault.example.com/v1/auth/token/lookup-self")
- # success: use real client
- client = httpx.Client(base_url=os.getenv("VAULT_ADDR"))
- assert t.is_authenticated(client, os.getenv("VAULT_TOKEN"))
- assert not t.is_authenticated(client, "invalid-token")
+ client = httpx.Client(base_url="https://vault.example.com")
+ assert is_authenticated(client, "test-token") is True
- # type error
- with pytest.raises(TypeError):
- t.is_authenticated("http://example.com", "token")
- with pytest.raises(TypeError):
- t.is_authenticated(client, 1234)
+ def test_fail(self, respx_mock: respx.MockRouter):
+ respx_mock.get("https://vault.example.com/v1/auth/token/lookup-self").respond(
+ status_code=403,
+ json={"errors": ["mock permission denied"]},
+ )
+
+ client = httpx.Client(base_url="https://vault.example.com")
+ assert is_authenticated(client, "test-token") is False
+
+ def test_integration(self):
+ if "VAULT_ADDR" not in os.environ:
+ raise pytest.skip("VAULT_ADDR is not set")
+ if "VAULT_TOKEN" not in os.environ:
+ raise pytest.skip("VAULT_TOKEN is not set")
+
+ client = httpx.Client(base_url=os.getenv("VAULT_ADDR"))
+
+ assert is_authenticated(client, os.getenv("VAULT_TOKEN")) is True
+ assert is_authenticated(client, "invalid-token") is False
-class TestGetMountPoint:
+class TestReadSecret:
@pytest.fixture()
- def route(self, respx_mock: respx.MockRouter):
- return respx_mock.get(
- "https://example.com/v1/sys/internal/ui/mounts/secrets/test"
+ def _set_mount_kv2(self, monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr(
+ "secrets_env.providers.vault.provider.get_mount",
+ lambda c, p: MountMetadata(path="secrets/", version=2),
)
- def test_success_kv1(self, route: respx.Route, unittest_client: httpx.Client):
- route.mock(
+ @pytest.mark.usefixtures("_set_mount_kv2")
+ def test_kv2(
+ self,
+ respx_mock: respx.MockRouter,
+ unittest_client: httpx.Client,
+ ):
+ respx_mock.get("https://example.com/v1/secrets/data/test").mock(
httpx.Response(
200,
json={
+ "request_id": "9ababbb6-3749-cf2c-5a5b-85660e917e8e",
+ "lease_id": "",
+ "renewable": False,
+ "lease_duration": 0,
"data": {
- "options": {"version": "1"},
- "path": "secrets/",
- "type": "kv",
+ "data": {"test": "mock"},
+ "metadata": {
+ "created_time": "2022-09-20T15:57:45.143053836Z",
+ "custom_metadata": None,
+ "deletion_time": "",
+ "destroyed": False,
+ "version": 1,
+ },
},
"wrap_info": None,
"warnings": None,
@@ -263,131 +340,145 @@ def test_success_kv1(self, route: respx.Route, unittest_client: httpx.Client):
},
)
)
- assert t.get_mount_point(unittest_client, "secrets/test") == ("secrets/", 1)
- def test_success_kv2(self, route: respx.Route, unittest_client: httpx.Client):
- route.mock(
- httpx.Response(
- 200,
- json={
- "data": {
- "options": {"version": "2"},
- "path": "secrets/",
- "type": "kv",
- },
- },
- )
- )
- assert t.get_mount_point(unittest_client, "secrets/test") == ("secrets/", 2)
+ assert read_secret(unittest_client, "secrets/test") == {"test": "mock"}
- def test_success_legacy(self, route: respx.Route, unittest_client: httpx.Client):
- route.mock(httpx.Response(404))
- assert t.get_mount_point(unittest_client, "secrets/test") == ("", 1)
+ def test_kv2_integration(self, intl_client: httpx.Client):
+ assert read_secret(intl_client, "kv2/test") == {
+ "foo": "hello, world",
+ "test": {"name.with-dot": "sample-value"},
+ }
- def test_not_ported_version(
- self, route: respx.Route, unittest_client: httpx.Client
+ def test_kv1(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ respx_mock: respx.MockRouter,
+ unittest_client: httpx.Client,
):
- route.mock(
+ monkeypatch.setattr(
+ "secrets_env.providers.vault.provider.get_mount",
+ lambda c, p: MountMetadata(path="secrets/", version=1),
+ )
+ respx_mock.get("https://example.com/v1/secrets/test").mock(
httpx.Response(
200,
json={
- "data": {
- "path": "mock/",
- "type": "kv",
- "options": {"version": "99"},
- }
+ "request_id": "a8f28d97-8a9d-c9dd-4d86-e815083b33ad",
+ "lease_id": "",
+ "renewable": False,
+ "lease_duration": 2764800,
+ "data": {"test": "mock"},
+ "wrap_info": None,
+ "warnings": None,
+ "auth": None,
},
)
)
- assert t.get_mount_point(unittest_client, "secrets/test") == (None, None)
- def test_bad_request(
+ assert read_secret(unittest_client, "secrets/test") == {"test": "mock"}
+
+ def test_kv1_integration(self, intl_client: httpx.Client):
+ assert read_secret(intl_client, "kv1/test") == {"foo": "hello"}
+
+ @pytest.mark.usefixtures("_set_mount_kv2")
+ def test_not_found(
self,
- route: respx.Route,
+ respx_mock: respx.MockRouter,
unittest_client: httpx.Client,
caplog: pytest.LogCaptureFixture,
):
- route.mock(httpx.Response(400))
- assert t.get_mount_point(unittest_client, "secrets/test") == (None, None)
- assert "Error occurred during checking metadata for secrets/test" in caplog.text
+ respx_mock.get("https://example.com/v1/secrets/data/test").mock(
+ httpx.Response(404)
+ )
+ assert read_secret(unittest_client, "secrets/test") is None
+ assert "Secret secrets/test not found" in caplog.text
+
+ def test_get_mount_error(
+ self, monkeypatch: pytest.MonkeyPatch, unittest_client: httpx.Client
+ ):
+ monkeypatch.setattr(
+ "secrets_env.providers.vault.provider.get_mount", lambda c, p: None
+ )
+ assert read_secret(unittest_client, "secrets/test") is None
+ @pytest.mark.usefixtures("_set_mount_kv2")
def test_connection_error(
self,
- route: respx.Route,
+ respx_mock: respx.MockRouter,
unittest_client: httpx.Client,
caplog: pytest.LogCaptureFixture,
):
- route.mock(side_effect=httpx.ConnectError)
- assert t.get_mount_point(unittest_client, "secrets/test") == (None, None)
+ respx_mock.get("https://example.com/v1/secrets/data/test").mock(
+ side_effect=httpx.ProxyError
+ )
+ assert read_secret(unittest_client, "secrets/test") is None
assert (
- "Error occurred during checking metadata for secrets/test: connection error"
+ "Error occurred during query secret secrets/test: proxy error"
in caplog.text
)
- def test_unhandled_exception(
- self, route: respx.Route, unittest_client: httpx.Client
+ @pytest.mark.usefixtures("_set_mount_kv2")
+ def test_http_exception(
+ self, respx_mock: respx.MockRouter, unittest_client: httpx.Client
):
- route.mock(side_effect=httpx.DecodingError)
+ respx_mock.get("https://example.com/v1/secrets/data/test").mock(
+ side_effect=httpx.DecodingError
+ )
with pytest.raises(httpx.DecodingError):
- t.get_mount_point(unittest_client, "secrets/test")
+ read_secret(unittest_client, "secrets/test")
- def test_type_error(self):
- with pytest.raises(TypeError):
- t.get_mount_point(1234, "secrets/test")
- with pytest.raises(TypeError):
- t.get_mount_point(Mock(spec=httpx.Client), 1234)
+ @pytest.mark.usefixtures("_set_mount_kv2")
+ def test_bad_request(
+ self,
+ respx_mock: respx.MockRouter,
+ unittest_client: httpx.Client,
+ caplog: pytest.LogCaptureFixture,
+ ):
+ respx_mock.get("https://example.com/v1/secrets/data/test").mock(
+ httpx.Response(499)
+ )
+ assert read_secret(unittest_client, "secrets/test") is None
+ assert "Error occurred during query secret secrets/test" in caplog.text
-class TestReadSecret:
+class TestGetMount:
@pytest.fixture()
- def patch_get_mount_point(self):
- with patch.object(t, "get_mount_point", return_value=("secrets/", 1)) as p:
- yield p
+ def route(self, respx_mock: respx.MockRouter):
+ return respx_mock.get(
+ "https://example.com/v1/sys/internal/ui/mounts/secrets/test"
+ )
- @pytest.mark.usefixtures("patch_get_mount_point")
- def test_kv1(self, respx_mock: respx.MockRouter, unittest_client: httpx.Client):
- respx_mock.get("https://example.com/v1/secrets/test").mock(
+ def test_success_kv2(self, route: respx.Route, unittest_client: httpx.Client):
+ route.mock(
httpx.Response(
200,
json={
- "request_id": "a8f28d97-8a9d-c9dd-4d86-e815083b33ad",
- "lease_id": "",
- "renewable": False,
- "lease_duration": 2764800,
- "data": {"test": "mock"},
- "wrap_info": None,
- "warnings": None,
- "auth": None,
+ "data": {
+ "options": {"version": "2"},
+ "path": "secrets/",
+ "type": "kv",
+ },
},
)
)
+ assert get_mount(unittest_client, "secrets/test") == MountMetadata(
+ path="secrets/", version=2
+ )
- with patch.object(t, "get_mount_point", return_value=("secrets/", 1)):
- assert t.read_secret(unittest_client, "secrets/test") == {"test": "mock"}
+ def test_success_kv2_integration(self, intl_client: httpx.Client):
+ assert get_mount(intl_client, "kv2/test") == MountMetadata(
+ path="kv2/", version=2
+ )
- def test_kv2(
- self,
- respx_mock: respx.MockRouter,
- unittest_client: httpx.Client,
- patch_get_mount_point: Mock,
- ):
- respx_mock.get("https://example.com/v1/secrets/data/test").mock(
+ def test_success_kv1(self, route: respx.Route, unittest_client: httpx.Client):
+ route.mock(
httpx.Response(
200,
json={
- "request_id": "9ababbb6-3749-cf2c-5a5b-85660e917e8e",
- "lease_id": "",
- "renewable": False,
- "lease_duration": 0,
"data": {
- "data": {"test": "mock"},
- "metadata": {
- "created_time": "2022-09-20T15:57:45.143053836Z",
- "custom_metadata": None,
- "deletion_time": "",
- "destroyed": False,
- "version": 1,
- },
+ "options": {"version": "1"},
+ "path": "secrets/",
+ "type": "kv",
},
"wrap_info": None,
"warnings": None,
@@ -395,102 +486,64 @@ def test_kv2(
},
)
)
+ assert get_mount(unittest_client, "secrets/test") == MountMetadata(
+ path="secrets/", version=1
+ )
- patch_get_mount_point.return_value = ("secrets/", 2)
- assert t.read_secret(unittest_client, "secrets/test") == {"test": "mock"}
-
- def test_get_mount_point_error(
- self, unittest_client: httpx.Client, patch_get_mount_point: Mock
- ):
- patch_get_mount_point.return_value = (None, None)
- assert t.read_secret(unittest_client, "secrets/test") is None
-
- @pytest.mark.usefixtures("patch_get_mount_point")
- def test_connection_error(
- self,
- respx_mock: respx.MockRouter,
- unittest_client: httpx.Client,
- caplog: pytest.LogCaptureFixture,
- ):
- respx_mock.get("https://example.com/v1/secrets/test").mock(
- side_effect=httpx.ProxyError
+ def test_success_kv1_integration(self, intl_client: httpx.Client):
+ assert get_mount(intl_client, "kv1/test") == MountMetadata(
+ path="kv1/", version=1
)
- assert t.read_secret(unittest_client, "secrets/test") is None
- assert (
- "Error occurred during query secret secrets/test: proxy error"
- in caplog.text
+ def test_success_legacy(self, route: respx.Route, unittest_client: httpx.Client):
+ route.mock(httpx.Response(404))
+ assert get_mount(unittest_client, "secrets/test") == MountMetadata(
+ path="", version=1
)
- @pytest.mark.usefixtures("patch_get_mount_point")
- def test_unhandled_exception(
- self, respx_mock: respx.MockRouter, unittest_client: httpx.Client
+ def test_not_ported_version(
+ self, route: respx.Route, unittest_client: httpx.Client
):
- respx_mock.get("https://example.com/v1/secrets/test").mock(
- side_effect=httpx.DecodingError
+ route.mock(
+ httpx.Response(
+ 200,
+ json={
+ "data": {
+ "path": "mock/",
+ "type": "kv",
+ "options": {"version": "99"},
+ }
+ },
+ )
)
- with pytest.raises(httpx.DecodingError):
- t.read_secret(unittest_client, "secrets/test")
- @pytest.mark.usefixtures("patch_get_mount_point")
- def test_not_found(
+ with pytest.raises(ValidationError):
+ get_mount(unittest_client, "secrets/test")
+
+ def test_bad_request(
self,
- respx_mock: respx.MockRouter,
+ route: respx.Route,
unittest_client: httpx.Client,
caplog: pytest.LogCaptureFixture,
):
- respx_mock.get("https://example.com/v1/secrets/test").mock(httpx.Response(404))
- assert t.read_secret(unittest_client, "secrets/test") is None
- assert "Secret secrets/test not found" in caplog.text
+ route.mock(httpx.Response(400))
+ assert get_mount(unittest_client, "secrets/test") is None
+ assert "Error occurred during checking metadata for secrets/test" in caplog.text
- @pytest.mark.usefixtures("patch_get_mount_point")
- def test_bad_request(
+ def test_connection_error(
self,
- respx_mock: respx.MockRouter,
+ route: respx.Route,
unittest_client: httpx.Client,
caplog: pytest.LogCaptureFixture,
):
- respx_mock.get("https://example.com/v1/secrets/test").mock(httpx.Response(499))
- assert t.read_secret(unittest_client, "secrets/test") is None
- assert "Error occurred during query secret secrets/test" in caplog.text
-
- def test_type_error(self):
- with pytest.raises(TypeError):
- t.read_secret(1234, "secrets/test")
- with pytest.raises(TypeError):
- t.read_secret(Mock(spec=httpx.Client), 1234)
-
-
-def test_split_field():
- assert t.split_field("aa") == ["aa"]
- assert t.split_field("aa.bb") == ["aa", "bb"]
- assert t.split_field('aa."bb.cc"') == ["aa", "bb.cc"]
- assert t.split_field('"aa.bb".cc') == ["aa.bb", "cc"]
- assert t.split_field('"aa.bb"') == ["aa.bb"]
-
- with pytest.raises(ValueError, match=r"Failed to parse name: "):
- t.split_field("")
- with pytest.raises(ValueError, match=r"Failed to parse name: .+"):
- t.split_field(".")
- with pytest.raises(ValueError, match=r"Failed to parse name: .+"):
- t.split_field("aa.")
- with pytest.raises(ValueError, match=r"Failed to parse name: .+"):
- t.split_field(".aa")
-
-
-class TestVaultPath:
- def test_success(self):
- path = VaultPath.model_validate("foo#bar")
- assert path == VaultPath(path="foo", field="bar")
-
- def test_empty(self):
- with pytest.raises(ValidationError):
- VaultPath.model_validate("a#")
- with pytest.raises(ValidationError):
- VaultPath.model_validate("#b")
+ route.mock(side_effect=httpx.ConnectError)
+ assert get_mount(unittest_client, "secrets/test") is None
+ assert (
+ "Error occurred during checking metadata for secrets/test: connection error"
+ in caplog.text
+ )
- def test_invalid(self):
- with pytest.raises(ValidationError):
- VaultPath.model_validate("foobar")
- with pytest.raises(ValidationError):
- VaultPath.model_validate("foo#bar#baz")
+ def test_http_exception(self, route: respx.Route, unittest_client: httpx.Client):
+ route.mock(side_effect=httpx.DecodingError)
+ with pytest.raises(httpx.DecodingError):
+ get_mount(unittest_client, "secrets/test")
diff --git a/tests/test_collect.py b/tests/test_collect.py
index a6768e0c..a1527c51 100644
--- a/tests/test_collect.py
+++ b/tests/test_collect.py
@@ -10,7 +10,7 @@
def test_read_values(caplog: pytest.LogCaptureFixture):
def create_provider(return_value: str):
- provider = Mock(spec=secrets_env.provider.ProviderBase)
+ provider = Mock(spec=secrets_env.provider.Provider)
provider.get.return_value = return_value
return provider
@@ -50,7 +50,7 @@ def create_provider(return_value: str):
class TestRead1:
def setup_method(self):
- self.provider = Mock(spec=secrets_env.provider.ProviderBase)
+ self.provider = Mock(spec=secrets_env.provider.Provider)
type(self.provider).name = PropertyMock(return_value="mock")
def test_success(self):