diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fff8155541..810c81fd80 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -75,6 +75,10 @@ repos: aiida/manage/manager.py| aiida/manage/database/delete/nodes.py| aiida/orm/querybuilder.py| + aiida/orm/implementation/backends.py| + aiida/orm/implementation/sql/backends.py| + aiida/orm/implementation/django/backend.py| + aiida/orm/implementation/sqlalchemy/backend.py| aiida/orm/implementation/querybuilder.py| aiida/orm/implementation/sqlalchemy/querybuilder/.*py| aiida/orm/nodes/data/jsonable.py| diff --git a/aiida/backends/djsite/utils.py b/aiida/backends/djsite/utils.py deleted file mode 100644 index 74bfb56269..0000000000 --- a/aiida/backends/djsite/utils.py +++ /dev/null @@ -1,28 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utility functions specific to the Django backend.""" - - -def delete_nodes_and_connections_django(pks_to_delete): # pylint: disable=invalid-name - """Delete all nodes corresponding to pks in the input. - - :param pks_to_delete: A list, tuple or set of pks that should be deleted. 
- """ - # pylint: disable=no-member,import-error,no-name-in-module - from django.db import transaction - from django.db.models import Q - - from aiida.backends.djsite.db import models - with transaction.atomic(): - # This is fixed in pylint-django>=2, but this supports only py3 - # Delete all links pointing to or from a given node - models.DbLink.objects.filter(Q(input__in=pks_to_delete) | Q(output__in=pks_to_delete)).delete() - # now delete nodes - models.DbNode.objects.filter(pk__in=pks_to_delete).delete() diff --git a/aiida/backends/sqlalchemy/models/group.py b/aiida/backends/sqlalchemy/models/group.py index f943e7a519..1fdb898987 100644 --- a/aiida/backends/sqlalchemy/models/group.py +++ b/aiida/backends/sqlalchemy/models/group.py @@ -31,6 +31,11 @@ ) +class DbGroupNode(Base): + """Class to store group to nodes relation using SQLA backend.""" + __table__ = table_groups_nodes + + class DbGroup(Base): """Class to store groups using SQLA backend.""" diff --git a/aiida/backends/sqlalchemy/utils.py b/aiida/backends/sqlalchemy/utils.py index a8d76265ef..780df99bf3 100644 --- a/aiida/backends/sqlalchemy/utils.py +++ b/aiida/backends/sqlalchemy/utils.py @@ -11,34 +11,6 @@ """Utility functions specific to the SqlAlchemy backend.""" -def delete_nodes_and_connections_sqla(pks_to_delete): # pylint: disable=invalid-name - """ - Delete all nodes corresponding to pks in the input. - :param pks_to_delete: A list, tuple or set of pks that should be deleted. - """ - # pylint: disable=no-value-for-parameter - from aiida.backends.sqlalchemy.models.group import table_groups_nodes - from aiida.backends.sqlalchemy.models.node import DbLink, DbNode - from aiida.manage.manager import get_manager - - backend = get_manager().get_backend() - - with backend.transaction() as session: - # I am first making a statement to delete the membership of these nodes to groups. 
- # Since table_groups_nodes is a sqlalchemy.schema.Table, I am using expression language to compile - # a stmt to be executed by the session. It works, but it's not nice that two different ways are used! - # Can this be changed? - stmt = table_groups_nodes.delete().where(table_groups_nodes.c.dbnode_id.in_(list(pks_to_delete))) - session.execute(stmt) - # First delete links, then the Nodes, since we are not cascading deletions. - # Here I delete the links coming out of the nodes marked for deletion. - session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') - # Here I delete the links pointing to the nodes marked for deletion. - session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') - # Now I am deleting the nodes - session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') - - def flag_modified(instance, key): """Wrapper around `sqlalchemy.orm.attributes.flag_modified` to correctly dereference utils.ModelWrapper diff --git a/aiida/backends/utils.py b/aiida/backends/utils.py index 234412e1f1..d73be4674d 100644 --- a/aiida/backends/utils.py +++ b/aiida/backends/utils.py @@ -8,9 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Backend-agnostic utility functions""" -from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA -from aiida.manage import configuration - AIIDA_ATTRIBUTE_SEP = '.' 
@@ -47,15 +44,3 @@ def create_scoped_session_factory(engine, **kwargs): """Create scoped SQLAlchemy session factory""" from sqlalchemy.orm import scoped_session, sessionmaker return scoped_session(sessionmaker(bind=engine, future=True, **kwargs)) - - -def delete_nodes_and_connections(pks): - """Backend-agnostic function to delete Nodes and connections""" - if configuration.PROFILE.database_backend == BACKEND_DJANGO: - from aiida.backends.djsite.utils import delete_nodes_and_connections_django as delete_nodes_backend - elif configuration.PROFILE.database_backend == BACKEND_SQLA: - from aiida.backends.sqlalchemy.utils import delete_nodes_and_connections_sqla as delete_nodes_backend - else: - raise Exception(f'unknown backend {configuration.PROFILE.database_backend}') - - delete_nodes_backend(pks) diff --git a/aiida/orm/entities.py b/aiida/orm/entities.py index f5019e2a99..19477a8671 100644 --- a/aiida/orm/entities.py +++ b/aiida/orm/entities.py @@ -10,6 +10,7 @@ """Module for all common top level AiiDA entity classes and methods""" import abc import copy +from enum import Enum import typing from plumpy.base.utils import call_with_super_check, super_check @@ -25,6 +26,19 @@ _NO_DEFAULT = tuple() +class EntityTypes(Enum): + """Enum for referring to ORM entities in a backend-agnostic manner.""" + AUTHINFO = 'authinfo' + COMMENT = 'comment' + COMPUTER = 'computer' + GROUP = 'group' + LOG = 'log' + NODE = 'node' + USER = 'user' + LINK = 'link' + GROUP_NODE = 'group_node' + + class Collection(typing.Generic[EntityType]): """Container class that represents the collection of objects of a particular type.""" diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index a0d43a7b43..b1273661d9 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -9,11 +9,12 @@ ########################################################################### """Generic backend related objects""" import abc -from typing import 
TYPE_CHECKING +from typing import TYPE_CHECKING, Any, ContextManager, List, Sequence, TypeVar if TYPE_CHECKING: from sqlalchemy.orm.session import Session + from aiida.orm.entities import EntityTypes from aiida.orm.implementation import ( BackendAuthInfoCollection, BackendCommentCollection, @@ -27,12 +28,14 @@ __all__ = ('Backend',) +TransactionType = TypeVar('TransactionType') + class Backend(abc.ABC): """The public interface that defines a backend factory that creates backend specific concrete objects.""" @abc.abstractmethod - def migrate(self): + def migrate(self) -> None: """Migrate the database to the latest schema generation or version.""" @property @@ -65,17 +68,24 @@ def logs(self) -> 'BackendLogCollection': def nodes(self) -> 'BackendNodeCollection': """Return the collection of nodes""" + @property + @abc.abstractmethod + def users(self) -> 'BackendUserCollection': + """Return the collection of users""" + @abc.abstractmethod def query(self) -> 'BackendQueryBuilder': """Return an instance of a query builder implementation for this backend""" - @property @abc.abstractmethod - def users(self) -> 'BackendUserCollection': - """Return the collection of users""" + def get_session(self) -> 'Session': + """Return a database session that can be used by the `QueryBuilder` to perform its query. + + :return: an instance of :class:`sqlalchemy.orm.session.Session` + """ @abc.abstractmethod - def transaction(self): + def transaction(self) -> ContextManager[Any]: """ Get a context manager that can be used as a transaction context for a series of backend operations. If there is an exception within the context then the changes will be rolled back and the state will @@ -84,9 +94,44 @@ def transaction(self): :return: a context manager to group database operations """ + @property @abc.abstractmethod - def get_session(self) -> 'Session': - """Return a database session that can be used by the `QueryBuilder` to perform its query. 
+ def in_transaction(self) -> bool: + """Return whether a transaction is currently active.""" - :return: an instance of :class:`sqlalchemy.orm.session.Session` + @abc.abstractmethod + def bulk_insert(self, entity_type: 'EntityTypes', rows: List[dict], allow_defaults: bool = False) -> List[int]: + """Insert a list of entities into the database, directly into a backend transaction. + + :param entity_type: The type of the entity + :param rows: A list of dictionaries, containing all fields of the backend model, + except the `id` field (a.k.a primary key), which will be generated dynamically + :param allow_defaults: If ``False``, assert that each row contains all fields (except primary key(s)), + otherwise, allow default values for missing fields. + + :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table + + :returns: The list of generated primary keys for the entities + """ + + @abc.abstractmethod + def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict]) -> None: + """Update a list of entities in the database, directly with a backend transaction. + + :param entity_type: The type of the entity + :param rows: A list of dictionaries, containing fields of the backend model to update, + and the `id` field (a.k.a primary key) + + :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table + """ + + @abc.abstractmethod + def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]): + """Delete all nodes corresponding to pks in the input and any links to/from them. + + This method is intended to be used within a transaction context. 
+ + :param pks_to_delete: a sequence of node pks to delete + + :raises: ``AssertionError`` if a transaction is not active """ diff --git a/aiida/orm/implementation/django/backend.py b/aiida/orm/implementation/django/backend.py index b6056bda35..915000c170 100644 --- a/aiida/orm/implementation/django/backend.py +++ b/aiida/orm/implementation/django/backend.py @@ -9,11 +9,18 @@ ########################################################################### """Django implementation of `aiida.orm.implementation.backends.Backend`.""" from contextlib import contextmanager +import functools +from typing import Any, ContextManager, List, Sequence # pylint: disable=import-error,no-name-in-module -from django.db import models, transaction +from django.apps import apps +from django.db import models +from django.db import transaction as django_transaction +from aiida.backends.djsite.db import models as dbm from aiida.backends.djsite.manager import DjangoBackendManager +from aiida.common.exceptions import IntegrityError +from aiida.orm.entities import EntityTypes from . import authinfos, comments, computers, convert, groups, logs, nodes, querybuilder, users from ..sql.backends import SqlBackend @@ -69,11 +76,6 @@ def query(self): def users(self): return self._users - @staticmethod - def transaction(): - """Open a transaction to be used as a context manager.""" - return transaction.atomic() - @staticmethod def get_session(): """Return a database session that can be used by the `QueryBuilder` to perform its query. 
@@ -86,6 +88,92 @@ def get_session(): from aiida.backends.djsite import get_scoped_session return get_scoped_session() + @staticmethod + def transaction() -> ContextManager[Any]: + """Open a transaction to be used as a context manager.""" + return django_transaction.atomic() + + @property + def in_transaction(self) -> bool: + return not django_transaction.get_autocommit() + + @staticmethod + @functools.lru_cache(maxsize=18) + def _get_model_from_entity(entity_type: EntityTypes, with_pk: bool): + """Return the Django model and fields corresponding to the given entity. + + :param with_pk: if True, the fields returned will include the primary key + """ + from sqlalchemy import inspect + + model = { + EntityTypes.AUTHINFO: dbm.DbAuthInfo, + EntityTypes.COMMENT: dbm.DbComment, + EntityTypes.COMPUTER: dbm.DbComputer, + EntityTypes.GROUP: dbm.DbGroup, + EntityTypes.LOG: dbm.DbLog, + EntityTypes.NODE: dbm.DbNode, + EntityTypes.USER: dbm.DbUser, + EntityTypes.LINK: dbm.DbLink, + EntityTypes.GROUP_NODE: + {model._meta.db_table: model for model in apps.get_models(include_auto_created=True)}['db_dbgroup_dbnodes'] + }[entity_type] + mapper = inspect(model.sa).mapper # here aldjemy provides us the SQLAlchemy model + keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key} + return model, keys + + def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]: + model, keys = self._get_model_from_entity(entity_type, False) + if allow_defaults: + for row in rows: + if not keys.issuperset(row): + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') + else: + for row in rows: + if set(row) != keys: + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}') + objects = [model(**row) for row in rows] + # if there is an mtime field, disable the automatic update, so as not to change it + if entity_type in (EntityTypes.NODE, 
EntityTypes.COMMENT): + with dbm.suppress_auto_now([(model, ['mtime'])]): + model.objects.bulk_create(objects) + else: + model.objects.bulk_create(objects) + return [obj.id for obj in objects] + + def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None: + model, keys = self._get_model_from_entity(entity_type, True) + id_entries = {} + fields = None + for row in rows: + if not keys.issuperset(row): + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') + try: + id_entries[row['id']] = {k: v for k, v in row.items() if k != 'id'} + fields = fields or list(id_entries[row['id']]) + assert fields == list(id_entries[row['id']]) + except KeyError: + raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}") + except AssertionError: + # this is handled in sqlalchemy, but would require more complex logic here + raise NotImplementedError(f'Cannot bulk update {entity_type} with different fields') + if fields is None: + return + objects = [] + for pk, obj in model.objects.in_bulk(list(id_entries), field_name='id').items(): + for name, value in id_entries[pk].items(): + setattr(obj, name, value) + objects.append(obj) + model.objects.bulk_update(objects, fields) + + def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: + if not self.in_transaction: + raise AssertionError('Cannot delete nodes and links outside a transaction') + # Delete all links pointing to or from a given node + dbm.DbLink.objects.filter(models.Q(input__in=pks_to_delete) | models.Q(output__in=pks_to_delete)).delete() + # now delete nodes + dbm.DbNode.objects.filter(pk__in=pks_to_delete).delete() + # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` def get_backend_entity(self, model): @@ -100,7 +188,7 @@ def cursor(self): :rtype: :class:`psycopg2.extensions.cursor` """ try: - yield self.get_connection().cursor() + yield self._get_connection().cursor() finally: pass 
@@ -117,7 +205,7 @@ def execute_raw(self, query): return results @staticmethod - def get_connection(): + def _get_connection(): """ Get the Django connection diff --git a/aiida/orm/implementation/sql/backends.py b/aiida/orm/implementation/sql/backends.py index 2bb21f22af..1423ce5d22 100644 --- a/aiida/orm/implementation/sql/backends.py +++ b/aiida/orm/implementation/sql/backends.py @@ -11,11 +11,11 @@ import abc import typing -from .. import backends +from .. import backends, entities __all__ = ('SqlBackend',) -# The template type for the base ORM model type +# The template type for the base sqlalchemy/django ORM model type ModelType = typing.TypeVar('ModelType') # pylint: disable=invalid-name @@ -30,13 +30,12 @@ class SqlBackend(typing.Generic[ModelType], backends.Backend): """ @abc.abstractmethod - def get_backend_entity(self, model): + def get_backend_entity(self, model: ModelType) -> entities.BackendEntity: """ Return the backend entity that corresponds to the given Model instance :param model: the ORM model instance to promote to a backend instance :return: the backend entity corresponding to the given model - :rtype: :class:`aiida.orm.implementation.entities.BackendEntity` """ @abc.abstractmethod diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index 64a7109bf9..01a34c125b 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -8,10 +8,17 @@ # For further information please visit http://www.aiida.net # ########################################################################### """SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`.""" -from contextlib import contextmanager +# pylint: disable=missing-function-docstring +from contextlib import contextmanager, nullcontext +import functools +from typing import Iterator, List, Sequence + +from sqlalchemy.orm import Session from aiida.backends.sqlalchemy.manager import 
SqlaBackendManager from aiida.backends.sqlalchemy.models import base +from aiida.common.exceptions import IntegrityError +from aiida.orm.entities import EntityTypes from . import authinfos, comments, computers, convert, groups, logs, nodes, querybuilder, users from ..sql.backends import SqlBackend @@ -67,8 +74,17 @@ def query(self): def users(self): return self._users + @staticmethod + def get_session() -> Session: + """Return a database session that can be used by the `QueryBuilder` to perform its query. + + :return: an instance of :class:`sqlalchemy.orm.session.Session` + """ + from aiida.backends.sqlalchemy import get_scoped_session + return get_scoped_session() + @contextmanager - def transaction(self): + def transaction(self) -> Iterator[Session]: """Open a transaction to be used as a context manager. If there is an exception within the context then the changes will be rolled back and the state will be as before @@ -78,46 +94,117 @@ def transaction(self): if session.in_transaction(): with session.begin_nested(): yield session + session.commit() else: with session.begin(): with session.begin_nested(): yield session + @property + def in_transaction(self) -> bool: + return self.get_session().in_nested_transaction() + @staticmethod - def get_session(): - """Return a database session that can be used by the `QueryBuilder` to perform its query. + @functools.lru_cache(maxsize=18) + def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool): + """Return the Sqlalchemy mapper and fields corresponding to the given entity. 
- :return: an instance of :class:`sqlalchemy.orm.session.Session` + :param with_pk: if True, the fields returned will include the primary key """ - from aiida.backends.sqlalchemy import get_scoped_session - return get_scoped_session() + from sqlalchemy import inspect + + from aiida.backends.sqlalchemy.models.authinfo import DbAuthInfo + from aiida.backends.sqlalchemy.models.comment import DbComment + from aiida.backends.sqlalchemy.models.computer import DbComputer + from aiida.backends.sqlalchemy.models.group import DbGroup, DbGroupNode + from aiida.backends.sqlalchemy.models.log import DbLog + from aiida.backends.sqlalchemy.models.node import DbLink, DbNode + from aiida.backends.sqlalchemy.models.user import DbUser + model = { + EntityTypes.AUTHINFO: DbAuthInfo, + EntityTypes.COMMENT: DbComment, + EntityTypes.COMPUTER: DbComputer, + EntityTypes.GROUP: DbGroup, + EntityTypes.LOG: DbLog, + EntityTypes.NODE: DbNode, + EntityTypes.USER: DbUser, + EntityTypes.LINK: DbLink, + EntityTypes.GROUP_NODE: DbGroupNode, + }[entity_type] + mapper = inspect(model).mapper + keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key} + return mapper, keys + + def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]: + mapper, keys = self._get_mapper_from_entity(entity_type, False) + if not rows: + return [] + if entity_type in (EntityTypes.COMPUTER, EntityTypes.LOG): + for row in rows: + row['_metadata'] = row.pop('metadata') + if allow_defaults: + for row in rows: + if not keys.issuperset(row): + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') + else: + for row in rows: + if set(row) != keys: + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}') + # note for postgresql+psycopg2 we could also use `save_all` + `flush` with minimal performance degradation, see + # 
https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#orm-batch-inserts-with-psycopg2-now-batch-statements-with-returning-in-most-cases + # by contrast, in sqlite, bulk_insert is faster: https://docs.sqlalchemy.org/en/14/faq/performance.html + session = self.get_session() + with (nullcontext() if self.in_transaction else self.transaction()): # type: ignore[attr-defined] + session.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True) + return [row['id'] for row in rows] + + def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None: # pylint: disable=no-self-use + mapper, keys = self._get_mapper_from_entity(entity_type, True) + if not rows: + return None + for row in rows: + if 'id' not in row: + raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}") + if not keys.issuperset(row): + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') + session = self.get_session() + with (nullcontext() if self.in_transaction else self.transaction()): # type: ignore[attr-defined] + session.bulk_update_mappings(mapper, rows) + + def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: # pylint: disable=no-self-use + # pylint: disable=no-value-for-parameter + from aiida.backends.sqlalchemy.models.group import DbGroupNode + from aiida.backends.sqlalchemy.models.node import DbLink, DbNode + + if not self.in_transaction: + raise AssertionError('Cannot delete nodes and links outside a transaction') + + session = self.get_session() + # Delete the membership of these nodes to groups. + session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete)) + ).delete(synchronize_session='fetch') + # Delete the links coming out of the nodes marked for deletion. + session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + # Delete the links pointing to the nodes marked for deletion. 
+ session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + # Delete the actual nodes + session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` def get_backend_entity(self, model): - """Return a `BackendEntity` instance from a `DbModel` instance.""" return convert.get_backend_entity(model, self) @contextmanager def cursor(self): - """Return a psycopg cursor to be used in a context manager. - - :return: a psycopg cursor - :rtype: :class:`psycopg2.extensions.cursor` - """ from aiida.backends import sqlalchemy as sa try: connection = sa.ENGINE.raw_connection() yield connection.cursor() finally: - self.get_connection().close() + self._get_connection().close() def execute_raw(self, query): - """Execute a raw SQL statement and return the result. - - :param query: a string containing a raw SQL statement - :return: the result of the query - """ from sqlalchemy import text from sqlalchemy.exc import ResourceClosedError # pylint: disable=import-error,no-name-in-module @@ -132,7 +219,7 @@ def execute_raw(self, query): return results @staticmethod - def get_connection(): + def _get_connection(): """Get the SQLA database connection :return: the SQLA database connection diff --git a/aiida/tools/graph/deletions.py b/aiida/tools/graph/deletions.py index 57f785e9c2..d14d9c7dd5 100644 --- a/aiida/tools/graph/deletions.py +++ b/aiida/tools/graph/deletions.py @@ -11,8 +11,8 @@ import logging from typing import Callable, Iterable, Set, Tuple, Union -from aiida.backends.utils import delete_nodes_and_connections from aiida.common.log import AIIDA_LOGGER +from aiida.manage.manager import get_manager from aiida.orm import Group, Node, QueryBuilder from aiida.tools.graph.graph_traversers import get_nodes_delete @@ -21,9 +21,12 @@ DELETE_LOGGER = AIIDA_LOGGER.getChild('delete') -def delete_nodes(pks: 
Iterable[int], - dry_run: Union[bool, Callable[[Set[int]], bool]] = True, - **traversal_rules: bool) -> Tuple[Set[int], bool]: +def delete_nodes( + pks: Iterable[int], + dry_run: Union[bool, Callable[[Set[int]], bool]] = True, + backend=None, + **traversal_rules: bool +) -> Tuple[Set[int], bool]: """Delete nodes given a list of "starting" PKs. This command will delete not only the specified nodes, but also the ones that are @@ -60,6 +63,7 @@ def delete_nodes(pks: Iterable[int], :returns: (pks to delete, whether they were deleted) """ + backend = backend or get_manager().get_backend() # pylint: disable=too-many-arguments,too-many-branches,too-many-locals,too-many-statements @@ -99,7 +103,8 @@ def _missing_callback(_pks: Iterable[int]): return (pks_set_to_delete, True) DELETE_LOGGER.report('Starting node deletion...') - delete_nodes_and_connections(pks_set_to_delete) + with backend.transaction(): + backend.delete_nodes_and_connections(pks_set_to_delete) DELETE_LOGGER.report('Deletion of nodes completed.') return (pks_set_to_delete, True) diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 44cf0958fc..b4670ae5f3 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -19,8 +19,12 @@ py:class builtins.str py:class builtins.dict # typing +py:class AbstractContextManager py:class asyncio.events.AbstractEventLoop py:class EntityType +py:class EntityTypes +py:class ModelType +py:class TransactionType py:class ReturnType py:class function py:class IO @@ -36,6 +40,7 @@ py:class aiida.engine.runners.ResultAndNode py:class aiida.engine.runners.ResultAndPk py:class aiida.engine.processes.workchains.workchain.WorkChainSpec py:class aiida.manage.manager.Manager +py:class aiida.orm.entities.EntityTypes py:class aiida.orm.nodes.node.WarnWhenNotEntered py:class aiida.orm.implementation.querybuilder.QueryDictType py:class aiida.orm.querybuilder.Classifier diff --git a/tests/orm/implementation/test_backend.py 
b/tests/orm/implementation/test_backend.py index b7b89e7fe5..82d3b6f72b 100644 --- a/tests/orm/implementation/test_backend.py +++ b/tests/orm/implementation/test_backend.py @@ -8,14 +8,23 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Unit tests for the ORM Backend class.""" +import pytest + from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.common import exceptions +from aiida.common.links import LinkType +from aiida.orm.entities import EntityTypes -class TestBackend(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test') +class TestBackend: """Test backend.""" + @pytest.fixture(autouse=True) + def init_test(self, backend): + """Set up the backend.""" + self.backend = backend # pylint: disable=attribute-defined-outside-init + def test_transaction_nesting(self): """Test that transaction nesting works.""" user = orm.User('initial@email.com').store() @@ -24,12 +33,12 @@ def test_transaction_nesting(self): try: with self.backend.transaction(): user.email = 'failure@email.com' - self.assertEqual(user.email, 'failure@email.com') + assert user.email == 'failure@email.com' raise RuntimeError except RuntimeError: pass - self.assertEqual(user.email, 'pre-failure@email.com') - self.assertEqual(user.email, 'pre-failure@email.com') + assert user.email == 'pre-failure@email.com' + assert user.email == 'pre-failure@email.com' def test_transaction(self): """Test that transaction nesting works.""" @@ -38,13 +47,14 @@ def test_transaction(self): try: with self.backend.transaction(): + assert self.backend.in_transaction user1.email = 'broken1@email.com' user2.email = 'broken2@email.com' raise RuntimeError except RuntimeError: pass - self.assertEqual(user1.email, 'user1@email.com') - self.assertEqual(user2.email, 'user2@email.com') + assert user1.email == 'user1@email.com' + assert user2.email == 'user2@email.com' def 
test_store_in_transaction(self): """Test that storing inside a transaction is correctly dealt with.""" @@ -62,5 +72,102 @@ def test_store_in_transaction(self): except RuntimeError: pass - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): orm.User.objects.get(email='user_store_fail@email.com') + + def test_bulk_insert(self): + """Test that bulk insert works.""" + rows = [{'email': 'user1@email.com'}, {'email': 'user2@email.com'}] + # should fail if all fields are not given and allow_defaults=False + with pytest.raises(exceptions.IntegrityError, match='Incorrect fields'): + self.backend.bulk_insert(EntityTypes.USER, rows) + pks = self.backend.bulk_insert(EntityTypes.USER, rows, allow_defaults=True) + assert len(pks) == len(rows) + for pk, row in zip(pks, rows): + assert isinstance(pk, int) + user = orm.User.objects.get(id=pk) + assert user.email == row['email'] + + def test_bulk_insert_in_transaction(self): + """Test that bulk insert in a cancelled transaction is not committed.""" + rows = [{'email': 'user1@email.com'}, {'email': 'user2@email.com'}] + try: + with self.backend.transaction(): + self.backend.bulk_insert(EntityTypes.USER, rows, allow_defaults=True) + raise RuntimeError + except RuntimeError: + pass + for row in rows: + with pytest.raises(exceptions.NotExistent): + orm.User.objects.get(email=row['email']) + + def test_bulk_update(self): + """Test that bulk update works.""" + users = [orm.User(f'user{i}@email.com').store() for i in range(3)] + # should raise if the 'id' field is not present + with pytest.raises(exceptions.IntegrityError, match="'id' field not given"): + self.backend.bulk_update(EntityTypes.USER, [{'email': 'other'}]) + # should raise if a non-existent field is present + with pytest.raises(exceptions.IntegrityError, match='Incorrect fields'): + self.backend.bulk_update(EntityTypes.USER, [{'id': users[0].pk, 'x': 'other'}]) + self.backend.bulk_update( + EntityTypes.USER, [{ + 'id': users[0].pk, + 
'email': 'other0' + }, { + 'id': users[1].pk, + 'email': 'other1' + }] + ) + assert users[0].email == 'other0' + assert users[1].email == 'other1' + assert users[2].email == 'user2@email.com' + + def test_bulk_update_in_transaction(self): + """Test that bulk update in a cancelled transaction is not committed.""" + users = [orm.User(f'user{i}@email.com').store() for i in range(3)] + try: + with self.backend.transaction(): + self.backend.bulk_update( + EntityTypes.USER, [{ + 'id': users[0].pk, + 'email': 'other0' + }, { + 'id': users[1].pk, + 'email': 'other1' + }] + ) + raise RuntimeError + except RuntimeError: + pass + for i, user in enumerate(users): + assert user.email == f'user{i}@email.com' + + def test_delete_nodes_and_connections(self): + """Delete all nodes and connections.""" + # create node, link and add to group + node = orm.Data() + calc_node = orm.CalcFunctionNode().store() + node.add_incoming(calc_node, link_type=LinkType.CREATE, link_label='link') + node.store() + node_pk = node.pk + group = orm.Group('name').store() + group.add_nodes([node]) + + # checks before deletion + orm.Node.objects.get(id=node_pk) + assert len(calc_node.get_outgoing().all()) == 1 + assert len(group.nodes) == 1 + + # cannot call outside a transaction + with pytest.raises(AssertionError): + self.backend.delete_nodes_and_connections([node_pk]) + + with self.backend.transaction(): + self.backend.delete_nodes_and_connections([node_pk]) + + # checks after deletion + with pytest.raises(exceptions.NotExistent): + orm.Node.objects.get(id=node_pk) + assert len(calc_node.get_outgoing().all()) == 0 + assert len(group.nodes) == 0 diff --git a/tests/orm/node/test_node.py b/tests/orm/node/test_node.py index 1533fa7c24..20d2d12c62 100644 --- a/tests/orm/node/test_node.py +++ b/tests/orm/node/test_node.py @@ -16,7 +16,8 @@ import pytest -from aiida.common import LinkType, exceptions +from aiida.common import LinkType, exceptions, timezone +from aiida.manage.manager import get_manager from 
aiida.orm import CalculationNode, Computer, Data, Log, Node, User, WorkflowNode, load_node from aiida.orm.utils.links import LinkTriple @@ -800,10 +801,9 @@ class TestNodeDelete: # pylint: disable=no-member,no-self-use @pytest.mark.usefixtures('clear_database_before_test') - def test_delete_through_utility_method(self): - """Test deletion works correctly through the `aiida.backends.utils.delete_nodes_and_connections`.""" - from aiida.backends.utils import delete_nodes_and_connections - from aiida.common import timezone + def test_delete_through_backend(self): + """Test deletion works correctly through the backend.""" + backend = get_manager().get_backend() data_one = Data().store() data_two = Data().store() @@ -820,7 +820,8 @@ def test_delete_through_utility_method(self): assert len(Log.objects.get_logs_for(data_two)) == 1 assert Log.objects.get_logs_for(data_two)[0].pk == log_two.pk - delete_nodes_and_connections([data_two.pk]) + with backend.transaction(): + backend.delete_nodes_and_connections([data_two.pk]) assert len(Log.objects.get_logs_for(data_one)) == 1 assert Log.objects.get_logs_for(data_one)[0].pk == log_one.pk @@ -829,8 +830,6 @@ def test_delete_through_utility_method(self): @pytest.mark.usefixtures('clear_database_before_test') def test_delete_collection_logs(self): """Test deletion works correctly through objects collection.""" - from aiida.common import timezone - data_one = Data().store() data_two = Data().store()