diff --git a/aiida/cmdline/commands/cmd_export.py b/aiida/cmdline/commands/cmd_export.py index 67f6ac9ef6..9baff6ad25 100644 --- a/aiida/cmdline/commands/cmd_export.py +++ b/aiida/cmdline/commands/cmd_export.py @@ -93,7 +93,8 @@ def create( their provenance, according to the rules outlined in the documentation. You can modify some of those rules using options of this command. """ - from aiida.tools.importexport import export, export_zip + from aiida.tools.importexport import export, ExportFileFormat + from aiida.tools.importexport.common.exceptions import ArchiveExportError entities = [] @@ -122,19 +123,18 @@ def create( } if archive_format == 'zip': - export_function = export_zip + export_format = ExportFileFormat.ZIP kwargs.update({'use_compression': True}) elif archive_format == 'zip-uncompressed': - export_function = export_zip + export_format = ExportFileFormat.ZIP kwargs.update({'use_compression': False}) elif archive_format == 'tar.gz': - export_function = export + export_format = ExportFileFormat.TAR_GZIPPED try: - export_function(entities, outfile=output_file, **kwargs) - - except IOError as exception: - echo.echo_critical('failed to write the export archive file: {}'.format(exception)) + export(entities, filename=output_file, file_format=export_format, **kwargs) + except ArchiveExportError as exception: + echo.echo_critical('failed to write the archive file. Exception: {}'.format(exception)) else: echo.echo_success('wrote the export archive file to {}'.format(output_file)) diff --git a/aiida/cmdline/commands/cmd_import.py b/aiida/cmdline/commands/cmd_import.py index 73054b182e..1a9d3e1f9d 100644 --- a/aiida/cmdline/commands/cmd_import.py +++ b/aiida/cmdline/commands/cmd_import.py @@ -12,6 +12,7 @@ from enum import Enum import traceback import urllib.request + import click from aiida.cmdline.commands.cmd_verdi import verdi @@ -34,6 +35,45 @@ class ExtrasImportCode(Enum): ask = 'kca' +def _echo_error( # pylint: disable=unused-argument + message, non_interactive, more_archives, raised_exception, **kwargs +): + """Utility function to help write an error message for ``verdi import`` + + :param message: Message following red-colored, bold "Error:". + :type message: str + :param non_interactive: Whether or not the user should be asked for input for any reason. + :type non_interactive: bool + :param more_archives: Whether or not there are more archives to import. + :type more_archives: bool + :param raised_exception: Exception raised during error. + :type raised_exception: `Exception` + """ + from aiida.tools.importexport import close_progress_bar, IMPORT_LOGGER + + # Close progress bar, if it exists + close_progress_bar(leave=False) + + IMPORT_LOGGER.debug('%s', traceback.format_exc()) + + exception = '{}: {}'.format(raised_exception.__class__.__name__, str(raised_exception)) + + echo.echo_error(message) + echo.echo(exception) + + if more_archives: + # There are more archives to go through + if non_interactive: + # Continue to next archive + pass + else: + # Ask if one should continue to next archive + click.confirm('Do you want to continue?', abort=True) + else: + # There are no more archives + click.Abort() + + def _try_import(migration_performed, file_to_import, archive, group, migration, non_interactive, **kwargs): """Utility function for `verdi import` to try to import archive @@ -66,8 +106,12 @@ def _try_import(migration_performed, file_to_import, archive, group, migration, except IncompatibleArchiveVersionError as exception: if migration_performed: # Migration has been performed, something is still wrong - crit_message = '{} has been migrated, but it still cannot be imported.\n{}'.format(archive, exception) - echo.echo_critical(crit_message) + _echo_error( + '{} has been migrated, but it still cannot be imported'.format(archive), + non_interactive=non_interactive, + raised_exception=exception, + **kwargs + ) else: # Migration has not yet been tried. if migration: @@ -85,18 +129,20 @@ def _try_import(migration_performed, file_to_import, archive, group, migration, else: # Abort echo.echo_critical(str(exception)) - except Exception: - echo.echo_error('an exception occurred while importing the archive {}'.format(archive)) - echo.echo(traceback.format_exc()) - if not non_interactive: - click.confirm('do you want to continue?', abort=True) + except Exception as exception: + _echo_error( + 'an exception occurred while importing the archive {}'.format(archive), + non_interactive=non_interactive, + raised_exception=exception, + **kwargs + ) else: echo.echo_success('imported archive {}'.format(archive)) return migrate_archive -def _migrate_archive(ctx, temp_folder, file_to_import, archive, non_interactive, **kwargs): # pylint: disable=unused-argument +def _migrate_archive(ctx, temp_folder, file_to_import, archive, non_interactive, more_archives, silent, **kwargs): # pylint: disable=unused-argument """Utility function for `verdi import` to migrate archive Invoke click command `verdi export migrate`, passing in the archive, outputting the migrated archive in a temporary SandboxFolder. @@ -107,6 +153,8 @@ def _migrate_archive(ctx, temp_folder, file_to_import, archive, non_interactive, :param file_to_import: Absolute path, including filename, of file to be migrated. :param archive: Filename of archive to be migrated, and later attempted imported. :param non_interactive: Whether or not the user should be asked for input for any reason. + :param more_archives: Whether or not there are more archives to be imported. + :param silent: Suppress console messages. :return: Absolute path to migrated archive within SandboxFolder. """ from aiida.cmdline.commands.cmd_export import migrate @@ -120,18 +168,19 @@ def _migrate_archive(ctx, temp_folder, file_to_import, archive, non_interactive, # Migration try: ctx.invoke( - migrate, input_file=file_to_import, output_file=temp_folder.get_abs_path(temp_out_file), silent=False + migrate, input_file=file_to_import, output_file=temp_folder.get_abs_path(temp_out_file), silent=silent ) - except Exception: - echo.echo_error( + except Exception as exception: + _echo_error( 'an exception occurred while migrating the archive {}.\n' - "Use 'verdi export migrate' to update this export file.".format(archive) + "Use 'verdi export migrate' to update this export file.".format(archive), + non_interactive=non_interactive, + more_archives=more_archives, + raised_exception=exception ) - echo.echo(traceback.format_exc()) - if not non_interactive: - click.confirm('do you want to continue?', abort=True) else: - echo.echo_success('archive migrated, proceeding with import') + # Success + echo.echo_info('proceeding with import') return temp_folder.get_abs_path(temp_out_file) @@ -197,7 +246,6 @@ def cmd_import( The archive can be specified by its relative or absolute file path, or its HTTP URL. """ - from aiida.common.folders import SandboxFolder from aiida.tools.importexport.common.utils import get_valid_import_links @@ -217,11 +265,13 @@ def cmd_import( try: echo.echo_info('retrieving archive URLS from {}'.format(webpage)) urls = get_valid_import_links(webpage) - except Exception: - echo.echo_error('an exception occurred while trying to discover archives at URL {}'.format(webpage)) - echo.echo(traceback.format_exc()) - if not non_interactive: - click.confirm('do you want to continue?', abort=True) + except Exception as exception: + _echo_error( + 'an exception occurred while trying to discover archives at URL {}'.format(webpage), + non_interactive=non_interactive, + more_archives=webpage != webpages[-1] or archives_file or archives_url, + raised_exception=exception + ) else: echo.echo_success('{} archive URLs discovered and added'.format(len(urls))) archives_url += urls @@ -239,7 +289,8 @@ def cmd_import( 'extras_mode_existing': ExtrasImportCode[extras_mode_existing].value, 'extras_mode_new': extras_mode_new, 'comment_mode': comment_mode, - 'non_interactive': non_interactive + 'non_interactive': non_interactive, + 'silent': False, } # Import local archives @@ -250,6 +301,7 @@ def cmd_import( # Initialization import_opts['archive'] = archive import_opts['file_to_import'] = import_opts['archive'] + import_opts['more_archives'] = archive != archives_file[-1] or archives_url # First attempt to import archive migrate_archive = _try_import(migration_performed=False, **import_opts) @@ -265,13 +317,14 @@ def cmd_import( # Initialization import_opts['archive'] = archive + import_opts['more_archives'] = archive != archives_url[-1] echo.echo_info('downloading archive {}'.format(archive)) try: response = urllib.request.urlopen(archive) except Exception as exception: - echo.echo_warning('downloading archive {} failed: {}'.format(archive, exception)) + _echo_error('downloading archive {} failed'.format(archive), raised_exception=exception, **import_opts) with SandboxFolder() as temp_folder: temp_file = 'importfile.tar.gz' diff --git a/aiida/cmdline/commands/cmd_restapi.py b/aiida/cmdline/commands/cmd_restapi.py index ca9cc45fb7..0ff7546b30 100644 --- a/aiida/cmdline/commands/cmd_restapi.py +++ b/aiida/cmdline/commands/cmd_restapi.py @@ -16,7 +16,7 @@ import click from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params.options import HOSTNAME, PORT +from aiida.cmdline.params.options import HOSTNAME, PORT, DEBUG from aiida.restapi.common import config @@ -30,7 +30,7 @@ default=config.CLI_DEFAULTS['CONFIG_DIR'], help='Path to the configuration directory' ) -@click.option('--debug', 'debug', is_flag=True, default=config.APP_CONFIG['DEBUG'], help='Enable debugging') +@DEBUG(default=config.APP_CONFIG['DEBUG']) @click.option( '--wsgi-profile', is_flag=True, diff --git a/aiida/cmdline/params/options/__init__.py b/aiida/cmdline/params/options/__init__.py index 708930028f..0b6fe9645b 100644 --- a/aiida/cmdline/params/options/__init__.py +++ b/aiida/cmdline/params/options/__init__.py @@ -30,7 +30,7 @@ 'DESCRIPTION', 'INPUT_PLUGIN', 'CALC_JOB_STATE', 'PROCESS_STATE', 'PROCESS_LABEL', 'TYPE_STRING', 'EXIT_STATUS', 'FAILED', 'LIMIT', 'PROJECT', 'ORDER_BY', 'PAST_DAYS', 'OLDER_THAN', 'ALL', 'ALL_STATES', 'ALL_USERS', 'GROUP_CLEAR', 'RAW', 'HOSTNAME', 'TRANSPORT', 'SCHEDULER', 'USER', 'PORT', 'FREQUENCY', 'VERBOSE', 'TIMEOUT', - 'FORMULA_MODE', 'TRAJECTORY_INDEX', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE' + 'FORMULA_MODE', 'TRAJECTORY_INDEX', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE', 'DEBUG' ) TRAVERSAL_RULE_HELP_STRING = { @@ -522,3 +522,7 @@ def decorator(command): DICT_KEYS = OverridableOption( '-k', '--keys', type=click.STRING, cls=MultipleValueOption, help='Filter the output by one or more keys.' ) + +DEBUG = OverridableOption( + '--debug', is_flag=True, default=False, help='Show debug messages. Mostly relevant for developers.', hidden=True +) diff --git a/aiida/common/folders.py b/aiida/common/folders.py index 9eedcb51d7..79b6eb3db5 100644 --- a/aiida/common/folders.py +++ b/aiida/common/folders.py @@ -86,13 +86,13 @@ def get_subfolder(self, subfolder, create=False, reset_limit=False): Return a Folder object pointing to a subfolder. :param subfolder: a string with the relative path of the subfolder, - relative to the absolute path of this object. Note that - this may also contain '..' parts, - as far as this does not go beyond the folder_limit. + relative to the absolute path of this object. Note that + this may also contain '..' parts, + as far as this does not go beyond the folder_limit. :param create: if True, the new subfolder is created, if it does not exist. :param reset_limit: when doing ``b = a.get_subfolder('xxx', reset_limit=False)``, - the limit of b will be the same limit of a. - if True, the limit will be set to the boundaries of folder b. + the limit of b will be the same limit of a. + if True, the limit will be set to the boundaries of folder b. :Returns: a Folder object pointing to the subfolder. """ @@ -114,18 +114,16 @@ def get_subfolder(self, subfolder, create=False, reset_limit=False): return new_folder def get_content_list(self, pattern='*', only_paths=True): - """ - Return a list of files (and subfolders) in the folder, - matching a given pattern. + """Return a list of files (and subfolders) in the folder, matching a given pattern. Example: If you want to exclude files starting with a dot, you can call this method with ``pattern='[!.]*'`` :param pattern: a pattern for the file/folder names, using Unix filename - pattern matching (see Python standard module fnmatch). - By default, pattern is '*', matching all files and folders. + pattern matching (see Python standard module fnmatch). + By default, pattern is '*', matching all files and folders. :param only_paths: if False (default), return pairs (name, is_file). - if True, return only a flat list. + if True, return only a flat list. :Returns: a list of tuples of two elements, the first is the file name and @@ -140,8 +138,7 @@ def get_content_list(self, pattern='*', only_paths=True): return [(fname, not os.path.isdir(os.path.join(self.abspath, fname))) for fname in file_list] def create_symlink(self, src, name): - """ - Create a symlink inside the folder to the location 'src'. + """Create a symlink inside the folder to the location 'src'. :param src: the location to which the symlink must point. Can be either a relative or an absolute path. Should, however, @@ -155,8 +152,7 @@ def create_symlink(self, src, name): # For symlinks, permissions should not be set def insert_path(self, src, dest_name=None, overwrite=True): - """ - Copy a file to the folder. + """Copy a file to the folder. :param src: the source filename to copy :param dest_name: if None, the same basename of src is used. Otherwise, @@ -236,8 +232,7 @@ def create_file_from_filelike(self, filelike, filename, mode='wb', encoding=None return filepath def remove_path(self, filename): - """ - Remove a file or folder from the folder. + """Remove a file or folder from the folder. :param filename: the relative path name to remove """ @@ -251,8 +246,7 @@ def remove_path(self, filename): os.remove(dest_abs_path) def get_abs_path(self, relpath, check_existence=False): - """ - Return an absolute path for a file or folder in this folder. + """Return an absolute path for a file or folder in this folder. The advantage of using this method is that it checks that filename is a valid filename within this folder, @@ -352,24 +346,20 @@ def create(self): os.makedirs(self.abspath, mode=self.mode_dir) def replace_with_folder(self, srcdir, move=False, overwrite=False): - """ - This routine copies or moves the source folder 'srcdir' to the local - folder pointed by this Folder object. + """This routine copies or moves the source folder 'srcdir' to the local folder pointed to by this Folder. - :param srcdir: the source folder on the disk; this must be a string with - an absolute path - :param move: if True, the srcdir is moved to the repository. Otherwise, it - is only copied. + :param srcdir: the source folder on the disk; this must be an absolute path + :type srcdir: str + :param move: if True, the srcdir is moved to the repository. Otherwise, it is only copied. + :type move: bool :param overwrite: if True, the folder will be erased first. - if False, a IOError is raised if the folder already exists. - Whatever the value of this flag, parent directories will be - created, if needed. + if False, an IOError is raised if the folder already exists. + Whatever the value of this flag, parent directories will be created, if needed. + :type overwrite: bool - :Raises: - OSError or IOError: in case of problems accessing or writing - the files. - :Raises: - ValueError: if the section is not recognized. + :raises IOError: in case of problems accessing or writing the files. + :raises OSError: in case of problems accessing or writing the files (from ``shutil`` module). + :raises ValueError: if the section is not recognized. """ if not os.path.isabs(srcdir): raise ValueError('srcdir must be an absolute path') @@ -390,13 +380,11 @@ def replace_with_folder(self, srcdir, move=False, overwrite=False): # Set the mode also for the current dir, recursively for dirpath, _, filenames in os.walk(self.abspath, followlinks=False): - # dirpath should already be absolute, because I am passing - # an absolute path to os.walk + # dirpath should already be absolute, because I am passing an absolute path to os.walk os.chmod(dirpath, self.mode_dir) for filename in filenames: - # do not change permissions of symlinks (this would - # actually change permissions of the linked file/dir) - # Toc check whether this is a big speed loss + # do not change permissions of symlinks (this would actually change permissions of the linked file/dir) + # TODO check whether this is a big speed loss # pylint: disable=fixme full_file_path = os.path.join(dirpath, filename) if not os.path.islink(full_file_path): os.chmod(full_file_path, self.mode_file) diff --git a/aiida/common/lang.py b/aiida/common/lang.py index defeea73dc..cb24ea4ba5 100644 --- a/aiida/common/lang.py +++ b/aiida/common/lang.py @@ -30,15 +30,19 @@ def type_check(what, of_type, msg=None, allow_none=False): :param of_type: the type (or tuple of types) to compare to :param msg: if specified, allows to customize the message that is passed within the TypeError exception :param allow_none: boolean, if True will not raise if the passed `what` is `None` + + :return: `what` or `None` """ if allow_none and what is None: - return + return None if not isinstance(what, of_type): if msg is None: msg = "Got object of type '{}', expecting '{}'".format(type(what), of_type) raise TypeError(msg) + return what + def override_decorator(check=False): """Decorator to signal that a method from a base class is being overridden completely.""" diff --git a/aiida/common/log.py b/aiida/common/log.py index 9f208072ed..de51d10a1a 100644 --- a/aiida/common/log.py +++ b/aiida/common/log.py @@ -13,10 +13,11 @@ import logging import types from contextlib import contextmanager +from wrapt import decorator from aiida.manage.configuration import get_config_option -__all__ = ('AIIDA_LOGGER', 'override_log_level') +__all__ = ('AIIDA_LOGGER', 'override_log_level', 'override_log_formatter') # Custom logging level, intended specifically for informative log messages reported during WorkChains. # We want the level between INFO(20) and WARNING(30) such that it will be logged for the default loglevel, however @@ -209,3 +210,29 @@ def override_log_level(level=logging.CRITICAL): yield finally: logging.disable(level=logging.NOTSET) + + +def override_log_formatter(fmt: str): + """Temporarily use a different formatter for all handlers. + + NOTE: One can _only_ set `fmt` (not `datefmt` or `style`). + Be aware! This may fail if the number of handlers is changed within the decorated function/method. + """ + + @decorator + def wrapper(wrapped, instance, args, kwargs): # pylint: disable=unused-argument + temp_formatter = logging.Formatter(fmt=fmt) + + cached_formatters = [] + for handler in AIIDA_LOGGER.handlers: + cached_formatters.append(handler.formatter) + + try: + for handler in AIIDA_LOGGER.handlers: + handler.setFormatter(temp_formatter) + return wrapped(*args, **kwargs) + finally: + for index, handler in enumerate(AIIDA_LOGGER.handlers): + handler.setFormatter(cached_formatters[index]) + + return wrapper diff --git a/aiida/tools/importexport/common/__init__.py b/aiida/tools/importexport/common/__init__.py index 7cd409cc08..03a5b00be6 100644 --- a/aiida/tools/importexport/common/__init__.py +++ b/aiida/tools/importexport/common/__init__.py @@ -13,5 +13,6 @@ from .archive import * from .config import * from .exceptions import * +from .progress_bar import * -__all__ = (archive.__all__ + config.__all__ + exceptions.__all__) +__all__ = (archive.__all__ + config.__all__ + exceptions.__all__ + progress_bar.__all__) diff --git a/aiida/tools/importexport/common/archive.py b/aiida/tools/importexport/common/archive.py index 95798cd7c3..f330bff117 100644 --- a/aiida/tools/importexport/common/archive.py +++ b/aiida/tools/importexport/common/archive.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=too-many-branches """Utility functions and classes to interact with AiiDA export archives.""" import os @@ -19,8 +20,10 @@ from aiida.common import json from aiida.common.exceptions import ContentNotExistent, InvalidOperation from aiida.common.folders import SandboxFolder + from aiida.tools.importexport.common.config import NODES_EXPORT_SUBFOLDER from aiida.tools.importexport.common.exceptions import CorruptArchive +from aiida.tools.importexport.common.progress_bar import get_progress_bar, close_progress_bar __all__ = ('Archive', 'extract_zip', 'extract_tar', 'extract_tree') @@ -40,8 +43,9 @@ class Archive: FILENAME_DATA = 'data.json' FILENAME_METADATA = 'metadata.json' - def __init__(self, filepath): + def __init__(self, filepath, silent=True): self._filepath = filepath + self._silent = silent self._folder = None self._unpacked = False self._data = None @@ -79,9 +83,9 @@ def unpack(self): if os.path.isdir(self.filepath): extract_tree(self.filepath, self.folder) elif tarfile.is_tarfile(self.filepath): - extract_tar(self.filepath, self.folder, silent=True, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER) + extract_tar(self.filepath, self.folder, silent=self._silent, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER) elif zipfile.is_zipfile(self.filepath): - extract_zip(self.filepath, self.folder, silent=True, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER) + extract_zip(self.filepath, self.folder, silent=self._silent, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER) else: raise CorruptArchive('unrecognized archive format') @@ -218,9 +222,70 @@ def _read_json_file(self, filename): return json.load(fhandle) -def extract_zip(infile, folder, nodes_export_subfolder=None, silent=False): +def update_description(path, refresh: bool = False): + """Update description for a progress bar given path + + :param path: path of file or directory + :type path: str + :param refresh: Whether or not to refresh the progress bar with the new description. Default: False. + :type refresh: bool """ - Extract the nodes to be imported from a zip file. + (path, description) = os.path.split(path) + while description == '': + (path, description) = os.path.split(path) + description = 'EXTRACTING: {}'.format(description) + + progress_bar = get_progress_bar() + progress_bar.set_description_str(description, refresh=refresh) + + +def get_file_iterator(file_handle, folderpath, silent=True, **kwargs): # pylint: disable=unused-argument + """Go through JSON files and then return new file_iterator + + :param file_handle: A file handle returned from `with open() as file_handle:`. + :type file_handle: `tarfile.TarFile`, `zipfile.ZipFile` + + :param folderpath: Path to folder. + :type folderpath: str + + :param silent: suppress progress bar. + :type silent: bool + + :return: List of filenames in the archive, wrapped in the `tqdm` progress bar. + :rtype: `tqdm.tqdm` + """ + json_files = {'metadata.json', 'data.json'} + + if isinstance(file_handle, tarfile.TarFile): + file_format = 'tar' + elif isinstance(file_handle, zipfile.ZipFile): + file_format = 'zip' + else: + raise TypeError('Can only handle Tar or Zip files.') + + close_progress_bar(leave=False) + file_iterator = get_progress_bar(iterable=json_files, leave=False, disable=silent) + + for json_file in file_iterator: + update_description(json_file, file_iterator) + + try: + if file_format == 'tar': + file_handle.extract(path=folderpath, member=file_handle.getmember(json_file)) + else: + file_handle.extract(path=folderpath, member=json_file) + except KeyError: + raise CorruptArchive('required file `{}` is not included'.format(json_file)) + + close_progress_bar(leave=False) + if file_format == 'tar': + return get_progress_bar(iterable=file_handle.getmembers(), unit='files', leave=False, disable=silent) + # zip + return get_progress_bar(iterable=file_handle.namelist(), unit='files', leave=False, disable=silent) + + +def extract_zip(infile, folder, nodes_export_subfolder=None, **kwargs): + """Extract the nodes to be imported from a zip file. :param infile: file path :type infile: str @@ -231,7 +296,7 @@ def extract_zip(infile, folder, nodes_export_subfolder=None, silent=False): :param nodes_export_subfolder: name of the subfolder for AiiDA nodes :type nodes_export_subfolder: str - :param silent: suppress debug print + :param silent: suppress progress bar :type silent: bool :raises TypeError: if parameter types are not respected @@ -239,9 +304,6 @@ def extract_zip(infile, folder, nodes_export_subfolder=None, silent=False): incorrect formats """ # pylint: disable=fixme - if not silent: - print('READING DATA AND METADATA...') - if nodes_export_subfolder: if not isinstance(nodes_export_subfolder, str): raise TypeError('nodes_export_subfolder must be a string') @@ -252,33 +314,26 @@ def extract_zip(infile, folder, nodes_export_subfolder=None, silent=False): with zipfile.ZipFile(infile, 'r', allowZip64=True) as handle: if not handle.namelist(): - raise CorruptArchive('no files detected') - - try: - handle.extract(path=folder.abspath, member='metadata.json') - except KeyError: - raise CorruptArchive('required file `metadata.json` is not included') + raise CorruptArchive('no files detected in archive') - try: - handle.extract(path=folder.abspath, member='data.json') - except KeyError: - raise CorruptArchive('required file `data.json` is not included') + file_iterator = get_file_iterator(file_handle=handle, folderpath=folder.abspath, **kwargs) - if not silent: - print('EXTRACTING NODE DATA...') - - for membername in handle.namelist(): + for membername in file_iterator: # Check that we are only exporting nodes within the subfolder! # TODO: better check such that there are no .. in the # path; use probably the folder limit checks if not membername.startswith(nodes_export_subfolder + os.sep): continue + + update_description(membername, file_iterator) + handle.extract(path=folder.abspath, member=membername) except zipfile.BadZipfile: raise ValueError('The input file format for import is not valid (not a zip file)') + close_progress_bar(leave=False) -def extract_tar(infile, folder, nodes_export_subfolder=None, silent=False): +def extract_tar(infile, folder, nodes_export_subfolder=None, **kwargs): """ Extract the nodes to be imported from a (possibly zipped) tar file. @@ -291,7 +346,7 @@ def extract_tar(infile, folder, nodes_export_subfolder=None, silent=False): :param nodes_export_subfolder: name of the subfolder for AiiDA nodes :type nodes_export_subfolder: str - :param silent: suppress debug print + :param silent: suppress progress bar :type silent: bool :raises TypeError: if parameter types are not respected @@ -299,9 +354,6 @@ def extract_tar(infile, folder, nodes_export_subfolder=None, silent=False): incorrect formats """ # pylint: disable=fixme - if not silent: - print('READING DATA AND METADATA...') - if nodes_export_subfolder: if not isinstance(nodes_export_subfolder, str): raise TypeError('nodes_export_subfolder must be a string') @@ -311,20 +363,12 @@ def extract_tar(infile, folder, nodes_export_subfolder=None, silent=False): try: with tarfile.open(infile, 'r:*', format=tarfile.PAX_FORMAT) as handle: - try: - handle.extract(path=folder.abspath, member=handle.getmember('metadata.json')) - except KeyError: - raise CorruptArchive('required file `metadata.json` is not included') + if len(handle.getmembers()) == 1 and handle.getmembers()[0].size == 0: + raise CorruptArchive('no files detected in archive') - try: - handle.extract(path=folder.abspath, member=handle.getmember('data.json')) - except KeyError: - raise CorruptArchive('required file `data.json` is not included') + file_iterator = get_file_iterator(file_handle=handle, folderpath=folder.abspath, **kwargs) - if not silent: - print('EXTRACTING NODE DATA...') - - for member in handle.getmembers(): + for member in file_iterator: if member.isdev(): # safety: skip if character device, block device or FIFO print('WARNING, device found inside the import file: {}'.format(member.name), file=sys.stderr) @@ -332,16 +376,20 @@ def extract_tar(infile, folder, nodes_export_subfolder=None, silent=False): if member.issym() or member.islnk(): # safety: in export, I set dereference=True therefore # there should be no symbolic or hard links. - print('WARNING, link found inside the import file: {}'.format(member.name), file=sys.stderr) + print('WARNING, symlink found inside the import file: {}'.format(member.name), file=sys.stderr) continue # Check that we are only exporting nodes within the subfolder! # TODO: better check such that there are no .. in the # path; use probably the folder limit checks if not member.name.startswith(nodes_export_subfolder + os.sep): continue + + update_description(member.name, file_iterator) + handle.extract(path=folder.abspath, member=member) except tarfile.ReadError: - raise ValueError('The input file format for import is not valid (1)') + raise ValueError('The input file format for import is not valid (not a tar file)') + close_progress_bar(leave=False) def extract_tree(infile, folder): diff --git a/aiida/tools/importexport/common/config.py b/aiida/tools/importexport/common/config.py index 5f5a8e0751..3ef0b98de1 100644 --- a/aiida/tools/importexport/common/config.py +++ b/aiida/tools/importexport/common/config.py @@ -21,6 +21,9 @@ # The name of the subfolder in which the node files are stored NODES_EXPORT_SUBFOLDER = 'nodes' +# Progress bar +BAR_FORMAT = '{desc:40.40}{percentage:6.1f}%|{bar}| {n_fmt}/{total_fmt}' + # Giving names to the various entities. Attributes and links are not AiiDA # entities but we will refer to them as entities in the file (to simplify # references to them). diff --git a/aiida/tools/importexport/common/exceptions.py b/aiida/tools/importexport/common/exceptions.py index 3e641f1139..5db8fd1c0d 100644 --- a/aiida/tools/importexport/common/exceptions.py +++ b/aiida/tools/importexport/common/exceptions.py @@ -18,7 +18,7 @@ __all__ = ( 'ExportImportException', 'ArchiveExportError', 'ArchiveImportError', 'CorruptArchive', 'IncompatibleArchiveVersionError', 'ExportValidationError', 'ImportUniquenessError', 'ImportValidationError', - 'ArchiveMigrationError', 'MigrationValidationError', 'DanglingLinkError' + 'ArchiveMigrationError', 'MigrationValidationError', 'DanglingLinkError', 'ProgressBarError' ) @@ -67,3 +67,7 @@ class MigrationValidationError(ArchiveMigrationError): class DanglingLinkError(MigrationValidationError): """Raised when an export archive is detected to contain dangling links when importing.""" + + +class ProgressBarError(ExportImportException): + """Something is wrong with setting up the tqdm progress bar""" diff --git a/aiida/tools/importexport/common/progress_bar.py b/aiida/tools/importexport/common/progress_bar.py new file mode 100644 index 0000000000..f195a8bc12 --- /dev/null +++ b/aiida/tools/importexport/common/progress_bar.py @@ -0,0 +1,79 @@ +# pylint: disable=global-statement,too-many-branches +"""Have a single tqdm progress bar instance that should be handled using the functions in this module""" +from typing import Iterable + +from tqdm import tqdm + +from aiida.common.lang import type_check + +from aiida.tools.importexport.common.config import BAR_FORMAT +from aiida.tools.importexport.common.exceptions import ProgressBarError + +__all__ = ('get_progress_bar', 'close_progress_bar') + +PROGRESS_BAR = None + + +def get_progress_bar(iterable=None, total=None, leave=None, **kwargs): + """Set up, cache and return cached tqdm progress bar""" + global PROGRESS_BAR + + leave_default = False + + type_check(iterable, Iterable, allow_none=True) + type_check(total, int, allow_none=True) + type_check(leave, bool, allow_none=True) + + # iterable and total are mutually exclusive + if iterable is not None and total is not None: + if len(iterable) == total: + kwargs['iterable'] = iterable + else: + raise ProgressBarError('You can not set both "iterable" and "total" for the progress bar.') + elif iterable is None and total is None: + if PROGRESS_BAR is None: + kwargs['total'] = 1 + # Else pass: we guess the desired outcome is to retrieve the current progress bar + elif iterable is not None: + kwargs['iterable'] = iterable + elif total is not None: + kwargs['total'] = total + + if PROGRESS_BAR is None: + leave = leave if leave is not None else leave_default + PROGRESS_BAR = tqdm(bar_format=BAR_FORMAT, leave=leave, **kwargs) + elif 'iterable' in kwargs or 'total' in kwargs: + # Create a new progress bar + # We leave it up to the caller/creator to properly have set leave before we close the current progress bar + if leave is None: + leave = PROGRESS_BAR.leave if PROGRESS_BAR.leave is not None else leave_default + for attribute in ('desc', 'disable'): + if getattr(PROGRESS_BAR, attribute, None) is not None: + kwargs[attribute] = getattr(PROGRESS_BAR, attribute) + close_progress_bar() + PROGRESS_BAR = tqdm(bar_format=BAR_FORMAT, leave=leave, **kwargs) + else: + for attribute, value in kwargs.items(): + try: + setattr(PROGRESS_BAR, attribute, value) + except AttributeError: + raise ProgressBarError( + 'The given attribute {} either can not be set or does not exist for the progress bar.'. + format(attribute) + ) + + return PROGRESS_BAR + + +def close_progress_bar(leave=None): + """Close instantiated progress bar""" + global PROGRESS_BAR + + type_check(leave, bool, allow_none=True) + + if PROGRESS_BAR is not None: + if leave is not None: + PROGRESS_BAR.leave = leave + PROGRESS_BAR.close() + + PROGRESS_BAR = None diff --git a/aiida/tools/importexport/common/utils.py b/aiida/tools/importexport/common/utils.py index ae4843677f..1c5214f3e5 100644 --- a/aiida/tools/importexport/common/utils.py +++ b/aiida/tools/importexport/common/utils.py @@ -10,9 +10,9 @@ """ Utility functions for import/export of AiiDA entities """ # pylint: disable=inconsistent-return-statements,too-many-branches,too-many-return-statements # pylint: disable=too-many-nested-blocks,too-many-locals +from html.parser import HTMLParser import urllib.request import urllib.parse -from html.parser import HTMLParser from aiida.tools.importexport.common.config import ( NODE_ENTITY_NAME, GROUP_ENTITY_NAME, COMPUTER_ENTITY_NAME, USER_ENTITY_NAME, LOG_ENTITY_NAME, COMMENT_ENTITY_NAME diff --git a/aiida/tools/importexport/dbexport/__init__.py b/aiida/tools/importexport/dbexport/__init__.py index 28ad308711..6cf8019797 100644 --- a/aiida/tools/importexport/dbexport/__init__.py +++ b/aiida/tools/importexport/dbexport/__init__.py @@ -9,17 +9,20 @@ ########################################################################### # pylint: disable=fixme,too-many-branches,too-many-locals,too-many-statements,too-many-arguments """Provides export functionalities.""" - +import logging import os import tarfile import time from aiida import get_version, orm from aiida.common import json -from aiida.common.folders import RepositoryFolder +from aiida.common.exceptions import LicensingException +from aiida.common.folders import RepositoryFolder, SandboxFolder, Folder +from aiida.common.lang import type_check +from aiida.common.log import override_log_formatter, LOG_LEVEL_REPORT from aiida.orm.utils.repository import Repository -from aiida.tools.importexport.common import exceptions +from aiida.tools.importexport.common import exceptions, get_progress_bar, close_progress_bar from aiida.tools.importexport.common.config import EXPORT_VERSION, NODES_EXPORT_SUBFOLDER from aiida.tools.importexport.common.config import ( NODE_ENTITY_NAME, GROUP_ENTITY_NAME, COMPUTER_ENTITY_NAME, LOG_ENTITY_NAME, COMMENT_ENTITY_NAME @@ -29,31 +32,47 @@ ) from aiida.tools.importexport.common.utils import export_shard_uuid from aiida.tools.importexport.dbexport.utils import ( - check_licenses, fill_in_query, serialize_dict, check_process_nodes_sealed + check_licenses, fill_in_query, serialize_dict, check_process_nodes_sealed, summary, EXPORT_LOGGER, ExportFileFormat, + deprecated_parameters ) from .zip import ZipFolder -__all__ = ('export', 'export_zip') +__all__ = ('export', 'EXPORT_LOGGER', 'ExportFileFormat') -def export_zip(what, outfile='testzip', overwrite=False, silent=False, use_compression=True, **kwargs): - """Export in a zipped folder +def export( + entities=None, + filename=None, + file_format=ExportFileFormat.ZIP, + overwrite=False, + silent=False, + use_compression=True, + **kwargs +): + """Export AiiDA data + + .. deprecated:: 1.2.1 + Support for the parameters `what` and `outfile` will be removed in `v2.0.0`. + Please use `entities` and `filename` instead, respectively. + + :param entities: a list of entity instances; they can belong to different models/entities. + :type entities: list - :param what: a list of entity instances; they can belong to different models/entities. - :type what: list + :param filename: the filename (possibly including the absolute path) of the file on which to export. + :type filename: str - :param outfile: the filename (possibly including the absolute path) of the file on which to export. - :type outfile: str + :param file_format: See `ExportFileFormat` for complete list of valid values (default: 'zip'). + :type file_format: str, `ExportFileFormat` :param overwrite: if True, overwrite the output file without asking, if it exists. If False, raise an :py:class:`~aiida.tools.importexport.common.exceptions.ArchiveExportError` if the output file already exists. :type overwrite: bool - :param silent: suppress prints. + :param silent: suppress console prints and progress bar. :type silent: bool - :param use_compression: Whether or not to compress the zip file. + :param use_compression: Whether or not to compress the archive file (only valid for the zip file format). :type use_compression: bool :param allowed_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the @@ -66,11 +85,11 @@ def export_zip(what, outfile='testzip', overwrite=False, silent=False, use_compr otherwise. :type forbidden_licenses: list - :param include_comments: In-/exclude export of comments for given node(s) in ``what``. + :param include_comments: In-/exclude export of comments for given node(s) in ``entities``. Default: True, *include* comments in export (as well as relevant users). :type include_comments: bool - :param include_logs: In-/exclude export of logs for given node(s) in ``what``. + :param include_logs: In-/exclude export of logs for given node(s) in ``entities``. Default: True, *include* logs in export. :type include_logs: bool @@ -81,19 +100,194 @@ def export_zip(what, outfile='testzip', overwrite=False, silent=False, use_compr exporting. :raises `~aiida.common.exceptions.LicensingException`: if any node is licensed under forbidden license. """ - if not overwrite and os.path.exists(outfile): - raise exceptions.ArchiveExportError("the output file '{}' already exists".format(outfile)) + if file_format not in list(ExportFileFormat): + raise exceptions.ArchiveExportError( + 'Can only export in the formats: {}, please specify one for "file_format".'.format( + tuple(_.value for _ in ExportFileFormat) + ) + ) + + # Backwards-compatibility + entities = deprecated_parameters( + old={ + 'name': 'what', + 'value': kwargs.pop('what', None) + }, + new={ + 'name': 'entities', + 'value': entities + }, + ) + filename = deprecated_parameters( + old={ + 'name': 'outfile', + 'value': kwargs.pop('outfile', None) + }, + new={ + 'name': 'filename', + 'value': filename + }, + ) + + type_check(entities, (list, tuple, set), msg='`entities` must be specified and given as a list of AiiDA entities') + entities = list(entities) + if type_check(filename, str, allow_none=True) is None: + filename = 'export_data.aiida' + + if not overwrite and os.path.exists(filename): + raise exceptions.ArchiveExportError("The output file '{}' already exists".format(filename)) + + if silent: + logging.disable(level=logging.CRITICAL) + + if file_format == ExportFileFormat.TAR_GZIPPED: + file_format_verbose = 'Gzipped tarball (compressed)' + # Must be a zip then + elif use_compression: + file_format_verbose = 'Zip (compressed)' + else: + file_format_verbose = 'Zip (uncompressed)' + summary(file_format_verbose, filename, **kwargs) + + try: + if file_format == ExportFileFormat.TAR_GZIPPED: + times = export_tar(entities=entities, filename=filename, silent=silent, **kwargs) + else: # zip + times = export_zip( + entities=entities, filename=filename, use_compression=use_compression, silent=silent, **kwargs + ) + except (exceptions.ArchiveExportError, LicensingException) as exc: + if os.path.exists(filename): + os.remove(filename) + raise exc + + if len(times) == 2: + export_start, export_end = times + EXPORT_LOGGER.debug('Exported in %6.2g s.', export_end - export_start) + elif len(times) == 4: + export_start, export_end, compress_start, compress_end = times + EXPORT_LOGGER.debug( + 'Exported in %6.2g s, compressed in %6.2g s, total: %6.2g s.', export_end - export_start, + compress_end - compress_start, compress_end - export_start + ) + else: + EXPORT_LOGGER.debug('No information about the timing of the export.') + + # Reset logging level + if silent: + logging.disable(level=logging.NOTSET) - time_start = time.time() - with ZipFolder(outfile, mode='w', use_compression=use_compression) as folder: - export_tree(what, folder=folder, silent=silent, **kwargs) - if not silent: - print('File written in {:10.3g} s.'.format(time.time() - time_start)) +def export_zip(entities=None, filename=None, use_compression=True, **kwargs): + """Export in a zipped folder + + .. deprecated:: 1.2.1 + Support for the parameters `what` and `outfile` will be removed in `v2.0.0`. + Please use `entities` and `filename` instead, respectively. + :param entities: a list of entity instances; they can belong to different models/entities. + :type entities: list + + :param filename: the filename (possibly including the absolute path) of the file on which to export. + :type filename: str + + :param use_compression: Whether or not to compress the zip file. + :type use_compression: bool + """ + # Backwards-compatibility + entities = deprecated_parameters( + old={ + 'name': 'what', + 'value': kwargs.pop('what', None) + }, + new={ + 'name': 'entities', + 'value': entities + }, + ) + filename = deprecated_parameters( + old={ + 'name': 'outfile', + 'value': kwargs.pop('outfile', None) + }, + new={ + 'name': 'filename', + 'value': filename + }, + ) + + type_check(entities, (list, tuple, set), msg='`entities` must be specified and given as a list of AiiDA entities') + entities = list(entities) + + if type_check(filename, str, allow_none=True) is None: + filename = 'export_data.aiida' + + with ZipFolder(filename, mode='w', use_compression=use_compression) as folder: + time_start = time.time() + export_tree(entities=entities, folder=folder, **kwargs) + time_end = time.time() + + return (time_start, time_end) + + +def export_tar(entities=None, filename=None, **kwargs): + """Export the entries passed in the 'entities' list to a gzipped tar file. + + .. deprecated:: 1.2.1 + Support for the parameters `what` and `outfile` will be removed in `v2.0.0`. + Please use `entities` and `filename` instead, respectively. + + :param entities: a list of entity instances; they can belong to different models/entities. + :type entities: list + + :param filename: the filename (possibly including the absolute path) of the file on which to export. + :type filename: str + """ + # Backwards-compatibility + entities = deprecated_parameters( + old={ + 'name': 'what', + 'value': kwargs.pop('what', None) + }, + new={ + 'name': 'entities', + 'value': entities + }, + ) + filename = deprecated_parameters( + old={ + 'name': 'outfile', + 'value': kwargs.pop('outfile', None) + }, + new={ + 'name': 'filename', + 'value': filename + }, + ) + + type_check(entities, (list, tuple, set), msg='`entities` must be specified and given as a list of AiiDA entities') + entities = list(entities) + + if type_check(filename, str, allow_none=True) is None: + filename = 'export_data.aiida' + + with SandboxFolder() as folder: + time_export_start = time.time() + export_tree(entities=entities, folder=folder, **kwargs) + time_export_end = time.time() + + with tarfile.open(filename, 'w:gz', format=tarfile.PAX_FORMAT, dereference=True) as tar: + time_compress_start = time.time() + tar.add(folder.abspath, arcname='') + time_compress_end = time.time() + + return (time_export_start, time_export_end, time_compress_start, time_compress_end) + + +@override_log_formatter('%(message)s') def export_tree( - what, - folder, + entities=None, + folder=None, allowed_licenses=None, forbidden_licenses=None, silent=False, @@ -101,10 +295,13 @@ def export_tree( include_logs=True, **kwargs ): - """Export the entries passed in the 'what' list to a file tree. + """Export the entries passed in the 'entities' list to a file tree. + + .. deprecated:: 1.2.1 + Support for the parameter `what` will be removed in `v2.0.0`. Please use `entities` instead. - :param what: a list of entity instances; they can belong to different models/entities. - :type what: list + :param entities: a list of entity instances; they can belong to different models/entities. + :type entities: list :param folder: a temporary folder to build the archive before compression. :type folder: :py:class:`~aiida.common.folders.Folder` @@ -119,14 +316,14 @@ def export_tree( otherwise. :type forbidden_licenses: list - :param silent: suppress prints. + :param silent: suppress console prints and progress bar. :type silent: bool - :param include_comments: In-/exclude export of comments for given node(s) in ``what``. + :param include_comments: In-/exclude export of comments for given node(s) in ``entities``. Default: True, *include* comments in export (as well as relevant users). :type include_comments: bool - :param include_logs: In-/exclude export of logs for given node(s) in ``what``. + :param include_logs: In-/exclude export of logs for given node(s) in ``entities``. Default: True, *include* logs in export. :type include_logs: bool @@ -140,24 +337,46 @@ def export_tree( from collections import defaultdict from aiida.tools.graph.graph_traversers import get_nodes_export - if not silent: - print('STARTING EXPORT...') + if silent: + logging.disable(level=logging.CRITICAL) + + EXPORT_LOGGER.debug('STARTING EXPORT...') + + # Backwards-compatibility + entities = deprecated_parameters( + old={ + 'name': 'what', + 'value': kwargs.pop('what', None) + }, + new={ + 'name': 'entities', + 'value': entities + }, + ) + + type_check(entities, (list, tuple, set), msg='`entities` must be specified and given as a list of AiiDA entities') + entities = list(entities) + + type_check(folder, (Folder, ZipFolder), msg='`folder` must be specified and given as an AiiDA Folder entity') all_fields_info, unique_identifiers = get_all_fields_info() entities_starting_set = defaultdict(set) # The set that contains the nodes ids of the nodes that should be exported - given_data_entry_ids = set() - given_calculation_entry_ids = set() - given_group_entry_ids = set() - given_computer_entry_ids = set() - given_groups = set() + given_node_entry_ids = set() given_log_entry_ids = set() given_comment_entry_ids = set() + # Instantiate progress bar - go through list of `entities` + pbar_total = len(entities) + 1 if entities else 1 + progress_bar = get_progress_bar(total=pbar_total, leave=False, disable=silent) + progress_bar.set_description_str('Collecting chosen entities', refresh=False) + # I store a list of the actual dbnodes - for entry in what: + for entry in entities: + progress_bar.update() + # This returns the class name (as in imports). E.g. for a model node: # aiida.backends.djsite.db.models.DbNode # entry_class_string = get_class_string(entry) @@ -165,33 +384,26 @@ def export_tree( # entry_entity_name = schema_to_entity_names(entry_class_string) if issubclass(entry.__class__, orm.Group): entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid) - given_group_entry_ids.add(entry.id) - given_groups.add(entry) elif issubclass(entry.__class__, orm.Node): entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid) - if issubclass(entry.__class__, orm.Data): - given_data_entry_ids.add(entry.pk) - elif issubclass(entry.__class__, orm.ProcessNode): - given_calculation_entry_ids.add(entry.pk) + given_node_entry_ids.add(entry.pk) elif issubclass(entry.__class__, orm.Computer): entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid) - given_computer_entry_ids.add(entry.pk) else: raise exceptions.ArchiveExportError( 'I was given {} ({}), which is not a Node, Computer, or Group instance'.format(entry, type(entry)) ) # Add all the nodes contained within the specified groups - if given_group_entry_ids: + if GROUP_ENTITY_NAME in entities_starting_set: - if not silent: - print('RETRIEVING NODES FROM GROUPS...') + progress_bar.set_description_str('Retrieving Nodes from Groups ...', refresh=True) # Use single query instead of given_group.nodes iterator for performance. qh_groups = orm.QueryBuilder().append( orm.Group, filters={ - 'id': { - 'in': given_group_entry_ids + 'uuid': { + 'in': entities_starting_set[GROUP_ENTITY_NAME] } }, tag='groups' ).queryhelp @@ -199,90 +411,89 @@ def export_tree( # Delete this import once the dbexport.zip module has been renamed from builtins import zip # pylint: disable=redefined-builtin - data_results = orm.QueryBuilder(**qh_groups).append(orm.Data, project=['id', 'uuid'], with_group='groups').all() - if data_results: - pks, uuids = map(list, zip(*data_results)) + node_results = orm.QueryBuilder(**qh_groups).append(orm.Node, project=['id', 'uuid'], with_group='groups').all() + if node_results: + pks, uuids = map(list, zip(*node_results)) entities_starting_set[NODE_ENTITY_NAME].update(uuids) - given_data_entry_ids.update(pks) - del data_results, pks, uuids + given_node_entry_ids.update(pks) + del node_results, pks, uuids - calc_results = orm.QueryBuilder(**qh_groups - ).append(orm.ProcessNode, project=['id', 'uuid'], with_group='groups').all() - if calc_results: - pks, uuids = map(list, zip(*calc_results)) - entities_starting_set[NODE_ENTITY_NAME].update(uuids) - given_calculation_entry_ids.update(pks) - del calc_results, pks, uuids - - for entity, entity_set in entities_starting_set.items(): - entities_starting_set[entity] = list(entity_set) + progress_bar.update() - # We will iteratively explore the AiiDA graph to find further nodes that - # should also be exported. + # We will iteratively explore the AiiDA graph to find further nodes that should also be exported. # At the same time, we will create the links_uuid list of dicts to be exported - if not silent: - print('RETRIEVING LINKED NODES AND STORING LINKS...') + progress_bar = get_progress_bar(total=1, disable=silent) + progress_bar.set_description_str('Getting provenance and storing links ...', refresh=True) - initial_nodes_ids = given_calculation_entry_ids.union(given_data_entry_ids) - traverse_output = get_nodes_export(starting_pks=initial_nodes_ids, get_links=True, **kwargs) - to_be_exported = traverse_output['nodes'] + traverse_output = get_nodes_export(starting_pks=given_node_entry_ids, get_links=True, **kwargs) + node_ids_to_be_exported = traverse_output['nodes'] graph_traversal_rules = traverse_output['rules'] - # I create a utility dictionary for mapping pk to uuid. - if traverse_output['nodes']: + # A utility dictionary for mapping PK to UUID. + if node_ids_to_be_exported: qbuilder = orm.QueryBuilder().append( orm.Node, project=('id', 'uuid'), filters={'id': { - 'in': traverse_output['nodes'] + 'in': node_ids_to_be_exported }}, ) - pk_2_uuid_dict = dict(qbuilder.all()) + node_pk_2_uuid_mapping = dict(qbuilder.all()) else: - pk_2_uuid_dict = {} + node_pk_2_uuid_mapping = {} # The set of tuples now has to be transformed to a list of dicts links_uuid = [{ - 'input': pk_2_uuid_dict[link.source_id], - 'output': pk_2_uuid_dict[link.target_id], + 'input': node_pk_2_uuid_mapping[link.source_id], + 'output': node_pk_2_uuid_mapping[link.target_id], 'label': link.link_label, 'type': link.link_type } for link in traverse_output['links']] + progress_bar.update() + + # Progress bar initialization - Entities + progress_bar = get_progress_bar(total=1, disable=silent) + progress_bar.set_description_str('Initializing export of all entities', refresh=True) + ## Universal "entities" attributed to all types of nodes # Logs - if include_logs and to_be_exported: + if include_logs and node_ids_to_be_exported: # Get related log(s) - universal for all nodes builder = orm.QueryBuilder() - builder.append(orm.Log, filters={'dbnode_id': {'in': to_be_exported}}, project='id') + builder.append(orm.Log, filters={'dbnode_id': {'in': node_ids_to_be_exported}}, project='uuid') res = set(builder.all(flat=True)) given_log_entry_ids.update(res) # Comments - if include_comments and to_be_exported: + if include_comments and node_ids_to_be_exported: # Get related log(s) - universal for all nodes builder = orm.QueryBuilder() - builder.append(orm.Comment, filters={'dbnode_id': {'in': to_be_exported}}, project='id') + builder.append(orm.Comment, filters={'dbnode_id': {'in': node_ids_to_be_exported}}, project='uuid') res = set(builder.all(flat=True)) given_comment_entry_ids.update(res) - # Here we get all the columns that we plan to project per entity that we - # would like to extract - given_entities = list() - if given_group_entry_ids: - given_entities.append(GROUP_ENTITY_NAME) - if to_be_exported: - given_entities.append(NODE_ENTITY_NAME) - if given_computer_entry_ids: - given_entities.append(COMPUTER_ENTITY_NAME) + # Here we get all the columns that we plan to project per entity that we would like to extract + given_entities = set(entities_starting_set.keys()) + if node_ids_to_be_exported: + given_entities.add(NODE_ENTITY_NAME) if given_log_entry_ids: - given_entities.append(LOG_ENTITY_NAME) + given_entities.add(LOG_ENTITY_NAME) if given_comment_entry_ids: - given_entities.append(COMMENT_ENTITY_NAME) + given_entities.add(COMMENT_ENTITY_NAME) + + progress_bar.update() + + if given_entities: + progress_bar = get_progress_bar(total=len(given_entities), disable=silent) + pbar_base_str = 'Preparing entities' entries_to_add = dict() for given_entity in given_entities: + progress_bar.set_description_str(pbar_base_str + ' - {}s'.format(given_entity), refresh=False) + progress_bar.update() + project_cols = ['id'] # The following gets a list of fields that we need, # e.g. user, mtime, uuid, computer @@ -298,22 +509,20 @@ def export_tree( project_cols.append(nprop) # Getting the ids that correspond to the right entity - if given_entity == GROUP_ENTITY_NAME: - entry_ids_to_add = given_group_entry_ids + entry_uuids_to_add = entities_starting_set.get(given_entity, set()) + if not entry_uuids_to_add: + if given_entity == LOG_ENTITY_NAME: + entry_uuids_to_add = given_log_entry_ids + elif given_entity == COMMENT_ENTITY_NAME: + entry_uuids_to_add = given_comment_entry_ids elif given_entity == NODE_ENTITY_NAME: - entry_ids_to_add = to_be_exported - elif given_entity == COMPUTER_ENTITY_NAME: - entry_ids_to_add = given_computer_entry_ids - elif given_entity == LOG_ENTITY_NAME: - entry_ids_to_add = given_log_entry_ids - elif given_entity == COMMENT_ENTITY_NAME: - entry_ids_to_add = given_comment_entry_ids + entry_uuids_to_add.update({node_pk_2_uuid_mapping[_] for _ in node_ids_to_be_exported}) builder = orm.QueryBuilder() builder.append( entity_names_to_entities[given_entity], - filters={'id': { - 'in': entry_ids_to_add + filters={'uuid': { + 'in': entry_uuids_to_add }}, project=project_cols, tag=given_entity, @@ -325,7 +534,11 @@ def export_tree( # Check the licenses of exported data. if allowed_licenses is not None or forbidden_licenses is not None: builder = orm.QueryBuilder() - builder.append(orm.Node, project=['id', 'attributes.source.license'], filters={'id': {'in': to_be_exported}}) + builder.append( + orm.Node, project=['id', 'attributes.source.license'], filters={'id': { + 'in': node_ids_to_be_exported + }} + ) # Skip those nodes where the license is not set (this is the standard behavior with Django) node_licenses = list((a, b) for [a, b] in builder.all() if b is not None) check_licenses(node_licenses, allowed_licenses, forbidden_licenses) @@ -333,100 +546,106 @@ def export_tree( ############################################################ ##### Start automatic recursive export data generation ##### ############################################################ - if not silent: - print('STORING DATABASE ENTRIES...') + EXPORT_LOGGER.debug('GATHERING DATABASE ENTRIES...') - export_data = dict() + if entries_to_add: + progress_bar = get_progress_bar(total=len(entries_to_add), disable=silent) + + export_data = defaultdict(dict) entity_separator = '_' for entity_name, partial_query in entries_to_add.items(): - foreign_fields = { - k: v - for k, v in all_fields_info[entity_name].items() - # all_fields_info[model_name].items() - if 'requires' in v - } + progress_bar.set_description_str('Exporting {}s'.format(entity_name), refresh=False) + progress_bar.update() + + foreign_fields = {k: v for k, v in all_fields_info[entity_name].items() if 'requires' in v} for value in foreign_fields.values(): ref_model_name = value['requires'] fill_in_query(partial_query, entity_name, ref_model_name, [entity_name], entity_separator) for temp_d in partial_query.iterdict(): - for k in temp_d.keys(): + for key in temp_d: # Get current entity - current_entity = k.split(entity_separator)[-1] + current_entity = key.split(entity_separator)[-1] # This is a empty result of an outer join. # It should not be taken into account. - if temp_d[k]['id'] is None: + if temp_d[key]['id'] is None: continue - temp_d2 = { - temp_d[k]['id']: + export_data[current_entity].update({ + temp_d[key]['id']: serialize_dict( - temp_d[k], remove_fields=['id'], rename_fields=model_fields_to_file_fields[current_entity] + temp_d[key], remove_fields=['id'], rename_fields=model_fields_to_file_fields[current_entity] ) - } - try: - export_data[current_entity].update(temp_d2) - except KeyError: - export_data[current_entity] = temp_d2 + }) + + # Close progress up until this point in order to print properly + close_progress_bar(leave=False) ####################################### # Manually manage attributes and extras ####################################### - # I use .get because there may be no nodes to export - all_nodes_pk = list() - if NODE_ENTITY_NAME in export_data: - all_nodes_pk.extend(export_data.get(NODE_ENTITY_NAME).keys()) - - if sum(len(model_data) for model_data in export_data.values()) == 0: - if not silent: - print('No nodes to store, exiting...') + # Pointer. Renaming, since Nodes have now technically been retrieved and "stored" + all_node_pks = node_ids_to_be_exported + + model_data = sum(len(model_data) for model_data in export_data.values()) + if not model_data: + EXPORT_LOGGER.log(msg='Nothing to store, exiting...', level=LOG_LEVEL_REPORT) return + EXPORT_LOGGER.log( + msg='Exporting a total of {} database entries, of which {} are Nodes.'.format(model_data, len(all_node_pks)), + level=LOG_LEVEL_REPORT + ) - if not silent: - print( - 'Exporting a total of {} db entries, of which {} nodes.'.format( - sum(len(model_data) for model_data in export_data.values()), len(all_nodes_pk) - ) - ) + # Instantiate new progress bar + progress_bar = get_progress_bar(total=1, leave=False, disable=silent) # ATTRIBUTES and EXTRAS - if not silent: - print('STORING NODE ATTRIBUTES AND EXTRAS...') + EXPORT_LOGGER.debug('GATHERING NODE ATTRIBUTES AND EXTRAS...') node_attributes = {} node_extras = {} - # A second QueryBuilder query to get the attributes and extras. See if this can be optimized - if all_nodes_pk: - all_nodes_query = orm.QueryBuilder() - all_nodes_query.append(orm.Node, filters={'id': {'in': all_nodes_pk}}, project=['id', 'attributes', 'extras']) - for res_pk, res_attributes, res_extras in all_nodes_query.iterall(): - node_attributes[str(res_pk)] = res_attributes - node_extras[str(res_pk)] = res_extras - - if not silent: - print('STORING GROUP ELEMENTS...') - groups_uuid = dict() - # If a group is in the exported date, we export the group/node correlation + # Another QueryBuilder query to get the attributes and extras. TODO: See if this can be optimized + if all_node_pks: + all_nodes_query = orm.QueryBuilder().append( + orm.Node, filters={'id': { + 'in': all_node_pks + }}, project=['id', 'attributes', 'extras'] + ) + + progress_bar = get_progress_bar(total=all_nodes_query.count(), disable=silent) + progress_bar.set_description_str('Exporting Attributes and Extras', refresh=False) + + for node_pk, attributes, extras in all_nodes_query.iterall(): + progress_bar.update() + + node_attributes[str(node_pk)] = attributes + node_extras[str(node_pk)] = extras + + EXPORT_LOGGER.debug('GATHERING GROUP ELEMENTS...') + groups_uuid = defaultdict(list) + # If a group is in the exported data, we export the group/node correlation if GROUP_ENTITY_NAME in export_data: - for curr_group in export_data[GROUP_ENTITY_NAME]: - group_uuid_qb = orm.QueryBuilder() - group_uuid_qb.append( - entity_names_to_entities[GROUP_ENTITY_NAME], - filters={'id': { - '==': curr_group - }}, - project=['uuid'], - tag='group' - ) - group_uuid_qb.append(entity_names_to_entities[NODE_ENTITY_NAME], project=['uuid'], with_group='group') - for res in group_uuid_qb.iterall(): - if str(res[0]) in groups_uuid: - groups_uuid[str(res[0])].append(str(res[1])) - else: - groups_uuid[str(res[0])] = [str(res[1])] + group_uuids_with_node_uuids = orm.QueryBuilder().append( + orm.Group, filters={ + 'id': { + 'in': export_data[GROUP_ENTITY_NAME] + } + }, project='uuid', tag='groups' + ).append(orm.Node, project='uuid', with_group='groups') + + # This part is _only_ for the progress bar + total_node_uuids_for_groups = group_uuids_with_node_uuids.count() + if total_node_uuids_for_groups: + progress_bar = get_progress_bar(total=total_node_uuids_for_groups, disable=silent) + progress_bar.set_description_str('Exporting Groups ...', refresh=False) + + for group_uuid, node_uuid in group_uuids_with_node_uuids.iterall(): + progress_bar.update() + + groups_uuid[group_uuid].append(node_uuid) ####################################### # Final check for unsealed ProcessNodes @@ -439,13 +658,12 @@ def export_tree( check_process_nodes_sealed(process_nodes) ###################################### - # Now I store + # Now collecting and storing ###################################### # subfolder inside the export package nodesubfolder = folder.get_subfolder(NODES_EXPORT_SUBFOLDER, create=True, reset_limit=True) - if not silent: - print('STORING DATA...') + EXPORT_LOGGER.debug('ADDING DATA TO EXPORT ARCHIVE...') data = { 'node_attributes': node_attributes, @@ -460,8 +678,9 @@ def export_tree( # fhandle.write(json.dumps(data, cls=UUIDEncoder)) fhandle.write(json.dumps(data)) - # Add proper signature to unique identifiers & all_fields_info - # Ignore if a key doesn't exist in any of the two dictionaries + # Turn sets into lists to be able to export them as JSON metadata. + for entity, entity_set in entities_starting_set.items(): + entities_starting_set[entity] = list(entity_set) metadata = { 'aiida_version': get_version(), @@ -479,18 +698,21 @@ def export_tree( with folder.open('metadata.json', 'w') as fhandle: fhandle.write(json.dumps(metadata)) - if silent is not True: - print('STORING REPOSITORY FILES...') + EXPORT_LOGGER.debug('ADDING REPOSITORY FILES TO EXPORT ARCHIVE...') # If there are no nodes, there are no repository files to store - if all_nodes_pk: - # Large speed increase by not getting the node itself and looping in memory in python, but just getting the uuid - uuid_query = orm.QueryBuilder() - uuid_query.append(orm.Node, filters={'id': {'in': all_nodes_pk}}, project=['uuid']) - for res in uuid_query.all(): - uuid = str(res[0]) + if all_node_pks: + all_node_uuids = {node_pk_2_uuid_mapping[_] for _ in all_node_pks} + + progress_bar = get_progress_bar(total=len(all_node_uuids), disable=silent) + pbar_base_str = 'Exporting repository - ' + + for uuid in all_node_uuids: sharded_uuid = export_shard_uuid(uuid) + progress_bar.set_description_str(pbar_base_str + 'UUID={}'.format(uuid.split('-')[0]), refresh=False) + progress_bar.update() + # Important to set create=False, otherwise creates twice a subfolder. Maybe this is a bug of insert_path? thisnodefolder = nodesubfolder.get_subfolder(sharded_uuid, create=False, reset_limit=True) @@ -504,75 +726,8 @@ def export_tree( # In this way, I copy the content of the folder, and not the folder itself thisnodefolder.insert_path(src=src.abspath, dest_name='.') + close_progress_bar(leave=False) -def export(what, outfile='export_data.aiida.tar.gz', overwrite=False, silent=False, **kwargs): - """Export the entries passed in the 'what' list to a file tree. - - :param what: a list of entity instances; they can belong to different models/entities. - :type what: list - - :param outfile: the filename (possibly including the absolute path) of the file on which to export. - :type outfile: str - - :param overwrite: if True, overwrite the output file without asking, if it exists. If False, raise an - :py:class:`~aiida.tools.importexport.common.exceptions.ArchiveExportError` if the output file already exists. - :type overwrite: bool - - :param silent: suppress prints. - :type silent: bool - - :param allowed_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the - list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False - otherwise. - :type allowed_licenses: list - - :param forbidden_licenses: List or function. If a list, then checks whether all licenses of Data nodes are in the - list. If a function, then calls function for licenses of Data nodes expecting True if license is allowed, False - otherwise. - :type forbidden_licenses: list - - :param include_comments: In-/exclude export of comments for given node(s) in ``what``. - Default: True, *include* comments in export (as well as relevant users). - :type include_comments: bool - - :param include_logs: In-/exclude export of logs for given node(s) in ``what``. - Default: True, *include* logs in export. - :type include_logs: bool - - :param kwargs: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` what rule names - are toggleable and what the defaults are. - - :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: if there are any internal errors when - exporting. - :raises `~aiida.common.exceptions.LicensingException`: if any node is licensed under forbidden license. - """ - from aiida.common.folders import SandboxFolder - - if not overwrite and os.path.exists(outfile): - raise exceptions.ArchiveExportError("The output file '{}' already exists".format(outfile)) - - folder = SandboxFolder() - time_export_start = time.time() - export_tree(what, folder=folder, silent=silent, **kwargs) - - time_export_end = time.time() - - if not silent: - print('COMPRESSING...') - - time_compress_start = time.time() - with tarfile.open(outfile, 'w:gz', format=tarfile.PAX_FORMAT, dereference=True) as tar: - tar.add(folder.abspath, arcname='') - time_compress_end = time.time() - - if not silent: - filecr_time = time_export_end - time_export_start - filecomp_time = time_compress_end - time_compress_start - print( - 'Exported in {:6.2g}s, compressed in {:6.2g}s, total: {:6.2g}s.'.format( - filecr_time, filecomp_time, filecr_time + filecomp_time - ) - ) - - if not silent: - print('DONE.') + # Reset logging level + if silent: + logging.disable(level=logging.NOTSET) diff --git a/aiida/tools/importexport/dbexport/utils.py b/aiida/tools/importexport/dbexport/utils.py index af73467ae9..84306c218a 100644 --- a/aiida/tools/importexport/dbexport/utils.py +++ b/aiida/tools/importexport/dbexport/utils.py @@ -9,13 +9,26 @@ ########################################################################### """ Utility functions for export of AiiDA entities """ # pylint: disable=too-many-locals,too-many-branches,too-many-nested-blocks +from enum import Enum +import warnings from aiida.orm import QueryBuilder, ProcessNode +from aiida.common.log import AIIDA_LOGGER, LOG_LEVEL_REPORT, override_log_formatter +from aiida.common.warnings import AiidaDeprecationWarning + from aiida.tools.importexport.common import exceptions from aiida.tools.importexport.common.config import ( file_fields_to_model_fields, entity_names_to_entities, get_all_fields_info ) +EXPORT_LOGGER = AIIDA_LOGGER.getChild('export') + + +class ExportFileFormat(str, Enum): + """Export file formats""" + ZIP = 'zip' + TAR_GZIPPED = 'tar.gz' + def fill_in_query(partial_query, originating_entity_str, current_entity_str, tag_suffixes=None, entity_separator='_'): """ @@ -270,3 +283,54 @@ def check_process_nodes_sealed(nodes): 'All ProcessNodes must be sealed before they can be exported. ' 'Node(s) with PK(s): {} is/are not sealed.'.format(', '.join(str(pk) for pk in nodes - sealed_nodes)) ) + + +@override_log_formatter('%(message)s') +def summary(file_format, outfile, **kwargs): + """Print summary for export""" + from tabulate import tabulate + from aiida.tools.importexport.common.config import EXPORT_VERSION + + parameters = [['Archive', outfile], ['Format', file_format], ['Export version', EXPORT_VERSION]] + + result = '\n{}'.format(tabulate(parameters, headers=['EXPORT', ''])) + + include_comments = kwargs.get('include_comments', True) + include_logs = kwargs.get('include_logs', True) + input_forward = kwargs.get('input_forward', False) + create_reversed = kwargs.get('create_reversed', True) + return_reversed = kwargs.get('return_reversed', False) + call_reversed = kwargs.get('call_reversed', False) + + inclusions = [['Include Comments', include_comments], ['Include Logs', include_logs]] + result += '\n\n{}'.format(tabulate(inclusions, headers=['Inclusion rules', ''])) + + traversal_rules = [['Follow INPUT Links forwards', + input_forward], ['Follow CREATE Links backwards', create_reversed], + ['Follow RETURN Links backwards', return_reversed], + ['Follow CALL Links backwards', call_reversed]] + result += '\n\n{}\n'.format(tabulate(traversal_rules, headers=['Traversal rules', ''])) + + EXPORT_LOGGER.log(msg=result, level=LOG_LEVEL_REPORT) + + +def deprecated_parameters(old, new): + """Handle deprecated parameter (where it is replaced with another) + + :param old: The old, deprecated parameter as a dict with keys "name" and "value" + :type old: dict + + :param new: The new parameter as a dict with keys "name" and "value" + :type new: dict + + :return: New parameter's value (if not defined, then old parameter's value) + """ + if old.get('value', None) is not None: + if new.get('value', None) is not None: + message = '`{}` is deprecated, the supplied `{}` input will be used'.format(old['name'], new['name']) + else: + message = '`{}` is deprecated, please use `{}` instead'.format(old['name'], new['name']) + new['value'] = old['value'] + warnings.warn(message, AiidaDeprecationWarning) # pylint: disable=no-member + + return new['value'] diff --git a/aiida/tools/importexport/dbimport/__init__.py b/aiida/tools/importexport/dbimport/__init__.py index 0d27785bf1..5e5a44c20c 100644 --- a/aiida/tools/importexport/dbimport/__init__.py +++ b/aiida/tools/importexport/dbimport/__init__.py @@ -8,8 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Provides import functionalities.""" +from aiida.tools.importexport.dbimport.utils import IMPORT_LOGGER -__all__ = ('import_data',) +__all__ = ('import_data', 'IMPORT_LOGGER') def import_data(in_path, group=None, silent=False, **kwargs): @@ -64,13 +65,17 @@ def import_data(in_path, group=None, silent=False, **kwargs): from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA from aiida.tools.importexport.common.exceptions import ArchiveImportError - if configuration.PROFILE.database_backend == BACKEND_SQLA: + backend = configuration.PROFILE.database_backend + + if backend == BACKEND_SQLA: from aiida.tools.importexport.dbimport.backends.sqla import import_data_sqla + IMPORT_LOGGER.debug('Calling import function import_data_sqla for the %s backend.', backend) return import_data_sqla(in_path, group=group, silent=silent, **kwargs) - if configuration.PROFILE.database_backend == BACKEND_DJANGO: + if backend == BACKEND_DJANGO: from aiida.tools.importexport.dbimport.backends.django import import_data_dj + IMPORT_LOGGER.debug('Calling import function import_data_dj for the %s backend.', backend) return import_data_dj(in_path, group=group, silent=silent, **kwargs) # else - raise ArchiveImportError('Unknown backend: {}'.format(configuration.PROFILE.database_backend)) + raise ArchiveImportError('Unknown backend: {}'.format(backend)) diff --git a/aiida/tools/importexport/dbimport/backends/django/__init__.py b/aiida/tools/importexport/dbimport/backends/django/__init__.py index d19c8d4075..877b5f7119 100644 --- a/aiida/tools/importexport/dbimport/backends/django/__init__.py +++ b/aiida/tools/importexport/dbimport/backends/django/__init__.py @@ -7,10 +7,11 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=protected-access,fixme,inconsistent-return-statements,too-many-arguments,too-many-locals,too-many-statements,too-many-branches +# pylint: disable=protected-access,fixme,too-many-arguments,too-many-locals,too-many-statements,too-many-branches,too-many-nested-blocks """ Django-specific import of AiiDA entities """ from distutils.version import StrictVersion +import logging import os import tarfile import zipfile @@ -19,21 +20,26 @@ from aiida.common import timezone, json from aiida.common.folders import SandboxFolder, RepositoryFolder from aiida.common.links import LinkType, validate_link_label +from aiida.common.log import override_log_formatter from aiida.common.utils import grouper, get_object_from_string +from aiida.manage.configuration import get_config_option from aiida.orm.utils.repository import Repository from aiida.orm import QueryBuilder, Node, Group, ImportGroup -from aiida.tools.importexport.common import exceptions + +from aiida.tools.importexport.common import exceptions, get_progress_bar, close_progress_bar from aiida.tools.importexport.common.archive import extract_tree, extract_tar, extract_zip -from aiida.tools.importexport.common.config import DUPL_SUFFIX, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER +from aiida.tools.importexport.common.config import DUPL_SUFFIX, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER, BAR_FORMAT from aiida.tools.importexport.common.config import ( NODE_ENTITY_NAME, GROUP_ENTITY_NAME, COMPUTER_ENTITY_NAME, USER_ENTITY_NAME, LOG_ENTITY_NAME, COMMENT_ENTITY_NAME ) from aiida.tools.importexport.common.config import entity_names_to_signatures from aiida.tools.importexport.common.utils import export_shard_uuid -from aiida.tools.importexport.dbimport.backends.utils import deserialize_field, merge_comment, merge_extras -from aiida.manage.configuration import get_config_option +from aiida.tools.importexport.dbimport.utils import ( + deserialize_field, merge_comment, merge_extras, start_summary, result_summary, IMPORT_LOGGER +) +@override_log_formatter('%(message)s') def import_data_dj( in_path, group=None, @@ -41,7 +47,8 @@ def import_data_dj( extras_mode_existing='kcl', extras_mode_new='import', comment_mode='newest', - silent=False + silent=False, + **kwargs ): """Import exported AiiDA archive to the AiiDA database and repository. @@ -81,7 +88,7 @@ def import_data_dj( 'overwrite' (will overwrite existing Comments with the ones from the import file). :type comment_mode: str - :param silent: suppress prints. + :param silent: suppress progress bar and summary. :type silent: bool :return: New and existing Nodes and Links. @@ -114,6 +121,9 @@ def import_data_dj( elif not group.is_stored: group.store() + if silent: + logging.disable(level=logging.CRITICAL) + ################ # EXTRACT DATA # ################ @@ -123,17 +133,13 @@ def import_data_dj( extract_tree(in_path, folder) else: if tarfile.is_tarfile(in_path): - extract_tar(in_path, folder, silent=silent, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER) + extract_tar(in_path, folder, silent=silent, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER, **kwargs) elif zipfile.is_zipfile(in_path): - try: - extract_zip(in_path, folder, silent=silent, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER) - except ValueError as exc: - print('The following problem occured while processing the provided file: {}'.format(exc)) - return + extract_zip(in_path, folder, silent=silent, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER, **kwargs) else: raise exceptions.ImportValidationError( 'Unable to detect the input file format, it is neither a ' - '(possibly compressed) tar file, nor a zip file.' + 'tar file, nor a (possibly compressed) zip file.' ) if not folder.get_content_list(): @@ -163,6 +169,8 @@ def import_data_dj( raise exceptions.IncompatibleArchiveVersionError(msg) + start_summary(in_path, comment_mode, extras_mode_new, extras_mode_existing) + ########################################################################## # CREATE UUID REVERSE TABLES AND CHECK IF I HAVE ALL NODES FOR THE LINKS # ########################################################################## @@ -189,7 +197,6 @@ def import_data_dj( # DOUBLE-CHECK MODEL DEPENDENCIES # ################################### # The entity import order. It is defined by the database model relationships. - model_order = ( USER_ENTITY_NAME, COMPUTER_ENTITY_NAME, NODE_ENTITY_NAME, GROUP_ENTITY_NAME, LOG_ENTITY_NAME, COMMENT_ENTITY_NAME @@ -230,6 +237,7 @@ def import_data_dj( # IMPORT DATA # ############### # DO ALL WITH A TRANSACTION + # !!! EXCEPT: Creating final import Group containing all Nodes in archive # batch size for bulk create operations batch_size = get_config_option('db.batch_size') @@ -239,6 +247,16 @@ def import_data_dj( new_entries = {} existing_entries = {} + IMPORT_LOGGER.debug('GENERATING LIST OF DATA...') + + # Instantiate progress bar + progress_bar = get_progress_bar(total=1, leave=False, disable=silent) + pbar_base_str = 'Generating list of data - ' + + # Get total entities from data.json + # To be used with progress bar + number_of_entities = 0 + # I first generate the list of data for model_name in model_order: cls_signature = entity_names_to_signatures[model_name] @@ -254,41 +272,102 @@ def import_data_dj( # Not necessarily all models are exported if model_name in data['export_data']: + IMPORT_LOGGER.debug(' %s...', model_name) + + progress_bar.set_description_str(pbar_base_str + model_name, refresh=False) + number_of_entities += len(data['export_data'][model_name]) + # skip nodes that are already present in the DB if unique_identifier is not None: import_unique_ids = set(v[unique_identifier] for v in data['export_data'][model_name].values()) - relevant_db_entries_result = model.objects.filter( - **{'{}__in'.format(unique_identifier): import_unique_ids} - ) - # Note: uuids need to be converted to strings - relevant_db_entries = { - str(getattr(n, unique_identifier)): n for n in relevant_db_entries_result - } + relevant_db_entries = {} + if import_unique_ids: + relevant_db_entries_result = model.objects.filter( + **{'{}__in'.format(unique_identifier): import_unique_ids} + ) + + # Note: UUIDs need to be converted to strings + if relevant_db_entries_result.count(): + progress_bar = get_progress_bar( + total=relevant_db_entries_result.count(), disable=silent + ) + # Imitating QueryBuilder.iterall() with default settings + for object_ in relevant_db_entries_result.iterator(chunk_size=100): + progress_bar.update() + relevant_db_entries.update({str(getattr(object_, unique_identifier)): object_}) foreign_ids_reverse_mappings[model_name] = {k: v.pk for k, v in relevant_db_entries.items()} + + IMPORT_LOGGER.debug(' GOING THROUGH ARCHIVE...') + + imported_comp_names = set() for key, value in data['export_data'][model_name].items(): - if value[unique_identifier] in relevant_db_entries.keys(): + if model_name == GROUP_ENTITY_NAME: + # Check if there is already a group with the same name + dupl_counter = 0 + orig_label = value['label'] + while model.objects.filter(label=value['label']): + value['label'] = orig_label + DUPL_SUFFIX.format(dupl_counter) + dupl_counter += 1 + if dupl_counter == 100: + raise exceptions.ImportUniquenessError( + 'A group of that label ( {} ) already exists and I could not create a new ' + 'one'.format(orig_label) + ) + + elif model_name == COMPUTER_ENTITY_NAME: + # Check if there is already a computer with the same name in the database + dupl = ( + model.objects.filter(name=value['name']) or value['name'] in imported_comp_names + ) + orig_name = value['name'] + dupl_counter = 0 + while dupl: + # Rename the new computer + value['name'] = orig_name + DUPL_SUFFIX.format(dupl_counter) + dupl = ( + model.objects.filter(name=value['name']) or value['name'] in imported_comp_names + ) + dupl_counter += 1 + if dupl_counter == 100: + raise exceptions.ImportUniquenessError( + 'A computer of that name ( {} ) already exists and I could not create a ' + 'new one'.format(orig_name) + ) + + imported_comp_names.add(value['name']) + + if value[unique_identifier] in relevant_db_entries: # Already in DB existing_entries[model_name][key] = value else: # To be added new_entries[model_name][key] = value else: - new_entries[model_name] = data['export_data'][model_name].copy() + new_entries[model_name] = data['export_data'][model_name] - # Show Comment mode if not silent - if not silent: - print('Comment mode: {}'.format(comment_mode)) + # Reset for import + progress_bar = get_progress_bar(total=number_of_entities, disable=silent) # I import data from the given model for model_name in model_order: + # Progress bar initialization - Model + pbar_base_str = '{}s - '.format(model_name) + progress_bar.set_description_str(pbar_base_str + 'Initializing', refresh=True) + cls_signature = entity_names_to_signatures[model_name] model = get_object_from_string(cls_signature) fields_info = metadata['all_fields_info'].get(model_name, {}) unique_identifier = metadata['unique_identifiers'].get(model_name, None) # EXISTING ENTRIES + if existing_entries[model_name]: + # Progress bar update - Model + progress_bar.set_description_str( + pbar_base_str + '{} existing entries'.format(len(existing_entries[model_name])), refresh=True + ) + for import_entry_pk, entry_data in existing_entries[model_name].items(): unique_id = entry_data[unique_identifier] existing_entry_id = foreign_ids_reverse_mappings[model_name][unique_id] @@ -312,18 +391,24 @@ def import_data_dj( if model_name not in ret_dict: ret_dict[model_name] = {'new': [], 'existing': []} ret_dict[model_name]['existing'].append((import_entry_pk, existing_entry_id)) - if not silent: - print('existing %s: %s (%s->%s)' % (model_name, unique_id, import_entry_pk, existing_entry_id)) - # print(" `-> WARNING: NO DUPLICITY CHECK DONE!") - # CHECK ALSO FILES! + IMPORT_LOGGER.debug( + 'Existing %s: %s (%s->%s)', model_name, unique_id, import_entry_pk, existing_entry_id + ) + # print(' `-> WARNING: NO DUPLICITY CHECK DONE!') + # CHECK ALSO FILES! # Store all objects for this model in a list, and store them all in once at the end. objects_to_create = [] # This is needed later to associate the import entry with the new pk import_new_entry_pks = {} - imported_comp_names = set() # NEW ENTRIES + if new_entries[model_name]: + # Progress bar update - Model + progress_bar.set_description_str( + pbar_base_str + '{} new entries'.format(len(new_entries[model_name])), refresh=True + ) + for import_entry_pk, entry_data in new_entries[model_name].items(): unique_id = entry_data[unique_identifier] import_data = dict( @@ -336,54 +421,21 @@ def import_data_dj( ) for k, v in entry_data.items() ) - if model is models.DbGroup: - # Check if there is already a group with the same name - dupl_counter = 0 - orig_label = import_data['label'] - while model.objects.filter(label=import_data['label']): - import_data['label'] = orig_label + DUPL_SUFFIX.format(dupl_counter) - dupl_counter += 1 - if dupl_counter == 100: - raise exceptions.ImportUniquenessError( - 'A group of that label ( {} ) already exists and I could not create a new one' - ''.format(orig_label) - ) - - elif model is models.DbComputer: - # Check if there is already a computer with the same name in the database - dupl = ( - model.objects.filter(name=import_data['name']) or import_data['name'] in imported_comp_names - ) - orig_name = import_data['name'] - dupl_counter = 0 - while dupl: - # Rename the new computer - import_data['name'] = (orig_name + DUPL_SUFFIX.format(dupl_counter)) - dupl = ( - model.objects.filter(name=import_data['name']) or - import_data['name'] in imported_comp_names - ) - dupl_counter += 1 - if dupl_counter == 100: - raise exceptions.ImportUniquenessError( - 'A computer of that name ( {} ) already exists and I could not create a new one' - ''.format(orig_name) - ) - - imported_comp_names.add(import_data['name']) - objects_to_create.append(model(**import_data)) import_new_entry_pks[unique_id] = import_entry_pk if model_name == NODE_ENTITY_NAME: - if not silent: - print('STORING NEW NODE REPOSITORY FILES...') + IMPORT_LOGGER.debug('STORING NEW NODE REPOSITORY FILES...') # NEW NODES for object_ in objects_to_create: import_entry_uuid = object_.uuid import_entry_pk = import_new_entry_pks[import_entry_uuid] + # Progress bar initialization - Node + progress_bar.update() + pbar_node_base_str = pbar_base_str + 'UUID={} - '.format(import_entry_uuid.split('-')[0]) + # Before storing entries in the DB, I store the files (if these are nodes). # Note: only for new entries! subfolder = folder.get_subfolder( @@ -397,11 +449,12 @@ def import_data_dj( destdir = RepositoryFolder(section=Repository._section_name, uuid=import_entry_uuid) # Replace the folder, possibly destroying existing previous folders, and move the files # (faster if we are on the same filesystem, and in any case the source is a SandboxFolder) + progress_bar.set_description_str(pbar_node_base_str + 'Repository', refresh=True) destdir.replace_with_folder(subfolder.abspath, move=True, overwrite=True) # For DbNodes, we also have to store its attributes - if not silent: - print('STORING NEW NODE ATTRIBUTES...') + IMPORT_LOGGER.debug('STORING NEW NODE ATTRIBUTES...') + progress_bar.set_description_str(pbar_node_base_str + 'Attributes', refresh=True) # Get attributes from import file try: @@ -413,8 +466,8 @@ def import_data_dj( # For DbNodes, we also have to store its extras if extras_mode_new == 'import': - if not silent: - print('STORING NEW NODE EXTRAS...') + IMPORT_LOGGER.debug('STORING NEW NODE EXTRAS...') + progress_bar.set_description_str(pbar_node_base_str + 'Extras', refresh=True) # Get extras from import file try: @@ -431,8 +484,7 @@ def import_data_dj( # till here object_.extras = extras elif extras_mode_new == 'none': - if not silent: - print('SKIPPING NEW NODE EXTRAS...') + IMPORT_LOGGER.debug('SKIPPING NEW NODE EXTRAS...') else: raise exceptions.ImportValidationError( "Unknown extras_mode_new value: {}, should be either 'import' or 'none'" @@ -441,8 +493,7 @@ def import_data_dj( # EXISTING NODES (Extras) # For the existing nodes that are also in the imported list we also update their extras if necessary - if not silent: - print('UPDATING EXISTING NODE EXTRAS (mode: {})'.format(extras_mode_existing)) + IMPORT_LOGGER.debug('UPDATING EXISTING NODE EXTRAS...') import_existing_entry_pks = { entry_data[unique_identifier]: import_entry_pk @@ -452,24 +503,38 @@ def import_data_dj( import_entry_uuid = str(node.uuid) import_entry_pk = import_existing_entry_pks[import_entry_uuid] + # Progress bar initialization - Node + pbar_node_base_str = pbar_base_str + 'UUID={} - '.format(import_entry_uuid.split('-')[0]) + progress_bar.set_description_str(pbar_node_base_str + 'Extras', refresh=False) + progress_bar.update() + # Get extras from import file try: extras = data['node_extras'][str(import_entry_pk)] except KeyError: raise exceptions.CorruptArchive( - 'Unable to find extra info for ode with UUID={}'.format(import_entry_uuid) + 'Unable to find extra info for Node with UUID={}'.format(import_entry_uuid) ) + old_extras = node.extras.copy() # TODO: remove when aiida extras will be moved somewhere else # from here extras = {key: value for key, value in extras.items() if not key.startswith('_aiida_')} if node.node_type.endswith('code.Code.'): extras = {key: value for key, value in extras.items() if not key == 'hidden'} # till here - node.extras = merge_extras(node.extras, extras, extras_mode_existing) + new_extras = merge_extras(node.extras, extras, extras_mode_existing) + + if new_extras != old_extras: + # Already saving existing node here to update its extras + node.extras = new_extras + node.save() - # Already saving existing node here to update its extras - node.save() + else: + # Update progress bar with new non-Node entries + progress_bar.update(n=len(existing_entries[model_name]) + len(new_entries[model_name])) + + progress_bar.set_description_str(pbar_base_str + 'Storing', refresh=True) # If there is an mtime in the field, disable the automatic update # to keep the mtime that we have set here @@ -498,11 +563,9 @@ def import_data_dj( ret_dict[model_name] = {'new': [], 'existing': []} ret_dict[model_name]['new'].append((import_entry_pk, new_pk)) - if not silent: - print('NEW %s: %s (%s->%s)' % (model_name, unique_id, import_entry_pk, new_pk)) + IMPORT_LOGGER.debug('New %s: %s (%s->%s)' % (model_name, unique_id, import_entry_pk, new_pk)) - if not silent: - print('STORING NODE LINKS...') + IMPORT_LOGGER.debug('STORING NODE LINKS...') import_links = data['links_uuid'] links_to_store = [] @@ -527,8 +590,15 @@ def import_data_dj( LinkType.RETURN: (workflow_node_types, data_node_types, 'unique_pair', 'unique_triple'), } + if import_links: + progress_bar = get_progress_bar(total=len(import_links), disable=silent) + pbar_base_str = 'Links - ' + for link in import_links: # Check for dangling Links within the, supposed, self-consistent archive + progress_bar.set_description_str(pbar_base_str + 'label={}'.format(link['label']), refresh=False) + progress_bar.update() + try: in_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][link['input']] out_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][link['output']] @@ -627,20 +697,27 @@ def import_data_dj( # Store new links if links_to_store: - if not silent: - print(' ({} new links...)'.format(len(links_to_store))) + IMPORT_LOGGER.debug(' (%d new links...)', len(links_to_store)) models.DbLink.objects.bulk_create(links_to_store, batch_size=batch_size) else: - if not silent: - print(' (0 new links...)') + IMPORT_LOGGER.debug(' (0 new links...)') + + IMPORT_LOGGER.debug('STORING GROUP ELEMENTS...') - if not silent: - print('STORING GROUP ELEMENTS...') import_groups = data['groups_uuid'] + + if import_groups: + progress_bar = get_progress_bar(total=len(import_groups), disable=silent) + pbar_base_str = 'Groups - ' + for groupuuid, groupnodes in import_groups.items(): # TODO: cache these to avoid too many queries group_ = models.DbGroup.objects.get(uuid=groupuuid) + + progress_bar.set_description_str(pbar_base_str + 'label={}'.format(group_.label), refresh=False) + progress_bar.update() + nodes_to_store = [foreign_ids_reverse_mappings[NODE_ENTITY_NAME][node_uuid] for node_uuid in groupnodes] if nodes_to_store: group_.dbnodes.add(*nodes_to_store) @@ -676,17 +753,32 @@ def import_data_dj( group = ImportGroup(label=group_label).store() # Add all the nodes to the new group - # TODO: decide if we want to return the group label - nodes = QueryBuilder().append(Node, filters={'id': {'in': pks_for_group}}).all(flat=True) + builder = QueryBuilder().append(Node, filters={'id': {'in': pks_for_group}}) + + progress_bar = get_progress_bar(total=len(pks_for_group), disable=silent) + progress_bar.set_description_str('Creating import Group - Preprocessing', refresh=True) + first = True + + nodes = [] + for entry in builder.iterall(): + if first: + progress_bar.set_description_str('Creating import Group', refresh=False) + first = False + progress_bar.update() + nodes.append(entry[0]) group.add_nodes(nodes) - - if not silent: - print("IMPORTED NODES ARE GROUPED IN THE IMPORT GROUP LABELED '{}'".format(group.label)) + progress_bar.set_description_str('Done (cleaning up)', refresh=True) else: - if not silent: - print('NO NODES TO IMPORT, SO NO GROUP CREATED, IF IT DID NOT ALREADY EXIST') + IMPORT_LOGGER.debug('No Nodes to import, so no Group created, if it did not already exist') + + # Finalize Progress bar + close_progress_bar(leave=False) + + # Summarize import + result_summary(ret_dict, getattr(group, 'label', None)) - if not silent: - print('DONE.') + # Reset logging level + if silent: + logging.disable(level=logging.NOTSET) return ret_dict diff --git a/aiida/tools/importexport/dbimport/backends/sqla/__init__.py b/aiida/tools/importexport/dbimport/backends/sqla/__init__.py index 2e800b1361..134e4a4a60 100644 --- a/aiida/tools/importexport/dbimport/backends/sqla/__init__.py +++ b/aiida/tools/importexport/dbimport/backends/sqla/__init__.py @@ -11,6 +11,7 @@ """ SQLAlchemy-specific import of AiiDA entities """ from distutils.version import StrictVersion +import logging import os import tarfile import zipfile @@ -19,14 +20,15 @@ from aiida.common import timezone, json from aiida.common.folders import SandboxFolder, RepositoryFolder from aiida.common.links import LinkType +from aiida.common.log import override_log_formatter from aiida.common.utils import get_object_from_string from aiida.orm import QueryBuilder, Node, Group, ImportGroup from aiida.orm.utils.links import link_triple_exists, validate_link from aiida.orm.utils.repository import Repository -from aiida.tools.importexport.common import exceptions +from aiida.tools.importexport.common import exceptions, get_progress_bar, close_progress_bar from aiida.tools.importexport.common.archive import extract_tree, extract_tar, extract_zip -from aiida.tools.importexport.common.config import DUPL_SUFFIX, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER +from aiida.tools.importexport.common.config import DUPL_SUFFIX, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER, BAR_FORMAT from aiida.tools.importexport.common.config import ( NODE_ENTITY_NAME, GROUP_ENTITY_NAME, COMPUTER_ENTITY_NAME, USER_ENTITY_NAME, LOG_ENTITY_NAME, COMMENT_ENTITY_NAME ) @@ -35,10 +37,13 @@ entity_names_to_entities ) from aiida.tools.importexport.common.utils import export_shard_uuid -from aiida.tools.importexport.dbimport.backends.utils import deserialize_field, merge_comment, merge_extras +from aiida.tools.importexport.dbimport.utils import ( + deserialize_field, merge_comment, merge_extras, start_summary, result_summary, IMPORT_LOGGER +) from aiida.tools.importexport.dbimport.backends.sqla.utils import validate_uuid +@override_log_formatter('%(message)s') def import_data_sqla( in_path, group=None, @@ -46,7 +51,8 @@ def import_data_sqla( extras_mode_existing='kcl', extras_mode_new='import', comment_mode='newest', - silent=False + silent=False, + **kwargs ): """Import exported AiiDA archive to the AiiDA database and repository. @@ -86,7 +92,7 @@ def import_data_sqla( 'overwrite' (will overwrite existing Comments with the ones from the import file). :type comment_mode: str - :param silent: suppress prints. + :param silent: suppress progress bar and summary. :type silent: bool :return: New and existing Nodes and Links. @@ -119,6 +125,9 @@ def import_data_sqla( elif not group.is_stored: group.store() + if silent: + logging.disable(level=logging.CRITICAL) + ################ # EXTRACT DATA # ################ @@ -128,21 +137,23 @@ def import_data_sqla( extract_tree(in_path, folder) else: if tarfile.is_tarfile(in_path): - extract_tar(in_path, folder, silent=silent, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER) + extract_tar(in_path, folder, silent=silent, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER, **kwargs) elif zipfile.is_zipfile(in_path): - extract_zip(in_path, folder, silent=silent, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER) + extract_zip(in_path, folder, silent=silent, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER, **kwargs) else: raise exceptions.ImportValidationError( 'Unable to detect the input file format, it is neither a ' - '(possibly compressed) tar file, nor a zip file.' + 'tar file, nor a (possibly compressed) zip file.' ) if not folder.get_content_list(): raise exceptions.CorruptArchive('The provided file/folder ({}) is empty'.format(in_path)) try: + IMPORT_LOGGER.debug('CACHING metadata.json') with open(folder.get_abs_path('metadata.json'), encoding='utf8') as fhandle: metadata = json.load(fhandle) + IMPORT_LOGGER.debug('CACHING data.json') with open(folder.get_abs_path('data.json'), encoding='utf8') as fhandle: data = json.load(fhandle) except IOError as error: @@ -164,10 +175,14 @@ def import_data_sqla( raise exceptions.IncompatibleArchiveVersionError(msg) + start_summary(in_path, comment_mode, extras_mode_new, extras_mode_existing) + ################################################################### # CREATE UUID REVERSE TABLES AND CHECK IF # # I HAVE ALL NODES FOR THE LINKS # ################################################################### + IMPORT_LOGGER.debug('CHECKING IF NODES FROM LINKS ARE IN DB OR ARCHIVE...') + linked_nodes = set(chain.from_iterable((l['input'], l['output']) for l in data['links_uuid'])) group_nodes = set(chain.from_iterable(data['groups_uuid'].values())) @@ -192,25 +207,21 @@ def import_data_sqla( # DOUBLE-CHECK MODEL DEPENDENCIES # ################################### # The entity import order. It is defined by the database model relationships. - entity_sig_order = [ - entity_names_to_signatures[m] for m in ( - USER_ENTITY_NAME, COMPUTER_ENTITY_NAME, NODE_ENTITY_NAME, GROUP_ENTITY_NAME, LOG_ENTITY_NAME, - COMMENT_ENTITY_NAME - ) + entity_order = [ + USER_ENTITY_NAME, COMPUTER_ENTITY_NAME, NODE_ENTITY_NAME, GROUP_ENTITY_NAME, LOG_ENTITY_NAME, + COMMENT_ENTITY_NAME ] # I make a new list that contains the entity names: # eg: ['User', 'Computer', 'Node', 'Group'] - all_entity_names = [signatures_to_entity_names[entity_sig] for entity_sig in entity_sig_order] for import_field_name in metadata['all_fields_info']: - if import_field_name not in all_entity_names: + if import_field_name not in entity_order: raise exceptions.ImportValidationError( "You are trying to import an unknown model '{}'!".format(import_field_name) ) - for idx, entity_sig in enumerate(entity_sig_order): + for idx, entity_name in enumerate(entity_order): dependencies = [] - entity_name = signatures_to_entity_names[entity_sig] # for every field, I checked the dependencies given as value for key requires for field in metadata['all_fields_info'][entity_name].values(): try: @@ -219,9 +230,9 @@ def import_data_sqla( # (No ForeignKey) pass for dependency in dependencies: - if dependency not in all_entity_names[:idx]: + if dependency not in entity_order[:idx]: raise exceptions.ArchiveImportError( - 'Entity {} requires {} but would be loaded first; stopping...'.format(entity_sig, dependency) + 'Entity {} requires {} but would be loaded first; stopping...'.format(entity_name, dependency) ) ################################################### @@ -235,6 +246,7 @@ def import_data_sqla( # 2363: 'ef04aa5d-99e7-4bfd-95ef-fe412a6a3524', 2364: '1dc59576-af21-4d71-81c2-bac1fc82a84a'}, # 'User': {1: 'aiida@localhost'} # } + IMPORT_LOGGER.debug('CREATING PK-2-UUID/EMAIL MAPPING...') import_unique_ids_mappings = {} # Export data since v0.3 contains the keys entity_name for entity_name, import_data in data['export_data'].items(): @@ -257,9 +269,18 @@ def import_data_sqla( new_entries = {} existing_entries = {} + IMPORT_LOGGER.debug('GENERATING LIST OF DATA...') + + # Instantiate progress bar + progress_bar = get_progress_bar(total=1, leave=False, disable=silent) + pbar_base_str = 'Generating list of data - ' + + # Get total entities from data.json + # To be used with progress bar + number_of_entities = 0 + # I first generate the list of data - for entity_sig in entity_sig_order: - entity_name = signatures_to_entity_names[entity_sig] + for entity_name in entity_order: entity = entity_names_to_entities[entity_name] # I get the unique identifier, since v0.3 stored under entity_name unique_identifier = metadata['unique_identifiers'].get(entity_name, None) @@ -272,29 +293,32 @@ def import_data_sqla( # Not necessarily all models are exported if entity_name in data['export_data']: + IMPORT_LOGGER.debug(' %s...', entity_name) + + progress_bar.set_description_str(pbar_base_str + entity_name, refresh=False) + number_of_entities += len(data['export_data'][entity_name]) + if unique_identifier is not None: import_unique_ids = set(v[unique_identifier] for v in data['export_data'][entity_name].values()) - relevant_db_entries = dict() + relevant_db_entries = {} if import_unique_ids: builder = QueryBuilder() - builder.append( - entity, - filters={unique_identifier: { - 'in': import_unique_ids - }}, - project=['*'], - tag='res' - ) - relevant_db_entries = { - str(getattr(v[0], unique_identifier)): # str() to convert UUID() to string - v[0] for v in builder.all() - } + builder.append(entity, filters={unique_identifier: {'in': import_unique_ids}}, project='*') + + if builder.count(): + progress_bar = get_progress_bar(total=builder.count(), disable=silent) + for object_ in builder.iterall(): + progress_bar.update() + + relevant_db_entries.update({getattr(object_[0], unique_identifier): object_[0]}) foreign_ids_reverse_mappings[entity_name] = { k: v.pk for k, v in relevant_db_entries.items() } + IMPORT_LOGGER.debug(' GOING THROUGH ARCHIVE...') + imported_comp_names = set() for key, value in data['export_data'][entity_name].items(): if entity_name == GROUP_ENTITY_NAME: @@ -331,19 +355,19 @@ def import_data_sqla( '==': value['name'] }}, project=['*'], tag='res' ) - dupl = (builder.count() or value['name'] in imported_comp_names) + dupl = builder.count() or value['name'] in imported_comp_names dupl_counter = 0 orig_name = value['name'] while dupl: # Rename the new computer - value['name'] = (orig_name + DUPL_SUFFIX.format(dupl_counter)) + value['name'] = orig_name + DUPL_SUFFIX.format(dupl_counter) builder = QueryBuilder() builder.append( entity, filters={'name': { '==': value['name'] }}, project=['*'], tag='res' ) - dupl = (builder.count() or value['name'] in imported_comp_names) + dupl = builder.count() or value['name'] in imported_comp_names dupl_counter += 1 if dupl_counter == 100: raise exceptions.ImportUniquenessError( @@ -361,21 +385,33 @@ def import_data_sqla( # To be added new_entries[entity_name][key] = value else: - # Why the copy: - new_entries[entity_name] = data['export_data'][entity_name].copy() + new_entries[entity_name] = data['export_data'][entity_name] - # Show Comment mode if not silent - if not silent: - print('Comment mode: {}'.format(comment_mode)) + # Progress bar - reset for import + progress_bar = get_progress_bar(total=number_of_entities, disable=silent) + reset_progress_bar = {} # I import data from the given model - for entity_sig in entity_sig_order: - entity_name = signatures_to_entity_names[entity_sig] + for entity_name in entity_order: entity = entity_names_to_entities[entity_name] fields_info = metadata['all_fields_info'].get(entity_name, {}) unique_identifier = metadata['unique_identifiers'].get(entity_name, '') + # Progress bar initialization - Model + if reset_progress_bar: + progress_bar = get_progress_bar(total=reset_progress_bar['total'], disable=silent) + progress_bar.n = reset_progress_bar['n'] + reset_progress_bar = {} + pbar_base_str = '{}s - '.format(entity_name) + progress_bar.set_description_str(pbar_base_str + 'Initializing', refresh=True) + # EXISTING ENTRIES + if existing_entries[entity_name]: + # Progress bar update - Model + progress_bar.set_description_str( + pbar_base_str + '{} existing entries'.format(len(existing_entries[entity_name])), refresh=True + ) + for import_entry_pk, entry_data in existing_entries[entity_name].items(): unique_id = entry_data[unique_identifier] existing_entry_pk = foreign_ids_reverse_mappings[entity_name][unique_id] @@ -390,7 +426,7 @@ def import_data_sqla( ) # TODO COMPARE, AND COMPARE ATTRIBUTES - if entity_sig is entity_names_to_signatures[COMMENT_ENTITY_NAME]: + if entity_name == COMMENT_ENTITY_NAME: new_entry_uuid = merge_comment(import_data, comment_mode) if new_entry_uuid is not None: entry_data[unique_identifier] = new_entry_uuid @@ -399,8 +435,9 @@ def import_data_sqla( if entity_name not in ret_dict: ret_dict[entity_name] = {'new': [], 'existing': []} ret_dict[entity_name]['existing'].append((import_entry_pk, existing_entry_pk)) - if not silent: - print('existing %s: %s (%s->%s)' % (entity_sig, unique_id, import_entry_pk, existing_entry_pk)) + IMPORT_LOGGER.debug( + 'Existing %s: %s (%s->%s)', entity_name, unique_id, import_entry_pk, existing_entry_pk + ) # Store all objects for this model in a list, and store them # all in once at the end. @@ -411,6 +448,12 @@ def import_data_sqla( import_new_entry_pks = dict() # NEW ENTRIES + if new_entries[entity_name]: + # Progress bar update - Model + progress_bar.set_description_str( + pbar_base_str + '{} new entries'.format(len(new_entries[entity_name])), refresh=True + ) + for import_entry_pk, entry_data in new_entries[entity_name].items(): unique_id = entry_data[unique_identifier] import_data = dict( @@ -451,15 +494,18 @@ def import_data_sqla( objects_to_create.append(db_entity(**import_data)) import_new_entry_pks[unique_id] = import_entry_pk - if entity_sig == entity_names_to_signatures[NODE_ENTITY_NAME]: - if not silent: - print('STORING NEW NODE REPOSITORY FILES & ATTRIBUTES...') + if entity_name == NODE_ENTITY_NAME: + IMPORT_LOGGER.debug('STORING NEW NODE REPOSITORY FILES & ATTRIBUTES...') # NEW NODES for object_ in objects_to_create: import_entry_uuid = object_.uuid import_entry_pk = import_new_entry_pks[import_entry_uuid] + # Progress bar initialization - Node + progress_bar.update() + pbar_node_base_str = pbar_base_str + 'UUID={} - '.format(import_entry_uuid.split('-')[0]) + # Before storing entries in the DB, I store the files (if these are nodes). # Note: only for new entries! subfolder = folder.get_subfolder( @@ -473,9 +519,13 @@ def import_data_sqla( destdir = RepositoryFolder(section=Repository._section_name, uuid=import_entry_uuid) # Replace the folder, possibly destroying existing previous folders, and move the files # (faster if we are on the same filesystem, and in any case the source is a SandboxFolder) + progress_bar.set_description_str(pbar_node_base_str + 'Repository', refresh=True) destdir.replace_with_folder(subfolder.abspath, move=True, overwrite=True) # For Nodes, we also have to store Attributes! + IMPORT_LOGGER.debug('STORING NEW NODE ATTRIBUTES...') + progress_bar.set_description_str(pbar_node_base_str + 'Attributes', refresh=True) + # Get attributes from import file try: object_.attributes = data['node_attributes'][str(import_entry_pk)] @@ -485,10 +535,11 @@ def import_data_sqla( ) # For DbNodes, we also have to store extras - # Get extras from import file if extras_mode_new == 'import': - if not silent: - print('STORING NEW NODE EXTRAS...') + IMPORT_LOGGER.debug('STORING NEW NODE EXTRAS...') + progress_bar.set_description_str(pbar_node_base_str + 'Extras', refresh=True) + + # Get extras from import file try: extras = data['node_extras'][str(import_entry_pk)] except KeyError: @@ -503,8 +554,7 @@ def import_data_sqla( # till here object_.extras = extras elif extras_mode_new == 'none': - if not silent: - print('SKIPPING NEW NODE EXTRAS...') + IMPORT_LOGGER.debug('SKIPPING NEW NODE EXTRAS...') else: raise exceptions.ImportValidationError( "Unknown extras_mode_new value: {}, should be either 'import' or 'none'" @@ -512,8 +562,7 @@ def import_data_sqla( ) # EXISTING NODES (Extras) - if not silent: - print('UPDATING EXISTING NODE EXTRAS (mode: {})'.format(extras_mode_existing)) + IMPORT_LOGGER.debug('UPDATING EXISTING NODE EXTRAS...') import_existing_entry_pks = { entry_data[unique_identifier]: import_entry_pk @@ -523,6 +572,11 @@ def import_data_sqla( import_entry_uuid = str(node.uuid) import_entry_pk = import_existing_entry_pks[import_entry_uuid] + # Progress bar initialization - Node + pbar_node_base_str = pbar_base_str + 'UUID={} - '.format(import_entry_uuid.split('-')[0]) + progress_bar.set_description_str(pbar_node_base_str + 'Extras', refresh=False) + progress_bar.update() + # Get extras from import file try: extras = data['node_extras'][str(import_entry_pk)] @@ -531,15 +585,24 @@ def import_data_sqla( 'Unable to find extra info for Node with UUID={}'.format(import_entry_uuid) ) + old_extras = node.extras.copy() # TODO: remove when aiida extras will be moved somewhere else # from here extras = {key: value for key, value in extras.items() if not key.startswith('_aiida_')} if node.node_type.endswith('code.Code.'): extras = {key: value for key, value in extras.items() if not key == 'hidden'} # till here - node.extras = merge_extras(node.extras, extras, extras_mode_existing) - flag_modified(node, 'extras') - objects_to_update.append(node) + new_extras = merge_extras(node.extras, extras, extras_mode_existing) + if new_extras != old_extras: + node.extras = new_extras + flag_modified(node, 'extras') + objects_to_update.append(node) + + else: + # Update progress bar with new non-Node entries + progress_bar.update(n=len(existing_entries[entity_name]) + len(new_entries[entity_name])) + + progress_bar.set_description_str(pbar_base_str + 'Storing', refresh=True) # Store them all in once; However, the PK are not set in this way... if objects_to_create: @@ -549,19 +612,26 @@ def import_data_sqla( session.flush() + just_saved = {} if import_new_entry_pks.keys(): + reset_progress_bar = {'total': progress_bar.total, 'n': progress_bar.n} + progress_bar = get_progress_bar(total=len(import_new_entry_pks), disable=silent) + builder = QueryBuilder() builder.append( entity, filters={unique_identifier: { 'in': list(import_new_entry_pks.keys()) }}, - project=[unique_identifier, 'id'], - tag='res' + project=[unique_identifier, 'id'] ) - just_saved = {v[0]: v[1] for v in builder.all()} - else: - just_saved = dict() + + for entry in builder.iterall(): + progress_bar.update() + + just_saved.update({entry[0]: entry[1]}) + + progress_bar.set_description_str(pbar_base_str + 'Done!', refresh=True) # Now I have the PKs, print the info # Moreover, add newly created Nodes to foreign_ids_reverse_mappings @@ -575,16 +645,21 @@ def import_data_sqla( ret_dict[entity_name] = {'new': [], 'existing': []} ret_dict[entity_name]['new'].append((import_entry_pk, new_pk)) - if not silent: - print('NEW %s: %s (%s->%s)' % (entity_sig, unique_id, import_entry_pk, new_pk)) + IMPORT_LOGGER.debug('N %s: %s (%s->%s)', entity_name, unique_id, import_entry_pk, new_pk) - if not silent: - print('STORING NODE LINKS...') + IMPORT_LOGGER.debug('STORING NODE LINKS...') import_links = data['links_uuid'] + if import_links: + progress_bar = get_progress_bar(total=len(import_links), disable=silent) + pbar_base_str = 'Links - ' + for link in import_links: # Check for dangling Links within the, supposed, self-consistent archive + progress_bar.set_description_str(pbar_base_str + 'label={}'.format(link['label']), refresh=False) + progress_bar.update() + try: in_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][link['input']] out_id = foreign_ids_reverse_mappings[NODE_ENTITY_NAME][link['output']] @@ -617,16 +692,24 @@ def import_data_sqla( ret_dict['Link'] = {'new': []} ret_dict['Link']['new'].append((in_id, out_id)) - if not silent: - print(' ({} new links...)'.format(len(ret_dict.get('Link', {}).get('new', [])))) + IMPORT_LOGGER.debug(' (%d new links...)', len(ret_dict.get('Link', {}).get('new', []))) + + IMPORT_LOGGER.debug('STORING GROUP ELEMENTS...') - if not silent: - print('STORING GROUP ELEMENTS...') import_groups = data['groups_uuid'] + + if import_groups: + progress_bar = get_progress_bar(total=len(import_groups), disable=silent) + pbar_base_str = 'Groups - ' + for groupuuid, groupnodes in import_groups.items(): # # TODO: cache these to avoid too many queries qb_group = QueryBuilder().append(Group, filters={'uuid': {'==': groupuuid}}) group_ = qb_group.first()[0] + + progress_bar.set_description_str(pbar_base_str + 'label={}'.format(group_.label), refresh=False) + progress_bar.update() + nodes_ids_to_add = [ foreign_ids_reverse_mappings[NODE_ENTITY_NAME][node_uuid] for node_uuid in groupnodes ] @@ -668,31 +751,45 @@ def import_data_sqla( session.add(group.backend_entity._dbmodel) # Adding nodes to group avoiding the SQLA ORM to increase speed - nodes = [ - entry[0].backend_entity - for entry in QueryBuilder().append(Node, filters={ - 'id': { - 'in': pks_for_group - } - }).all() - ] + builder = QueryBuilder().append(Node, filters={'id': {'in': pks_for_group}}) + + progress_bar = get_progress_bar(total=len(pks_for_group), disable=silent) + progress_bar.set_description_str('Creating import Group - Preprocessing', refresh=True) + first = True + + nodes = [] + for entry in builder.iterall(): + if first: + progress_bar.set_description_str('Creating import Group', refresh=False) + first = False + progress_bar.update() + nodes.append(entry[0].backend_entity) group.backend_entity.add_nodes(nodes, skip_orm=True) - if not silent: - print("IMPORTED NODES ARE GROUPED IN THE IMPORT GROUP LABELED '{}'".format(group.label)) + progress_bar.set_description_str('Done (cleaning up)', refresh=True) else: - if not silent: - print('NO NODES TO IMPORT, SO NO GROUP CREATED, IF IT DID NOT ALREADY EXIST') + IMPORT_LOGGER.debug('No Nodes to import, so no Group created, if it did not already exist') - if not silent: - print('COMMITTING EVERYTHING...') + IMPORT_LOGGER.debug('COMMITTING EVERYTHING...') session.commit() + + # Finalize Progress bar + close_progress_bar(leave=False) + + # Summarize import + result_summary(ret_dict, getattr(group, 'label', None)) + except: - if not silent: - print('Rolling back') + # Finalize Progress bar + close_progress_bar(leave=False) + + result_summary({}, None) + + IMPORT_LOGGER.debug('Rolling back') session.rollback() raise - if not silent: - print('DONE.') + # Reset logging level + if silent: + logging.disable(level=logging.NOTSET) return ret_dict diff --git a/aiida/tools/importexport/dbimport/backends/utils.py b/aiida/tools/importexport/dbimport/utils.py similarity index 86% rename from aiida/tools/importexport/dbimport/backends/utils.py rename to aiida/tools/importexport/dbimport/utils.py index cf56c76102..faaac81e1c 100644 --- a/aiida/tools/importexport/dbimport/backends/utils.py +++ b/aiida/tools/importexport/dbimport/utils.py @@ -9,12 +9,19 @@ ########################################################################### """ Utility functions for import of AiiDA entities """ # pylint: disable=inconsistent-return-statements,too-many-branches +import os + import click +from tabulate import tabulate -from aiida.orm import QueryBuilder, Comment +from aiida.common.log import AIIDA_LOGGER, LOG_LEVEL_REPORT from aiida.common.utils import get_new_uuid +from aiida.orm import QueryBuilder, Comment + from aiida.tools.importexport.common import exceptions +IMPORT_LOGGER = AIIDA_LOGGER.getChild('import') + def merge_comment(incoming_comment, comment_mode): """ Merge comment according comment_mode @@ -241,3 +248,42 @@ def deserialize_field(key, value, fields_info, import_unique_ids_mappings, forei # else return ('{}_id'.format(key), None) + + +def start_summary(archive, comment_mode, extras_mode_new, extras_mode_existing): + """Print starting summary for import""" + archive = os.path.basename(archive) + result = '\n{}'.format(tabulate([['Archive', archive]], headers=['IMPORT', ''])) + + parameters = [ + ['Comment rules', comment_mode], + ['New Node Extras rules', extras_mode_new], + ['Existing Node Extras rules', extras_mode_existing], + ] + result += '\n\n{}'.format(tabulate(parameters, headers=['Parameters', ''])) + + IMPORT_LOGGER.log(msg=result, level=LOG_LEVEL_REPORT) + + +def result_summary(results, import_group_label): + """Summarize import results""" + title = None + + if results or import_group_label: + parameters = [] + title = 'Summary' + + if import_group_label: + parameters.append(['Auto-import Group label', import_group_label]) + + for model in results: + value = [] + if results[model].get('new', None): + value.append('{} new'.format(len(results[model]['new']))) + if results[model].get('existing', None): + value.append('{} existing'.format(len(results[model]['existing']))) + + parameters.extend([[param, val] for param, val in zip(['{}(s)'.format(model)], value)]) + + if title: + IMPORT_LOGGER.log(msg='\n{}\n'.format(tabulate(parameters, headers=[title, ''])), level=LOG_LEVEL_REPORT) diff --git a/docs/requirements_for_rtd.txt b/docs/requirements_for_rtd.txt index f518609215..6035209bfd 100644 --- a/docs/requirements_for_rtd.txt +++ b/docs/requirements_for_rtd.txt @@ -53,6 +53,7 @@ sqlalchemy-utils~=0.34.2 sqlalchemy>=1.3.10,~=1.3 tabulate~=0.8.5 tornado<5.0 +tqdm~=4.45 tzlocal~=2.0 upf_to_json~=0.9.2 wrapt~=1.11.1 \ No newline at end of file diff --git a/docs/source/verdi/verdi_user_guide.rst b/docs/source/verdi/verdi_user_guide.rst index c11b280b24..2db1f58e04 100644 --- a/docs/source/verdi/verdi_user_guide.rst +++ b/docs/source/verdi/verdi_user_guide.rst @@ -699,7 +699,6 @@ Below is a list with all available subcommands. -H, --hostname HOSTNAME Hostname. -P, --port INTEGER Port number. -c, --config-dir PATH Path to the configuration directory - --debug Enable debugging --wsgi-profile Whether to enable WSGI profiler middleware for finding bottlenecks diff --git a/environment.yml b/environment.yml index e63cbefaa5..e9e7c7bd65 100644 --- a/environment.yml +++ b/environment.yml @@ -35,6 +35,7 @@ dependencies: - sqlalchemy>=1.3.10,~=1.3 - tabulate~=0.8.5 - tornado<5.0 +- tqdm~=4.45 - tzlocal~=2.0 - upf_to_json~=0.9.2 - wrapt~=1.11.1 diff --git a/requirements/requirements-py-3.5.txt b/requirements/requirements-py-3.5.txt index 96a42ecec3..6a0064d39f 100644 --- a/requirements/requirements-py-3.5.txt +++ b/requirements/requirements-py-3.5.txt @@ -140,6 +140,7 @@ terminado==0.8.3 testpath==0.4.4 topika==0.2.1 tornado==4.5.3 +tqdm==4.45.0 traitlets==4.3.3 tzlocal==2.0.0 upf-to-json==0.9.2 diff --git a/requirements/requirements-py-3.6.txt b/requirements/requirements-py-3.6.txt index c4e942b38e..5e4e7e04c7 100644 --- a/requirements/requirements-py-3.6.txt +++ b/requirements/requirements-py-3.6.txt @@ -139,6 +139,7 @@ terminado==0.8.3 testpath==0.4.4 topika==0.2.1 tornado==4.5.3 +tqdm==4.45.0 traitlets==4.3.3 tzlocal==2.0.0 upf-to-json==0.9.2 diff --git a/requirements/requirements-py-3.7.txt b/requirements/requirements-py-3.7.txt index 6a5fd19a03..2f11b6ad66 100644 --- a/requirements/requirements-py-3.7.txt +++ b/requirements/requirements-py-3.7.txt @@ -138,6 +138,7 @@ terminado==0.8.3 testpath==0.4.4 topika==0.2.1 tornado==4.5.3 +tqdm==4.45.0 traitlets==4.3.3 tzlocal==2.0.0 upf-to-json==0.9.2 diff --git a/requirements/requirements-py-3.8.txt b/requirements/requirements-py-3.8.txt index 4603457b4b..611c434ab8 100644 --- a/requirements/requirements-py-3.8.txt +++ b/requirements/requirements-py-3.8.txt @@ -137,6 +137,7 @@ terminado==0.8.3 testpath==0.4.4 topika==0.2.1 tornado==4.5.3 +tqdm==4.45.0 traitlets==4.3.3 tzlocal==2.0.0 upf-to-json==0.9.2 diff --git a/setup.json b/setup.json index 1cd5b6cd06..4eade9dfbf 100644 --- a/setup.json +++ b/setup.json @@ -51,6 +51,7 @@ "sqlalchemy~=1.3,>=1.3.10", "tabulate~=0.8.5", "tornado<5.0", + "tqdm~=4.45", "tzlocal~=2.0", "upf_to_json~=0.9.2", "wrapt~=1.11.1" diff --git a/tests/cmdline/commands/test_export.py b/tests/cmdline/commands/test_export.py index 441dcfae68..2683f745df 100644 --- a/tests/cmdline/commands/test_export.py +++ b/tests/cmdline/commands/test_export.py @@ -10,6 +10,7 @@ """Tests for `verdi export`.""" import errno import os +import shutil import tempfile import tarfile import traceback @@ -69,7 +70,7 @@ def setUpClass(cls, *args, **kwargs): @classmethod def tearDownClass(cls, *args, **kwargs): os.chdir(cls.old_cwd) - os.rmdir(cls.cwd) + shutil.rmtree(cls.cwd, ignore_errors=True) def setUp(self): self.cli_runner = CliRunner() @@ -263,7 +264,7 @@ def test_inspect(self): options = ['--version', filename_input] result = self.cli_runner.invoke(cmd_export.inspect, options) self.assertIsNone(result.exception, result.output) - self.assertEqual(result.output.strip(), version_number) + self.assertEqual(result.output.strip()[-len(version_number):], version_number) def test_inspect_empty_archive(self): """Test the functionality of `verdi export inspect` for an empty archive.""" diff --git a/tests/cmdline/commands/test_import.py b/tests/cmdline/commands/test_import.py index 6ac168ad8f..cad98c783a 100644 --- a/tests/cmdline/commands/test_import.py +++ b/tests/cmdline/commands/test_import.py @@ -8,10 +8,11 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for `verdi import`.""" - from click.testing import CliRunner from click.exceptions import BadParameter +import pytest + from aiida.backends.testbase import AiidaTestCase from aiida.cmdline.commands import cmd_import from aiida.orm import Group @@ -144,15 +145,20 @@ def test_import_make_new_group(self): self.assertFalse(new_group, msg='The Group should not have been created now, but instead when it was imported.') self.assertFalse(group.is_empty, msg='The Group should not be empty.') + @pytest.mark.skip('Due to summary being logged, this can not be checked against `results.output`.') def test_comment_mode(self): """Test toggling comment mode flag""" + import re archives = [get_archive_file(self.newest_archive, filepath=self.archive_path)] - for mode in {'newest', 'overwrite'}: + for mode in ['newest', 'overwrite']: options = ['--comment-mode', mode] + archives result = self.cli_runner.invoke(cmd_import.cmd_import, options) self.assertIsNone(result.exception, result.output) - self.assertIn('Comment mode: {}'.format(mode), result.output) + self.assertTrue( + any([re.fullmatch(r'Comment rules[\s]*{}'.format(mode), line) for line in result.output.split('\n')]), + msg='Mode: {}. Output: {}'.format(mode, result.output) + ) self.assertEqual(result.exit_code, 0, result.output) def test_import_old_local_archives(self): diff --git a/tests/orm/implementation/test_comments.py b/tests/orm/implementation/test_comments.py index 5203a958e8..7bacd6bc1b 100644 --- a/tests/orm/implementation/test_comments.py +++ b/tests/orm/implementation/test_comments.py @@ -96,7 +96,7 @@ def test_creation_with_time(self): Test creation of a BackendComment when passing the mtime and the ctime. The passed ctime and mtime should be respected since it is important for the correct import of nodes at the AiiDA import/export. """ - from aiida.tools.importexport.dbimport.backends.utils import deserialize_attributes + from aiida.tools.importexport.dbimport.utils import deserialize_attributes ctime = deserialize_attributes('2019-02-27T16:20:12.245738', 'date') mtime = deserialize_attributes('2019-02-27T16:27:14.798838', 'date') diff --git a/tests/orm/implementation/test_logs.py b/tests/orm/implementation/test_logs.py index d27f2a71d9..54057cfee4 100644 --- a/tests/orm/implementation/test_logs.py +++ b/tests/orm/implementation/test_logs.py @@ -90,7 +90,7 @@ def test_creation_with_static_time(self): Test creation of a BackendLog when passing the mtime and the ctime. The passed ctime and mtime should be respected since it is important for the correct import of nodes at the AiiDA import/export. """ - from aiida.tools.importexport.dbimport.backends.utils import deserialize_attributes + from aiida.tools.importexport.dbimport.utils import deserialize_attributes time = deserialize_attributes('2019-02-27T16:20:12.245738', 'date') diff --git a/tests/orm/implementation/test_nodes.py b/tests/orm/implementation/test_nodes.py index 87f629b08f..c6086d01f6 100644 --- a/tests/orm/implementation/test_nodes.py +++ b/tests/orm/implementation/test_nodes.py @@ -103,7 +103,7 @@ def test_creation_with_time(self): Test creation of a BackendNode when passing the mtime and the ctime. The passed ctime and mtime should be respected since it is important for the correct import of nodes at the AiiDA import/export. """ - from aiida.tools.importexport.dbimport.backends.utils import deserialize_attributes + from aiida.tools.importexport.dbimport.utils import deserialize_attributes ctime = deserialize_attributes('2019-02-27T16:20:12.245738', 'date') mtime = deserialize_attributes('2019-02-27T16:27:14.798838', 'date') diff --git a/tests/tools/importexport/orm/test_attributes.py b/tests/tools/importexport/orm/test_attributes.py index 34c1389f8f..d0ea7bf19d 100644 --- a/tests/tools/importexport/orm/test_attributes.py +++ b/tests/tools/importexport/orm/test_attributes.py @@ -48,7 +48,7 @@ def test_import_of_attributes(self, temp_dir): # Export self.export_file = os.path.join(temp_dir, 'export.aiida') - export([self.data], outfile=self.export_file, silent=True) + export([self.data], filename=self.export_file, silent=True) # Clean db self.reset_database() diff --git a/tests/tools/importexport/orm/test_calculations.py b/tests/tools/importexport/orm/test_calculations.py index 95d45b4528..74f1007bd4 100644 --- a/tests/tools/importexport/orm/test_calculations.py +++ b/tests/tools/importexport/orm/test_calculations.py @@ -53,8 +53,8 @@ def max_(**kwargs): # These are the uuids that shouldn't be exported since it's a selection. not_wanted_uuids = [v.uuid for v in (b, c, d)] # At this point we export the generated data - filename1 = os.path.join(temp_dir, 'export1.tar.gz') - export([res], outfile=filename1, silent=True, return_backward=True) + filename1 = os.path.join(temp_dir, 'export1.aiida') + export([res], filename=filename1, silent=True, return_backward=True) self.clean_db() self.insert_data() import_data(filename1, silent=True) @@ -90,8 +90,8 @@ def test_workcalculation(self, temp_dir): slave.seal() uuids_values = [(v.uuid, v.value) for v in (output_1,)] - filename1 = os.path.join(temp_dir, 'export1.tar.gz') - export([output_1], outfile=filename1, silent=True) + filename1 = os.path.join(temp_dir, 'export1.aiida') + export([output_1], filename=filename1, silent=True) self.clean_db() self.insert_data() import_data(filename1, silent=True) diff --git a/tests/tools/importexport/orm/test_codes.py b/tests/tools/importexport/orm/test_codes.py index d8f173107b..2e26e4f392 100644 --- a/tests/tools/importexport/orm/test_codes.py +++ b/tests/tools/importexport/orm/test_codes.py @@ -46,8 +46,8 @@ def test_that_solo_code_is_exported_correctly(self, temp_dir): code_uuid = code.uuid - export_file = os.path.join(temp_dir, 'export.tar.gz') - export([code], outfile=export_file, silent=True) + export_file = os.path.join(temp_dir, 'export.aiida') + export([code], filename=export_file, silent=True) self.reset_database() @@ -82,8 +82,8 @@ def test_input_code(self, temp_dir): export_links = get_all_node_links() - export_file = os.path.join(temp_dir, 'export.tar.gz') - export([calc], outfile=export_file, silent=True) + export_file = os.path.join(temp_dir, 'export.aiida') + export([calc], filename=export_file, silent=True) self.reset_database() @@ -119,8 +119,8 @@ def test_solo_code(self, temp_dir): code_uuid = code.uuid - export_file = os.path.join(temp_dir, 'export.tar.gz') - export([code], outfile=export_file, silent=True) + export_file = os.path.join(temp_dir, 'export.aiida') + export([code], filename=export_file, silent=True) self.clean_db() self.insert_data() diff --git a/tests/tools/importexport/orm/test_comments.py b/tests/tools/importexport/orm/test_comments.py index fa623a352e..57e3034f7e 100644 --- a/tests/tools/importexport/orm/test_comments.py +++ b/tests/tools/importexport/orm/test_comments.py @@ -48,8 +48,8 @@ def test_multiple_imports_for_single_node(self, temp_dir): comment_uuids = [c.uuid for c in [comment_one, comment_two]] # Export as "EXISTING" DB - export_file_existing = os.path.join(temp_dir, 'export_EXISTING.tar.gz') - export([node], outfile=export_file_existing, silent=True) + export_file_existing = os.path.join(temp_dir, 'export_EXISTING.aiida') + export([node], filename=export_file_existing, silent=True) # Add 2 more Comments and save UUIDs prior to export comment_three = orm.Comment(node, user, self.comments[2]).store() @@ -57,8 +57,8 @@ def test_multiple_imports_for_single_node(self, temp_dir): comment_uuids += [c.uuid for c in [comment_three, comment_four]] # Export as "FULL" DB - export_file_full = os.path.join(temp_dir, 'export_FULL.tar.gz') - export([node], outfile=export_file_full, silent=True) + export_file_full = os.path.join(temp_dir, 'export_FULL.aiida') + export([node], filename=export_file_full, silent=True) # Clean database and reimport "EXISTING" DB self.reset_database() @@ -124,8 +124,8 @@ def test_exclude_comments_flag(self, temp_dir): self.assertEqual(node.user.email, users_email[0]) # Export nodes, excluding comments - export_file = os.path.join(temp_dir, 'export.tar.gz') - export([node], outfile=export_file, silent=True, include_comments=False) + export_file = os.path.join(temp_dir, 'export.aiida') + export([node], filename=export_file, silent=True, include_comments=False) # Clean database and reimport exported file self.reset_database() @@ -168,8 +168,8 @@ def test_calc_and_data_nodes_with_comments(self, temp_dir): data_comments_uuid = [c.uuid for c in [comment_three, comment_four]] # Export nodes - export_file = os.path.join(temp_dir, 'export.tar.gz') - export([calc_node, data_node], outfile=export_file, silent=True) + export_file = os.path.join(temp_dir, 'export.aiida') + export([calc_node, data_node], filename=export_file, silent=True) # Clean database and reimport exported file self.reset_database() @@ -219,8 +219,8 @@ def test_multiple_user_comments_single_node(self, temp_dir): user_two_comments_uuid = [str(c.uuid) for c in [comment_three, comment_four]] # Export node, along with comments and users recursively - export_file = os.path.join(temp_dir, 'export.tar.gz') - export([node], outfile=export_file, silent=True) + export_file = os.path.join(temp_dir, 'export.aiida') + export([node], filename=export_file, silent=True) # Clean database and reimport exported file self.reset_database() @@ -307,8 +307,8 @@ def test_mtime_of_imported_comments(self, temp_dir): calc_mtime = builder[0][1] # Export, reset database and reimport - export_file = os.path.join(temp_dir, 'export.tar.gz') - export([calc], outfile=export_file, silent=True) + export_file = os.path.join(temp_dir, 'export.aiida') + export([calc], filename=export_file, silent=True) self.reset_database() import_data(export_file, silent=True) @@ -357,8 +357,8 @@ def test_import_arg_comment_mode(self, temp_dir): cmt_uuid = cmt.uuid # Export calc and comment - export_file = os.path.join(temp_dir, 'export_file.tar.gz') - export([calc], outfile=export_file, silent=True) + export_file = os.path.join(temp_dir, 'export_file.aiida') + export([calc], filename=export_file, silent=True) # Update comment cmt.set_content(self.comments[1]) @@ -370,8 +370,8 @@ def test_import_arg_comment_mode(self, temp_dir): self.assertEqual(export_comments.all()[0][1], self.comments[1]) # Export calc and UPDATED comment - export_file_updated = os.path.join(temp_dir, 'export_file_updated.tar.gz') - export([calc], outfile=export_file_updated, silent=True) + export_file_updated = os.path.join(temp_dir, 'export_file_updated.aiida') + export([calc], filename=export_file_updated, silent=True) # Reimport exported 'old' calc and comment import_data(export_file, silent=True, comment_mode='newest') @@ -474,7 +474,7 @@ def test_reimport_of_comments_for_single_node(self, temp_dir): # Export "EXISTING" DB export_file_existing = os.path.join(temp_dir, export_filenames['EXISTING']) - export([calc], outfile=export_file_existing, silent=True) + export([calc], filename=export_file_existing, silent=True) # Add remaining Comments for comment in self.comments[1:]: @@ -494,7 +494,7 @@ def test_reimport_of_comments_for_single_node(self, temp_dir): # Export "FULL" DB export_file_full = os.path.join(temp_dir, export_filenames['FULL']) - export([calc], outfile=export_file_full, silent=True) + export([calc], filename=export_file_full, silent=True) # Clean database self.reset_database() @@ -533,7 +533,7 @@ def test_reimport_of_comments_for_single_node(self, temp_dir): # Export "NEW" DB export_file_new = os.path.join(temp_dir, export_filenames['NEW']) - export([calc], outfile=export_file_new, silent=True) + export([calc], filename=export_file_new, silent=True) # Clean database self.reset_database() diff --git a/tests/tools/importexport/orm/test_computers.py b/tests/tools/importexport/orm/test_computers.py index 939ee4b224..e32ed98bb6 100644 --- a/tests/tools/importexport/orm/test_computers.py +++ b/tests/tools/importexport/orm/test_computers.py @@ -64,12 +64,12 @@ def test_same_computer_import(self, temp_dir): comp_uuid = str(comp.uuid) # Export the first job calculation - filename1 = os.path.join(temp_dir, 'export1.tar.gz') - export([calc1], outfile=filename1, silent=True) + filename1 = os.path.join(temp_dir, 'export1.aiida') + export([calc1], filename=filename1, silent=True) # Export the second job calculation - filename2 = os.path.join(temp_dir, 'export2.tar.gz') - export([calc2], outfile=filename2, silent=True) + filename2 = os.path.join(temp_dir, 'export2.aiida') + export([calc2], filename=filename2, silent=True) # Clean the local database self.clean_db() @@ -152,8 +152,8 @@ def test_same_computer_different_name_import(self, temp_dir): comp1_name = str(comp1.name) # Export the first job calculation - filename1 = os.path.join(temp_dir, 'export1.tar.gz') - export([calc1], outfile=filename1, silent=True) + filename1 = os.path.join(temp_dir, 'export1.aiida') + export([calc1], filename=filename1, silent=True) # Rename the computer comp1.set_name(comp1_name + '_updated') @@ -168,8 +168,8 @@ def test_same_computer_different_name_import(self, temp_dir): calc2.seal() # Export the second job calculation - filename2 = os.path.join(temp_dir, 'export2.tar.gz') - export([calc2], outfile=filename2, silent=True) + filename2 = os.path.join(temp_dir, 'export2.aiida') + export([calc2], filename=filename2, silent=True) # Clean the local database self.clean_db() @@ -235,8 +235,8 @@ def test_different_computer_same_name_import(self, temp_dir): calc1.seal() # Export the first job calculation - filename1 = os.path.join(temp_dir, 'export1.tar.gz') - export([calc1], outfile=filename1, silent=True) + filename1 = os.path.join(temp_dir, 'export1.aiida') + export([calc1], filename=filename1, silent=True) # Reset the database self.clean_db() @@ -255,8 +255,8 @@ def test_different_computer_same_name_import(self, temp_dir): calc2.seal() # Export the second job calculation - filename2 = os.path.join(temp_dir, 'export2.tar.gz') - export([calc2], outfile=filename2, silent=True) + filename2 = os.path.join(temp_dir, 'export2.aiida') + export([calc2], filename=filename2, silent=True) # Reset the database self.clean_db() @@ -275,8 +275,8 @@ def test_different_computer_same_name_import(self, temp_dir): calc3.seal() # Export the third job calculation - filename3 = os.path.join(temp_dir, 'export3.tar.gz') - export([calc3], outfile=filename3, silent=True) + filename3 = os.path.join(temp_dir, 'export3.aiida') + export([calc3], filename=filename3, silent=True) # Clean the local database self.clean_db() @@ -332,8 +332,8 @@ def test_import_of_computer_json_params(self, temp_dir): calc1.seal() # Export the first job calculation - filename1 = os.path.join(temp_dir, 'export1.tar.gz') - export([calc1], outfile=filename1, silent=True) + filename1 = os.path.join(temp_dir, 'export1.aiida') + export([calc1], filename=filename1, silent=True) # Clean the local database self.clean_db() diff --git a/tests/tools/importexport/orm/test_extras.py b/tests/tools/importexport/orm/test_extras.py index 5da3a532ea..3a612038c3 100644 --- a/tests/tools/importexport/orm/test_extras.py +++ b/tests/tools/importexport/orm/test_extras.py @@ -33,7 +33,7 @@ def setUpClass(cls, *args, **kwargs): data.set_extra_many({'b': 2, 'c': 3}) cls.tmp_folder = tempfile.mkdtemp() cls.export_file = os.path.join(cls.tmp_folder, 'export.aiida') - export([data], outfile=cls.export_file, silent=True) + export([data], filename=cls.export_file, silent=True) @classmethod def tearDownClass(cls, *args, **kwargs): @@ -89,7 +89,8 @@ def test_absence_of_extras(self): self.imported_node.get_extra('c') def test_extras_import_mode_keep_existing(self): - """Check if old extras are not modified in case of name collision""" + """Check if old extras are not modified in case of name collision + (keep original, create new, leave original)""" self.import_extras() imported_node = self.modify_extras(mode_existing='kcl') @@ -99,7 +100,8 @@ def test_extras_import_mode_keep_existing(self): self.assertEqual(imported_node.get_extra('c'), 3) def test_extras_import_mode_update_existing(self): - """Check if old extras are modified in case of name collision""" + """Check if old extras are modified in case of name collision + (keep original, create new, update original)""" self.import_extras() imported_node = self.modify_extras(mode_existing='kcu') @@ -109,7 +111,8 @@ def test_extras_import_mode_update_existing(self): self.assertEqual(imported_node.get_extra('c'), 3) def test_extras_import_mode_mirror(self): - """Check if old extras are fully overwritten by the imported ones""" + """Check if old extras are fully overwritten by the imported ones + (not keep original, create new, update original)""" self.import_extras() imported_node = self.modify_extras(mode_existing='ncu') @@ -122,7 +125,8 @@ def test_extras_import_mode_mirror(self): self.assertEqual(imported_node.get_extra('c'), 3) def test_extras_import_mode_none(self): - """Check if old extras are fully overwritten by the imported ones""" + """Check if old extras are fully overwritten by the imported ones + (keep original, not create new, leave original)""" self.import_extras() imported_node = self.modify_extras(mode_existing='knl') @@ -134,7 +138,8 @@ def test_extras_import_mode_none(self): imported_node.get_extra('c') def test_extras_import_mode_strange(self): - """Check a mode that is probably does not make much sense but is still available""" + """Check a mode that probably does not make much sense but is still available + (keep original, create new, delete)""" self.import_extras() imported_node = self.modify_extras(mode_existing='kcd') @@ -146,7 +151,7 @@ def test_extras_import_mode_strange(self): imported_node.get_extra('b') def test_extras_import_mode_correct(self): - """Test all possible import modes except 'ask' """ + """Test all possible import modes except 'ask'""" self.import_extras() for mode1 in ['k', 'n']: # keep or not keep old extras for mode2 in ['n', 'c']: # create or not create new extras diff --git a/tests/tools/importexport/orm/test_groups.py b/tests/tools/importexport/orm/test_groups.py index 2010c7b167..b2462a2135 100644 --- a/tests/tools/importexport/orm/test_groups.py +++ b/tests/tools/importexport/orm/test_groups.py @@ -62,8 +62,8 @@ def test_nodes_in_group(self, temp_dir): gr1_uuid = gr1.uuid # At this point we export the generated data - filename1 = os.path.join(temp_dir, 'export1.tar.gz') - export([sd1, jc1, gr1], outfile=filename1, silent=True) + filename1 = os.path.join(temp_dir, 'export1.aiida') + export([sd1, jc1, gr1], filename=filename1, silent=True) n_uuids = [sd1.uuid, jc1.uuid] self.clean_db() self.insert_data() @@ -100,8 +100,8 @@ def test_group_export(self, temp_dir): group_uuid = group.uuid # At this point we export the generated data - filename = os.path.join(temp_dir, 'export.tar.gz') - export([group], outfile=filename, silent=True) + filename = os.path.join(temp_dir, 'export.aiida') + export([group], filename=filename, silent=True) n_uuids = [sd1.uuid] self.clean_db() self.insert_data() @@ -142,8 +142,8 @@ def test_group_import_existing(self, temp_dir): group.add_nodes([sd1]) # At this point we export the generated data - filename = os.path.join(temp_dir, 'export1.tar.gz') - export([group], outfile=filename, silent=True) + filename = os.path.join(temp_dir, 'export1.aiida') + export([group], filename=filename, silent=True) self.clean_db() self.insert_data() @@ -184,7 +184,7 @@ def test_import_to_group(self, temp_dir): # Export Nodes filename = os.path.join(temp_dir, 'export.aiida') - export([data1, data2], outfile=filename, silent=True) + export([data1, data2], filename=filename, silent=True) self.reset_database() # Create Group, do not store diff --git a/tests/tools/importexport/orm/test_links.py b/tests/tools/importexport/orm/test_links.py index c2e6872e85..d636b8c710 100644 --- a/tests/tools/importexport/orm/test_links.py +++ b/tests/tools/importexport/orm/test_links.py @@ -45,8 +45,8 @@ def test_links_to_unknown_nodes(self, temp_dir): struct.store() struct_uuid = struct.uuid - filename = os.path.join(temp_dir, 'export.tar.gz') - export([struct], outfile=filename, silent=True) + filename = os.path.join(temp_dir, 'export.aiida') + export([struct], filename=filename, file_format='tar.gz', silent=True) unpack = SandboxFolder() with tarfile.open(filename, 'r:gz', format=tarfile.PAX_FORMAT) as tar: @@ -93,8 +93,8 @@ def test_input_and_create_links(self, temp_dir): node_work.seal() export_links = get_all_node_links() - export_file = os.path.join(temp_dir, 'export.tar.gz') - export([node_output], outfile=export_file, silent=True) + export_file = os.path.join(temp_dir, 'export.aiida') + export([node_output], filename=export_file, silent=True) self.reset_database() @@ -264,8 +264,8 @@ def test_complex_workflow_graph_links(self, temp_dir): ) export_links = builder.all() - export_file = os.path.join(temp_dir, 'export.tar.gz') - export(graph_nodes, outfile=export_file, silent=True) + export_file = os.path.join(temp_dir, 'export.aiida') + export(graph_nodes, filename=export_file, silent=True) self.reset_database() @@ -285,8 +285,8 @@ def test_complex_workflow_graph_export_sets(self, temp_dir): _, (export_node, export_target) = self.construct_complex_graph(export_conf) export_target_uuids = set(_.uuid for _ in export_target) - export_file = os.path.join(temp_dir, 'export.tar.gz') - export([export_node], outfile=export_file, silent=True, overwrite=True) + export_file = os.path.join(temp_dir, 'export.aiida') + export([export_node], filename=export_file, silent=True, overwrite=True) export_node_str = str(export_node) self.reset_database() @@ -350,8 +350,8 @@ def test_high_level_workflow_links(self, temp_dir): export_links = builder.all() - export_file = os.path.join(temp_dir, 'export.tar.gz') - export(graph_nodes, outfile=export_file, silent=True, overwrite=True) + export_file = os.path.join(temp_dir, 'export.aiida') + export(graph_nodes, filename=export_file, silent=True, overwrite=True) self.reset_database() @@ -377,7 +377,7 @@ def prepare_link_flags_export(nodes_to_export, test_data): for export_file, rule_changes, expected_nodes in test_data.values(): traversal_rules.update(rule_changes) - export(nodes_to_export[0], outfile=export_file, silent=True, **traversal_rules) + export(nodes_to_export[0], filename=export_file, silent=True, **traversal_rules) for node_type in nodes_to_export[1]: if node_type in expected_nodes: @@ -597,8 +597,8 @@ def test_double_return_links_for_workflows(self, temp_dir): uuids_wanted = set(_.uuid for _ in (work1, data_out, data_in, work2)) links_wanted = get_all_node_links() - export_file = os.path.join(temp_dir, 'export.tar.gz') - export([data_out, work1, work2, data_in], outfile=export_file, silent=True) + export_file = os.path.join(temp_dir, 'export.aiida') + export([data_out, work1, work2, data_in], filename=export_file, silent=True) self.reset_database() @@ -627,8 +627,8 @@ def test_dangling_link_to_existing_db_node(self, temp_dir): calc.seal() calc_uuid = calc.uuid - filename = os.path.join(temp_dir, 'export.tar.gz') - export([struct], outfile=filename, silent=True) + filename = os.path.join(temp_dir, 'export.aiida') + export([struct], filename=filename, file_format='tar.gz', silent=True) unpack = SandboxFolder() with tarfile.open(filename, 'r:gz', format=tarfile.PAX_FORMAT) as tar: @@ -693,8 +693,8 @@ def test_multiple_post_return_links(self, temp_dir): # pylint: disable=too-many data_provenance = os.path.join(temp_dir, 'data.aiida') all_provenance = os.path.join(temp_dir, 'all.aiida') - export([data], outfile=data_provenance, silent=True, return_backward=False) - export([data], outfile=all_provenance, silent=True, return_backward=True) + export([data], filename=data_provenance, silent=True, return_backward=False) + export([data], filename=all_provenance, silent=True, return_backward=True) self.reset_database() diff --git a/tests/tools/importexport/orm/test_logs.py b/tests/tools/importexport/orm/test_logs.py index 1a8d4f6e06..bbe34dc53f 100644 --- a/tests/tools/importexport/orm/test_logs.py +++ b/tests/tools/importexport/orm/test_logs.py @@ -53,8 +53,8 @@ def test_critical_log_msg_and_metadata(self, temp_dir): # Store Log metadata log_metadata = orm.Log.objects.get(dbnode_id=calc.id).metadata - export_file = os.path.join(temp_dir, 'export.tar.gz') - export([calc], outfile=export_file, silent=True) + export_file = os.path.join(temp_dir, 'export.aiida') + export([calc], filename=export_file, silent=True) self.reset_database() @@ -84,8 +84,8 @@ def test_exclude_logs_flag(self, temp_dir): calc_uuid = calc.uuid # Export, excluding logs - export_file = os.path.join(temp_dir, 'export.tar.gz') - export([calc], outfile=export_file, silent=True, include_logs=False) + export_file = os.path.join(temp_dir, 'export.aiida') + export([calc], filename=export_file, silent=True, include_logs=False) # Clean database and reimport exported data self.reset_database() @@ -121,8 +121,8 @@ def test_export_of_imported_logs(self, temp_dir): log_uuid = str(log_uuid[0][0]) # Export - export_file = os.path.join(temp_dir, 'export.tar.gz') - export([calc], outfile=export_file, silent=True) + export_file = os.path.join(temp_dir, 'export.aiida') + export([calc], filename=export_file, silent=True) # Clean database and reimport exported data self.reset_database() @@ -142,8 +142,8 @@ def test_export_of_imported_logs(self, temp_dir): # Re-export calc = orm.load_node(import_calcs[0][0]) - re_export_file = os.path.join(temp_dir, 're_export.tar.gz') - export([calc], outfile=re_export_file, silent=True) + re_export_file = os.path.join(temp_dir, 're_export.aiida') + export([calc], filename=re_export_file, silent=True) # Clean database and reimport exported data self.reset_database() @@ -175,8 +175,8 @@ def test_multiple_imports_for_single_node(self, temp_dir): log_uuid_existing = str(log_uuid_existing[0][0]) # Export as "EXISTING" DB - export_file_existing = os.path.join(temp_dir, 'export_EXISTING.tar.gz') - export([node], outfile=export_file_existing, silent=True) + export_file_existing = os.path.join(temp_dir, 'export_EXISTING.aiida') + export([node], filename=export_file_existing, silent=True) # Add 2 more Logs and save UUIDs for all three Logs prior to export node.logger.critical(log_msgs[1]) @@ -185,8 +185,8 @@ def test_multiple_imports_for_single_node(self, temp_dir): log_uuids_full = [str(log[0]) for log in log_uuids_full] # Export as "FULL" DB - export_file_full = os.path.join(temp_dir, 'export_FULL.tar.gz') - export([node], outfile=export_file_full, silent=True) + export_file_full = os.path.join(temp_dir, 'export_FULL.aiida') + export([node], filename=export_file_full, silent=True) # Clean database and reimport "EXISTING" DB self.reset_database() @@ -290,7 +290,7 @@ def test_reimport_of_logs_for_single_node(self, temp_dir): # Export "EXISTING" DB export_file_existing = os.path.join(temp_dir, export_filenames['EXISTING']) - export([calc], outfile=export_file_existing, silent=True) + export([calc], filename=export_file_existing, silent=True) # Add remaining Log messages for log_msg in log_msgs[1:]: @@ -310,7 +310,7 @@ def test_reimport_of_logs_for_single_node(self, temp_dir): # Export "FULL" DB export_file_full = os.path.join(temp_dir, export_filenames['FULL']) - export([calc], outfile=export_file_full, silent=True) + export([calc], filename=export_file_full, silent=True) # Clean database self.reset_database() @@ -348,7 +348,7 @@ def test_reimport_of_logs_for_single_node(self, temp_dir): # Export "NEW" DB export_file_new = os.path.join(temp_dir, export_filenames['NEW']) - export([calc], outfile=export_file_new, silent=True) + export([calc], filename=export_file_new, silent=True) # Clean database self.reset_database() diff --git a/tests/tools/importexport/orm/test_users.py b/tests/tools/importexport/orm/test_users.py index d1658b13c5..47466caf33 100644 --- a/tests/tools/importexport/orm/test_users.py +++ b/tests/tools/importexport/orm/test_users.py @@ -80,9 +80,9 @@ def test_nodes_belonging_to_different_users(self, temp_dir): uuids_u1 = [sd1.uuid, jc1.uuid, sd2.uuid] uuids_u2 = [jc2.uuid, sd3.uuid] - filename = os.path.join(temp_dir, 'export.tar.gz') + filename = os.path.join(temp_dir, 'export.aiida') - export([sd3], outfile=filename, silent=True) + export([sd3], filename=filename, silent=True) self.clean_db() self.create_user() import_data(filename, silent=True) @@ -137,8 +137,8 @@ def test_non_default_user_nodes(self, temp_dir): # pylint: disable=too-many-sta sd2_uuid = sd2.uuid # At this point we export the generated data - filename1 = os.path.join(temp_dir, 'export1.tar.gz') - export([sd2], outfile=filename1, silent=True) + filename1 = os.path.join(temp_dir, 'export1.aiidaz') + export([sd2], filename=filename1, silent=True) uuids1 = [sd1.uuid, jc1.uuid, sd2.uuid] self.clean_db() self.insert_data() @@ -170,8 +170,8 @@ def test_non_default_user_nodes(self, temp_dir): # pylint: disable=too-many-sta # if they can be imported correctly. uuids2 = [jc2.uuid, sd3.uuid] - filename2 = os.path.join(temp_dir, 'export2.tar.gz') - export([sd3], outfile=filename2, silent=True) + filename2 = os.path.join(temp_dir, 'export2.aiida') + export([sd3], filename=filename2, silent=True) self.clean_db() self.insert_data() import_data(filename2, silent=True) diff --git a/tests/tools/importexport/test_complex.py b/tests/tools/importexport/test_complex.py index bbdb243bc0..08a11ef145 100644 --- a/tests/tools/importexport/test_complex.py +++ b/tests/tools/importexport/test_complex.py @@ -88,8 +88,8 @@ def test_complex_graph_import_export(self, temp_dir): fd1.uuid: fd1.label } - filename = os.path.join(temp_dir, 'export.tar.gz') - export([fd1], outfile=filename, silent=True) + filename = os.path.join(temp_dir, 'export.aiida') + export([fd1], filename=filename, silent=True) self.clean_db() self.create_user() @@ -157,12 +157,12 @@ def get_hash_from_db_content(grouplabel): size = 10 grouplabel = 'test-group' - nparr = np.random.random((4, 3, 2)) + nparr = np.random.random((4, 3, 2)) # pylint: disable=no-member trial_dict = {} # give some integers: trial_dict.update({str(k): np.random.randint(100) for k in range(10)}) # give some floats: - trial_dict.update({str(k): np.random.random() for k in range(10, 20)}) + trial_dict.update({str(k): np.random.random() for k in range(10, 20)}) # pylint: disable=no-member # give some booleans: trial_dict.update({str(k): bool(np.random.randint(1)) for k in range(20, 30)}) # give some text: @@ -196,12 +196,12 @@ def get_hash_from_db_content(grouplabel): # I export and reimport 3 times in a row: for i in range(3): # Always new filename: - filename = os.path.join(temp_dir, 'export-{}.zip'.format(i)) + filename = os.path.join(temp_dir, 'export-{}.aiida'.format(i)) # Loading the group from the string group = orm.Group.get(label=grouplabel) # exporting based on all members of the group # this also checks if group memberships are preserved! - export([group] + list(group.nodes), outfile=filename, silent=True) + export([group] + list(group.nodes), filename=filename, silent=True) # cleaning the DB! self.clean_db() self.create_user() diff --git a/tests/tools/importexport/test_deprecation.py b/tests/tools/importexport/test_deprecation.py new file mode 100644 index 0000000000..d3ea749c89 --- /dev/null +++ b/tests/tools/importexport/test_deprecation.py @@ -0,0 +1,74 @@ +"""Test deprecated parts still work and emit deprecations warnings""" +# pylint: disable=invalid-name +import os + +import pytest + +from aiida.common.warnings import AiidaDeprecationWarning +from aiida.tools.importexport import dbexport + +pytestmark = pytest.mark.usefixtures('clear_database_before_test') + + +def test_export_functions(temp_dir): + """Check `what` and `outfile` in export(), export_tar() and export_zip()""" + what = [] + outfile = os.path.join(temp_dir, 'deprecated.aiida') + + for export_function in (dbexport.export, dbexport.export_tar, dbexport.export_zip): + if os.path.exists(outfile): + os.remove(outfile) + with pytest.warns(AiidaDeprecationWarning, match='`what` is deprecated, please use `entities` instead'): + export_function(what=what, filename=outfile) + + if os.path.exists(outfile): + os.remove(outfile) + with pytest.warns( + AiidaDeprecationWarning, match='`what` is deprecated, the supplied `entities` input will be used' + ): + export_function(entities=what, what=what, filename=outfile) + + if os.path.exists(outfile): + os.remove(outfile) + with pytest.warns( + AiidaDeprecationWarning, + match='`outfile` is deprecated, please use `filename` instead', + ): + export_function(what, outfile=outfile) + + if os.path.exists(outfile): + os.remove(outfile) + with pytest.warns( + AiidaDeprecationWarning, match='`outfile` is deprecated, the supplied `filename` input will be used' + ): + export_function(what, filename=outfile, outfile=outfile) + + if os.path.exists(outfile): + os.remove(outfile) + with pytest.raises(TypeError, match='`entities` must be specified'): + export_function(filename=outfile) + + +def test_export_tree(): + """Check `what` in export_tree()""" + from aiida.common.folders import SandboxFolder + + what = [] + + with SandboxFolder() as folder: + with pytest.warns(AiidaDeprecationWarning, match='`what` is deprecated, please use `entities` instead'): + dbexport.export_tree(what=what, folder=folder) + + folder.erase(create_empty_folder=True) + with pytest.warns( + AiidaDeprecationWarning, match='`what` is deprecated, the supplied `entities` input will be used' + ): + dbexport.export_tree(entities=what, what=what, folder=folder) + + folder.erase(create_empty_folder=True) + with pytest.raises(TypeError, match='`entities` must be specified'): + dbexport.export_tree(folder=folder) + + folder.erase(create_empty_folder=True) + with pytest.raises(TypeError, match='`folder` must be specified'): + dbexport.export_tree(entities=what) diff --git a/tests/tools/importexport/test_prov_redesign.py b/tests/tools/importexport/test_prov_redesign.py index 5ef849c51c..85b945caa1 100644 --- a/tests/tools/importexport/test_prov_redesign.py +++ b/tests/tools/importexport/test_prov_redesign.py @@ -58,8 +58,8 @@ def test_base_data_type_change(self, temp_dir): export_nodes.append(list_node) # Export nodes - filename = os.path.join(temp_dir, 'export.tar.gz') - export(export_nodes, outfile=filename, silent=True) + filename = os.path.join(temp_dir, 'export.aiida') + export(export_nodes, filename=filename, silent=True) # Clean the database self.reset_database() @@ -107,8 +107,8 @@ def test_node_process_type(self, temp_dir): self.assertEqual(node.process_type, node_process_type) # Export nodes - filename = os.path.join(temp_dir, 'export.tar.gz') - export([node], outfile=filename, silent=True) + filename = os.path.join(temp_dir, 'export.aiida') + export([node], filename=filename, silent=True) # Clean the database and reimport data self.reset_database() @@ -150,8 +150,8 @@ def test_code_type_change(self, temp_dir): self.assertEqual(code_type, 'data.code.Code.') # Export node - filename = os.path.join(temp_dir, 'export.tar.gz') - export([code], outfile=filename, silent=True) + filename = os.path.join(temp_dir, 'export.aiida') + export([code], filename=filename, silent=True) # Clean the database and reimport self.reset_database() @@ -232,8 +232,8 @@ def test_group_name_and_type_change(self, temp_dir): self.assertListEqual(groups_type_string, ['core', 'core.upf']) # Export node - filename = os.path.join(temp_dir, 'export.tar.gz') - export([group_user, group_upf], outfile=filename, silent=True) + filename = os.path.join(temp_dir, 'export.aiida') + export([group_user, group_upf], filename=filename, silent=True) # Clean the database and reimport self.reset_database() diff --git a/tests/tools/importexport/test_simple.py b/tests/tools/importexport/test_simple.py index 646afbc373..5f8994a92b 100644 --- a/tests/tools/importexport/test_simple.py +++ b/tests/tools/importexport/test_simple.py @@ -37,14 +37,14 @@ def test_base_data_nodes(self, temp_dir): """Test ex-/import of Base Data nodes""" # producing values for each base type values = ('Hello', 6, -1.2399834e12, False) # , ["Bla", 1, 1e-10]) - filename = os.path.join(temp_dir, 'export.tar.gz') + filename = os.path.join(temp_dir, 'export.aiida') # producing nodes: nodes = [cls(val).store() for val, cls in zip(values, (orm.Str, orm.Int, orm.Float, orm.Bool))] # my uuid - list to reload the node: uuids = [n.uuid for n in nodes] # exporting the nodes: - export(nodes, outfile=filename, silent=True) + export(nodes, filename=filename, silent=True) # cleaning: self.clean_db() self.create_user() @@ -79,9 +79,9 @@ def test_calc_of_structuredata(self, temp_dir): for k in node.attributes.keys(): attrs[node.uuid][k] = node.get_attribute(k) - filename = os.path.join(temp_dir, 'export.tar.gz') + filename = os.path.join(temp_dir, 'export.aiida') - export([calc], outfile=filename, silent=True) + export([calc], filename=filename, silent=True) self.clean_db() self.create_user() @@ -104,8 +104,8 @@ def test_check_for_export_format_version(self): struct = orm.StructureData() struct.store() - filename = os.path.join(export_file_tmp_folder, 'export.tar.gz') - export([struct], outfile=filename, silent=True) + filename = os.path.join(export_file_tmp_folder, 'export.aiida') + export([struct], filename=filename, file_format='tar.gz', silent=True) with tarfile.open(filename, 'r:gz', format=tarfile.PAX_FORMAT) as tar: tar.extractall(unpack_tmp_folder) diff --git a/tests/tools/importexport/test_specific_import.py b/tests/tools/importexport/test_specific_import.py index 28cfb7bd26..9ee506250f 100644 --- a/tests/tools/importexport/test_specific_import.py +++ b/tests/tools/importexport/test_specific_import.py @@ -66,7 +66,7 @@ def test_simple_import(self): with tempfile.NamedTemporaryFile() as handle: nodes = [parameters] - export(nodes, outfile=handle.name, overwrite=True, silent=True) + export(nodes, filename=handle.name, overwrite=True, silent=True) # Check that we have the expected number of nodes in the database self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), len(nodes)) @@ -127,7 +127,7 @@ def test_cycle_structure_data(self): with tempfile.NamedTemporaryFile() as handle: nodes = [structure, child_calculation, parent_process, remote_folder] - export(nodes, outfile=handle.name, overwrite=True, silent=True) + export(nodes, filename=handle.name, overwrite=True, silent=True) # Check that we have the expected number of nodes in the database self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), len(nodes)) @@ -198,9 +198,9 @@ def test_missing_node_repo_folder_export(self, temp_dir): ) # Try to export, check it raises and check the raise message - filename = os.path.join(temp_dir, 'export.tar.gz') + filename = os.path.join(temp_dir, 'export.aiida') with self.assertRaises(exceptions.ArchiveExportError) as exc: - export([node], outfile=filename, silent=True) + export([node], filename=filename, silent=True) self.assertIn( 'Unable to find the repository folder for Node with UUID={}'.format(node_uuid), str(exc.exception) @@ -233,8 +233,8 @@ def test_missing_node_repo_folder_import(self, temp_dir): ) # Export and reset db - filename = os.path.join(temp_dir, 'export.tar.gz') - export([node], outfile=filename, silent=True) + filename = os.path.join(temp_dir, 'export.aiida') + export([node], filename=filename, file_format='tar.gz', silent=True) self.reset_database() # Untar export file, remove repository folder, re-tar @@ -256,7 +256,7 @@ def test_missing_node_repo_folder_import(self, temp_dir): msg="The Node's repository folder should now have been removed in the export file" ) - filename_corrupt = os.path.join(temp_dir, 'export_corrupt.tar.gz') + filename_corrupt = os.path.join(temp_dir, 'export_corrupt.aiida') with tarfile.open(filename_corrupt, 'w:gz', format=tarfile.PAX_FORMAT, dereference=True) as tar: tar.add(folder.abspath, arcname='') @@ -273,7 +273,7 @@ def test_missing_node_repo_folder_import(self, temp_dir): def test_empty_repo_folder_export(self, temp_dir): """Check a Node's empty repository folder is exported properly""" from aiida.common.folders import Folder - from aiida.tools.importexport.dbexport import export_zip, export_tree + from aiida.tools.importexport.dbexport import export_tree node = orm.Dict().store() node_uuid = node.uuid @@ -302,8 +302,8 @@ def test_empty_repo_folder_export(self, temp_dir): } export_tree([node], folder=Folder(archive_variants['archive folder']), silent=True) - export([node], outfile=archive_variants['tar archive'], silent=True) - export_zip([node], outfile=archive_variants['zip archive'], silent=True) + export([node], filename=archive_variants['tar archive'], file_format='tar.gz', silent=True) + export([node], filename=archive_variants['zip archive'], file_format='zip', silent=True) for variant, filename in archive_variants.items(): self.reset_database()