Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make construction of AttributeDict recursive #3005

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 40 additions & 22 deletions aiida/backends/tests/common/test_extendeddicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_copy(self):
"""Test copying."""
dictionary_01 = extendeddicts.AttributeDict()
dictionary_01.alpha = 'a'
dictionary_02 = dictionary_01.copy()
dictionary_02 = copy.copy(dictionary_01)
dictionary_02.alpha = 'b'
self.assertEqual(dictionary_01.alpha, 'a')
self.assertEqual(dictionary_02.alpha, 'b')
Expand All @@ -106,7 +106,7 @@ def test_delete_after_copy(self):
dictionary_01 = extendeddicts.AttributeDict()
dictionary_01.alpha = 'a'
dictionary_01.beta = 'b'
dictionary_02 = dictionary_01.copy()
dictionary_02 = copy.copy(dictionary_01)
del dictionary_01.alpha
del dictionary_01['beta']
with self.assertRaises(AttributeError):
Expand All @@ -123,7 +123,7 @@ def test_shallowcopy1(self):
dictionary_01 = extendeddicts.AttributeDict()
dictionary_01.alpha = [1, 2, 3]
dictionary_01.beta = 3
dictionary_02 = dictionary_01.copy()
dictionary_02 = copy.copy(dictionary_01)
dictionary_02.alpha[0] = 4
dictionary_02.beta = 5
self.assertEqual(dictionary_01.alpha, [4, 2, 3]) # copy does a shallow copy
Expand All @@ -135,13 +135,7 @@ def test_shallowcopy2(self):
"""Test shallow copying."""
dictionary_01 = extendeddicts.AttributeDict()
dictionary_01.alpha = {'a': 'b', 'c': 'd'}
# dictionary_02 = copy.deepcopy(dictionary_01)
dictionary_02 = dictionary_01.copy()
# doesn't work like this, would work as dictionary_02['x']['a']
# i think that it is because deepcopy on dict actually creates a
# copy only if the data is changed; but for a nested dict,
# dictionary_02.alpha returns a dict wrapped in our class and this looses all the
# information on what should be updated when changed.
dictionary_02 = copy.copy(dictionary_01)
dictionary_02.alpha['a'] = 'ggg'
self.assertEqual(dictionary_01.alpha['a'], 'ggg') # copy does a shallow copy
self.assertEqual(dictionary_02.alpha['a'], 'ggg')
Expand All @@ -155,20 +149,44 @@ def test_deepcopy1(self):
self.assertEqual(dictionary_01.alpha, [1, 2, 3])
self.assertEqual(dictionary_02.alpha, [4, 2, 3])

def test_shallowcopy3(self):
"""Test shallow copying."""
dictionary_01 = extendeddicts.AttributeDict()
dictionary_01.alpha = {'a': 'b', 'c': 'd'}
dictionary_02 = copy.deepcopy(dictionary_01)
dictionary_02.alpha['a'] = 'ggg'
self.assertEqual(dictionary_01.alpha['a'], 'b') # copy does a shallow copy
self.assertEqual(dictionary_02.alpha['a'], 'ggg')


class TestAttributeDictNested(unittest.TestCase):
"""
Test the functionality of nested AttributeDict classes.
"""
"""Test the functionality of nested AttributeDict classes."""

def test_shallow_copy(self):
"""Test shallow copying using either the copy method of the dict class or the copy module."""
nested = {'x': 1, 'y': 2, 'sub': {'namespace': {'a': 1}, 'b': 'string'}}
dictionary = extendeddicts.AttributeDict(nested)
copied_by_method = dictionary.copy()
copied_by_module = copy.copy(dictionary)

dictionary.x = 400
dictionary.sub.namespace.b = 'other_string'

# The shallow copied dictionaries should be different objects
self.assertTrue(dictionary is not copied_by_method)
self.assertTrue(dictionary is not copied_by_module)

# However, nested dictionaries should be the same
self.assertTrue(dictionary.sub is copied_by_method['sub'])
self.assertTrue(dictionary.sub is copied_by_module['sub'])

# The top-level values should not have changed, because they have been deep copied
self.assertEqual(copied_by_method['x'], 1)
self.assertEqual(copied_by_module['x'], 1)

# The nested value should have also changed for the shallow copies
self.assertEqual(copied_by_method['sub']['namespace']['b'], 'other_string')
self.assertEqual(copied_by_module['sub']['namespace']['b'], 'other_string')

def test_recursive_attribute_dict(self):
"""Test that all nested dictionaries are also recursively turned into AttributeDict instances."""
nested = {'x': 1, 'y': 2, 'sub': {'namespace': {'a': 1}, 'b': 'string'}}
dictionary = extendeddicts.AttributeDict(nested)
self.assertIsInstance(dictionary, extendeddicts.AttributeDict)
self.assertIsInstance(dictionary.sub, extendeddicts.AttributeDict)
self.assertIsInstance(dictionary.sub.namespace, extendeddicts.AttributeDict)
self.assertEqual(dictionary.sub.namespace.a, nested['sub']['namespace']['a'])

def test_nested(self):
"""Test nested dictionary."""
Expand Down
51 changes: 25 additions & 26 deletions aiida/common/extendeddicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from __future__ import print_function
from __future__ import absolute_import

from collections import Mapping

from . import exceptions

__all__ = ('AttributeDict', 'FixedFieldsAttributeDict', 'DefaultFieldsAttributeDict')
Expand All @@ -28,16 +30,26 @@ class AttributeDict(dict):
used.
"""

def __init__(self, dictionary=None):
"""Recursively turn the `dict` and all its nested dictionaries into `AttributeDict` instance."""
super(AttributeDict, self).__init__()
if dictionary is None:
dictionary = {}

for key, value in dictionary.items():
if isinstance(value, Mapping):
self[key] = AttributeDict(value)
else:
self[key] = value

def __repr__(self):
"""
Representation of the object.
"""
"""Representation of the object."""
return "%s(%s)" % (self.__class__.__name__, dict.__repr__(self))

def __getattr__(self, attr):
"""
Read a key as an attribute. Raise AttributeError on missing key.
Called only for attributes that do not exist.
"""Read a key as an attribute.

:raises AttributeError: if the attribute does not correspond to an existing key.
"""
try:
return self[attr]
Expand All @@ -46,35 +58,26 @@ def __getattr__(self, attr):
raise AttributeError(errmsg)

def __setattr__(self, attr, value):
"""
Set a key as an attribute.
"""
"""Set a key as an attribute."""
try:
self[attr] = value
except KeyError:
raise AttributeError("AttributeError: '{}' is not a valid attribute of the object "
"'{}'".format(attr, self.__class__.__name__))

def __delattr__(self, attr):
"""
Delete a key as an attribute. Raise AttributeError on missing key.
"""Delete a key as an attribute.

:raises AttributeError: if the attribute does not correspond to an existing key.
"""
try:
del self[attr]
except KeyError:
errmsg = "'{}' object has no attribute '{}'".format(self.__class__.__name__, attr)
raise AttributeError(errmsg)

def copy(self):
"""
Shallow copy.
"""
return self.__class__(self)

def __deepcopy__(self, memo=None):
"""
Support deepcopy.
"""
"""Deep copy."""
from copy import deepcopy

if memo is None:
Expand All @@ -83,15 +86,11 @@ def __deepcopy__(self, memo=None):
return self.__class__(retval)

def __getstate__(self):
"""
Needed for pickling this class.
"""
"""Needed for pickling this class."""
return self.__dict__.copy()

def __setstate__(self, dictionary):
"""
Needed for pickling this class.
"""
"""Needed for pickling this class."""
self.__dict__.update(dictionary)

def __dir__(self):
Expand Down
5 changes: 2 additions & 3 deletions aiida/orm/nodes/process/calculation/calcjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,9 @@ def _validate(self):
except TypeError as exception:
raise exceptions.ValidationError('invalid parser specified: {}'.format(exception))

computer = self.computer
scheduler = computer.get_scheduler()
resources = self.get_option('resources')
def_cpus_machine = computer.get_default_mpiprocs_per_machine()
scheduler = self.computer.get_scheduler() # pylint: disable=no-member
def_cpus_machine = self.computer.get_default_mpiprocs_per_machine() # pylint: disable=no-member

if def_cpus_machine is not None:
resources['default_mpiprocs_per_machine'] = def_cpus_machine
Expand Down
1 change: 1 addition & 0 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ py:class unittest.case.TestCase
py:class unittest.runner.TextTestRunner
py:class unittest2.case.TestCase
py:meth unittest.TestLoader.discover
py:meth copy.copy
py:class ABCMeta
py:class exceptions.Exception
py:class exceptions.ValueError
Expand Down