Add event modulation to peri_event.py #56

Open · wants to merge 5 commits into main
4 changes: 2 additions & 2 deletions neuro_py/io/loading.py
@@ -18,7 +18,6 @@

from neuro_py.behavior.kinematics import get_speed
from neuro_py.process.intervals import find_interval, in_intervals
-from neuro_py.process.peri_event import get_participation


def loadXML(basepath: str) -> Union[Tuple[int, int, int, Dict[int, list]], None]:
@@ -1025,6 +1024,7 @@ def load_barrage_events(
Union[pd.DataFrame, nel.EpochArray]
DataFrame with barrage events.
"""
+    from neuro_py.process.peri_event import count_in_interval

# locate barrage file
filename = os.path.join(basepath, os.path.basename(basepath) + ".HSEn2.events.mat")
@@ -1061,7 +1061,7 @@ def load_barrage_events(
# load ca2 pyr cells
st, _ = load_spikes(basepath, putativeCellType="Pyr", brainRegion="CA2")
# bin spikes into barrages
-    bst = get_participation(st.data, df["start"].values, df["stop"].values)
+    bst = count_in_interval(st.data, df["start"].values, df["stop"].values)
# keep only barrages with some activity
df = df[np.sum(bst > 0, axis=0) > 0].reset_index(drop=True)

2 changes: 2 additions & 0 deletions neuro_py/process/__init__.pyi
@@ -53,6 +53,7 @@ __all__ = (
    "average_diagonal",
    "remove_inactive_cells",
    "remove_inactive_cells_pre_task_post",
+    "swr_modulated_units",
)

from . import batch_analysis
@@ -92,6 +93,7 @@ from .peri_event import (
    nearest_event_delay,
    peth_matrix,
    relative_times,
+    swr_modulated_units,
)
from .precession_utils import (
acf_power,
249 changes: 247 additions & 2 deletions neuro_py/process/peri_event.py
@@ -1,6 +1,5 @@
import warnings
from typing import List, Optional, Tuple, Union

import bottleneck as bn
import numpy as np
import pandas as pd
@@ -10,7 +9,8 @@
from scipy import stats
from scipy.linalg import toeplitz
from scipy.ndimage import gaussian_filter1d

+from tqdm import tqdm
+from neuro_py.io.loading import load_spikes, load_ripples_events
from neuro_py.process.intervals import in_intervals, split_epoch_by_width


@@ -1278,3 +1278,248 @@ def event_spiking_threshold(
sns.despine()

return valid_events


def swr_modulated_units(
    basepath: str,
    bin_width: float = 0.002,
    n_bins: int = 350,
    n_shuffles: int = 500,
    min_spikes: int = 50,
    alpha: float = 0.05,
    recording_length: Optional[float] = None,
    pre_window: Tuple[float, float] = (-0.5, -0.1),  # baseline window (e.g. -500 ms to -100 ms)
    post_window: Tuple[float, float] = (0.0, 0.2),  # SWR window (e.g. 0 to +200 ms)
    sd_threshold: float = 1.0,
    return_psth: bool = False,
) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame]]:
"""
Analyze SWR-modulation with an integrated approach:
1) Load spikes & ripples from basepath.
2) Filter out neurons with < min_spikes in ripple intervals.
3) Compute PSTH from -0.5s to +0.2s (encompassing baseline & post).
4) Shuffle spike trains (global circular shift) n_shuffles times => baseline PSTH distribution.
5) Identify modulated vs. not_modulated via mean-squared difference (MSD).
6) Classify modulated cells as excited or inhibited by comparing post_window vs. pre_window.
7) Compute latency = first crossing of ±1 SD from baseline after 0 ms.

By default, returns only cell_metrics (DataFrame). If return_psth=True,
returns a tuple: (cell_metrics, psth_full).

Parameters
----------
basepath : str
Path to data directory containing spikes/ripples.
bin_width : float
PSTH bin size in seconds (default=0.002 => 2 ms).
n_bins : int
Total bins for PSTH (default=350 => covers 0.7s if bin_width=2 ms).
n_shuffles : int
Number of global circular shifts (default=500).
min_spikes : int
        Minimum number of spikes inside ripple intervals (default=50).
alpha : float
Significance level for MSD test (default=0.05).
recording_length : float or None
If None, inferred from max spike time. Otherwise, pass a known total length.
pre_window : tuple
Baseline window (start, stop), e.g. (-0.5, -0.1).
post_window : tuple
Post-SWR window (start, stop), e.g. (0.0, 0.2).
sd_threshold : float
        Number of SDs from baseline required for the latency crossing (default=1.0).
return_psth : bool
If False (default), return only cell_metrics (pd.DataFrame).
If True, return (cell_metrics, psth_full) as a tuple.

Returns
-------
pd.DataFrame (if return_psth=False)
A DataFrame describing each neuron's SWR modulation.
(pd.DataFrame, pd.DataFrame) (if return_psth=True)
A tuple where the first element is the same DataFrame as above,
and the second element is the PSTH covering [-0.5, +0.2].
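
    Examples
    --------
    Minimal usage sketch (the session path below is illustrative):

    >>> cm = swr_modulated_units("/path/to/session")  # doctest: +SKIP
    >>> cm, psth = swr_modulated_units("/path/to/session", return_psth=True)  # doctest: +SKIP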
"""

print(f"Loading data from: {basepath}")
# 1) Load spikes & ripples
spike_data, cell_metrics = load_spikes(basepath)
ripple_df = load_ripples_events(basepath, return_epoch_array=False)

# Infer recording length if needed
if recording_length is None:
all_spike_times = np.concatenate(spike_data.data)
recording_length = all_spike_times.max()
print(f"Inferred total recording length = {recording_length:.3f} s")

    # Convert ripple times to float arrays
    ripple_starts = ripple_df["start"].astype(float).values
    ripple_stops = ripple_df["stop"].astype(float).values

# Initialize or reset columns in cell_metrics
cell_metrics["swr_is_modulated"] = False
cell_metrics["swr_modulation_type"] = "insufficient_spikes"
cell_metrics["swr_modulation_pval"] = np.nan
cell_metrics["swr_msd"] = np.nan
cell_metrics["swr_msd_threshold"] = np.nan

cell_metrics["swr_exc_inh_type"] = "insufficient_spikes"
cell_metrics["swr_exc_inh_latency"] = np.nan

# 2) Filter units by # spikes in ripple intervals
keep_mask = []
for unit_idx in range(spike_data.n_units):
st = spike_data.data[unit_idx]
count_in_ripples = 0
for (t0, t1) in zip(ripple_starts, ripple_stops):
count_in_ripples += np.sum((st >= t0) & (st <= t1))
keep_mask.append(count_in_ripples >= min_spikes)

keep_mask = np.array(keep_mask)
valid_unit_indices = np.where(keep_mask)[0]
print(f"Found {len(valid_unit_indices)} valid units (≥ {min_spikes} spikes) out of {spike_data.n_units}.")

# If none pass threshold, return early
if len(valid_unit_indices) == 0:
print("No units pass the min_spikes threshold.")
if return_psth:
return cell_metrics, pd.DataFrame()
else:
return cell_metrics

# Build a list of float arrays for valid units
valid_spike_list = []
for idx in valid_unit_indices:
valid_spike_list.append(spike_data.data[idx].astype(float))

# 3) Compute PSTH from -0.5 to +0.2
print(f"Computing PSTH from {pre_window[0]} to {post_window[1]} for valid units...")
psth_full = compute_psth(
spikes=valid_spike_list,
event=ripple_starts,
bin_width=bin_width,
n_bins=n_bins,
window=[pre_window[0], post_window[1]]
)
time_bins = psth_full.index.values

# Helper to do global shift
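    # A circular (wrap-around) shift preserves each unit's inter-spike-interval
    # structure while destroying its alignment to ripple times.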
def global_shift_spike_train(spike_times, total_len):
shift = np.random.uniform(0, total_len)
shifted = (spike_times + shift) % total_len
return np.sort(shifted)

    # Collect per-unit shuffle-test results
results = {}

# 4) For each valid unit, do shuffles => measure significance => classify modulated
for col_idx, unit_idx in enumerate(tqdm(valid_unit_indices, desc="Units")):
st_float = spike_data.data[unit_idx].astype(float)
real_psth_unit = psth_full.iloc[:, col_idx].values

# Shuffle PSTHs
shuffle_psths = []
for _ in tqdm(range(n_shuffles), desc=f"Shuffles (unit {unit_idx})", leave=False):
shifted_spikes = global_shift_spike_train(st_float, recording_length)
shift_psth = compute_psth(
spikes=[shifted_spikes],
event=ripple_starts,
bin_width=bin_width,
n_bins=n_bins,
window=[pre_window[0], post_window[1]]
)
shuffle_psths.append(shift_psth.iloc[:, 0].values)

shuffle_matrix = np.array(shuffle_psths) # shape: [n_shuffles, time_bins]
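        # Mean across shuffles = the chance-level (baseline) PSTH for this unit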
baseline_psth = shuffle_matrix.mean(axis=0)

        # Compare the real PSTH vs. the shuffle baseline within the post window => MSD
        post_msd_mask = (time_bins >= post_window[0]) & (time_bins < post_window[1])
        real_segment = real_psth_unit[post_msd_mask]
        baseline_segment = baseline_psth[post_msd_mask]

        actual_msd = np.mean((real_segment - baseline_segment) ** 2)

        shuffle_msds = np.array([
            np.mean((shuffle_matrix[i, post_msd_mask] - baseline_segment) ** 2)
            for i in range(n_shuffles)
        ])

msd_threshold = np.percentile(shuffle_msds, 100*(1 - alpha))
is_modulated = (actual_msd >= msd_threshold)

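        # Empirical p-value: fraction of shuffles whose MSD >= the observed MSD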
rank = np.sum(actual_msd > shuffle_msds)
pval = 1.0 - (rank / float(n_shuffles))

results[unit_idx] = {
"is_mod": is_modulated,
"pval": pval,
"msd": actual_msd,
"msd_thresh": msd_threshold
}

# 5) Update cell_metrics: modulated / not_modulated
for unit_idx, vals in results.items():
cell_metrics.loc[unit_idx, "swr_is_modulated"] = vals["is_mod"]
if vals["is_mod"]:
cell_metrics.loc[unit_idx, "swr_modulation_type"] = "modulated"
else:
cell_metrics.loc[unit_idx, "swr_modulation_type"] = "not_modulated"

cell_metrics.loc[unit_idx, "swr_modulation_pval"] = vals["pval"]
cell_metrics.loc[unit_idx, "swr_msd"] = vals["msd"]
cell_metrics.loc[unit_idx, "swr_msd_threshold"] = vals["msd_thresh"]

# 6) For modulated cells, define excited vs inhibited + compute latency
pre_mask = (time_bins >= pre_window[0]) & (time_bins < pre_window[1])
post_mask = (time_bins >= post_window[0]) & (time_bins < post_window[1])

for col_idx, unit_idx in enumerate(valid_unit_indices):
mod_type = cell_metrics.loc[unit_idx, "swr_modulation_type"]
if mod_type != "modulated":
continue

psth_unit = psth_full.iloc[:, col_idx]
pre_rate = psth_unit[pre_mask].mean()
post_rate = psth_unit[post_mask].mean()

if post_rate > pre_rate:
direction = "excited"
else:
direction = "inhibited"

cell_metrics.loc[unit_idx, "swr_exc_inh_type"] = direction

# Latency: first crossing ± sd_threshold * baseline_std after 0 ms
baseline_mean = pre_rate
baseline_std = psth_unit[pre_mask].std(ddof=1)
if baseline_std == 0 or np.isnan(baseline_std):
continue

if direction == "excited":
threshold = baseline_mean + sd_threshold * baseline_std
crossing_mask = (psth_unit >= threshold)
else:
threshold = baseline_mean - sd_threshold * baseline_std
crossing_mask = (psth_unit <= threshold)

after_zero_idx = np.where(time_bins >= 0)[0]
if len(after_zero_idx) == 0:
continue

candidate_vals = crossing_mask.iloc[after_zero_idx].values
if np.any(candidate_vals):
first_rel = np.where(candidate_vals)[0][0]
first_idx = after_zero_idx[first_rel]
latency_sec = time_bins[first_idx]
cell_metrics.loc[unit_idx, "swr_exc_inh_latency"] = latency_sec

print("Done analyzing SWR modulation!")

# 7) Return only cell_metrics (default), or a tuple if return_psth=True
if return_psth:
return cell_metrics, psth_full
else:
return cell_metrics
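
A hedged sketch of how the new analysis might be driven once merged (the session path and the filtering below are illustrative, not part of this diff):

from neuro_py.process.peri_event import swr_modulated_units

cell_metrics, psth = swr_modulated_units("/path/to/session", return_psth=True)
# Keep only SWR-excited units and inspect their modulation stats and latencies
excited = cell_metrics[cell_metrics["swr_exc_inh_type"] == "excited"]
print(excited[["swr_modulation_pval", "swr_exc_inh_latency"]])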