Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TemplateModel json I/O #301

Merged
merged 3 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion mira/metamodel/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,13 @@ def from_json(cls, data, rate_symbols=None) -> "Template":
# Handle concepts
for concept_key in stmt_cls.concept_keys:
if concept_key in data:
data[concept_key] = Concept.from_json(data[concept_key])
concept_data = data[concept_key]
# Handle lists of concepts for e.g. controllers in
# GroupedControlledConversion
if isinstance(concept_data, list):
data[concept_key] = [Concept.from_json(c) for c in concept_data]
else:
data[concept_key] = Concept.from_json(data[concept_key])

return stmt_cls(**{k: v for k, v in data.items()
if k not in {'rate_law', 'type'}},
Expand Down
54 changes: 54 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Test for metamodel io operations."""
import tempfile

from mira.metamodel import *

# List all templates classes
template_cls_list = Template.__subclasses__()

# Create Concepts for testing
controller1 = Concept(name="controller1")
controller2 = Concept(name="controller2")
subject = Concept(name="subject1")
outcome = Concept(name="outcome1")


def _check_roundtrip(tm: TemplateModel):
# Test json serialization and deserialization

# Get a temporary file
with tempfile.NamedTemporaryFile(suffix=".json") as temp_file:
# Write the model to the file
model_to_json_file(tm, temp_file.name)
# Read the model from the file
tm2 = model_from_json_file(temp_file.name)
# Check that the models are the same
for t1, t2 in zip(tm.templates, tm2.templates):
assert t1.is_equal_to(t2)


def test_templates():
failed = []
for templ_cls in template_cls_list:
# Create the template
template = templ_cls(
# As long as the BaseModel is allowed to accept unused
# arguments, we can pass all the arguments to the constructor.
# The unused arguments will be ignored.
controller=controller1,
subject=subject,
outcome=outcome,
controllers=[controller1, controller2]
)
# Create the template model
tm = TemplateModel(templates=[template])
# Check the roundtrip
try:
_check_roundtrip(tm)
except Exception as e:
failed.append((templ_cls, str(e)))
if failed:
print(f"{len(failed)} roundtrips failed")
for f in failed:
print(f)
raise AssertionError(f"{len(failed)} roundtrips failed")
Loading