Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

👌 IMPROVE: constructor of base data types #5165

Merged
merged 2 commits into from
Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions aiida/orm/nodes/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,15 @@ def to_aiida_type(value):
class BaseType(Data):
"""`Data` sub class to be used as a base for data containers that represent base python data types."""

def __init__(self, *args, **kwargs):
def __init__(self, value=None, **kwargs):
try:
getattr(self, '_type')
except AttributeError:
raise RuntimeError('Derived class must define the `_type` class member')

super().__init__(**kwargs)

try:
value = args[0]
except IndexError:
value = self._type() # pylint: disable=no-member

self.value = value
self.value = value or self._type() # pylint: disable=no-member

@property
def value(self):
Expand Down
12 changes: 6 additions & 6 deletions aiida/orm/nodes/data/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ class Dict(Data):
Finally, all dictionary mutations will be forbidden once the node is stored.
"""

def __init__(self, **kwargs):
"""Store a dictionary as a `Node` instance.
def __init__(self, value=None, **kwargs):
"""Initialise a ``Dict`` node instance.
Comment on lines +49 to +50
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just want to make sure this will not break any backwards compatibility, correct? We can still use d = orm.Dict(dict={}).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I see a test case for this particular. 👍


Usual rules for attribute names apply, in particular, keys cannot start with an underscore, or a `ValueError`
Usual rules for attribute names apply, in particular, keys cannot start with an underscore, or a ``ValueError``
will be raised.

Initial attributes can be changed, deleted or added as long as the node is not stored.

:param dict: the dictionary to set
:param value: dictionary to initialise the ``Dict`` node from
"""
dictionary = kwargs.pop('dict', None)
dictionary = value or kwargs.pop('dict', None)
super().__init__(**kwargs)
if dictionary:
self.set_dict(dictionary)
Expand Down Expand Up @@ -135,4 +135,4 @@ def dict(self):

@to_aiida_type.register(dict)
def _(value):
return Dict(dict=value)
return Dict(value)
16 changes: 12 additions & 4 deletions aiida/orm/nodes/data/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ class List(Data, MutableSequence):

_LIST_KEY = 'list'

def __init__(self, **kwargs):
data = kwargs.pop('list', [])
def __init__(self, value=None, **kwargs):
"""Initialise a ``List`` node instance.

:param value: list to initialise the ``List`` node from
"""
data = value or kwargs.pop('list', [])
super().__init__(**kwargs)
self.set_list(data)

Expand Down Expand Up @@ -75,7 +79,11 @@ def insert(self, i, value): # pylint: disable=arguments-renamed
self.set_list(data)

def remove(self, value):
del self[value]
data = self.get_list()
item = data.remove(value)
if not self._using_list_reference():
self.set_list(data)
return item

def pop(self, **kwargs): # pylint: disable=arguments-differ
"""Remove and return item at index (default last)."""
Expand Down Expand Up @@ -123,7 +131,7 @@ def set_list(self, data):
"""
if not isinstance(data, list):
raise TypeError('Must supply list type')
self.set_attribute(self._LIST_KEY, data)
self.set_attribute(self._LIST_KEY, data.copy())

def _using_list_reference(self):
"""
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
207 changes: 207 additions & 0 deletions tests/orm/nodes/data/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=invalid-name
"""Tests for :class:`aiida.orm.nodes.data.base.BaseType` classes."""

import operator

import pytest

from aiida.orm import Bool, Float, Int, NumericType, Str, load_node


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize(
'node_type, default, value', [
(Bool, False, True),
(Int, 0, 5),
(Float, 0.0, 5.5),
(Str, '', 'a'),
]
)
def test_create(node_type, default, value):
"""Test the creation of the ``BaseType`` nodes."""

node = node_type()
assert node.value == default

node = node_type(value)
assert node.value == value


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize('node_type', [Bool, Float, Int, Str])
def test_store_load(node_type):
"""Test ``BaseType`` node storing and loading."""
node = node_type()
node.store()
loaded = load_node(node.pk)
assert node.value == loaded.value


@pytest.mark.usefixtures('clear_database_before_test')
def test_modulo():
"""Test ``Int`` modulus operation."""
term_a = Int(12)
term_b = Int(10)

assert term_a % term_b == 2
assert isinstance(term_a % term_b, NumericType)
assert term_a % 10 == 2
assert isinstance(term_a % 10, NumericType)
assert 12 % term_b == 2
assert isinstance(12 % term_b, NumericType)


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize('node_type, a, b', [
(Int, 3, 5),
(Float, 1.2, 5.5),
])
def test_add(node_type, a, b):
"""Test addition for ``Int`` and ``Float`` nodes."""
node_a = node_type(a)
node_b = node_type(b)

result = node_a + node_b
assert isinstance(result, node_type)
assert result.value == a + b

# Node and native (both ways)
result = node_a + b
assert isinstance(result, node_type)
assert result.value == a + b

result = a + node_b
assert isinstance(result, node_type)
assert result.value == a + b

# Inplace
result = node_type(a)
result += node_b
assert isinstance(result, node_type)
assert result.value == a + b


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize('node_type, a, b', [
(Int, 3, 5),
(Float, 1.2, 5.5),
])
def test_multiplication(node_type, a, b):
"""Test floats multiplication."""
node_a = node_type(a)
node_b = node_type(b)

# Check multiplication
result = node_a * node_b
assert isinstance(result, node_type)
assert result.value == a * b

# Check multiplication Node and native (both ways)
result = node_a * b
assert isinstance(result, node_type)
assert result.value == a * b

result = a * node_b
assert isinstance(result, node_type)
assert result.value == a * b

# Inplace
result = node_type(a)
result *= node_b
assert isinstance(result, node_type)
assert result.value == a * b


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize('node_type, a, b', [
(Int, 3, 5),
(Float, 1.2, 5.5),
])
@pytest.mark.usefixtures('clear_database_before_test')
def test_division(node_type, a, b):
"""Test the ``BaseType`` normal division operator."""
node_a = node_type(a)
node_b = node_type(b)

result = node_a / node_b
assert result == a / b
assert isinstance(result, Float) # Should be a `Float` for both node types


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize('node_type, a, b', [
(Int, 3, 5),
(Float, 1.2, 5.5),
])
@pytest.mark.usefixtures('clear_database_before_test')
def test_division_integer(node_type, a, b):
"""Test the ``Int`` integer division operator."""
node_a = node_type(a)
node_b = node_type(b)

result = node_a // node_b
assert result == a // b
assert isinstance(result, node_type)


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize('node_type, base, power', [
(Int, 5, 2),
(Float, 3.5, 3),
])
def test_power(node_type, base, power):
"""Test power operator."""
node_base = node_type(base)
node_power = node_type(power)

result = node_base**node_power
assert result == base**power
assert isinstance(result, node_type)


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize('node_type, a, b', [
(Int, 5, 2),
(Float, 3.5, 3),
])
def test_modulus(node_type, a, b):
"""Test modulus operator."""
node_a = node_type(a)
node_b = node_type(b)

assert node_a % node_b == a % b
assert isinstance(node_a % node_b, node_type)

assert node_a % b == a % b
assert isinstance(node_a % b, node_type)

assert a % node_b == a % b
assert isinstance(a % node_b, node_type)


@pytest.mark.usefixtures('clear_database_before_test')
@pytest.mark.parametrize(
'opera', [
operator.add, operator.mul, operator.pow, operator.lt, operator.le, operator.gt, operator.ge, operator.iadd,
operator.imul
]
)
def test_operator(opera):
"""Test operations between Int and Float objects."""
node_a = Float(2.2)
node_b = Int(3)

for node_x, node_y in [(node_a, node_b), (node_b, node_a)]:
res = opera(node_x, node_y)
c_val = opera(node_x.value, node_y.value)
assert res._type == type(c_val) # pylint: disable=protected-access
assert res == opera(node_x.value, node_y.value)
File renamed without changes.
File renamed without changes.
25 changes: 16 additions & 9 deletions tests/orm/data/test_dict.py → tests/orm/nodes/data/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,29 @@ def dictionary():
@pytest.mark.usefixtures('clear_database_before_test')
def test_keys(dictionary):
"""Test the ``keys`` method."""
node = Dict(dict=dictionary)
node = Dict(dictionary)
assert sorted(node.keys()) == sorted(dictionary.keys())


@pytest.mark.usefixtures('clear_database_before_test')
def test_get_dict(dictionary):
"""Test the ``get_dict`` method."""
node = Dict(dict=dictionary)
node = Dict(dictionary)
assert node.get_dict() == dictionary


@pytest.mark.usefixtures('clear_database_before_test')
def test_dict_property(dictionary):
"""Test the ``dict`` property."""
node = Dict(dict=dictionary)
node = Dict(dictionary)
assert node.dict.value == dictionary['value']
assert node.dict.nested == dictionary['nested']


@pytest.mark.usefixtures('clear_database_before_test')
def test_get_item(dictionary):
"""Test the ``__getitem__`` method."""
node = Dict(dict=dictionary)
node = Dict(dictionary)
assert node['value'] == dictionary['value']
assert node['nested'] == dictionary['nested']

Expand All @@ -56,7 +56,7 @@ def test_set_item(dictionary):
* ``__setitem__`` directly on the node
* ``__setattr__`` through the ``AttributeManager`` returned by the ``dict`` property
"""
node = Dict(dict=dictionary)
node = Dict(dictionary)

node['value'] = 2
assert node['value'] == 2
Expand All @@ -72,7 +72,7 @@ def test_correct_raises(dictionary):
* ``node['inexistent']`` should raise ``KeyError``
* ``node.dict.inexistent`` should raise ``AttributeError``
"""
node = Dict(dict=dictionary)
node = Dict(dictionary)

with pytest.raises(KeyError):
_ = node['inexistent_key']
Expand All @@ -89,8 +89,8 @@ def test_eq(dictionary):
compare equal to another node that has the same content. This is a hot issue and is being discussed in the following
ticket: https://github.com/aiidateam/aiida-core/issues/1917
"""
node = Dict(dict=dictionary)
clone = Dict(dict=dictionary)
node = Dict(dictionary)
clone = Dict(dictionary)

assert node is node # pylint: disable=comparison-with-itself
assert node == dictionary
Expand All @@ -101,8 +101,15 @@ def test_eq(dictionary):
# wouldn't happen unless, by accident, two different nodes get the same UUID, the probability of which is minimal.
# Note that we have to set the UUID directly through the database model instance of the backend entity, since it is
# forbidden to change it through the front-end or backend entity instance, for good reasons.
other = Dict(dict={})
other = Dict({})
other.backend_entity._dbmodel.uuid = node.uuid # pylint: disable=protected-access
assert other.uuid == node.uuid
assert other.dict != node.dict
assert node == other


@pytest.mark.usefixtures('clear_database_before_test')
def test_initialise_with_dict_kwarg(dictionary):
"""Test that the ``Dict`` node can be initialized with the ``dict`` keyword argument for backwards compatibility."""
node = Dict(dict=dictionary)
assert sorted(node.keys()) == sorted(dictionary.keys())
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading