Skip to content

Commit

Permalink
Add plot_jacobian (#1930)
Browse files Browse the repository at this point in the history

Co-authored-by: Fabian Fröhlich <fabian@schaluck.com>
  • Loading branch information
dweindl and FFroehlich authored Jan 17, 2023
1 parent 8c58eb3 commit 3cb0a32
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions python/sdist/amici/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
--------
Plotting related functions
"""
from . import ReturnDataView, Model
from typing import Iterable, Optional

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from typing import Optional, Iterable

from . import Model, ReturnDataView

def plotStateTrajectories(

def plot_state_trajectories(
rdata: ReturnDataView,
state_indices: Optional[Iterable[int]] = None,
ax: Optional[Axes] = None,
Expand Down Expand Up @@ -50,7 +53,7 @@ def plotStateTrajectories(
ax.set_title('State trajectories')


def plotObservableTrajectories(
def plot_observable_trajectories(
rdata: ReturnDataView,
observable_indices: Optional[Iterable[int]] = None,
ax: Optional[Axes] = None,
Expand Down Expand Up @@ -88,3 +91,18 @@ def plotObservableTrajectories(
ax.set_ylabel('$y(t)$')
ax.legend()
ax.set_title('Observable trajectories')


def plot_jacobian(rdata: ReturnDataView):
"""Plot Jacobian as heatmap."""
df = pd.DataFrame(
data=rdata.J,
index=rdata._swigptr.state_ids,
columns=rdata._swigptr.state_ids
)
sns.heatmap(df, center=0.0)
plt.title("Jacobian")

# backwards compatibility
plotStateTrajectories = plot_state_trajectories
plotObservableTrajectories = plot_observable_trajectories

0 comments on commit 3cb0a32

Please sign in to comment.