Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent nodes without registered entry points from being stored #3886

Merged
merged 2 commits into from
Apr 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions aiida/orm/nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,20 @@ def _validate(self):
# pylint: disable=no-self-use
return True

def validate_storability(self):
"""Verify that the current node is allowed to be stored.

:raises `aiida.common.exceptions.StoringNotAllowed`: if the node does not match all requirements for storing
"""
from aiida.plugins.entry_point import is_registered_entry_point

if not self._storable:
raise exceptions.StoringNotAllowed(self._unstorable_message)

if not is_registered_entry_point(self.__module__, self.__class__.__name__, groups=('aiida.node', 'aiida.data')):
msg = 'class `{}:{}` does not have registered entry point'.format(self.__module__, self.__class__.__name__)
raise exceptions.StoringNotAllowed(msg)

@classproperty
def class_node_type(cls):
"""Returns the node type of this node (sub) class."""
Expand Down Expand Up @@ -998,11 +1012,10 @@ def store(self, with_transaction=True, use_cache=None): # pylint: disable=argum
'the `use_cache` argument is deprecated and will be removed in `v2.0.0`', AiidaDeprecationWarning
)

if not self._storable:
raise exceptions.StoringNotAllowed(self._unstorable_message)

if not self.is_stored:

# Call `validate_storability` directly and not in `_validate` in case sub class forgets to call the super.
self.validate_storability()
self._validate()

# Verify that parents are already stored. Raises if this is not the case.
Expand Down
18 changes: 11 additions & 7 deletions aiida/orm/utils/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
import math
import numbers
import warnings
from collections.abc import Iterable, Mapping

from aiida.common import exceptions
Expand Down Expand Up @@ -70,7 +71,14 @@ def load_node_class(type_string):
entry_point_name = strip_prefix(base_path, 'nodes.')
return load_entry_point('aiida.node', entry_point_name)

raise exceptions.EntryPointError('unknown type string {}'.format(type_string))
# At this point we really have an anomalous type string. At some point, storing nodes with unresolvable type strings
# was allowed, for example by creating a sub class in a shell and then storing an instance. Attempting to load the
# node then would fail miserably. This is now no longer allowed, but we need a fallback for existing cases, which
# should be rare. We fallback on `Data` and not `Node` because bare node instances are also not storable and so the
# logic of the ORM is not well defined for a loaded instance of the base `Node` class.
warnings.warn('unknown type string `{}`, falling back onto `Data` class'.format(type_string)) # pylint: disable=no-member

return Data


def get_type_string_from_class(class_module, class_name):
Expand Down Expand Up @@ -247,13 +255,9 @@ def clean_builtin(val):


class AbstractNodeMeta(ABCMeta): # pylint: disable=too-few-public-methods
"""
Some python black magic to set correctly the logger also in subclasses.
"""

# pylint: disable=arguments-differ,protected-access,too-many-function-args
"""Some python black magic to set correctly the logger also in subclasses."""

def __new__(mcs, name, bases, namespace):
def __new__(mcs, name, bases, namespace): # pylint: disable=arguments-differ,protected-access,too-many-function-args
newcls = ABCMeta.__new__(mcs, name, bases, namespace)
newcls._logger = logging.getLogger('{}.{}'.format(namespace['__module__'], name))

Expand Down
33 changes: 29 additions & 4 deletions aiida/plugins/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def load_entry_point_from_string(entry_point_string):
group, name = parse_entry_point_string(entry_point_string)
return load_entry_point(group, name)


def load_entry_point(group, name):
"""
Load the class registered under the entry point for a given name and group
Expand Down Expand Up @@ -244,6 +245,7 @@ def get_entry_points(group):
"""
return [ep for ep in ENTRYPOINT_MANAGER.iter_entry_points(group=group)]


@functools.lru_cache(maxsize=None)
def get_entry_point(group, name):
"""
Expand All @@ -258,12 +260,12 @@ def get_entry_point(group, name):
entry_points = [ep for ep in get_entry_points(group) if ep.name == name]

if not entry_points:
raise MissingEntryPointError("Entry point '{}' not found in group '{}'.".format(name, group) +
'Try running `reentry scan` to update the entry point cache.')
raise MissingEntryPointError("Entry point '{}' not found in group '{}'. Try running `reentry scan` to update "
'the entry point cache.'.format(name, group))

if len(entry_points) > 1:
raise MultipleEntryPointError("Multiple entry points '{}' found in group '{}'. ".format(name, group) +
'Try running `reentry scan` to repopulate the entry point cache.')
raise MultipleEntryPointError("Multiple entry points '{}' found in group '{}'.Try running `reentry scan` to "
'repopulate the entry point cache.'.format(name, group))

return entry_points[0]

Expand Down Expand Up @@ -332,3 +334,26 @@ def is_valid_entry_point_string(entry_point_string):
return False

return group in entry_point_group_to_module_path_map


@functools.lru_cache(maxsize=None)
def is_registered_entry_point(class_module, class_name, groups=None):
"""Verify whether the class with the given module and class name is a registered entry point.

.. note:: this function only checks whether the class has a registered entry point. It does explicitly not verify
if the corresponding class is also importable. Use `load_entry_point` for this purpose instead.

:param class_module: the module of the class
:param class_name: the name of the class
:param groups: optionally consider only these entry point groups to look for the class
:return: boolean, True if the class is a registered entry point, False otherwise.
"""
if groups is None:
groups = list(entry_point_group_to_module_path_map.keys())

for group in groups:
for entry_point in ENTRYPOINT_MANAGER.iter_entry_points(group):
if class_module == entry_point.module_name and [class_name] == entry_point.attrs:
return True
else:
return False
13 changes: 6 additions & 7 deletions tests/cmdline/commands/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
###########################################################################
# pylint: disable=invalid-name,protected-access
"""Tests for `verdi database`."""

import enum

from click.testing import CliRunner

from aiida.backends.testbase import AiidaTestCase
from aiida.cmdline.commands import cmd_database
from aiida.common.links import LinkType
from aiida.orm import Data, Node, CalculationNode, WorkflowNode
from aiida.orm import Data, CalculationNode, WorkflowNode


class TestVerdiDatabasaIntegrity(AiidaTestCase):
Expand Down Expand Up @@ -162,11 +161,11 @@ def test_detect_invalid_nodes_unknown_node_type(self):
self.assertEqual(result.exit_code, 0)
self.assertClickResultNoException(result)

# Create a node with invalid type: a base Node type string is considered invalid
# Note that there is guard against storing base Nodes for this reason, which we temporarily disable
Node._storable = True
Node().store()
Node._storable = False
# Create a node with invalid type: since there are a lot of validation rules that prevent us from creating an
# invalid node type normally, we have to do it manually on the database model instance before storing
node = Data()
node.backend_entity.dbmodel.node_type = '__main__.SubClass.'
node.store()

result = self.cli_runner.invoke(cmd_database.detect_invalid_nodes, [])
self.assertNotEqual(result.exit_code, 0)
Expand Down
6 changes: 6 additions & 0 deletions tests/orm/utils/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Tests for the `Node` utils."""
import pytest

from aiida.backends.testbase import AiidaTestCase
from aiida.orm import Data
Expand All @@ -21,3 +22,8 @@ def test_load_node_class_fallback(self):
"""Verify that `load_node_class` will fall back to `Data` class if entry point cannot be loaded."""
loaded_class = load_node_class('data.some.non.existing.plugin.')
self.assertEqual(loaded_class, Data)

# For really unresolvable type strings, we fall back onto the `Data` class
with pytest.warns(UserWarning):
loaded_class = load_node_class('__main__.SubData.')
self.assertEqual(loaded_class, Data)
63 changes: 15 additions & 48 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,28 @@


class TestNodeIsStorable(AiidaTestCase):
"""
Test if one can store specific Node subclasses, and that Node and
ProcessType are not storable, intead.
"""
"""Test that checks on storability of certain node sub classes work correctly."""

def test_storable_unstorable(self):
"""
Test storability of Nodes
"""
node = orm.Node()
def test_base_classes(self):
"""Test storability of `Node` base sub classes."""
with self.assertRaises(StoringNotAllowed):
node.store()
orm.Node().store()

process = orm.ProcessNode()
with self.assertRaises(StoringNotAllowed):
process.store()
orm.ProcessNode().store()

# These below should be allowed instead
data = orm.Data()
data.store()
# The following base classes are storable
orm.Data().store()
orm.CalculationNode().store()
orm.WorkflowNode().store()

calc = orm.CalculationNode()
calc.store()
def test_unregistered_sub_class(self):
"""Sub classes without a registered entry point are not storable."""
class SubData(orm.Data):
pass

work = orm.WorkflowNode()
work.store()
with self.assertRaises(StoringNotAllowed):
SubData().store()


class TestNodeCopyDeepcopy(AiidaTestCase):
Expand Down Expand Up @@ -1207,35 +1203,6 @@ def test_load_node(self):
with self.assertRaises(NotExistent):
orm.load_node(spec, sub_classes=(orm.ArrayData,))

def test_load_unknown_data_type(self):
"""
Test that the loader will choose a common data ancestor for an unknown data type.
For the case where, e.g., the user doesn't have the necessary plugin.
"""
from aiida.plugins import DataFactory

KpointsData = DataFactory('array.kpoints')
kpoint = KpointsData().store()

# compare if plugin exist
obj = orm.load_node(uuid=kpoint.uuid)
self.assertEqual(type(kpoint), type(obj))

class TestKpointsData(KpointsData):
pass

# change node type and save in database again
TestKpointsData().store()

# changed node should return data node as its plugin is not exist
obj = orm.load_node(uuid=kpoint.uuid)
self.assertEqual(type(kpoint), type(obj))

# for node
n1 = orm.Data().store()
obj = orm.load_node(n1.uuid)
self.assertEqual(type(n1), type(obj))


class TestSubNodesAndLinks(AiidaTestCase):

Expand Down