diff --git a/.github/system_tests/test_daemon.py b/.github/system_tests/test_daemon.py index 631d7a1784..e91112095f 100644 --- a/.github/system_tests/test_daemon.py +++ b/.github/system_tests/test_daemon.py @@ -9,11 +9,14 @@ ########################################################################### # pylint: disable=no-name-in-module """Tests to run with a running daemon.""" +import os +import shutil import subprocess import sys +import tempfile import time -from aiida.common import exceptions +from aiida.common import exceptions, StashMode from aiida.engine import run, submit from aiida.engine.daemon.client import get_daemon_client from aiida.engine.persistence import ObjectLoader @@ -415,6 +418,24 @@ def launch_all(): print('Running the `MultiplyAddWorkChain`') run_multiply_add_workchain() + # Testing the stashing functionality + process, inputs, expected_result = create_calculation_process(code=code_doubler, inputval=1) + with tempfile.TemporaryDirectory() as tmpdir: + + # Delete the temporary directory to test that the stashing functionality will create it if necessary + shutil.rmtree(tmpdir, ignore_errors=True) + + source_list = ['output.txt', 'triple_value.tmp'] + inputs['metadata']['options']['stash'] = {'target_base': tmpdir, 'source_list': source_list} + _, node = run.get_node(process, **inputs) + assert node.is_finished_ok + assert 'remote_stash' in node.outputs + remote_stash = node.outputs.remote_stash + assert remote_stash.stash_mode == StashMode.COPY + assert remote_stash.target_basepath.startswith(tmpdir) + assert sorted(remote_stash.source_list) == sorted(source_list) + assert sorted(p for p in os.listdir(remote_stash.target_basepath)) == sorted(source_list) + # Submitting the calcfunction through the launchers print('Submitting calcfunction to the daemon') proc, expected_result = launch_calcfunction(inputval=1) diff --git a/aiida/common/datastructures.py b/aiida/common/datastructures.py index 166c7d88d9..271cdaec48 100644 --- 
a/aiida/common/datastructures.py +++ b/aiida/common/datastructures.py @@ -12,7 +12,13 @@ from .extendeddicts import DefaultFieldsAttributeDict -__all__ = ('CalcJobState', 'CalcInfo', 'CodeInfo', 'CodeRunMode') +__all__ = ('StashMode', 'CalcJobState', 'CalcInfo', 'CodeInfo', 'CodeRunMode') + + +class StashMode(Enum): + """Mode to use when stashing files from the working directory of a completed calculation job for safekeeping.""" + + COPY = 'copy' class CalcJobState(Enum): @@ -21,6 +27,7 @@ class CalcJobState(Enum): UPLOADING = 'uploading' SUBMITTING = 'submitting' WITHSCHEDULER = 'withscheduler' + STASHING = 'stashing' RETRIEVING = 'retrieving' PARSING = 'parsing' diff --git a/aiida/engine/daemon/execmanager.py b/aiida/engine/daemon/execmanager.py index fa086e4d94..02dc638a99 100644 --- a/aiida/engine/daemon/execmanager.py +++ b/aiida/engine/daemon/execmanager.py @@ -332,6 +332,65 @@ def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str: return job_id +def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: + """Stash files from the working directory of a completed calculation to a permanent remote folder. + + After a calculation has been completed, optionally stash files from the work directory to a storage location on the + same remote machine. This is useful if one wants to prevent certain files of a completed calculation from being + removed from the scratch directory, because they are necessary for restarts but are too heavy to retrieve. + Which files to copy, and to where, is read from the `stash` option (`source_list` and `target_base`). + + :param calculation: the calculation job node. + :param transport: an already opened transport. 
+ """ + from aiida.common.datastructures import StashMode + from aiida.orm import RemoteStashFolderData + + logger_extra = get_dblogger_extra(calculation) + + stash_options = calculation.get_option('stash') + stash_mode = stash_options.get('mode', StashMode.COPY.value) + source_list = stash_options.get('source_list', []) + + if not source_list: + return + + if stash_mode != StashMode.COPY.value: + EXEC_LOGGER.warning(f'stashing mode {stash_mode} is not implemented yet.') + return + + cls = RemoteStashFolderData + + EXEC_LOGGER.debug(f'stashing files for calculation<{calculation.pk}>: {source_list}', extra=logger_extra) + + uuid = calculation.uuid + target_basepath = os.path.join(stash_options['target_base'], uuid[:2], uuid[2:4], uuid[4:]) + + for source_filename in source_list: + + source_filepath = os.path.join(calculation.get_remote_workdir(), source_filename) + target_filepath = os.path.join(target_basepath, source_filename) + + # If the source file is in a (nested) directory, create those directories first in the target directory + target_dirname = os.path.dirname(target_filepath) + transport.makedirs(target_dirname, ignore_existing=True) + + try: + transport.copy(source_filepath, target_filepath) + except (IOError, ValueError) as exception: + EXEC_LOGGER.warning(f'failed to stash {source_filepath} to {target_filepath}: {exception}') + else: + EXEC_LOGGER.debug(f'stashed {source_filepath} to {target_filepath}') + + remote_stash = cls( + computer=calculation.computer, + target_basepath=target_basepath, + stash_mode=StashMode(stash_mode), + source_list=source_list, + ).store() + remote_stash.add_incoming(calculation, link_type=LinkType.CREATE, link_label='remote_stash') + + def retrieve_calculation(calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str) -> None: """Retrieve all the files of a completed job calculation using the given transport. 
diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py index b0bd6bf174..f13a65a965 100644 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ b/aiida/engine/processes/calcjobs/calcjob.py @@ -97,6 +97,33 @@ def validate_calc_job(inputs: Any, ctx: PortNamespace) -> Optional[str]: # pyli return None +def validate_stash_options(stash_options: Any, _: Any) -> Optional[str]: + """Validate the ``stash`` options.""" + from aiida.common.datastructures import StashMode + + target_base = stash_options.get('target_base', None) + source_list = stash_options.get('source_list', None) + stash_mode = stash_options.get('mode', StashMode.COPY.value) + + if not isinstance(target_base, str) or not os.path.isabs(target_base): + return f'`metadata.options.stash.target_base` should be an absolute filepath, got: {target_base}' + + if ( + not isinstance(source_list, (list, tuple)) or + any(not isinstance(src, str) or os.path.isabs(src) for src in source_list) + ): + port = 'metadata.options.stash.source_list' + return f'`{port}` should be a list or tuple of relative filepaths, got: {source_list}' + + try: + StashMode(stash_mode) + except ValueError: + port = 'metadata.options.stash.mode' + return f'`{port}` should be a member of aiida.common.datastructures.StashMode, got: {stash_mode}' + + return None + + def validate_parser(parser_name: Any, _: Any) -> Optional[str]: """Validate the parser. 
@@ -104,11 +131,10 @@ def validate_parser(parser_name: Any, _: Any) -> Optional[str]: """ from aiida.plugins import ParserFactory - if parser_name is not plumpy.ports.UNSPECIFIED: - try: - ParserFactory(parser_name) - except exceptions.EntryPointError as exception: - return f'invalid parser specified: {exception}' + try: + ParserFactory(parser_name) + except exceptions.EntryPointError as exception: + return f'invalid parser specified: {exception}' return None @@ -118,9 +144,6 @@ def validate_additional_retrieve_list(additional_retrieve_list: Any, _: Any) -> :return: string with error message in case the input is invalid. """ - if additional_retrieve_list is plumpy.ports.UNSPECIFIED: - return None - if any(not isinstance(value, str) or os.path.isabs(value) for value in additional_retrieve_list): return f'`additional_retrieve_list` should only contain relative filepaths but got: {additional_retrieve_list}' @@ -216,9 +239,21 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] spec.input('metadata.options.additional_retrieve_list', required=False, valid_type=(list, tuple), validator=validate_additional_retrieve_list, help='List of relative file paths that should be retrieved in addition to what the plugin specifies.') + spec.input_namespace('metadata.options.stash', required=False, populate_defaults=False, + validator=validate_stash_options, + help='Optional directives to stash files after the calculation job has completed.') + spec.input('metadata.options.stash.target_base', valid_type=str, required=False, + help='The base location to where the files should be stashd. 
For example, for the `copy` stash mode, this ' + 'should be an absolute filepath on the remote computer.') + spec.input('metadata.options.stash.source_list', valid_type=(tuple, list), required=False, + help='Sequence of relative filepaths representing files in the remote directory that should be stashed.') + spec.input('metadata.options.stash.stash_mode', valid_type=str, required=False, + help='Mode with which to perform the stashing, should be value of `aiida.common.datastructures.StashMode.') spec.output('remote_folder', valid_type=orm.RemoteData, help='Input files necessary to run the process will be stored in this folder node.') + spec.output('remote_stash', valid_type=orm.RemoteStashData, required=False, + help='Contents of the `stash.source_list` option are stored in this remote folder after job completion.') spec.output(cls.link_label_retrieved, valid_type=orm.FolderData, pass_to_parser=True, help='Files that are retrieved by the daemon will be stored in this node. By default the stdout and stderr ' 'of the scheduler will be added, but one can add more by specifying them in `CalcInfo.retrieve_list`.') @@ -653,29 +688,29 @@ def presubmit(self, folder: Folder) -> CalcInfo: local_copy_list = calc_info.local_copy_list try: validate_list_of_string_tuples(local_copy_list, tuple_length=3) - except ValidationError as exc: + except ValidationError as exception: raise PluginInternalError( - f'[presubmission of calc {this_pk}] local_copy_list format problem: {exc}' - ) from exc + f'[presubmission of calc {this_pk}] local_copy_list format problem: {exception}' + ) from exception remote_copy_list = calc_info.remote_copy_list try: validate_list_of_string_tuples(remote_copy_list, tuple_length=3) - except ValidationError as exc: + except ValidationError as exception: raise PluginInternalError( - f'[presubmission of calc {this_pk}] remote_copy_list format problem: {exc}' - ) from exc + f'[presubmission of calc {this_pk}] remote_copy_list format problem: {exception}' + ) from 
exception for (remote_computer_uuid, _, dest_rel_path) in remote_copy_list: try: Computer.objects.get(uuid=remote_computer_uuid) # pylint: disable=unused-variable - except exceptions.NotExistent as exc: + except exceptions.NotExistent as exception: raise PluginInternalError( '[presubmission of calc {}] ' 'The remote copy requires a computer with UUID={}' 'but no such computer was found in the ' 'database'.format(this_pk, remote_computer_uuid) - ) from exc + ) from exception if os.path.isabs(dest_rel_path): raise PluginInternalError( '[presubmission of calc {}] ' diff --git a/aiida/engine/processes/calcjobs/tasks.py b/aiida/engine/processes/calcjobs/tasks.py index 8e57fb5db5..95fb4b0f8e 100644 --- a/aiida/engine/processes/calcjobs/tasks.py +++ b/aiida/engine/processes/calcjobs/tasks.py @@ -37,6 +37,7 @@ SUBMIT_COMMAND = 'submit' UPDATE_COMMAND = 'update' RETRIEVE_COMMAND = 'retrieve' +STASH_COMMAND = 'stash' KILL_COMMAND = 'kill' RETRY_INTERVAL_OPTION = 'transport.task_retry_initial_interval' @@ -100,9 +101,9 @@ async def do_upload(): raise except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): raise - except Exception: + except Exception as exception: logger.warning(f'uploading CalcJob<{node.pk}> failed') - raise TransportTaskException(f'upload_calculation failed {max_attempts} times consecutively') + raise TransportTaskException(f'upload_calculation failed {max_attempts} times consecutively') from exception else: logger.info(f'uploading CalcJob<{node.pk}> successful') node.set_state(CalcJobState.SUBMITTING) @@ -146,9 +147,9 @@ async def do_submit(): ) except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): # pylint: disable=try-except-raise raise - except Exception: + except Exception as exception: logger.warning(f'submitting CalcJob<{node.pk}> failed') - raise TransportTaskException(f'submit_calculation failed {max_attempts} times consecutively') + raise TransportTaskException(f'submit_calculation failed {max_attempts} 
times consecutively') from exception else: logger.info(f'submitting CalcJob<{node.pk}> successful') node.set_state(CalcJobState.WITHSCHEDULER) @@ -171,8 +172,10 @@ async def task_update_job(node: CalcJobNode, job_manager, cancellable: Interrupt :type cancellable: :class:`aiida.engine.utils.InterruptableFuture` :return: True if the tasks was successfully completed, False otherwise """ - if node.get_state() == CalcJobState.RETRIEVING: - logger.warning(f'CalcJob<{node.pk}> already marked as RETRIEVING, skipping task_update_job') + state = node.get_state() + + if state in [CalcJobState.RETRIEVING, CalcJobState.STASHING]: + logger.warning(f'CalcJob<{node.pk}> already marked as `{state}`, skipping task_update_job') return True initial_interval = get_config_option(RETRY_INTERVAL_OPTION) @@ -205,13 +208,13 @@ async def do_update(): ) except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): # pylint: disable=try-except-raise raise - except Exception: + except Exception as exception: logger.warning(f'updating CalcJob<{node.pk}> failed') - raise TransportTaskException(f'update_calculation failed {max_attempts} times consecutively') + raise TransportTaskException(f'update_calculation failed {max_attempts} times consecutively') from exception else: logger.info(f'updating CalcJob<{node.pk}> successful') if job_done: - node.set_state(CalcJobState.RETRIEVING) + node.set_state(CalcJobState.STASHING) return job_done @@ -271,15 +274,65 @@ async def do_retrieve(): ) except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): # pylint: disable=try-except-raise raise - except Exception: + except Exception as exception: logger.warning(f'retrieving CalcJob<{node.pk}> failed') - raise TransportTaskException(f'retrieve_calculation failed {max_attempts} times consecutively') + raise TransportTaskException(f'retrieve_calculation failed {max_attempts} times consecutively') from exception else: node.set_state(CalcJobState.PARSING) logger.info(f'retrieving 
CalcJob<{node.pk}> successful') return result +async def task_stash_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture): + """Transport task that will optionally stash files of a completed job calculation on the remote. + + The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager + function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. + If all retries fail, the task will raise a TransportTaskException. + + :param node: the node that represents the job calculation + :param transport_queue: the TransportQueue from which to request a Transport + :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled + :type cancellable: :class:`aiida.engine.utils.InterruptableFuture` + :return: None if the task was successfully completed + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted + """ + if node.get_state() == CalcJobState.RETRIEVING: + logger.warning(f'calculation<{node.pk}> already marked as RETRIEVING, skipping task_stash_job') + return + + initial_interval = get_config_option(RETRY_INTERVAL_OPTION) + max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) + + authinfo = node.get_authinfo() + + async def do_stash(): + with transport_queue.request_transport(authinfo) as request: + transport = await cancellable.with_interrupt(request) + + logger.info(f'stashing calculation<{node.pk}>') + return execmanager.stash_calculation(node, transport) + + try: + await exponential_backoff_retry( + do_stash, + initial_interval, + max_attempts, + logger=node.logger, + ignore_exceptions=plumpy.process_states.Interruption + ) + except plumpy.process_states.Interruption: + raise + except Exception as exception: + 
logger.warning(f'stashing calculation<{node.pk}> failed') + raise TransportTaskException(f'stash_calculation failed {max_attempts} times consecutively') from exception + else: + node.set_state(CalcJobState.RETRIEVING) + logger.info(f'stashing calculation<{node.pk}> successful') + return + + async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture): """Transport task that will attempt to kill a job calculation. @@ -313,9 +366,9 @@ async def do_kill(): result = await exponential_backoff_retry(do_kill, initial_interval, max_attempts, logger=node.logger) except plumpy.process_states.Interruption: raise - except Exception: + except Exception as exception: logger.warning(f'killing CalcJob<{node.pk}> failed') - raise TransportTaskException(f'kill_calculation failed {max_attempts} times consecutively') + raise TransportTaskException(f'kill_calculation failed {max_attempts} times consecutively') from exception else: logger.info(f'killing CalcJob<{node.pk}> successful') node.set_scheduler_state(JobState.DONE) @@ -353,11 +406,11 @@ def load_instance_state(self, saved_state, load_context): async def execute(self) -> plumpy.process_states.State: # type: ignore[override] # pylint: disable=invalid-overridden-method """Override the execute coroutine of the base `Waiting` state.""" - # pylint: disable=too-many-branches, too-many-statements + # pylint: disable=too-many-branches,too-many-statements node = self.process.node transport_queue = self.process.runner.transport - command = self.data result: plumpy.process_states.State = self + command = self.data process_status = f'Waiting for transport task: {command}' @@ -376,7 +429,7 @@ async def execute(self) -> plumpy.process_states.State: # type: ignore[override await self._launch_task(task_submit_job, node, transport_queue) result = self.update() - elif self.data == UPDATE_COMMAND: + elif command == UPDATE_COMMAND: job_done = False while not job_done: @@ -386,11 +439,18 @@ async def 
execute(self) -> plumpy.process_states.State: # type: ignore[override node.set_process_status(process_status) job_done = await self._launch_task(task_update_job, node, self.process.runner.job_manager) + if node.get_option('stash') is not None: + result = self.stash() + else: + result = self.retrieve() + + elif command == STASH_COMMAND: + node.set_process_status(process_status) + await self._launch_task(task_stash_job, node, transport_queue) result = self.retrieve() - elif self.data == RETRIEVE_COMMAND: + elif command == RETRIEVE_COMMAND: node.set_process_status(process_status) - # Create a temporary folder that has to be deleted by JobProcess.retrieved after successful parsing temp_folder = tempfile.mkdtemp() await self._launch_task(task_retrieve_job, node, transport_queue, temp_folder) result = self.parse(temp_folder) @@ -453,6 +513,11 @@ def retrieve(self) -> 'Waiting': ProcessState.WAITING, None, msg=msg, data=RETRIEVE_COMMAND ) # type: ignore[return-value] + def stash(self): + """Return the `Waiting` state that will `stash` the `CalcJob`.""" + msg = 'Waiting to stash' + return self.create_state(ProcessState.WAITING, None, msg=msg, data=STASH_COMMAND) + def parse(self, retrieved_temporary_folder: str) -> plumpy.process_states.Running: """Return the `Running` state that will parse the `CalcJob`. 
diff --git a/aiida/orm/nodes/data/__init__.py b/aiida/orm/nodes/data/__init__.py index 9734c08d2d..8ed0d10aa4 100644 --- a/aiida/orm/nodes/data/__init__.py +++ b/aiida/orm/nodes/data/__init__.py @@ -8,7 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with `Node` sub classes for data structures.""" - from .array import ArrayData, BandsData, KpointsData, ProjectionData, TrajectoryData, XyData from .base import BaseType, to_aiida_type from .bool import Bool @@ -22,7 +21,7 @@ from .list import List from .numeric import NumericType from .orbital import OrbitalData -from .remote import RemoteData +from .remote import RemoteData, RemoteStashData, RemoteStashFolderData from .singlefile import SinglefileData from .str import Str from .structure import StructureData @@ -30,6 +29,6 @@ __all__ = ( 'Data', 'BaseType', 'ArrayData', 'BandsData', 'KpointsData', 'ProjectionData', 'TrajectoryData', 'XyData', 'Bool', - 'CifData', 'Code', 'Float', 'FolderData', 'Int', 'List', 'OrbitalData', 'Dict', 'RemoteData', 'SinglefileData', - 'Str', 'StructureData', 'UpfData', 'NumericType', 'to_aiida_type' + 'CifData', 'Code', 'Float', 'FolderData', 'Int', 'List', 'OrbitalData', 'Dict', 'RemoteData', 'RemoteStashData', + 'RemoteStashFolderData', 'SinglefileData', 'Str', 'StructureData', 'UpfData', 'NumericType', 'to_aiida_type' ) diff --git a/aiida/orm/nodes/data/remote/__init__.py b/aiida/orm/nodes/data/remote/__init__.py new file mode 100644 index 0000000000..2f88d7edbc --- /dev/null +++ b/aiida/orm/nodes/data/remote/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +"""Module with data plugins that represent remote resources and so effectively are symbolic links.""" +from .base import RemoteData +from .stash import RemoteStashData, RemoteStashFolderData + +__all__ = ('RemoteData', 'RemoteStashData', 'RemoteStashFolderData') diff --git a/aiida/orm/nodes/data/remote.py 
b/aiida/orm/nodes/data/remote/base.py similarity index 95% rename from aiida/orm/nodes/data/remote.py rename to aiida/orm/nodes/data/remote/base.py index ba8b8e52e0..b293e2e6b9 100644 --- a/aiida/orm/nodes/data/remote.py +++ b/aiida/orm/nodes/data/remote/base.py @@ -11,7 +11,7 @@ import os from aiida.orm import AuthInfo -from .data import Data +from ..data import Data __all__ = ('RemoteData',) @@ -79,7 +79,7 @@ def getfile(self, relpath, destpath): full_path, self.computer.label # pylint: disable=no-member ) - ) + ) from exception raise def listdir(self, relpath='.'): @@ -102,7 +102,7 @@ def listdir(self, relpath='.'): format(full_path, self.computer.label) # pylint: disable=no-member ) exc.errno = exception.errno - raise exc + raise exc from exception else: raise @@ -115,7 +115,7 @@ def listdir(self, relpath='.'): format(full_path, self.computer.label) # pylint: disable=no-member ) exc.errno = exception.errno - raise exc + raise exc from exception else: raise @@ -139,7 +139,7 @@ def listdir_withattributes(self, path='.'): format(full_path, self.computer.label) # pylint: disable=no-member ) exc.errno = exception.errno - raise exc + raise exc from exception else: raise @@ -152,7 +152,7 @@ def listdir_withattributes(self, path='.'): format(full_path, self.computer.label) # pylint: disable=no-member ) exc.errno = exception.errno - raise exc + raise exc from exception else: raise @@ -176,8 +176,8 @@ def _validate(self): try: self.get_remote_path() - except AttributeError: - raise ValidationError("attribute 'remote_path' not set.") + except AttributeError as exception: + raise ValidationError("attribute 'remote_path' not set.") from exception computer = self.computer if computer is None: diff --git a/aiida/orm/nodes/data/remote/stash/__init__.py b/aiida/orm/nodes/data/remote/stash/__init__.py new file mode 100644 index 0000000000..f744240cfc --- /dev/null +++ b/aiida/orm/nodes/data/remote/stash/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +"""Module with data 
plugins that represent files of completed calculation jobs that have been stashed.""" +from .base import RemoteStashData +from .folder import RemoteStashFolderData + +__all__ = ('RemoteStashData', 'RemoteStashFolderData') diff --git a/aiida/orm/nodes/data/remote/stash/base.py b/aiida/orm/nodes/data/remote/stash/base.py new file mode 100644 index 0000000000..f904643bab --- /dev/null +++ b/aiida/orm/nodes/data/remote/stash/base.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +"""Data plugin that models an archived folder on a remote computer.""" +from aiida.common.datastructures import StashMode +from aiida.common.lang import type_check +from ...data import Data + +__all__ = ('RemoteStashData',) + + +class RemoteStashData(Data): + """Data plugin that models an archived folder on a remote computer. + + A stashed folder is essentially an instance of ``RemoteData`` that has been archived. Archiving in this context can + simply mean copying the content of the folder to another location on the same or another filesystem as long as it is + on the same machine. In addition, the folder may have been compressed into a single file for efficiency or even + written to tape. The ``stash_mode`` attribute will distinguish how the folder was stashed which will allow the + implementation to also `unstash` it and transform it back into a ``RemoteData`` such that it can be used as an input + for new ``CalcJobs``. + + This class is a non-storable base class that merely registers the ``stash_mode`` attribute. Only its subclasses, + that actually implement a certain stash mode, can be stored. The reason for this design + is that the behavior of the class can change significantly based on the mode employed to stash the files and + implementing all these variants in the same class will lead to an unintuitive interface where certain properties or + methods of the class will only be available or function properly based on the ``stash_mode``. 
+ """ + + _storable = False + + def __init__(self, stash_mode: StashMode, **kwargs): + """Construct a new instance + + :param stash_mode: the stashing mode with which the data was stashed on the remote. + """ + super().__init__(**kwargs) + self.stash_mode = stash_mode + + @property + def stash_mode(self) -> StashMode: + """Return the mode with which the data was stashed on the remote. + + :return: the stash mode. + """ + return StashMode(self.get_attribute('stash_mode')) + + @stash_mode.setter + def stash_mode(self, value: StashMode): + """Set the mode with which the data was stashed on the remote. + + :param value: the stash mode. + """ + type_check(value, StashMode) + self.set_attribute('stash_mode', value.value) diff --git a/aiida/orm/nodes/data/remote/stash/folder.py b/aiida/orm/nodes/data/remote/stash/folder.py new file mode 100644 index 0000000000..7d7c00b2fc --- /dev/null +++ b/aiida/orm/nodes/data/remote/stash/folder.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +"""Data plugin that models a stashed folder on a remote computer.""" +import typing + +from aiida.common.datastructures import StashMode +from aiida.common.lang import type_check +from .base import RemoteStashData + +__all__ = ('RemoteStashFolderData',) + + +class RemoteStashFolderData(RemoteStashData): + """Data plugin that models a folder with files of a completed calculation job that has been stashed through a copy. + + This data plugin can and should be used to stash files if and only if the stash mode is `StashMode.COPY`. + """ + + _storable = True + + def __init__(self, stash_mode: StashMode, target_basepath: str, source_list: typing.List, **kwargs): + """Construct a new instance + + :param stash_mode: the stashing mode with which the data was stashed on the remote. + :param target_basepath: the target basepath. + :param source_list: the list of source files. 
+ """ + super().__init__(stash_mode, **kwargs) + self.target_basepath = target_basepath + self.source_list = source_list + + if stash_mode != StashMode.COPY: + raise ValueError('`RemoteStashFolderData` can only be used with `stash_mode == StashMode.COPY`.') + + @property + def target_basepath(self) -> str: + """Return the target basepath. + + :return: the target basepath. + """ + return self.get_attribute('target_basepath') + + @target_basepath.setter + def target_basepath(self, value: str): + """Set the target basepath. + + :param value: the target basepath. + """ + type_check(value, str) + self.set_attribute('target_basepath', value) + + @property + def source_list(self) -> typing.Union[typing.List, typing.Tuple]: + """Return the list of source files that were stashed. + + :return: the list of source files. + """ + return self.get_attribute('source_list') + + @source_list.setter + def source_list(self, value: typing.Union[typing.List, typing.Tuple]): + """Set the list of source files that were stashed. + + :param value: the list of source files. 
+ """ + type_check(value, (list, tuple)) + self.set_attribute('source_list', value) diff --git a/setup.json b/setup.json index 1c62482f47..d7cac88646 100644 --- a/setup.json +++ b/setup.json @@ -163,7 +163,9 @@ "list = aiida.orm.nodes.data.list:List", "numeric = aiida.orm.nodes.data.numeric:NumericType", "orbital = aiida.orm.nodes.data.orbital:OrbitalData", - "remote = aiida.orm.nodes.data.remote:RemoteData", + "remote = aiida.orm.nodes.data.remote.base:RemoteData", + "remote.stash = aiida.orm.nodes.data.remote.stash.base:RemoteStashData", + "remote.stash.folder = aiida.orm.nodes.data.remote.stash.folder:RemoteStashFolderData", "singlefile = aiida.orm.nodes.data.singlefile:SinglefileData", "str = aiida.orm.nodes.data.str:Str", "structure = aiida.orm.nodes.data.structure:StructureData", diff --git a/tests/engine/processes/test_builder.py b/tests/engine/processes/test_builder.py index aa7ad19b0b..239bc4984d 100644 --- a/tests/engine/processes/test_builder.py +++ b/tests/engine/processes/test_builder.py @@ -28,29 +28,29 @@ def test_access_methods(): builder = ProcessBuilder(ArithmeticAddCalculation) builder['x'] = node_numb - assert dict(builder) == {'metadata': {'options': {}}, 'x': node_numb} + assert dict(builder) == {'metadata': {'options': {'stash': {}}}, 'x': node_numb} del builder['x'] - assert dict(builder) == {'metadata': {'options': {}}} + assert dict(builder) == {'metadata': {'options': {'stash': {}}}} with pytest.raises(ValueError): builder['x'] = node_dict builder['x'] = node_numb - assert dict(builder) == {'metadata': {'options': {}}, 'x': node_numb} + assert dict(builder) == {'metadata': {'options': {'stash': {}}}, 'x': node_numb} # AS ATTRIBUTES del builder builder = ProcessBuilder(ArithmeticAddCalculation) builder.x = node_numb - assert dict(builder) == {'metadata': {'options': {}}, 'x': node_numb} + assert dict(builder) == {'metadata': {'options': {'stash': {}}}, 'x': node_numb} del builder.x - assert dict(builder) == {'metadata': {'options': {}}} + 
assert dict(builder) == {'metadata': {'options': {'stash': {}}}} with pytest.raises(ValueError): builder.x = node_dict builder.x = node_numb - assert dict(builder) == {'metadata': {'options': {}}, 'x': node_numb} + assert dict(builder) == {'metadata': {'options': {'stash': {}}}, 'x': node_numb} diff --git a/tests/engine/test_calc_job.py b/tests/engine/test_calc_job.py index 6b67541b80..75a611d8b9 100644 --- a/tests/engine/test_calc_job.py +++ b/tests/engine/test_calc_job.py @@ -19,9 +19,10 @@ from aiida import orm from aiida.backends.testbase import AiidaTestCase -from aiida.common import exceptions, LinkType, CalcJobState +from aiida.common import exceptions, LinkType, CalcJobState, StashMode from aiida.engine import launch, CalcJob, Process, ExitCode from aiida.engine.processes.ports import PortNamespace +from aiida.engine.processes.calcjobs.calcjob import validate_stash_options from aiida.plugins import CalculationFactory ArithmeticAddCalculation = CalculationFactory('arithmetic.add') # pylint: disable=invalid-name @@ -616,3 +617,41 @@ def test_additional_retrieve_list(generate_process, fixture_sandbox): with pytest.raises(ValueError, match=r'`additional_retrieve_list` should only contain relative filepaths.*'): process = generate_process({'metadata': {'options': {'additional_retrieve_list': ['/abs/path']}}}) + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize(('stash_options', 'expected'), ( + ({ + 'target_base': None + }, '`metadata.options.stash.target_base` should be'), + ({ + 'target_base': 'relative/path' + }, '`metadata.options.stash.target_base` should be'), + ({ + 'target_base': '/path' + }, '`metadata.options.stash.source_list` should be'), + ({ + 'target_base': '/path', + 'source_list': ['/abspath'] + }, '`metadata.options.stash.source_list` should be'), + ({ + 'target_base': '/path', + 'source_list': ['rel/path'], + 'mode': 'test' + }, '`metadata.options.stash.mode` should be'), + ({ + 'target_base': '/path', + 
'source_list': ['rel/path'] + }, None), + ({ + 'target_base': '/path', + 'source_list': ['rel/path'], + 'mode': StashMode.COPY.value + }, None), +)) +def test_validate_stash_options(stash_options, expected): + """Test the ``validate_stash_options`` function.""" + if expected is None: + assert validate_stash_options(stash_options, None) is expected + else: + assert expected in validate_stash_options(stash_options, None) diff --git a/tests/orm/data/test_remote_stash.py b/tests/orm/data/test_remote_stash.py new file mode 100644 index 0000000000..45318ca1b3 --- /dev/null +++ b/tests/orm/data/test_remote_stash.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the :mod:`aiida.orm.nodes.data.remote.stash` module.""" +import pytest + +from aiida.common.datastructures import StashMode +from aiida.common.exceptions import StoringNotAllowed +from aiida.orm import RemoteStashData, RemoteStashFolderData + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_base_class(): + """Verify that base class cannot be stored.""" + node = RemoteStashData(stash_mode=StashMode.COPY) + + with pytest.raises(StoringNotAllowed): + node.store() + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize('store', (False, True)) +def test_constructor(store): + """Test the constructor and storing functionality.""" + stash_mode = StashMode.COPY + target_basepath = '/absolute/path' + source_list = ['relative/folder', 'relative/file'] + + data = RemoteStashFolderData(stash_mode, 
target_basepath, source_list) + + assert data.stash_mode == stash_mode + assert data.target_basepath == target_basepath + assert data.source_list == source_list + + if store: + data.store() + assert data.is_stored + assert data.stash_mode == stash_mode + assert data.target_basepath == target_basepath + assert data.source_list == source_list + + +@pytest.mark.usefixtures('clear_database_before_test') +@pytest.mark.parametrize( + 'argument, value', ( + ('stash_mode', 'copy'), + ('target_basepath', ['list']), + ('source_list', 'relative/path'), + ('source_list', ('/absolute/path')), + ) +) +def test_constructor_invalid(argument, value): + """Test the constructor for invalid argument types.""" + kwargs = { + 'stash_mode': StashMode.COPY, + 'target_basepath': '/absolute/path', + 'source_list': ('relative/folder', 'relative/file'), + } + + with pytest.raises(TypeError): + kwargs[argument] = value + RemoteStashFolderData(**kwargs)