Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into status2str
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Sep 10, 2022
2 parents 69a0640 + 81f8b43 commit 39e4127
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 14 deletions.
7 changes: 4 additions & 3 deletions python/amici/cxxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _get_sym_lines_symbols(

def format_regular_line(symbol, math, index):
return (
f'{indent}{symbol} = {self.doprint(math)};'
f'{indent}{self.doprint(symbol)} = {self.doprint(math)};'
f' // {variable}[{index}]'.replace('\n', '\n' + indent)
)

Expand Down Expand Up @@ -136,10 +136,11 @@ def format_regular_line(symbol, math, index):
})
symbol_to_idx = {sym: idx for idx, sym in enumerate(symbols)}

def format_line(symbol):
def format_line(symbol: sp.Symbol):
math = expr_dict[symbol]
if str(symbol).startswith(cse_sym_prefix):
return f'{indent}const realtype {symbol} '\
return f'{indent}const realtype ' \
f'{self.doprint(symbol)} ' \
f'= {self.doprint(math)};'
elif math not in [0, 0.0]:
return format_regular_line(
Expand Down
9 changes: 9 additions & 0 deletions python/amici/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -2978,6 +2978,15 @@ def _get_function_body(
iterator = 'iy'
lines.extend(get_switch_statement(iterator, cases, 1))

elif function in self.model.sym_names() \
and function not in non_unique_id_symbols:
if function in sparse_functions:
symbols = self.model.sparsesym(function)
else:
symbols = self.model.sym(function)
lines += self.model._code_printer._get_sym_lines_symbols(
symbols, equations, function, 4)

else:
lines += self.model._code_printer._get_sym_lines_array(
equations, function, 4)
Expand Down
56 changes: 45 additions & 11 deletions python/tests/test_sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from amici.testing import TemporaryDirectoryWinSafe as TemporaryDirectory, \
skip_on_valgrind

EXAMPLES_DIR = Path(__file__).parent / '..' / 'examples'
STEADYSTATE_MODEL_FILE = (EXAMPLES_DIR / 'example_steadystate'
/ 'model_steadystate_scaled.xml')


@pytest.fixture
def simple_sbml_model():
Expand Down Expand Up @@ -163,9 +167,7 @@ def test_sbml2amici_observable_dependent_error(observable_dependent_error_model)

@pytest.fixture(scope='session')
def model_steadystate_module():
sbml_file = os.path.join(os.path.dirname(__file__), '..',
'examples', 'example_steadystate',
'model_steadystate_scaled.xml')
sbml_file = STEADYSTATE_MODEL_FILE
sbml_importer = amici.SbmlImporter(sbml_file)

observables = amici.assignmentRules2observables(
Expand All @@ -192,8 +194,7 @@ def model_steadystate_module():

@pytest.fixture(scope='session')
def model_units_module():
sbml_file = Path(__file__).parent / '..' / 'examples' \
/ 'example_units' / 'model_units.xml'
sbml_file = EXAMPLES_DIR / 'example_units' / 'model_units.xml'
module_name = 'test_model_units'

sbml_importer = amici.SbmlImporter(sbml_file)
Expand Down Expand Up @@ -339,9 +340,7 @@ def test_solver_reuse(model_steadystate_module):
def model_test_likelihoods():
"""Test model for various likelihood functions."""
# load sbml model
sbml_file = os.path.join(os.path.dirname(__file__), '..',
'examples', 'example_steadystate',
'model_steadystate_scaled.xml')
sbml_file = STEADYSTATE_MODEL_FILE
sbml_importer = amici.SbmlImporter(sbml_file)

# define observables
Expand Down Expand Up @@ -438,9 +437,7 @@ def test_likelihoods(model_test_likelihoods):
@skip_on_valgrind
def test_likelihoods_error():
"""Test whether wrong inputs lead to expected errors."""
sbml_file = os.path.join(os.path.dirname(__file__), '..',
'examples', 'example_steadystate',
'model_steadystate_scaled.xml')
sbml_file = STEADYSTATE_MODEL_FILE
sbml_importer = amici.SbmlImporter(sbml_file)

# define observables
Expand Down Expand Up @@ -571,3 +568,40 @@ def _test_set_parameters_by_dict(model_module):
assert model.getParameterByName(change_par_name) == new_par_val
model.setParameterByName(change_par_name, old_par_val)
assert model.getParameters() == old_parameter_values


@pytest.mark.parametrize("extract_cse", [True, False])
def test_code_gen_uses_cse(extract_cse):
"""Check that code generation honors AMICI_EXTRACT_CSE"""
old_environ = os.environ.copy()
try:
os.environ["AMICI_EXTRACT_CSE"] = str(extract_cse)
sbml_importer = amici.SbmlImporter(STEADYSTATE_MODEL_FILE)
model_name = "test_code_gen_uses_cse"
with TemporaryDirectory() as tmpdir:
sbml_importer.sbml2amici(
model_name=model_name,
compile=False,
generate_sensitivity_code=False,
output_dir = tmpdir
)
xdot = Path(tmpdir, f'{model_name}_xdot.cpp').read_text()
assert ("__amici_cse_0 = " in xdot) == extract_cse
finally:
os.environ = old_environ


def test_code_gen_uses_lhs_symbol_ids():
"""Check that code generation uses symbol IDs instead of plain array
indices"""
sbml_importer = amici.SbmlImporter(STEADYSTATE_MODEL_FILE)
model_name = "test_code_gen_uses_lhs_symbol_ids"
with TemporaryDirectory() as tmpdir:
sbml_importer.sbml2amici(
model_name=model_name,
compile=False,
generate_sensitivity_code=False,
output_dir=tmpdir
)
dwdx = Path(tmpdir, f'{model_name}_dwdx.cpp').read_text()
assert "dobservable_x1_dx1 = " in dwdx

0 comments on commit 39e4127

Please sign in to comment.