Skip to content

Commit

Permalink
Merge pull request #301 from gyorilab/test-tm-json
Browse files Browse the repository at this point in the history
Fix TemplateModel json I/O
  • Loading branch information
bgyori authored Feb 29, 2024
2 parents 44e1612 + f660fa0 commit b186555
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
8 changes: 7 additions & 1 deletion mira/metamodel/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,13 @@ def from_json(cls, data, rate_symbols=None) -> "Template":
# Handle concepts
for concept_key in stmt_cls.concept_keys:
if concept_key in data:
data[concept_key] = Concept.from_json(data[concept_key])
concept_data = data[concept_key]
# Handle lists of concepts for e.g. controllers in
# GroupedControlledConversion
if isinstance(concept_data, list):
data[concept_key] = [Concept.from_json(c) for c in concept_data]
else:
data[concept_key] = Concept.from_json(data[concept_key])

return stmt_cls(**{k: v for k, v in data.items()
if k not in {'rate_law', 'type'}},
Expand Down
54 changes: 54 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Test for metamodel io operations."""
import tempfile

from mira.metamodel import *

# List all templates classes
template_cls_list = Template.__subclasses__()

# Create Concepts for testing
controller1 = Concept(name="controller1")
controller2 = Concept(name="controller2")
subject = Concept(name="subject1")
outcome = Concept(name="outcome1")


def _check_roundtrip(tm: TemplateModel):
# Test json serialization and deserialization

# Get a temporary file
with tempfile.NamedTemporaryFile(suffix=".json") as temp_file:
# Write the model to the file
model_to_json_file(tm, temp_file.name)
# Read the model from the file
tm2 = model_from_json_file(temp_file.name)
# Check that the models are the same
for t1, t2 in zip(tm.templates, tm2.templates):
assert t1.is_equal_to(t2)


def test_templates():
failed = []
for templ_cls in template_cls_list:
# Create the template
template = templ_cls(
# As long as the BaseModel is allowed to accept unused
# arguments, we can pass all the arguments to the constructor.
# The unused arguments will be ignored.
controller=controller1,
subject=subject,
outcome=outcome,
controllers=[controller1, controller2]
)
# Create the template model
tm = TemplateModel(templates=[template])
# Check the roundtrip
try:
_check_roundtrip(tm)
except Exception as e:
failed.append((templ_cls, str(e)))
if failed:
print(f"{len(failed)} roundtrips failed")
for f in failed:
print(f)
raise AssertionError(f"{len(failed)} roundtrips failed")

0 comments on commit b186555

Please sign in to comment.