Skip to content

Commit

Permalink
SBML import: Allow hardcoding of numerical values (#2134)
Browse files Browse the repository at this point in the history
Allows selecting parameters whose values are to be hard-coded (#1192).

So far, restricted to parameters that aren't targets of rule or initial assignments. 
This can be extended later: lifting those restrictions on parameters, allow hard-coding Species with constant=True, ...
  • Loading branch information
dweindl authored Jul 4, 2023
1 parent b8edbcd commit 752b0e5
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 12 deletions.
76 changes: 65 additions & 11 deletions python/sdist/amici/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,18 @@
import warnings
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)

import libsbml as sbml
import numpy as np
Expand Down Expand Up @@ -281,6 +292,7 @@ def sbml2amici(
cache_simplify: bool = False,
log_as_log10: bool = True,
generate_sensitivity_code: bool = True,
hardcode_symbols: Sequence[str] = None,
) -> None:
"""
Generate and compile AMICI C++ files for the model provided to the
Expand Down Expand Up @@ -385,6 +397,12 @@ def sbml2amici(
:param generate_sensitivity_code:
If ``False``, the code required for sensitivity computation will
not be generated
:param hardcode_symbols:
List of SBML entitiy IDs that are to be hardcoded in the generated model.
Their values cannot be changed anymore after model import.
Currently only parameters that are not targets of rules or
initial assignments are supported.
"""
set_log_level(logger, verbose)

Expand All @@ -401,6 +419,7 @@ def sbml2amici(
simplify=simplify,
cache_simplify=cache_simplify,
log_as_log10=log_as_log10,
hardcode_symbols=hardcode_symbols,
)

exporter = DEExporter(
Expand Down Expand Up @@ -437,13 +456,21 @@ def _build_ode_model(
simplify: Optional[Callable] = _default_simplify,
cache_simplify: bool = False,
log_as_log10: bool = True,
hardcode_symbols: Sequence[str] = None,
) -> DEModel:
"""Generate an ODEModel from this SBML model.
See :py:func:`sbml2amici` for parameters.
"""
constant_parameters = list(constant_parameters) if constant_parameters else []

hardcode_symbols = set(hardcode_symbols) if hardcode_symbols else {}
if invalid := (set(constant_parameters) & set(hardcode_symbols)):
raise ValueError(
"The following parameters were selected as both constant "
f"and hard-coded which is not allowed: {invalid}"
)

if sigmas is None:
sigmas = {}

Expand All @@ -460,7 +487,9 @@ def _build_ode_model(
self.sbml_parser_settings.setParseLog(
sbml.L3P_PARSE_LOG_AS_LOG10 if log_as_log10 else sbml.L3P_PARSE_LOG_AS_LN
)
self._process_sbml(constant_parameters)
self._process_sbml(
constant_parameters=constant_parameters, hardcode_symbols=hardcode_symbols
)

if (
self.symbols.get(SymbolId.EVENT, False)
Expand Down Expand Up @@ -496,18 +525,26 @@ def _build_ode_model(
return ode_model

@log_execution_time("importing SBML", logger)
def _process_sbml(self, constant_parameters: List[str] = None) -> None:
def _process_sbml(
self,
constant_parameters: List[str] = None,
hardcode_symbols: Sequence[str] = None,
) -> None:
"""
Read parameters, species, reactions, and so on from SBML model
:param constant_parameters:
SBML Ids identifying constant parameters
:param hardcode_parameters:
Parameter IDs to be replaced by their values in the generated model.
"""
if not self._discard_annotations:
self._process_annotations()
self.check_support()
self._gather_locals()
self._process_parameters(constant_parameters)
self._gather_locals(hardcode_symbols=hardcode_symbols)
self._process_parameters(
constant_parameters=constant_parameters, hardcode_symbols=hardcode_symbols
)
self._process_compartments()
self._process_species()
self._process_reactions()
Expand Down Expand Up @@ -639,18 +676,18 @@ def check_event_support(self) -> None:
)

@log_execution_time("gathering local SBML symbols", logger)
def _gather_locals(self) -> None:
def _gather_locals(self, hardcode_symbols: Sequence[str] = None) -> None:
"""
Populate self.local_symbols with all model entities.
This is later used during sympifications to avoid sympy builtins
shadowing model entities as well as to avoid possibly costly
symbolic substitutions
"""
self._gather_base_locals()
self._gather_base_locals(hardcode_symbols=hardcode_symbols)
self._gather_dependent_locals()

def _gather_base_locals(self):
def _gather_base_locals(self, hardcode_symbols: Sequence[str] = None) -> None:
"""
Populate self.local_symbols with pure symbol definitions that do not
depend on any other symbol.
Expand All @@ -677,8 +714,20 @@ def _gather_base_locals(self):
):
if not c.isSetId():
continue

self.add_local_symbol(c.getId(), _get_identifier_symbol(c))
if c.getId() in hardcode_symbols:
if c.getConstant() is not True:
# disallow anything that can be changed by rules/reaction/events
raise ValueError(
f"Cannot hardcode non-constant symbol `{c.getId()}`."
)
if self.sbml.getInitialAssignment(c.getId()):
raise NotImplementedError(
f"Cannot hardcode symbol `{c.getId()}` "
"that is an initial assignment target."
)
self.add_local_symbol(c.getId(), sp.Float(c.getValue()))
else:
self.add_local_symbol(c.getId(), _get_identifier_symbol(c))

for x_ref in _get_list_of_species_references(self.sbml):
if not x_ref.isSetId():
Expand Down Expand Up @@ -940,7 +989,11 @@ def _process_annotations(self) -> None:
self.sbml.removeParameter(parameter_id)

@log_execution_time("processing SBML parameters", logger)
def _process_parameters(self, constant_parameters: List[str] = None) -> None:
def _process_parameters(
self,
constant_parameters: List[str] = None,
hardcode_symbols: Sequence[str] = None,
) -> None:
"""
Get parameter information from SBML model.
Expand Down Expand Up @@ -983,6 +1036,7 @@ def _process_parameters(self, constant_parameters: List[str] = None) -> None:
if parameter.getId() not in constant_parameters
and self._get_element_initial_assignment(parameter.getId()) is None
and not self.is_assignment_rule_target(parameter)
and parameter.getId() not in hardcode_symbols
]

loop_settings = {
Expand Down
33 changes: 32 additions & 1 deletion python/tests/test_sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def simple_sbml_model():
model.addSpecies(s1)
p1 = model.createParameter()
p1.setId("p1")
p1.setValue(0.0)
p1.setValue(2.0)
model.addParameter(p1)

return document, model
Expand Down Expand Up @@ -662,3 +662,34 @@ def test_code_gen_uses_lhs_symbol_ids():
)
dwdx = Path(tmpdir, "dwdx.cpp").read_text()
assert "dobservable_x1_dx1 = " in dwdx


def test_hardcode_parameters(simple_sbml_model):
"""Test model generation works for model without observables"""
sbml_doc, sbml_model = simple_sbml_model
sbml_importer = SbmlImporter(sbml_source=sbml_model, from_file=False)
r = sbml_model.createRateRule()
r.setVariable("S1")
r.setFormula("p1")
assert sbml_model.getParameter("p1").getValue() != 0

ode_model = sbml_importer._build_ode_model()
assert str(ode_model.parameters()) == "[p1]"
assert ode_model.differential_states()[0].get_dt().name == "p1"

ode_model = sbml_importer._build_ode_model(
constant_parameters=[],
hardcode_symbols=["p1"],
)
assert str(ode_model.parameters()) == "[]"
assert (
ode_model.differential_states()[0].get_dt()
== sbml_model.getParameter("p1").getValue()
)

with pytest.raises(ValueError):
sbml_importer._build_ode_model(
# mutually exclusive
constant_parameters=["p1"],
hardcode_symbols=["p1"],
)

0 comments on commit 752b0e5

Please sign in to comment.