From 93a91e22be75352ed46a64f6c2c18d72adcfb601 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 13 Dec 2022 14:43:21 +0100 Subject: [PATCH 1/2] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20Add=20dataclass?= =?UTF-8?q?=20serialisation=20to=20context?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aiida/orm/utils/serialize.py | 25 +++++++++++++++++++++++++ tests/orm/utils/test_serialize.py | 17 +++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/aiida/orm/utils/serialize.py b/aiida/orm/utils/serialize.py index 3ddd1a9ae0..a5cb5a1bec 100644 --- a/aiida/orm/utils/serialize.py +++ b/aiida/orm/utils/serialize.py @@ -16,8 +16,10 @@ """ from __future__ import annotations +from dataclasses import asdict, is_dataclass from enum import Enum from functools import partial +import inspect from typing import Any, Protocol, Type, overload from plumpy import Bundle, get_object_loader # type: ignore[attr-defined] @@ -28,6 +30,7 @@ from aiida.common import AttributeDict _ENUM_TAG = '!enum' +_DATACLASS_TAG = '!dataclass' _NODE_TAG = '!aiida_node' _GROUP_TAG = '!aiida_group' _COMPUTER_TAG = '!aiida_computer' @@ -51,6 +54,25 @@ def enum_constructor(loader: yaml.Loader, serialized: yaml.Node) -> Enum: return enum +def represent_dataclass(dumper: yaml.Dumper, obj: Any) -> yaml.MappingNode: + """Represent an arbitrary dataclass in yaml.""" + loader = get_object_loader() + data = { + '__type__': loader.identify_object(obj.__class__), + '__fields__': asdict(obj), + } + return dumper.represent_mapping(_DATACLASS_TAG, data) + + +def dataclass_constructor(loader: yaml.Loader, serialized: yaml.Node) -> Any: + """Construct a dataclass from the serialized representation.""" + deserialized = loader.construct_mapping(serialized, deep=True) # type: ignore[arg-type] + identifier = deserialized['__type__'] + cls = get_object_loader().load_object(identifier) + data = deserialized['__fields__'] + return cls(**data) + + def represent_node(dumper: yaml.Dumper, node: orm.Node) -> yaml.ScalarNode: """Represent a node in yaml.""" if not node.is_stored: @@ -136,6 +158,8 @@ def represent_data(self, data): return represent_computer(self, data) if isinstance(data, orm.Group): return represent_group(self, data) + if is_dataclass(data) and not inspect.isclass(data): + return represent_dataclass(self, data) return super().represent_data(data) @@ -163,6 +187,7 @@ class AiiDALoader(yaml.Loader): yaml.add_constructor(_GROUP_TAG, group_constructor, Loader=AiiDALoader) yaml.add_constructor(_COMPUTER_TAG, computer_constructor, Loader=AiiDALoader) yaml.add_constructor(_ENUM_TAG, enum_constructor, Loader=AiiDALoader) +yaml.add_constructor(_DATACLASS_TAG, dataclass_constructor, Loader=AiiDALoader) @overload diff --git a/tests/orm/utils/test_serialize.py b/tests/orm/utils/test_serialize.py index 557154168d..a419dcd6af 100644 --- a/tests/orm/utils/test_serialize.py +++ b/tests/orm/utils/test_serialize.py @@ -8,6 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the :mod:`aiida.orm.utils.serialize` module.""" +from dataclasses import dataclass import types import uuid @@ -162,3 +163,19 @@ def test_enum(): deserialized = serialize.deserialize_unsafe(serialized) assert deserialized == enum + + +@dataclass +class DataClass: + """A dataclass for testing.""" + my_value: int + + +def test_dataclass(): + """Test serialization and deserialization of a ``dataclass``.""" + obj = DataClass(1) + serialized = serialize.serialize(obj) + assert isinstance(serialized, str) + + deserialized = serialize.deserialize_unsafe(serialized) + assert deserialized == DataClass(1) From f236f54afe866ecc2540da02ef5a8ae290340b3f Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 14 Dec 2022 00:24:00 +0100 Subject: [PATCH 2/2] Update tests/orm/utils/test_serialize.py Co-authored-by: Sebastiaan Huber --- tests/orm/utils/test_serialize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/orm/utils/test_serialize.py b/tests/orm/utils/test_serialize.py index a419dcd6af..74802a5b1b 100644 --- a/tests/orm/utils/test_serialize.py +++ b/tests/orm/utils/test_serialize.py @@ -178,4 +178,4 @@ def test_dataclass(): assert isinstance(serialized, str) deserialized = serialize.deserialize_unsafe(serialized) - assert deserialized == DataClass(1) + assert deserialized == obj