Skip to content

Commit

Permalink
CLI: use factories in PluginParamType whenever possible
Browse files Browse the repository at this point in the history
The `get_entry_point_from_string` method of the `PluginParamType`
parameter type, tries to determine the entry point group and name from
the value passed to the parameter and when matched, attempts to load the
corresponding entry point. It was doing so by directly calling the
`get_entry_point` method from the `aiida.plugins` module.

Since the plugins that ship with `aiida-core` were recently updated to
have them properly prefixed with `core.`, the old unprefixed entry point
names are now deprecated. When used through the factories, the legacy
entry point names are automatically detected and converted to the new
one, with a deprecation warning being printed. However, the command line
didn't have this functionality, since the `PluginParamType` was not
going through the factories.

Here, for the entry point groups that have a factory, the plugin param
type is updated to use the factories, hence also automatically profiting
from the deprecation pathway, allowing users to keep using the old entry
point names for a while.

The factories had to be modified slightly in order to make this work.
Since the `PluginParamType` has the argument `load` which determines
whether the matched entry point should be loaded or not, it should be
able to pass this through to the factories, since this was always
loading the entry point by default. For this reason, the factories now
also have the `load` keyword argument. When set to false, the entry
point itself is returned, instead of the resource that it points to. It
is set to `True` by default to maintain backwards compatibility.
  • Loading branch information
sphuber committed Aug 31, 2021
1 parent 16daad0 commit 9cc5624
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 68 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ repos:
aiida/orm/nodes/node.py|
aiida/orm/nodes/process/.*py|
aiida/plugins/entry_point.py|
aiida/plugins/factories.py|
aiida/repository/.*py|
aiida/tools/graph/graph_traversers.py|
aiida/tools/groups/paths.py|
Expand Down
5 changes: 4 additions & 1 deletion aiida/cmdline/params/options/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,10 @@ def set_log_level(ctx, __, value):
)

INPUT_PLUGIN = OverridableOption(
'-P', '--input-plugin', type=types.PluginParamType(group='calculations'), help='Calculation input plugin string.'
'-P',
'--input-plugin',
type=types.PluginParamType(group='calculations', load=False),
help='Calculation input plugin string.'
)

CALC_JOB_STATE = OverridableOption(
Expand Down
34 changes: 29 additions & 5 deletions aiida/cmdline/params/types/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Click parameter type for AiiDA Plugins."""
import functools

import click

from aiida.cmdline.utils import decorators
from aiida.common import exceptions
from aiida.plugins import factories
from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, ENTRY_POINT_GROUP_PREFIX, EntryPointFormat
from aiida.plugins.entry_point import format_entry_point_string, get_entry_point_string_format
from aiida.plugins.entry_point import get_entry_point, get_entry_points, get_entry_point_groups

from .strings import EntryPointType

__all__ = ('PluginParamType',)
Expand All @@ -40,6 +43,18 @@ class PluginParamType(EntryPointType):
"""
name = 'plugin'

_factory_mapping = {
'aiida.calculations': factories.CalculationFactory,
'aiida.data': factories.DataFactory,
'aiida.groups': factories.GroupFactory,
'aiida.parsers': factories.ParserFactory,
'aiida.schedulers': factories.SchedulerFactory,
'aiida.transports': factories.TransportFactory,
'aiida.tools.dbimporters': factories.DbImporterFactory,
'aiida.tools.data.orbitals': factories.OrbitalFactory,
'aiida.workflows': factories.WorkflowFactory,
}

def __init__(self, group=None, load=False, *args, **kwargs):
"""
Validate that group is either a string or a tuple of valid entry point groups, or if it
Expand Down Expand Up @@ -172,6 +187,11 @@ def get_entry_point_from_string(self, entry_point_string):
if group not in self.groups:
raise ValueError('entry point group {} is not supported by this parameter')

elif entry_point_format == EntryPointFormat.MINIMAL and len(self.groups) == 1:

name = entry_point_string
group = self.groups[0]

elif entry_point_format == EntryPointFormat.MINIMAL:

name = entry_point_string
Expand All @@ -187,19 +207,23 @@ def get_entry_point_from_string(self, entry_point_string):
"entry point '{}' is not valid for any of the allowed "
'entry point groups: {}'.format(name, ' '.join(self.groups))
)
else:
group = matching_groups[0]

group = matching_groups[0]

else:
ValueError(f'invalid entry point string format: {entry_point_string}')

# If there is a factory for the entry point group, use that, otherwise use ``get_entry_point``
try:
entry_point = get_entry_point(group, name)
get_entry_point_partial = functools.partial(self._factory_mapping[group], load=False)
except KeyError:
get_entry_point_partial = functools.partial(get_entry_point, group)

try:
return get_entry_point_partial(name)
except exceptions.EntryPointError as exception:
raise ValueError(exception)

return entry_point

@decorators.with_dbenv()
def convert(self, value, param, ctx):
"""
Expand Down
8 changes: 4 additions & 4 deletions aiida/orm/utils/builders/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
import enum
import os

import importlib_metadata

from aiida.cmdline.utils.decorators import with_dbenv
from aiida.cmdline.params.types.plugin import PluginParamType
from aiida.common.utils import ErrorAccumulator


Expand Down Expand Up @@ -63,7 +64,7 @@ def new(self):

code.label = self._get_and_count('label', used)
code.description = self._get_and_count('description', used)
code.set_input_plugin_name(self._get_and_count('input_plugin', used).name)
code.set_input_plugin_name(self._get_and_count('input_plugin', used))
code.set_prepend_text(self._get_and_count('prepend_text', used))
code.set_append_text(self._get_and_count('append_text', used))

Expand Down Expand Up @@ -154,8 +155,7 @@ def _set_code_attr(self, key, value):
Checks compatibility with other code attributes.
"""
# store only string of input plugin
if key == 'input_plugin' and isinstance(value, PluginParamType):
if key == 'input_plugin' and isinstance(value, importlib_metadata.EntryPoint):
value = value.name

backup = self._code_spec.copy()
Expand Down
Loading

0 comments on commit 9cc5624

Please sign in to comment.