Skip to content

Commit

Permalink
Fixtures: make the entry_points fixture publicly available
Browse files Browse the repository at this point in the history
This fixture makes it easy to temporarily add or remove entry points
which comes in very handy when testing custom implementations of various
classes that are pluginnable. By moving them from `tests/conftest.py` to
the `aiida.manage.tests.pytest_fixtures` module, it can also be used by
plugin packages.
  • Loading branch information
sphuber committed Nov 4, 2022
1 parent f8ebf1e commit db0c254
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 120 deletions.
122 changes: 122 additions & 0 deletions aiida/manage/tests/pytest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
from __future__ import annotations

import asyncio
import copy
import pathlib
import shutil
import tempfile
import time
import warnings

import plumpy
import pytest
import wrapt

from aiida import plugins
from aiida.common.lang import type_check
from aiida.common.log import AIIDA_LOGGER
from aiida.common.warnings import warn_deprecation
from aiida.engine import Process, ProcessBuilder, submit
Expand Down Expand Up @@ -317,3 +322,120 @@ def _factory(
return node

return _factory


@wrapt.decorator
def suppress_deprecations(wrapped, _, args, kwargs):
"""Decorator that suppresses all ``DeprecationWarning``."""
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
return wrapped(*args, **kwargs)


class EntryPointManager:
"""Manager to temporarily add or remove entry points."""

@staticmethod
def eps():
return plugins.entry_point.eps()

@staticmethod
def _validate_entry_point(entry_point_string: str | None, group: str | None, name: str | None) -> tuple[str, str]:
"""Validate the definition of the entry point.
:param entry_point_string: Fully qualified entry point string.
:param name: Entry point name.
:param group: Entry point group.
:returns: The entry point group and name.
:raises TypeError: If `entry_point_string`, `group` or `name` are not a string, when defined.
:raises ValueError: If `entry_point_string` is not defined, nor a `group` and `name`.
:raises ValueError: If `entry_point_string` is not a complete entry point string with group and name.
"""
if entry_point_string is not None:
try:
group, name = plugins.entry_point.parse_entry_point_string(entry_point_string)
except TypeError as exception:
raise TypeError('`entry_point_string` should be a string when defined.') from exception
except ValueError as exception:
raise ValueError('invalid `entry_point_string` format, should `group:name`.') from exception

if name is None or group is None:
raise ValueError('neither `entry_point_string` is defined, nor `name` and `group`.')

type_check(group, str)
type_check(name, str)

return group, name

@suppress_deprecations
def add(
self,
value: type | str,
entry_point_string: str | None = None,
*,
name: str | None = None,
group: str | None = None
) -> None:
"""Add an entry point.
:param value: The class or function to register as entry point. The resource needs to be importable, so it can't
be inlined. Alternatively, the fully qualified name can be passed as a string.
:param entry_point_string: Fully qualified entry point string.
:param name: Entry point name.
:param group: Entry point group.
:returns: The entry point group and name.
:raises TypeError: If `entry_point_string`, `group` or `name` are not a string, when defined.
:raises ValueError: If `entry_point_string` is not defined, nor a `group` and `name`.
:raises ValueError: If `entry_point_string` is not a complete entry point string with group and name.
"""
if not isinstance(value, str):
value = f'{value.__module__}:{value.__name__}'

group, name = self._validate_entry_point(entry_point_string, group, name)
entry_point = plugins.entry_point.EntryPoint(name, value, group)
self.eps()[group].append(entry_point)

@suppress_deprecations
def remove(
self, entry_point_string: str | None = None, *, name: str | None = None, group: str | None = None
) -> None:
"""Remove an entry point.
:param value: Entry point value, fully qualified import path name.
:param entry_point_string: Fully qualified entry point string.
:param name: Entry point name.
:param group: Entry point group.
:returns: The entry point group and name.
:raises TypeError: If `entry_point_string`, `group` or `name` are not a string, when defined.
:raises ValueError: If `entry_point_string` is not defined, nor a `group` and `name`.
:raises ValueError: If `entry_point_string` is not a complete entry point string with group and name.
"""
group, name = self._validate_entry_point(entry_point_string, group, name)

for entry_point in self.eps()[group]:
if entry_point.name == name:
self.eps()[group].remove(entry_point)
break
else:
raise KeyError(f'entry point `{name}` does not exist in group `{group}`.')


@pytest.fixture
def entry_points(monkeypatch) -> EntryPointManager:
"""Return an instance of the ``EntryPointManager`` which allows to temporarily add or remove entry points.
This fixture creates a deep copy of the entry point cache returned by the :func:`aiida.plugins.entry_point.eps`
method and then monkey patches that function to return the deepcopy. This ensures that the changes on the entry
point cache performed during the test through the manager are undone at the end of the function scope.
.. note:: This fixture does not use the ``suppress_deprecations`` decorator on purpose, but instead adds it manually
inside the fixture's body. The reason is that otherwise all deprecations would be suppressed for the entire
scope of the fixture, including those raised by the code run in the test using the fixture, which is not
desirable.
"""
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
eps_copy = copy.deepcopy(plugins.entry_point.eps())
monkeypatch.setattr(plugins.entry_point, 'eps', lambda: eps_copy)
yield EntryPointManager()
23 changes: 23 additions & 0 deletions docs/source/topics/plugins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ The module provides the following fixtures:
* :ref:`started_daemon_client <topics:plugins:testfixtures:started-daemon-client>`: Same as ``daemon_client`` but the daemon is guaranteed to be running
* :ref:`stopped_daemon_client <topics:plugins:testfixtures:stopped-daemon-client>`: Same as ``daemon_client`` but the daemon is guaranteed to *not* be running
* :ref:`daemon_client <topics:plugins:testfixtures:daemon-client>`: Return a :class:`~aiida.engine.daemon.client.DaemonClient` instance to control the daemon
* :ref:`entry_points <topics:plugins:testfixtures:entry-points>`: Return a :class:`~aiida.manage.tests.pytest_fixtures.EntryPointManager` instance to add and remove entry points


.. _topics:plugins:testfixtures:aiida-profile:
Expand Down Expand Up @@ -509,6 +510,28 @@ Return a :class:`~aiida.engine.daemon.client.DaemonClient` instance that can be
daemon_client.stop_daemon(wait=True)
.. _topics:plugins:testfixtures:entry-points:

``entry_points``
----------------

Return a :class:`~aiida.manage.tests.pytest_fixtures.EntryPointManager` instance to add and remove entry points.

.. code-block:: python
def test_parser(entry_points):
"""Test a custom ``Parser`` implementation."""
from aiida.parsers import Parser
from aiida.plugins import ParserFactory
class CustomParser(Parser):
"""Parser implementation."""
entry_points.add(CustomParser, 'custom.parser')
assert ParserFactory('custom.parser', CustomParser)
Any entry points additions and removals are automatically undone at the end of the test.


.. _click: https://click.palletsprojects.com/
Expand Down
121 changes: 1 addition & 120 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@

import click
import pytest
import wrapt

from aiida import get_profile, plugins
from aiida.common.lang import type_check
from aiida import get_profile
from aiida.manage.configuration import Config, Profile, get_config, load_profile

pytest_plugins = ['aiida.manage.tests.pytest_fixtures', 'sphinx.testing.fixtures'] # pylint: disable=invalid-name
Expand Down Expand Up @@ -480,120 +478,3 @@ def reset_log_level():
finally:
log.CLI_LOG_LEVEL = None
log.configure_logging(with_orm=True)


@wrapt.decorator
def suppress_deprecations(wrapped, _, args, kwargs):
"""Decorator that suppresses all ``DeprecationWarning``s."""
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
return wrapped(*args, **kwargs)


class EntryPointManager:
"""Manager to temporarily add or remove entry points."""

@staticmethod
def eps():
return plugins.entry_point.eps()

@staticmethod
def _validate_entry_point(entry_point_string: str | None, group: str | None, name: str | None) -> tuple[str, str]:
"""Validate the definition of the entry point.
:param entry_point_string: Fully qualified entry point string.
:param name: Entry point name.
:param group: Entry point group.
:returns: The entry point group and name.
:raises TypeError: If `entry_point_string`, `group` or `name` are not a string, when defined.
:raises ValueError: If `entry_point_string` is not defined, nor a `group` and `name`.
:raises ValueError: If `entry_point_string` is not a complete entry point string with group and name.
"""
if entry_point_string is not None:
try:
group, name = plugins.entry_point.parse_entry_point_string(entry_point_string)
except TypeError as exception:
raise TypeError('`entry_point_string` should be a string when defined.') from exception
except ValueError as exception:
raise ValueError('invalid `entry_point_string` format, should `group:name`.') from exception

if name is None or group is None:
raise ValueError('neither `entry_point_string` is defined, nor `name` and `group`.')

type_check(group, str)
type_check(name, str)

return group, name

@suppress_deprecations
def add(
self,
value: type | str,
entry_point_string: str | None = None,
*,
name: str | None = None,
group: str | None = None
) -> None:
"""Add an entry point.
:param value: The class or function to register as entry point. The resource needs to be importable, so it can't
be inlined. Alternatively, the fully qualified name can be passed as a string.
:param entry_point_string: Fully qualified entry point string.
:param name: Entry point name.
:param group: Entry point group.
:returns: The entry point group and name.
:raises TypeError: If `entry_point_string`, `group` or `name` are not a string, when defined.
:raises ValueError: If `entry_point_string` is not defined, nor a `group` and `name`.
:raises ValueError: If `entry_point_string` is not a complete entry point string with group and name.
"""
if not isinstance(value, str):
value = f'{value.__module__}:{value.__name__}'

group, name = self._validate_entry_point(entry_point_string, group, name)
entry_point = plugins.entry_point.EntryPoint(name, value, group)
self.eps()[group].append(entry_point)

@suppress_deprecations
def remove(
self, entry_point_string: str | None = None, *, name: str | None = None, group: str | None = None
) -> None:
"""Remove an entry point.
:param value: Entry point value, fully qualified import path name.
:param entry_point_string: Fully qualified entry point string.
:param name: Entry point name.
:param group: Entry point group.
:returns: The entry point group and name.
:raises TypeError: If `entry_point_string`, `group` or `name` are not a string, when defined.
:raises ValueError: If `entry_point_string` is not defined, nor a `group` and `name`.
:raises ValueError: If `entry_point_string` is not a complete entry point string with group and name.
"""
group, name = self._validate_entry_point(entry_point_string, group, name)

for entry_point in self.eps()[group]:
if entry_point.name == name:
self.eps()[group].remove(entry_point)
break
else:
raise KeyError(f'entry point `{name}` does not exist in group `{group}`.')


@pytest.fixture
def entry_points(monkeypatch) -> EntryPointManager:
"""Return an instance of the ``EntryPointManager`` which allows to temporarily add or remove entry points.
This fixture creates a deep copy of the entry point cache returned by the :func:`aiida.plugins.entry_point.eps`
method and then monkey patches that function to return the deepcopy. This ensures that the changes on the entry
point cache performed during the test through the manager are undone at the end of the function scope.
.. note:: This fixture does not use the ``suppress_deprecations`` decorator on purpose, but instead adds it manually
inside the fixture's body. The reason is that otherwise all deprecations would be suppressed for the entire
scope of the fixture, including those raised by the code run in the test using the fixture, which is not
desirable.
"""
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
eps_copy = copy.deepcopy(plugins.entry_point.eps())
monkeypatch.setattr(plugins.entry_point, 'eps', lambda: eps_copy)
yield EntryPointManager()

0 comments on commit db0c254

Please sign in to comment.