Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: raise when detecting multiple entry points #5531

Merged
merged 4 commits into from
May 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion aiida/plugins/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,9 @@ def get_entry_point(group: str, name: str) -> EntryPoint:
found = eps().select(group=group, name=name)
if name not in found.names:
raise MissingEntryPointError(f"Entry point '{name}' not found in group '{group}'")
if len(found.names) > 1:
# If multiple entry points are found and they have different values we raise, otherwise if they all
# correspond to the same value, we simply return one of them
if len(found) > 1 and len(set(ep.value for ep in found)) != 1:
raise MultipleEntryPointError(f"Multiple entry points '{name}' found in group '{group}': {found}")
return found[name]

Expand Down
54 changes: 53 additions & 1 deletion tests/plugins/test_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=redefined-outer-name
"""Tests for the :mod:`~aiida.plugins.entry_point` module."""
import pytest

from aiida.common.exceptions import MissingEntryPointError, MultipleEntryPointError
from aiida.common.warnings import AiidaDeprecationWarning
from aiida.plugins.entry_point import get_entry_point, validate_registered_entry_points
from aiida.plugins import entry_point
from aiida.plugins.entry_point import EntryPoint as EP
from aiida.plugins.entry_point import EntryPoints, get_entry_point, validate_registered_entry_points


def test_validate_registered_entry_points():
Expand Down Expand Up @@ -43,3 +47,51 @@ def test_get_entry_point_deprecated(group, name):

with pytest.warns(AiidaDeprecationWarning, match=warning):
get_entry_point(group, name)


@pytest.fixture
def eps(request):
"""Mocked version of :func:`aiida.plugins.entry_point.eps`.

The mocked function returns a dummy class whose ``select`` method returns a fixed list of entry points that are
passed in via the ``request`` parameter with which ``pytest`` invokes the fixture.
"""

class MockEntryPoints:

@staticmethod
def select(group, name): # pylint: disable=unused-argument
return EntryPoints(request.param)

return MockEntryPoints


@pytest.mark.parametrize(
'eps, name, exception', (
((EP(name='ep', group='gr', value='x'),), None, None),
((EP(name='ep', group='gr', value='x'),), 'non-existing', MissingEntryPointError),
((EP(name='ep', group='gr', value='x'), EP(name='ep', group='gr', value='y')), None, MultipleEntryPointError),
((EP(name='ep', group='gr', value='x'), EP(name='ep', group='gr', value='x')), None, None),
),
indirect=['eps']
)
def test_get_entry_point(eps, name, exception, monkeypatch):
"""Test the ``get_entry_point`` method.

Test four different cases:

* Requested entry point exists and no duplicates -> no exception
* Requested entry points does not exist -> MissingEntryPointError
* Requested entry point has two matches by name but hits have different values -> MultipleEntryPointError
* Requested entry point has two matches by name but hits have same values -> no exception

"""
monkeypatch.setattr(entry_point, 'eps', eps)

name = name or 'ep' # Try to load the entry point with name ``ep`` unless the fixture provides one

if exception:
with pytest.raises(exception):
get_entry_point(group='gr', name=name)
else:
get_entry_point(group='gr', name=name)