Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add nan safe log&divide #2611

Merged
merged 3 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions python/sdist/amici/jax.template.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# ruff: noqa: F401, F821, F841
import jax.numpy as jnp
from interpax import interp1d

from amici.jax.model import JAXModel
from amici.jax.model import JAXModel, safe_log, safe_div


class JAXModel_TPL_MODEL_NAME(JAXModel):
Expand All @@ -11,7 +12,6 @@ def __init__(self):
super().__init__()

def _xdot(self, t, x, args):

pk, tcl = args

TPL_X_SYMS = x
Expand All @@ -24,7 +24,6 @@ def _xdot(self, t, x, args):
return TPL_XDOT_RET

def _w(self, t, x, pk, tcl):

TPL_X_SYMS = x
TPL_PK_SYMS = pk
TPL_TCL_SYMS = tcl
Expand All @@ -34,23 +33,20 @@ def _w(self, t, x, pk, tcl):
return TPL_W_RET

def _x0(self, pk):

TPL_PK_SYMS = pk

TPL_X0_EQ

return TPL_X0_RET

def _x_solver(self, x):

TPL_X_RDATA_SYMS = x

TPL_X_SOLVER_EQ

return TPL_X_SOLVER_RET

def _x_rdata(self, x, tcl):

TPL_X_SYMS = x
TPL_TCL_SYMS = tcl

Expand All @@ -59,7 +55,6 @@ def _x_rdata(self, x, tcl):
return TPL_X_RDATA_RET

def _tcl(self, x, pk):

TPL_X_RDATA_SYMS = x
TPL_PK_SYMS = pk

Expand All @@ -68,7 +63,6 @@ def _tcl(self, x, pk):
return TPL_TOTAL_CL_RET

def _y(self, t, x, pk, tcl):

TPL_X_SYMS = x
TPL_PK_SYMS = pk
TPL_W_SYMS = self._w(t, x, pk, tcl)
Expand All @@ -86,7 +80,6 @@ def _sigmay(self, y, pk):

return TPL_SIGMAY_RET


def _nllh(self, t, x, pk, tcl, my, iy):
y = self._y(t, x, pk, tcl)
TPL_Y_SYMS = y
Expand Down
36 changes: 36 additions & 0 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,3 +557,39 @@
stats_dyn=stats_dyn,
stats_posteq=stats_posteq,
)


def safe_log(x: jnp.float_) -> jnp.float_:
"""
Safe logarithm that returns `jnp.log(jnp.finfo(jnp.float_).eps)` for x <= 0.

:param x:
input
:return:
logarithm of x
"""
# see https://docs.kidger.site/equinox/api/debug/, need double jnp.where to guard
# against nans in forward & backward passes
safe_x = jnp.where(

Check warning on line 573 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L573

Added line #L573 was not covered by tests
x > jnp.finfo(jnp.float_).eps, x, jnp.finfo(jnp.float_).eps
)
return jnp.where(

Check warning on line 576 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L576

Added line #L576 was not covered by tests
x > 0, jnp.log(safe_x), jnp.log(jnp.finfo(jnp.float_).eps)
)


def safe_div(x: jnp.float_, y: jnp.float_) -> jnp.float_:
"""
Safe division that returns `x/jnp.finfo(jnp.float_).eps` for `y == 0`.

:param x:
numerator
:param y:
denominator
:return:
x / y
"""
# see https://docs.kidger.site/equinox/api/debug/, need double jnp.where to guard
# against nans in forward & backward passes
safe_y = jnp.where(y != 0, y, jnp.finfo(jnp.float_).eps)
return jnp.where(y != 0, x / safe_y, x / jnp.finfo(jnp.float_).eps)

Check warning on line 595 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L594-L595

Added lines #L594 - L595 were not covered by tests
9 changes: 9 additions & 0 deletions python/sdist/amici/jaxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ def _print_AmiciSpline(self, expr: sp.Expr) -> str:
# FIXME: untested, where are spline nodes coming from anyways?
return f'interp1d(time, {self.doprint(expr.args[2:])}, kind="cubic")'

def _print_log(self, expr: sp.Expr) -> str:
return f"safe_log({self.doprint(expr.args[0])})"

def _print_Mul(self, expr: sp.Expr) -> str:
numer, denom = expr.as_numer_denom()
if denom == 1:
return super()._print_Mul(expr)
return f"safe_div({self.doprint(numer)}, {self.doprint(denom)})"

def _get_sym_lines(
self,
symbols: sp.Matrix | Iterable[str],
Expand Down
20 changes: 7 additions & 13 deletions tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,8 @@ def test_jax_llh(benchmark_problem):

np.random.seed(cur_settings.rng_seed)

problems_for_gradient_check_jax = list(
set(problems_for_gradient_check) - {"Laske_PLOSComputBiol2019"}
# Laske has nan values in gradient due to nan values in observables that are not used in the likelihood
# but are problematic during backpropagation
)

problem_parameters = None
if problem_id in problems_for_gradient_check_jax:
if problem_id in problems_for_gradient_check:
point = petab_problem.x_nominal_free_scaled
for _ in range(20):
amici_solver.setSensitivityMethod(amici.SensitivityMethod.adjoint)
Expand Down Expand Up @@ -352,12 +346,12 @@ def test_jax_llh(benchmark_problem):
[problem_parameters[pid] for pid in jax_problem.parameter_ids]
),
)
if problem_id in problems_for_gradient_check_jax:
(llh_jax, _), sllh_jax = eqx.filter_jit(
eqx.filter_value_and_grad(run_simulations, has_aux=True)
if problem_id in problems_for_gradient_check:
(llh_jax, _), sllh_jax = eqx.filter_value_and_grad(
run_simulations, has_aux=True
)(jax_problem, simulation_conditions)
else:
llh_jax, _ = beartype(eqx.filter_jit(run_simulations))(
llh_jax, _ = beartype(run_simulations)(
jax_problem, simulation_conditions
)

Expand All @@ -369,14 +363,14 @@ def test_jax_llh(benchmark_problem):
err_msg=f"LLH mismatch for {problem_id}",
)

if problem_id in problems_for_gradient_check_jax:
if problem_id in problems_for_gradient_check:
sllh_amici = r_amici[SLLH]
np.testing.assert_allclose(
sllh_jax.parameters,
np.array([sllh_amici[pid] for pid in jax_problem.parameter_ids]),
rtol=1e-2,
atol=1e-2,
err_msg=f"SLLH mismatch for {problem_id}",
err_msg=f"SLLH mismatch for {problem_id}, {dict(zip(jax_problem.parameter_ids, sllh_jax.parameters))}",
)


Expand Down
Loading