From 3cb0a32ab48d3baf2e36651aa890aa1749e42d91 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Tue, 17 Jan 2023 12:10:01 +0100 Subject: [PATCH] Add `plot_jacobian` (#1930) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Fabian Fröhlich --- python/sdist/amici/plotting.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/python/sdist/amici/plotting.py b/python/sdist/amici/plotting.py index d2917de9fe..bdbffdb7be 100644 --- a/python/sdist/amici/plotting.py +++ b/python/sdist/amici/plotting.py @@ -3,14 +3,17 @@ -------- Plotting related functions """ -from . import ReturnDataView, Model +from typing import Iterable, Optional import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns from matplotlib.axes import Axes -from typing import Optional, Iterable +from . import Model, ReturnDataView -def plotStateTrajectories( + +def plot_state_trajectories( rdata: ReturnDataView, state_indices: Optional[Iterable[int]] = None, ax: Optional[Axes] = None, @@ -50,7 +53,7 @@ def plotStateTrajectories( ax.set_title('State trajectories') -def plotObservableTrajectories( +def plot_observable_trajectories( rdata: ReturnDataView, observable_indices: Optional[Iterable[int]] = None, ax: Optional[Axes] = None, @@ -88,3 +91,18 @@ def plotObservableTrajectories( ax.set_ylabel('$y(t)$') ax.legend() ax.set_title('Observable trajectories') + + +def plot_jacobian(rdata: ReturnDataView): + """Plot Jacobian as heatmap.""" + df = pd.DataFrame( + data=rdata.J, + index=rdata._swigptr.state_ids, + columns=rdata._swigptr.state_ids + ) + sns.heatmap(df, center=0.0) + plt.title("Jacobian") + +# backwards compatibility +plotStateTrajectories = plot_state_trajectories +plotObservableTrajectories = plot_observable_trajectories