From 829512e67a19cc24343d70172293c2125c9df5f2 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Tue, 27 Feb 2024 13:50:54 +0100 Subject: [PATCH 1/3] Refactor de_export.py, extract C++ function info to cxx_functions.py Move everything related to information on C++ model functions to a separate module. Related to #2306. No changes in functionality. --- python/sdist/amici/_codegen/cxx_functions.py | 385 ++++++++++++++++++ python/sdist/amici/de_export.py | 388 +------------------ 2 files changed, 399 insertions(+), 374 deletions(-) create mode 100644 python/sdist/amici/_codegen/cxx_functions.py diff --git a/python/sdist/amici/_codegen/cxx_functions.py b/python/sdist/amici/_codegen/cxx_functions.py new file mode 100644 index 0000000000..25a8af3c2c --- /dev/null +++ b/python/sdist/amici/_codegen/cxx_functions.py @@ -0,0 +1,385 @@ +"""Info about C++ functions in the generated model code.""" +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class _FunctionInfo: + """Information on a model-specific generated C++ function + + :ivar ode_arguments: argument list of the ODE function. + input variables should be ``const``. + :ivar dae_arguments: argument list of the DAE function, if different from + ODE function. input variables should be ``const``. + :ivar return_type: the return type of the function + :ivar assume_pow_positivity: + identifies the functions on which ``assume_pow_positivity`` will have + an effect when specified during model generation. generally these are + functions that are used for solving the ODE, where negative values may + negatively affect convergence of the integration algorithm + :ivar sparse: + specifies whether the result of this function will be stored in sparse + format. sparse format means that the function will only return an + array of nonzero values and not a full matrix. + :ivar generate_body: + indicates whether a model-specific implementation is to be generated + :ivar body: + the actual function body. will be filled later + """ + + ode_arguments: str = "" + dae_arguments: str = "" + return_type: str = "void" + assume_pow_positivity: bool = False + sparse: bool = False + generate_body: bool = True + body: str = "" + + def arguments(self, ode: bool = True) -> str: + """Get the arguments for the ODE or DAE function""" + if ode or not self.dae_arguments: + return self.ode_arguments + return self.dae_arguments + + +# Information on a model-specific generated C++ function +# prototype for generated C++ functions, keys are the names of functions +functions = { + "Jy": _FunctionInfo( + "realtype *Jy, const int iy, const realtype *p, " + "const realtype *k, const realtype *y, const realtype *sigmay, " + "const realtype *my" + ), + "dJydsigma": _FunctionInfo( + "realtype *dJydsigma, const int iy, const realtype *p, " + "const realtype *k, const realtype *y, const realtype *sigmay, " + "const realtype *my" + ), + "dJydy": _FunctionInfo( + "realtype *dJydy, const int iy, const realtype *p, " + "const realtype *k, const realtype *y, " + "const realtype *sigmay, const realtype *my", + sparse=True, + ), + "Jz": _FunctionInfo( + "realtype *Jz, const int iz, const realtype *p, const realtype *k, " + "const realtype *z, const realtype *sigmaz, const realtype *mz" + ), + "dJzdsigma": _FunctionInfo( + "realtype *dJzdsigma, const int iz, const realtype *p, " + "const realtype *k, const realtype *z, const realtype *sigmaz, " + "const realtype *mz" + ), + "dJzdz": _FunctionInfo( + "realtype *dJzdz, const int iz, const realtype *p, " + "const realtype *k, const realtype *z, const realtype *sigmaz, " + "const double *mz", + ), + "Jrz": _FunctionInfo( + "realtype *Jrz, const int iz, const realtype *p, " + "const realtype *k, const realtype *rz, const realtype *sigmaz" + ), + "dJrzdsigma": _FunctionInfo( + "realtype *dJrzdsigma, const int iz, const realtype *p, " + "const realtype *k, const realtype *rz, const realtype *sigmaz" + ), + "dJrzdz": _FunctionInfo( + "realtype *dJrzdz, const int iz, const realtype *p, " + "const realtype *k, const realtype *rz, const realtype *sigmaz", + ), + "root": _FunctionInfo( + "realtype *root, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h, " + "const realtype *tcl" + ), + "dwdp": _FunctionInfo( + "realtype *dwdp, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h, " + "const realtype *w, const realtype *tcl, const realtype *dtcldp, " + "const realtype *spl, const realtype *sspl, bool include_static", + assume_pow_positivity=True, + sparse=True, + ), + "dwdx": _FunctionInfo( + "realtype *dwdx, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h, " + "const realtype *w, const realtype *tcl, const realtype *spl, " + "bool include_static", + assume_pow_positivity=True, + sparse=True, + ), + "create_splines": _FunctionInfo( + "const realtype *p, const realtype *k", + return_type="std::vector", + ), + "spl": _FunctionInfo(generate_body=False), + "sspl": _FunctionInfo(generate_body=False), + "spline_values": _FunctionInfo( + "const realtype *p, const realtype *k", generate_body=False + ), + "spline_slopes": _FunctionInfo( + "const realtype *p, const realtype *k", generate_body=False + ), + "dspline_valuesdp": _FunctionInfo( + "realtype *dspline_valuesdp, const realtype *p, const realtype *k, " + "const int ip" + ), + "dspline_slopesdp": _FunctionInfo( + "realtype *dspline_slopesdp, const realtype *p, const realtype *k, " + "const int ip" + ), + "dwdw": _FunctionInfo( + "realtype *dwdw, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h, " + "const realtype *w, const realtype *tcl, bool include_static", + assume_pow_positivity=True, + sparse=True, + ), + "dxdotdw": _FunctionInfo( + "realtype *dxdotdw, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h, " + "const realtype *w", + "realtype *dxdotdw, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h, " + "const realtype *dx, const realtype *w", + assume_pow_positivity=True, + sparse=True, + ), + "dxdotdx_explicit": _FunctionInfo( + "realtype *dxdotdx_explicit, const realtype t, " + "const realtype *x, const realtype *p, const realtype *k, " + "const realtype *h, const realtype *w", + "realtype *dxdotdx_explicit, const realtype t, " + "const realtype *x, const realtype *p, const realtype *k, " + "const realtype *h, const realtype *dx, const realtype *w", + assume_pow_positivity=True, + sparse=True, + ), + "dxdotdp_explicit": _FunctionInfo( + "realtype *dxdotdp_explicit, const realtype t, " + "const realtype *x, const realtype *p, const realtype *k, " + "const realtype *h, const realtype *w", + "realtype *dxdotdp_explicit, const realtype t, " + "const realtype *x, const realtype *p, const realtype *k, " + "const realtype *h, const realtype *dx, const realtype *w", + assume_pow_positivity=True, + sparse=True, + ), + "dydx": _FunctionInfo( + "realtype *dydx, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h, " + "const realtype *w, const realtype *dwdx", + ), + "dydp": _FunctionInfo( + "realtype *dydp, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h, " + "const int ip, const realtype *w, const realtype *tcl, " + "const realtype *dtcldp, const realtype *spl, const realtype *sspl" + ), + "dzdx": _FunctionInfo( + "realtype *dzdx, const int ie, const realtype t, " + "const realtype *x, const realtype *p, const realtype *k, " + "const realtype *h", + ), + "dzdp": _FunctionInfo( + "realtype *dzdp, const int ie, const realtype t, " + "const realtype *x, const realtype *p, const realtype *k, " + "const realtype *h, const int ip", + ), + "drzdx": _FunctionInfo( + "realtype *drzdx, const int ie, const realtype t, " + "const realtype *x, const realtype *p, const realtype *k, " + "const realtype *h", + ), + "drzdp": _FunctionInfo( + "realtype *drzdp, const int ie, const realtype t, " + "const realtype *x, const realtype *p, const realtype *k, " + "const realtype *h, const int ip", + ), + "dsigmaydy": _FunctionInfo( + "realtype *dsigmaydy, const realtype t, const realtype *p, " + "const realtype *k, const realtype *y" + ), + "dsigmaydp": _FunctionInfo( + "realtype *dsigmaydp, const realtype t, const realtype *p, " + "const realtype *k, const realtype *y, const int ip", + ), + "sigmay": _FunctionInfo( + "realtype *sigmay, const realtype t, const realtype *p, " + "const realtype *k, const realtype *y", + ), + "dsigmazdp": _FunctionInfo( + "realtype *dsigmazdp, const realtype t, const realtype *p," + " const realtype *k, const int ip", + ), + "sigmaz": _FunctionInfo( + "realtype *sigmaz, const realtype t, const realtype *p, " + "const realtype *k", + ), + "sroot": _FunctionInfo( + "realtype *stau, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h, " + "const realtype *sx, const int ip, const int ie, " + "const realtype *tcl", + generate_body=False, + ), + "drootdt": _FunctionInfo(generate_body=False), + "drootdt_total": _FunctionInfo(generate_body=False), + "drootdp": _FunctionInfo(generate_body=False), + "drootdx": _FunctionInfo(generate_body=False), + "stau": _FunctionInfo( + "realtype *stau, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h, " + "const realtype *tcl, const realtype *sx, const int ip, " + "const int ie" + ), + "deltax": _FunctionInfo( + "double *deltax, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h, " + "const int ie, const realtype *xdot, const realtype *xdot_old" + ), + "ddeltaxdx": _FunctionInfo(generate_body=False), + "ddeltaxdt": _FunctionInfo(generate_body=False), + "ddeltaxdp": _FunctionInfo(generate_body=False), + "deltasx": _FunctionInfo( + "realtype *deltasx, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h, " + "const realtype *w, const int ip, const int ie, " + "const realtype *xdot, const realtype *xdot_old, " + "const realtype *sx, const realtype *stau, const realtype *tcl" + ), + "w": _FunctionInfo( + "realtype *w, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, " + "const realtype *h, const realtype *tcl, const realtype *spl, " + "bool include_static", + assume_pow_positivity=True, + ), + "x0": _FunctionInfo( + "realtype *x0, const realtype t, const realtype *p, " + "const realtype *k" + ), + "x0_fixedParameters": _FunctionInfo( + "realtype *x0_fixedParameters, const realtype t, " + "const realtype *p, const realtype *k, " + "gsl::span reinitialization_state_idxs", + ), + "sx0": _FunctionInfo( + "realtype *sx0, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const int ip", + ), + "sx0_fixedParameters": _FunctionInfo( + "realtype *sx0_fixedParameters, const realtype t, " + "const realtype *x0, const realtype *p, const realtype *k, " + "const int ip, gsl::span reinitialization_state_idxs", + ), + "xdot": _FunctionInfo( + "realtype *xdot, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h, " + "const realtype *w", + "realtype *xdot, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h, " + "const realtype *dx, const realtype *w", + assume_pow_positivity=True, + ), + "xdot_old": _FunctionInfo(generate_body=False), + "y": _FunctionInfo( + "realtype *y, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, " + "const realtype *h, const realtype *w", + ), + "x_rdata": _FunctionInfo( + "realtype *x_rdata, const realtype *x, const realtype *tcl, " + "const realtype *p, const realtype *k" + ), + "total_cl": _FunctionInfo( + "realtype *total_cl, const realtype *x_rdata, " + "const realtype *p, const realtype *k" + ), + "dtotal_cldp": _FunctionInfo( + "realtype *dtotal_cldp, const realtype *x_rdata, " + "const realtype *p, const realtype *k, const int ip" + ), + "dtotal_cldx_rdata": _FunctionInfo( + "realtype *dtotal_cldx_rdata, const realtype *x_rdata, " + "const realtype *p, const realtype *k, const realtype *tcl", + sparse=True, + ), + "x_solver": _FunctionInfo("realtype *x_solver, const realtype *x_rdata"), + "dx_rdatadx_solver": _FunctionInfo( + "realtype *dx_rdatadx_solver, const realtype *x, " + "const realtype *tcl, const realtype *p, const realtype *k", + sparse=True, + ), + "dx_rdatadp": _FunctionInfo( + "realtype *dx_rdatadp, const realtype *x, " + "const realtype *tcl, const realtype *p, const realtype *k, " + "const int ip" + ), + "dx_rdatadtcl": _FunctionInfo( + "realtype *dx_rdatadtcl, const realtype *x, " + "const realtype *tcl, const realtype *p, const realtype *k", + sparse=True, + ), + "z": _FunctionInfo( + "realtype *z, const int ie, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h" + ), + "rz": _FunctionInfo( + "realtype *rz, const int ie, const realtype t, const realtype *x, " + "const realtype *p, const realtype *k, const realtype *h" + ), +} + +#: list of sparse functions +sparse_functions = [ + func_name for func_name, func_info in functions.items() if func_info.sparse +] + +#: list of nobody functions +nobody_functions = [ + func_name + for func_name, func_info in functions.items() + if not func_info.generate_body +] + +#: list of sensitivity functions +sensi_functions = [ + func_name + for func_name, func_info in functions.items() + if "const int ip" in func_info.arguments() +] + +#: list of sparse sensitivity functions +sparse_sensi_functions = [ + func_name + for func_name, func_info in functions.items() + if "const int ip" not in func_info.arguments() + and func_name.endswith("dp") + or func_name.endswith("dp_explicit") +] + +#: list of event functions +event_functions = [ + func_name + for func_name, func_info in functions.items() + if "const int ie" in func_info.arguments() + and "const int ip" not in func_info.arguments() +] + +#: list of event sensitivity functions +event_sensi_functions = [ + func_name + for func_name, func_info in functions.items() + if "const int ie" in func_info.arguments() + and "const int ip" in func_info.arguments() +] + +#: list of multiobs functions +multiobs_functions = [ + func_name + for func_name, func_info in functions.items() + if "const int iy" in func_info.arguments() + or "const int iz" in func_info.arguments() +] diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index b853dd6ad0..2131f9dd9e 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -17,7 +17,6 @@ import os import re import shutil -from dataclasses import dataclass from itertools import chain from pathlib import Path from typing import ( @@ -39,6 +38,17 @@ amiciSwigPath, splines, ) +from ._codegen.cxx_functions import ( + _FunctionInfo, + functions, + sparse_functions, + nobody_functions, + sensi_functions, + sparse_sensi_functions, + event_functions, + event_sensi_functions, + multiobs_functions, +) from ._codegen.template import apply_template from .cxxcodeprinter import ( AmiciCxxCodePrinter, @@ -86,380 +96,10 @@ DERIVATIVE_PATTERN = re.compile(r"^d(x_rdata|xdot|\w+?)d(\w+?)(?:_explicit)?$") -@dataclass -class _FunctionInfo: - """Information on a model-specific generated C++ function - - :ivar ode_arguments: argument list of the ODE function. input variables should be - ``const``. - :ivar dae_arguments: argument list of the DAE function, if different from ODE - function. input variables should be ``const``. - :ivar return_type: the return type of the function - :ivar assume_pow_positivity: - identifies the functions on which ``assume_pow_positivity`` will have - an effect when specified during model generation. generally these are - functions that are used for solving the ODE, where negative values may - negatively affect convergence of the integration algorithm - :ivar sparse: - specifies whether the result of this function will be stored in sparse - format. sparse format means that the function will only return an - array of nonzero values and not a full matrix. - :ivar generate_body: - indicates whether a model-specific implementation is to be generated - :ivar body: - the actual function body. will be filled later - """ - - ode_arguments: str = "" - dae_arguments: str = "" - return_type: str = "void" - assume_pow_positivity: bool = False - sparse: bool = False - generate_body: bool = True - body: str = "" - - def arguments(self, ode: bool = True) -> str: - """Get the arguments for the ODE or DAE function""" - if ode or not self.dae_arguments: - return self.ode_arguments - return self.dae_arguments - - -# Information on a model-specific generated C++ function -# prototype for generated C++ functions, keys are the names of functions -functions = { - "Jy": _FunctionInfo( - "realtype *Jy, const int iy, const realtype *p, " - "const realtype *k, const realtype *y, const realtype *sigmay, " - "const realtype *my" - ), - "dJydsigma": _FunctionInfo( - "realtype *dJydsigma, const int iy, const realtype *p, " - "const realtype *k, const realtype *y, const realtype *sigmay, " - "const realtype *my" - ), - "dJydy": _FunctionInfo( - "realtype *dJydy, const int iy, const realtype *p, " - "const realtype *k, const realtype *y, " - "const realtype *sigmay, const realtype *my", - sparse=True, - ), - "Jz": _FunctionInfo( - "realtype *Jz, const int iz, const realtype *p, const realtype *k, " - "const realtype *z, const realtype *sigmaz, const realtype *mz" - ), - "dJzdsigma": _FunctionInfo( - "realtype *dJzdsigma, const int iz, const realtype *p, " - "const realtype *k, const realtype *z, const realtype *sigmaz, " - "const realtype *mz" - ), - "dJzdz": _FunctionInfo( - "realtype *dJzdz, const int iz, const realtype *p, " - "const realtype *k, const realtype *z, const realtype *sigmaz, " - "const double *mz", - ), - "Jrz": _FunctionInfo( - "realtype *Jrz, const int iz, const realtype *p, " - "const realtype *k, const realtype *rz, const realtype *sigmaz" - ), - "dJrzdsigma": _FunctionInfo( - "realtype *dJrzdsigma, const int iz, const realtype *p, " - "const realtype *k, const realtype *rz, const realtype *sigmaz" - ), - "dJrzdz": _FunctionInfo( - "realtype *dJrzdz, const int iz, const realtype *p, " - "const realtype *k, const realtype *rz, const realtype *sigmaz", - ), - "root": _FunctionInfo( - "realtype *root, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h, " - "const realtype *tcl" - ), - "dwdp": _FunctionInfo( - "realtype *dwdp, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h, " - "const realtype *w, const realtype *tcl, const realtype *dtcldp, " - "const realtype *spl, const realtype *sspl, bool include_static", - assume_pow_positivity=True, - sparse=True, - ), - "dwdx": _FunctionInfo( - "realtype *dwdx, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h, " - "const realtype *w, const realtype *tcl, const realtype *spl, " - "bool include_static", - assume_pow_positivity=True, - sparse=True, - ), - "create_splines": _FunctionInfo( - "const realtype *p, const realtype *k", - return_type="std::vector", - ), - "spl": _FunctionInfo(generate_body=False), - "sspl": _FunctionInfo(generate_body=False), - "spline_values": _FunctionInfo( - "const realtype *p, const realtype *k", generate_body=False - ), - "spline_slopes": _FunctionInfo( - "const realtype *p, const realtype *k", generate_body=False - ), - "dspline_valuesdp": _FunctionInfo( - "realtype *dspline_valuesdp, const realtype *p, const realtype *k, const int ip" - ), - "dspline_slopesdp": _FunctionInfo( - "realtype *dspline_slopesdp, const realtype *p, const realtype *k, const int ip" - ), - "dwdw": _FunctionInfo( - "realtype *dwdw, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h, " - "const realtype *w, const realtype *tcl, bool include_static", - assume_pow_positivity=True, - sparse=True, - ), - "dxdotdw": _FunctionInfo( - "realtype *dxdotdw, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h, " - "const realtype *w", - "realtype *dxdotdw, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h, " - "const realtype *dx, const realtype *w", - assume_pow_positivity=True, - sparse=True, - ), - "dxdotdx_explicit": _FunctionInfo( - "realtype *dxdotdx_explicit, const realtype t, " - "const realtype *x, const realtype *p, const realtype *k, " - "const realtype *h, const realtype *w", - "realtype *dxdotdx_explicit, const realtype t, " - "const realtype *x, const realtype *p, const realtype *k, " - "const realtype *h, const realtype *dx, const realtype *w", - assume_pow_positivity=True, - sparse=True, - ), - "dxdotdp_explicit": _FunctionInfo( - "realtype *dxdotdp_explicit, const realtype t, " - "const realtype *x, const realtype *p, const realtype *k, " - "const realtype *h, const realtype *w", - "realtype *dxdotdp_explicit, const realtype t, " - "const realtype *x, const realtype *p, const realtype *k, " - "const realtype *h, const realtype *dx, const realtype *w", - assume_pow_positivity=True, - sparse=True, - ), - "dydx": _FunctionInfo( - "realtype *dydx, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h, " - "const realtype *w, const realtype *dwdx", - ), - "dydp": _FunctionInfo( - "realtype *dydp, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h, " - "const int ip, const realtype *w, const realtype *tcl, " - "const realtype *dtcldp, const realtype *spl, const realtype *sspl" - ), - "dzdx": _FunctionInfo( - "realtype *dzdx, const int ie, const realtype t, " - "const realtype *x, const realtype *p, const realtype *k, " - "const realtype *h", - ), - "dzdp": _FunctionInfo( - "realtype *dzdp, const int ie, const realtype t, " - "const realtype *x, const realtype *p, const realtype *k, " - "const realtype *h, const int ip", - ), - "drzdx": _FunctionInfo( - "realtype *drzdx, const int ie, const realtype t, " - "const realtype *x, const realtype *p, const realtype *k, " - "const realtype *h", - ), - "drzdp": _FunctionInfo( - "realtype *drzdp, const int ie, const realtype t, " - "const realtype *x, const realtype *p, const realtype *k, " - "const realtype *h, const int ip", - ), - "dsigmaydy": _FunctionInfo( - "realtype *dsigmaydy, const realtype t, const realtype *p, " - "const realtype *k, const realtype *y" - ), - "dsigmaydp": _FunctionInfo( - "realtype *dsigmaydp, const realtype t, const realtype *p, " - "const realtype *k, const realtype *y, const int ip", - ), - "sigmay": _FunctionInfo( - "realtype *sigmay, const realtype t, const realtype *p, " - "const realtype *k, const realtype *y", - ), - "dsigmazdp": _FunctionInfo( - "realtype *dsigmazdp, const realtype t, const realtype *p," - " const realtype *k, const int ip", - ), - "sigmaz": _FunctionInfo( - "realtype *sigmaz, const realtype t, const realtype *p, " - "const realtype *k", - ), - "sroot": _FunctionInfo( - "realtype *stau, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h, " - "const realtype *sx, const int ip, const int ie, " - "const realtype *tcl", - generate_body=False, - ), - "drootdt": _FunctionInfo(generate_body=False), - "drootdt_total": _FunctionInfo(generate_body=False), - "drootdp": _FunctionInfo(generate_body=False), - "drootdx": _FunctionInfo(generate_body=False), - "stau": _FunctionInfo( - "realtype *stau, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h, " - "const realtype *tcl, const realtype *sx, const int ip, " - "const int ie" - ), - "deltax": _FunctionInfo( - "double *deltax, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h, " - "const int ie, const realtype *xdot, const realtype *xdot_old" - ), - "ddeltaxdx": _FunctionInfo(generate_body=False), - "ddeltaxdt": _FunctionInfo(generate_body=False), - "ddeltaxdp": _FunctionInfo(generate_body=False), - "deltasx": _FunctionInfo( - "realtype *deltasx, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h, " - "const realtype *w, const int ip, const int ie, " - "const realtype *xdot, const realtype *xdot_old, " - "const realtype *sx, const realtype *stau, const realtype *tcl" - ), - "w": _FunctionInfo( - "realtype *w, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, " - "const realtype *h, const realtype *tcl, const realtype *spl, " - "bool include_static", - assume_pow_positivity=True, - ), - "x0": _FunctionInfo( - "realtype *x0, const realtype t, const realtype *p, " - "const realtype *k" - ), - "x0_fixedParameters": _FunctionInfo( - "realtype *x0_fixedParameters, const realtype t, " - "const realtype *p, const realtype *k, " - "gsl::span reinitialization_state_idxs", - ), - "sx0": _FunctionInfo( - "realtype *sx0, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const int ip", - ), - "sx0_fixedParameters": _FunctionInfo( - "realtype *sx0_fixedParameters, const realtype t, " - "const realtype *x0, const realtype *p, const realtype *k, " - "const int ip, gsl::span reinitialization_state_idxs", - ), - "xdot": _FunctionInfo( - "realtype *xdot, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h, " - "const realtype *w", - "realtype *xdot, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h, " - "const realtype *dx, const realtype *w", - assume_pow_positivity=True, - ), - "xdot_old": _FunctionInfo(generate_body=False), - "y": _FunctionInfo( - "realtype *y, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, " - "const realtype *h, const realtype *w", - ), - "x_rdata": _FunctionInfo( - "realtype *x_rdata, const realtype *x, const realtype *tcl, " - "const realtype *p, const realtype *k" - ), - "total_cl": _FunctionInfo( - "realtype *total_cl, const realtype *x_rdata, " - "const realtype *p, const realtype *k" - ), - "dtotal_cldp": _FunctionInfo( - "realtype *dtotal_cldp, const realtype *x_rdata, " - "const realtype *p, const realtype *k, const int ip" - ), - "dtotal_cldx_rdata": _FunctionInfo( - "realtype *dtotal_cldx_rdata, const realtype *x_rdata, " - "const realtype *p, const realtype *k, const realtype *tcl", - sparse=True, - ), - "x_solver": _FunctionInfo("realtype *x_solver, const realtype *x_rdata"), - "dx_rdatadx_solver": _FunctionInfo( - "realtype *dx_rdatadx_solver, const realtype *x, " - "const realtype *tcl, const realtype *p, const realtype *k", - sparse=True, - ), - "dx_rdatadp": _FunctionInfo( - "realtype *dx_rdatadp, const realtype *x, " - "const realtype *tcl, const realtype *p, const realtype *k, " - "const int ip" - ), - "dx_rdatadtcl": _FunctionInfo( - "realtype *dx_rdatadtcl, const realtype *x, " - "const realtype *tcl, const realtype *p, const realtype *k", - sparse=True, - ), - "z": _FunctionInfo( - "realtype *z, const int ie, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h" - ), - "rz": _FunctionInfo( - "realtype *rz, const int ie, const realtype t, const realtype *x, " - "const realtype *p, const realtype *k, const realtype *h" - ), -} - -# list of sparse functions -sparse_functions = [ - func_name for func_name, func_info in functions.items() if func_info.sparse -] -# list of nobody functions -nobody_functions = [ - func_name - for func_name, func_info in functions.items() - if not func_info.generate_body -] -# list of sensitivity functions -sensi_functions = [ - func_name - for func_name, func_info in functions.items() - if "const int ip" in func_info.arguments() -] -# list of sensitivity functions -sparse_sensi_functions = [ - func_name - for func_name, func_info in functions.items() - if "const int ip" not in func_info.arguments() - and func_name.endswith("dp") - or func_name.endswith("dp_explicit") -] -# list of event functions -event_functions = [ - func_name - for func_name, func_info in functions.items() - if "const int ie" in func_info.arguments() - and "const int ip" not in func_info.arguments() -] -event_sensi_functions = [ - func_name - for func_name, func_info in functions.items() - if "const int ie" in func_info.arguments() - and "const int ip" in func_info.arguments() -] -# list of multiobs functions -multiobs_functions = [ - func_name - for func_name, func_info in functions.items() - if "const int iy" in func_info.arguments() - or "const int iz" in func_info.arguments() -] -# list of equations that have ids which may not be unique +#: list of equations that have ids which may not be unique non_unique_id_symbols = ["x_rdata", "y"] -# custom c++ function replacements +#: custom c++ function replacements CUSTOM_FUNCTIONS = [ { "sympy": "polygamma", @@ -471,7 +111,7 @@ def arguments(self, ode: bool = True) -> str: {"sympy": "DiracDelta", "c++": "amici::dirac"}, ] -# python log manager +#: python log manager logger = get_logger(__name__, logging.ERROR) From 09d94f9bc0cbd9fd60ac9b76b83dcf7a11b68925 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Tue, 27 Feb 2024 13:59:36 +0100 Subject: [PATCH 2/3] Refactor de_export.py, extract _codegen.model_class Move functionality for generating the ``amici::Model`` subclass code to a separate file. Related to #2306 --- python/sdist/amici/_codegen/model_class.py | 191 ++++++++++++++++++++ python/sdist/amici/de_export.py | 195 ++------------------- 2 files changed, 201 insertions(+), 185 deletions(-) create mode 100644 python/sdist/amici/_codegen/model_class.py diff --git a/python/sdist/amici/_codegen/model_class.py b/python/sdist/amici/_codegen/model_class.py new file mode 100644 index 0000000000..d6c1bcdc81 --- /dev/null +++ b/python/sdist/amici/_codegen/model_class.py @@ -0,0 +1,191 @@ +"""Function for generating the ``amici::Model`` subclass for an amici model.""" +from __future__ import annotations + +from .cxx_functions import functions, multiobs_functions + +from ..de_model import Event + + +def get_function_extern_declaration(fun: str, name: str, ode: bool) -> str: + """ + Constructs the extern function declaration for a given function + + :param fun: + function name + :param name: + model name + :param ode: + whether to generate declaration for DAE or ODE + + :return: + C++ function definition string + """ + f = functions[fun] + return f"extern {f.return_type} {fun}_{name}({f.arguments(ode)});" + + +def get_sunindex_extern_declaration( + fun: str, name: str, indextype: str +) -> str: + """ + Constructs the function declaration for an index function of a given + function + + :param fun: + function name + + :param name: + model name + + :param indextype: + index function {'colptrs', 'rowvals'} + + :return: + C++ function declaration string + """ + index_arg = ", int index" if fun in multiobs_functions else "" + return ( + f"extern void {fun}_{indextype}_{name}" + f"(SUNMatrixWrapper &{indextype}{index_arg});" + ) + + +def get_model_override_implementation( + fun: str, name: str, ode: bool, nobody: bool = False +) -> str: + """ + Constructs ``amici::Model::*`` override implementation for a given function + + :param fun: + function name + + :param name: + model name + + :param nobody: + whether the function has a nontrivial implementation + + :return: + C++ function implementation string + """ + func_info = functions[fun] + body = ( + "" + if nobody + else "\n{ind8}{maybe_return}{fun}_{name}({eval_signature});\n{ind4}".format( + ind4=" " * 4, + ind8=" " * 8, + maybe_return="" if func_info.return_type == "void" else "return ", + fun=fun, + name=name, + eval_signature=remove_argument_types(func_info.arguments(ode)), + ) + ) + return "{return_type} f{fun}({signature}) override {{{body}}}\n".format( + return_type=func_info.return_type, + fun=fun, + signature=func_info.arguments(ode), + body=body, + ) + + +def get_sunindex_override_implementation( + fun: str, name: str, indextype: str, nobody: bool = False +) -> str: + """ + Constructs the ``amici::Model`` function implementation for an index + function of a given function + + :param fun: + function name + + :param name: + model name + + :param indextype: + index function {'colptrs', 'rowvals'} + + :param nobody: + whether the corresponding function has a nontrivial implementation + + :return: + C++ function implementation string + """ + index_arg = ", int index" if fun in multiobs_functions else "" + index_arg_eval = ", index" if fun in multiobs_functions else "" + + impl = "void f{fun}_{indextype}({signature}) override {{" + + if nobody: + impl += "}}\n" + else: + impl += ( + "\n{ind8}{fun}_{indextype}_{name}({eval_signature});\n{ind4}}}\n" + ) + + return impl.format( + ind4=" " * 4, + ind8=" " * 8, + fun=fun, + indextype=indextype, + name=name, + signature=f"SUNMatrixWrapper &{indextype}{index_arg}", + eval_signature=f"{indextype}{index_arg_eval}", + ) + + +def remove_argument_types(signature: str) -> str: + """ + Strips argument types from a function signature + + :param signature: + function signature + + :return: + string that can be used to construct function calls with the same + variable names and ordering as in the function signature + """ + # remove * prefix for pointers (pointer must always be removed before + # values otherwise we will inadvertently dereference values, + # same applies for const specifications) + # + # always add whitespace after type definition for cosmetic reasons + known_types = [ + "const realtype *", + "const double *", + "const realtype ", + "double *", + "realtype *", + "const int ", + "int ", + "bool ", + "SUNMatrixContent_Sparse ", + "gsl::span", + ] + + for type_str in known_types: + signature = signature.replace(type_str, "") + + return signature + + +def get_state_independent_event_intializer(events: list[Event]) -> str: + """Get initializer list for state independent events in amici::Model.""" + map_time_to_event_idx = {} + for event_idx, event in enumerate(events): + if not event.triggers_at_fixed_timepoint(): + continue + trigger_time = float(event.get_trigger_time()) + try: + map_time_to_event_idx[trigger_time].append(event_idx) + except KeyError: + map_time_to_event_idx[trigger_time] = [event_idx] + + def vector_initializer(v): + """std::vector initializer list with elements from `v`""" + return f"{{{', '.join(map(str, v))}}}" + + return ", ".join( + f"{{{trigger_time}, {vector_initializer(event_idxs)}}}" + for trigger_time, event_idxs in map_time_to_event_idx.items() + ) diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index 2131f9dd9e..d5c6c511c4 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -49,6 +49,13 @@ event_sensi_functions, multiobs_functions, ) +from ._codegen.model_class import ( + get_function_extern_declaration, + get_sunindex_extern_declaration, + get_model_override_implementation, + get_sunindex_override_implementation, + get_state_independent_event_intializer, +) from ._codegen.template import apply_template from .cxxcodeprinter import ( AmiciCxxCodePrinter, @@ -3330,7 +3337,9 @@ def _write_model_header_cpp(self) -> None: ) ), "Z2EVENT": ", ".join(map(str, self.model._z2event)), - "STATE_INDEPENDENT_EVENTS": self._get_state_independent_event_intializer(), + "STATE_INDEPENDENT_EVENTS": get_state_independent_event_intializer( + self.model.events() + ), "ID": ", ".join( str(float(isinstance(s, DifferentialState))) for s in self.model.states() @@ -3464,27 +3473,6 @@ def _get_symbol_id_initializer_list(self, name: str) -> str: for idx, symbol in enumerate(self.model.sym(name)) ) - def _get_state_independent_event_intializer(self) -> str: - """Get initializer list for state independent events in amici::Model.""" - map_time_to_event_idx = {} - for event_idx, event in enumerate(self.model.events()): - if not event.triggers_at_fixed_timepoint(): - continue - trigger_time = float(event.get_trigger_time()) - try: - map_time_to_event_idx[trigger_time].append(event_idx) - except KeyError: - map_time_to_event_idx[trigger_time] = [event_idx] - - def vector_initializer(v): - """std::vector initializer list with elements from `v`""" - return f"{{{', '.join(map(str, v))}}}" - - return ", ".join( - f"{{{trigger_time}, {vector_initializer(event_idxs)}}}" - for trigger_time, event_idxs in map_time_to_event_idx.items() - ) - def _write_c_make_file(self): """Write CMake ``CMakeLists.txt`` file for this model.""" sources = "\n".join( @@ -3585,169 +3573,6 @@ def set_name(self, model_name: str) -> None: self.model_name = model_name -def get_function_extern_declaration(fun: str, name: str, ode: bool) -> str: - """ - Constructs the extern function declaration for a given function - - :param fun: - function name - :param name: - model name - :param ode: - whether to generate declaration for DAE or ODE - - :return: - C++ function definition string - """ - f = functions[fun] - return f"extern {f.return_type} {fun}_{name}({f.arguments(ode)});" - - -def get_sunindex_extern_declaration( - fun: str, name: str, indextype: str -) -> str: - """ - Constructs the function declaration for an index function of a given - function - - :param fun: - function name - - :param name: - model name - - :param indextype: - index function {'colptrs', 'rowvals'} - - :return: - C++ function declaration string - """ - index_arg = ", int index" if fun in multiobs_functions else "" - return ( - f"extern void {fun}_{indextype}_{name}" - f"(SUNMatrixWrapper &{indextype}{index_arg});" - ) - - -def get_model_override_implementation( - fun: str, name: str, ode: bool, nobody: bool = False -) -> str: - """ - Constructs ``amici::Model::*`` override implementation for a given function - - :param fun: - function name - - :param name: - model name - - :param nobody: - whether the function has a nontrivial implementation - - :return: - C++ function implementation string - """ - func_info = functions[fun] - body = ( - "" - if nobody - else "\n{ind8}{maybe_return}{fun}_{name}({eval_signature});\n{ind4}".format( - ind4=" " * 4, - ind8=" " * 8, - maybe_return="" if func_info.return_type == "void" else "return ", - fun=fun, - name=name, - eval_signature=remove_argument_types(func_info.arguments(ode)), - ) - ) - return "{return_type} f{fun}({signature}) override {{{body}}}\n".format( - return_type=func_info.return_type, - fun=fun, - signature=func_info.arguments(ode), - body=body, - ) - - -def get_sunindex_override_implementation( - fun: str, name: str, indextype: str, nobody: bool = False -) -> str: - """ - Constructs the ``amici::Model`` function implementation for an index - function of a given function - - :param fun: - function name - - :param name: - model name - - :param indextype: - index function {'colptrs', 'rowvals'} - - :param nobody: - whether the corresponding function has a nontrivial implementation - - :return: - C++ function implementation string - """ - index_arg = ", int index" if fun in multiobs_functions else "" - index_arg_eval = ", index" if fun in multiobs_functions else "" - - impl = "void f{fun}_{indextype}({signature}) override {{" - - if nobody: - impl += "}}\n" - else: - impl += ( - "\n{ind8}{fun}_{indextype}_{name}({eval_signature});\n{ind4}}}\n" - ) - - return impl.format( - ind4=" " * 4, - ind8=" " * 8, - fun=fun, - indextype=indextype, - name=name, - signature=f"SUNMatrixWrapper &{indextype}{index_arg}", - eval_signature=f"{indextype}{index_arg_eval}", - ) - - -def remove_argument_types(signature: str) -> str: - """ - Strips argument types from a function signature - - :param signature: - function signature - - :return: - string that can be used to construct function calls with the same - variable names and ordering as in the function signature - """ - # remove * prefix for pointers (pointer must always be removed before - # values otherwise we will inadvertently dereference values, - # same applies for const specifications) - # - # always add whitespace after type definition for cosmetic reasons - known_types = [ - "const realtype *", - "const double *", - "const realtype ", - "double *", - "realtype *", - "const int ", - "int ", - "bool ", - "SUNMatrixContent_Sparse ", - "gsl::span", - ] - - for type_str in known_types: - signature = signature.replace(type_str, "") - - return signature - - def is_valid_identifier(x: str) -> bool: """ Check whether `x` is a valid identifier for conditions, parameters, From 86433eb54072e6dc6c7e04e186cfbea610d2754c Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Tue, 27 Feb 2024 18:09:29 +0100 Subject: [PATCH 3/3] .. --- python/sdist/amici/_codegen/model_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sdist/amici/_codegen/model_class.py b/python/sdist/amici/_codegen/model_class.py index d6c1bcdc81..e6366c1dfd 100644 --- a/python/sdist/amici/_codegen/model_class.py +++ b/python/sdist/amici/_codegen/model_class.py @@ -3,7 +3,7 @@ from .cxx_functions import functions, multiobs_functions -from ..de_model import Event +from ..de_model_components import Event def get_function_extern_declaration(fun: str, name: str, ode: bool) -> str: