
Improve testing
chrisjsewell committed Oct 16, 2021
1 parent d5d39fd commit 264c61b
Showing 6 changed files with 145 additions and 91 deletions.
44 changes: 19 additions & 25 deletions aiida/orm/implementation/backends.py
@@ -9,7 +9,7 @@
###########################################################################
"""Generic backend related objects"""
import abc
from typing import TYPE_CHECKING, ContextManager, Generic, List, Sequence, TypeVar
from typing import TYPE_CHECKING, Any, ContextManager, List, Sequence, TypeVar

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
@@ -31,7 +31,7 @@
TransactionType = TypeVar('TransactionType')


class Backend(abc.ABC, Generic[TransactionType]):
class Backend(abc.ABC):
"""The public interface that defines a backend factory that creates backend specific concrete objects."""

@abc.abstractmethod
@@ -85,7 +85,7 @@ def get_session(self) -> 'Session':
"""

@abc.abstractmethod
def transaction(self) -> ContextManager[TransactionType]:
def transaction(self) -> ContextManager[Any]:
"""
Get a context manager that can be used as a transaction context for a series of backend operations.
If there is an exception within the context then the changes will be rolled back and the state will
@@ -94,32 +94,18 @@ def transaction(self) -> ContextManager[TransactionType]:
:return: a context manager to group database operations
"""

@property
@abc.abstractmethod
def delete_nodes_and_connections(self, pks_to_delete: Sequence[int], transaction: TransactionType):
"""Delete all nodes corresponding to pks in the input.
This method is intended to be used within a transaction context.
:param pks_to_delete: a sequence of node pks to delete
:param transact: the returned instance from entering transaction context
"""
def in_transaction(self) -> bool:
"""Return whether a transaction is currently active."""

@abc.abstractmethod
def bulk_insert(
self,
entity_type: 'EntityTypes',
rows: List[dict],
transaction: TransactionType,
allow_defaults: bool = False
) -> List[int]:
def bulk_insert(self, entity_type: 'EntityTypes', rows: List[dict], allow_defaults: bool = False) -> List[int]:
"""Insert a list of entities into the database, directly into a backend transaction.
This method is intended to be used within a transaction context.
:param entity_type: The type of the entity
:param data: A list of dictionaries, containing all fields of the backend model,
except the `id` field (a.k.a primary key), which will be generated dynamically
:param transaction: the returned object of the ``self.transaction`` context
:param allow_defaults: If ``False``, assert that each row contains all fields (except primary key(s)),
otherwise, allow default values for missing fields.
@@ -129,15 +115,23 @@ def bulk_insert(
"""

@abc.abstractmethod
def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict], transaction: TransactionType) -> None:
def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict]) -> None:
"""Update a list of entities in the database, directly with a backend transaction.
This method is intended to be used within a transaction context.
:param entity_type: The type of the entity
:param data: A list of dictionaries, containing fields of the backend model to update,
and the `id` field (a.k.a primary key)
:param transaction: the returned object of the ``self.transaction`` context
:raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table
"""

@abc.abstractmethod
def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]):
"""Delete all nodes corresponding to pks in the input and any links to/from them.
This method is intended to be used within a transaction context.
:param pks_to_delete: a sequence of node pks to delete
:raises: ``AssertionError`` if a transaction is not active
"""
22 changes: 12 additions & 10 deletions aiida/orm/implementation/django/backend.py
@@ -10,7 +10,7 @@
"""Django implementation of `aiida.orm.implementation.backends.Backend`."""
from contextlib import contextmanager
import functools
from typing import ContextManager, List, Sequence
from typing import Any, ContextManager, List, Sequence

# pylint: disable=import-error,no-name-in-module
from django.apps import apps
@@ -28,7 +28,7 @@
__all__ = ('DjangoBackend',)


class DjangoBackend(SqlBackend[None, models.Model]):
class DjangoBackend(SqlBackend[models.Model]):
"""Django implementation of `aiida.orm.implementation.backends.Backend`."""

def __init__(self):
@@ -89,10 +89,14 @@ def get_session():
return get_scoped_session()

@staticmethod
def transaction() -> ContextManager[None]:
def transaction() -> ContextManager[Any]:
"""Open a transaction to be used as a context manager."""
return django_transaction.atomic()

@property
def in_transaction(self) -> bool:
return not django_transaction.get_autocommit()

@staticmethod
@functools.lru_cache(maxsize=18)
def _get_model_from_entity(entity_type: EntityTypes, with_pk: bool):
@@ -118,11 +122,7 @@ def _get_model_from_entity(entity_type: EntityTypes, with_pk: bool):
keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key}
return model, keys

def bulk_insert(self,
entity_type: EntityTypes,
rows: List[dict],
transaction: None,
allow_defaults: bool = False) -> List[int]:
def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]:
model, keys = self._get_model_from_entity(entity_type, False)
if allow_defaults:
for row in rows:
@@ -141,7 +141,7 @@ def bulk_insert(self,
model.objects.bulk_create(objects)
return [obj.id for obj in objects]

def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: None) -> None:
def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None:
model, keys = self._get_model_from_entity(entity_type, True)
id_entries = {}
fields = None
@@ -166,7 +166,9 @@ def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: N
objects.append(obj)
model.objects.bulk_update(objects, fields)

def delete_nodes_and_connections(self, pks_to_delete: Sequence[int], transaction: None) -> None:
def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None:
if not self.in_transaction:
raise AssertionError('Cannot delete nodes outside a transaction')
# Delete all links pointing to or from a given node
dbm.DbLink.objects.filter(models.Q(input__in=pks_to_delete) | models.Q(output__in=pks_to_delete)).delete()
# now delete nodes
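
Because the Django guard is just the autocommit flag, the commit's focus on testing is easy to serve: whether a bulk call is allowed can be asserted directly against ``in_transaction``. A hypothetical pytest-style check, not part of this diff (the ``backend`` fixture providing a ``DjangoBackend`` is assumed):

    import pytest

    def test_delete_requires_transaction(backend):
        # outside `atomic()`, Django autocommit is on, so the guard must trip
        assert not backend.in_transaction
        with pytest.raises(AssertionError):
            backend.delete_nodes_and_connections([])
        # inside the context manager the same call is permitted
        with backend.transaction():
            assert backend.in_transaction
            backend.delete_nodes_and_connections([])
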
2 changes: 1 addition & 1 deletion aiida/orm/implementation/sql/backends.py
@@ -19,7 +19,7 @@
ModelType = typing.TypeVar('ModelType') # pylint: disable=invalid-name


class SqlBackend(typing.Generic[backends.TransactionType, ModelType], backends.Backend[backends.TransactionType]):
class SqlBackend(typing.Generic[ModelType], backends.Backend):
"""
A class for SQL based backends. Assumptions are that:
* there is an ORM
50 changes: 24 additions & 26 deletions aiida/orm/implementation/sqlalchemy/backend.py
@@ -26,7 +26,7 @@
__all__ = ('SqlaBackend',)


class SqlaBackend(SqlBackend[Session, base.Base]):
class SqlaBackend(SqlBackend[base.Base]):
"""SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`."""

def __init__(self):
@@ -100,6 +100,10 @@ def transaction(self) -> Iterator[Session]:
with session.begin_nested():
yield session

@property
def in_transaction(self) -> bool:
return self.get_session().in_nested_transaction()

@staticmethod
@functools.lru_cache(maxsize=18)
def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool):
@@ -131,13 +135,7 @@ def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool):
keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key}
return mapper, keys

def bulk_insert(
self,
entity_type: EntityTypes,
rows: List[dict],
transaction: Session,
allow_defaults: bool = False
) -> List[int]:
def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]:
mapper, keys = self._get_mapper_from_entity(entity_type, False)
if not rows:
return []
@@ -155,10 +153,10 @@ def bulk_insert(
# note for postgresql+psycopg2 we could also use `save_all` + `flush` with minimal performance degradation, see
# https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#orm-batch-inserts-with-psycopg2-now-batch-statements-with-returning-in-most-cases
# by contrast, in sqlite, bulk_insert is faster: https://docs.sqlalchemy.org/en/14/faq/performance.html
transaction.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True)
self.get_session().bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True)
return [row['id'] for row in rows]

def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: Session) -> None: # pylint: disable=no-self-use
def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None: # pylint: disable=no-self-use
mapper, keys = self._get_mapper_from_entity(entity_type, True)
if not rows:
return None
@@ -167,26 +165,26 @@ def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: S
raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}")
if not keys.issuperset(row):
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
transaction.bulk_update_mappings(mapper, rows)
self.get_session().bulk_update_mappings(mapper, rows)

def delete_nodes_and_connections(self, pks_to_delete: Sequence[int], transaction: Session) -> None: # pylint: disable=no-self-use
def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: # pylint: disable=no-self-use
# pylint: disable=no-value-for-parameter
from aiida.backends.sqlalchemy.models.group import table_groups_nodes
from aiida.backends.sqlalchemy.models.group import DbGroupNode
from aiida.backends.sqlalchemy.models.node import DbLink, DbNode

# I am first making a statement to delete the membership of these nodes to groups.
# Since table_groups_nodes is a sqlalchemy.schema.Table, I am using expression language to compile
# a stmt to be executed by the session. It works, but it's not nice that two different ways are used!
# Can this be changed?
stmt = table_groups_nodes.delete().where(table_groups_nodes.c.dbnode_id.in_(list(pks_to_delete)))
transaction.execute(stmt)
# First delete links, then the Nodes, since we are not cascading deletions.
# Here I delete the links coming out of the nodes marked for deletion.
transaction.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
# Here I delete the links pointing to the nodes marked for deletion.
transaction.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
# Now I am deleting the nodes
transaction.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
if not self.in_transaction:
raise AssertionError('Cannot delete nodes outside a transaction')

session = self.get_session()
# Delete the membership of these nodes to groups.
session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete))
).delete(synchronize_session='fetch')
# Delete the links coming out of the nodes marked for deletion.
session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
# Delete the links pointing to the nodes marked for deletion.
session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
# Delete the actual nodes
session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')

# Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend`

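
On the SqlAlchemy side, ``transaction()`` wraps ``session.begin_nested()``, so ``in_transaction`` is true exactly while such a SAVEPOINT block is open, and the docstring's promise that an exception rolls the changes back can be pictured roughly as follows (a sketch with a placeholder ``backend`` and row values, ``EntityTypes`` as above, not code from this diff):

    try:
        with backend.transaction():  # opens a SAVEPOINT via session.begin_nested()
            backend.bulk_update(EntityTypes.NODE, [{'id': 1, 'label': 'temporary'}])
            raise RuntimeError('abort')  # leaving the block with an exception discards the update
    except RuntimeError:
        pass
    assert not backend.in_transaction  # the savepoint has been rolled back and released
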
4 changes: 2 additions & 2 deletions aiida/tools/graph/deletions.py
@@ -103,8 +103,8 @@ def _missing_callback(_pks: Iterable[int]):
return (pks_set_to_delete, True)

DELETE_LOGGER.report('Starting node deletion...')
with backend.transaction() as transaction:
backend.delete_nodes_and_connections(pks_set_to_delete, transaction)
with backend.transaction():
backend.delete_nodes_and_connections(pks_set_to_delete)
DELETE_LOGGER.report('Deletion of nodes completed.')

return (pks_set_to_delete, True)