Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor de_export.py, extract _codegen.model_class #2319

Merged
merged 4 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions python/sdist/amici/_codegen/model_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""Function for generating the ``amici::Model`` subclass for an amici model."""
from __future__ import annotations

from .cxx_functions import functions, multiobs_functions

from ..de_model_components import Event


def get_function_extern_declaration(fun: str, name: str, ode: bool) -> str:
"""
Constructs the extern function declaration for a given function

:param fun:
function name
:param name:
model name
:param ode:
whether to generate declaration for DAE or ODE

:return:
C++ function definition string
"""
f = functions[fun]
return f"extern {f.return_type} {fun}_{name}({f.arguments(ode)});"


def get_sunindex_extern_declaration(
fun: str, name: str, indextype: str
) -> str:
"""
Constructs the function declaration for an index function of a given
function

:param fun:
function name

:param name:
model name

:param indextype:
index function {'colptrs', 'rowvals'}

:return:
C++ function declaration string
"""
index_arg = ", int index" if fun in multiobs_functions else ""
return (
f"extern void {fun}_{indextype}_{name}"
f"(SUNMatrixWrapper &{indextype}{index_arg});"
)


def get_model_override_implementation(
fun: str, name: str, ode: bool, nobody: bool = False
) -> str:
"""
Constructs ``amici::Model::*`` override implementation for a given function

:param fun:
function name

:param name:
model name

:param nobody:
whether the function has a nontrivial implementation

:return:
C++ function implementation string
"""
func_info = functions[fun]
body = (
""
if nobody
else "\n{ind8}{maybe_return}{fun}_{name}({eval_signature});\n{ind4}".format(
ind4=" " * 4,
ind8=" " * 8,
maybe_return="" if func_info.return_type == "void" else "return ",
fun=fun,
name=name,
eval_signature=remove_argument_types(func_info.arguments(ode)),
)
)
return "{return_type} f{fun}({signature}) override {{{body}}}\n".format(
return_type=func_info.return_type,
fun=fun,
signature=func_info.arguments(ode),
body=body,
)


def get_sunindex_override_implementation(
fun: str, name: str, indextype: str, nobody: bool = False
) -> str:
"""
Constructs the ``amici::Model`` function implementation for an index
function of a given function

:param fun:
function name

:param name:
model name

:param indextype:
index function {'colptrs', 'rowvals'}

:param nobody:
whether the corresponding function has a nontrivial implementation

:return:
C++ function implementation string
"""
index_arg = ", int index" if fun in multiobs_functions else ""
index_arg_eval = ", index" if fun in multiobs_functions else ""

impl = "void f{fun}_{indextype}({signature}) override {{"

if nobody:
impl += "}}\n"
else:
impl += (
"\n{ind8}{fun}_{indextype}_{name}({eval_signature});\n{ind4}}}\n"
)

return impl.format(
ind4=" " * 4,
ind8=" " * 8,
fun=fun,
indextype=indextype,
name=name,
signature=f"SUNMatrixWrapper &{indextype}{index_arg}",
eval_signature=f"{indextype}{index_arg_eval}",
)


def remove_argument_types(signature: str) -> str:
"""
Strips argument types from a function signature

:param signature:
function signature

:return:
string that can be used to construct function calls with the same
variable names and ordering as in the function signature
"""
# remove * prefix for pointers (pointer must always be removed before
# values otherwise we will inadvertently dereference values,
# same applies for const specifications)
#
# always add whitespace after type definition for cosmetic reasons
known_types = [
"const realtype *",
"const double *",
"const realtype ",
"double *",
"realtype *",
"const int ",
"int ",
"bool ",
"SUNMatrixContent_Sparse ",
"gsl::span<const int>",
]

for type_str in known_types:
signature = signature.replace(type_str, "")

return signature


def get_state_independent_event_intializer(events: list[Event]) -> str:
"""Get initializer list for state independent events in amici::Model."""
map_time_to_event_idx = {}
for event_idx, event in enumerate(events):
if not event.triggers_at_fixed_timepoint():
continue
trigger_time = float(event.get_trigger_time())
try:
map_time_to_event_idx[trigger_time].append(event_idx)
except KeyError:
map_time_to_event_idx[trigger_time] = [event_idx]

Check warning on line 182 in python/sdist/amici/_codegen/model_class.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/_codegen/model_class.py#L176-L182

Added lines #L176 - L182 were not covered by tests

def vector_initializer(v):
"""std::vector initializer list with elements from `v`"""
return f"{{{', '.join(map(str, v))}}}"

Check warning on line 186 in python/sdist/amici/_codegen/model_class.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/_codegen/model_class.py#L186

Added line #L186 was not covered by tests

return ", ".join(
f"{{{trigger_time}, {vector_initializer(event_idxs)}}}"
for trigger_time, event_idxs in map_time_to_event_idx.items()
)
Loading
Loading