Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

linear-positions make private methods #58

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions neuro_py/behavior/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ __all__ = [
"get_velocity",
"get_speed",
"linearize_position",
"find_laps",
"peakdetz",
"find_good_laps",
"get_linear_track_lap_epochs",
"find_good_lap_epochs",
"NodePicker",
Expand All @@ -37,11 +34,8 @@ from .get_trials import (
from .kinematics import get_speed, get_velocity
from .linear_positions import (
find_good_lap_epochs,
find_good_laps,
find_laps,
get_linear_track_lap_epochs,
linearize_position,
peakdetz,
)
from .linearization_pipeline import NodePicker
from .well_traversal_classification import (
Expand Down
157 changes: 82 additions & 75 deletions neuro_py/behavior/linear_positions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def linearize_position(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.nda
return x, y


def find_laps(
def __find_laps(
Vts: np.ndarray,
Vdata: np.ndarray,
newLapThreshold: float = 15,
Expand Down Expand Up @@ -91,6 +91,10 @@ def find_laps(
start index, and direction.
"""

# Handle empty input
if len(Vdata) == 0 or len(Vts) == 0:
return pd.DataFrame(columns=["start_ts", "pos", "start_idx", "direction"])

TL = np.abs(np.nanmax(Vdata) - np.nanmin(Vdata)) # % track length
th1 = (
np.nanmin(Vdata) + TL * newLapThreshold / 100
Expand Down Expand Up @@ -142,7 +146,8 @@ def find_laps(

# % fix direction of first lap which was unknown above
# % make first lap direction opposite of second lap's direction (laps alternate!)
laps.loc[0, "direction"] = -laps.iloc[1].direction
if len(laps) > 1:
laps.loc[0, "direction"] = -laps.iloc[1].direction

# % make sure laps cross the halfway point
middle = np.nanmedian(np.arange(np.nanmin(Vdata), np.nanmax(Vdata)))
Expand All @@ -161,7 +166,7 @@ def find_laps(
break

if good_laps:
laps = find_good_laps(
laps = __find_good_laps(
Vts,
Vdata,
laps,
Expand All @@ -173,11 +178,8 @@ def find_laps(
return laps


def peakdetz(
v: np.ndarray,
delta: float,
lookformax: int = 1,
backwards: int = 0
def __peakdetz(
v: np.ndarray, delta: float, lookformax: int = 1, backwards: int = 0
) -> Tuple[list[Tuple[int, float]], list[Tuple[int, float]]]:
"""
Detect peaks in a vector.
Expand All @@ -196,7 +198,7 @@ def peakdetz(
Returns
-------
tuple[list[tuple[int, float]], list[tuple[int, float]]]
A tuple containing the maxima and minima found in the input vector. Each list contains tuples of
A tuple containing the maxima and minima found in the input vector. Each list contains tuples of
the form (index, value).
"""

Expand Down Expand Up @@ -237,44 +239,36 @@ def peakdetz(
mnpos = ii

if lookformax:
try:
idx = mx - delta > mintab[-1]
except Exception:
idx = mx - delta > mintab

if (this < mx - delta) | ((ii == last - 1) & (len(mintab) > 0) & idx):
if this < mx - delta:
maxtab.append((mxpos, mx))
mn = this
mnpos = ii
lookformax = 0
else:
try:
idx = mx - delta < maxtab[-1]
except Exception:
idx = mx - delta < maxtab
if (this > mn + delta) | ((ii == last - 1) & (len(maxtab) > 0) & idx):
if this > mn + delta:
mintab.append((mnpos, mn))
mx = this
mxpos = ii
lookformax = 1

if (len(maxtab) == 0) & (len(mintab) == 0):
if lookformax:
if mx - mn > delta:
maxtab = [mxpos, mx]
else:
if mx - mn > delta:
mintab = [mnpos, mn]
# Handle the last extremum
if lookformax:
if mx - mn > delta:
maxtab.append((mxpos, mx))
else:
if mx - mn > delta:
mintab.append((mnpos, mn))

return maxtab, mintab


def find_good_laps(
ts: np.ndarray,
V_rest: np.ndarray,
laps: pd.DataFrame,
edgethresh: float = 0.1,
completeprop: float = 0.2,
posbins: int = 50
def __find_good_laps(
ts: np.ndarray,
V_rest: np.ndarray,
laps: pd.DataFrame,
edgethresh: float = 0.1,
completeprop: float = 0.2,
posbins: int = 50,
) -> pd.DataFrame:
"""
Find and eliminate laps that have too many NaNs or laps where the rat turns around in the middle.
Expand All @@ -299,6 +293,9 @@ def find_good_laps(
pd.DataFrame
Updated DataFrame with bad laps removed.
"""
# Handle all bad laps
if laps["pos"].isna().all():
return pd.DataFrame(columns=laps.columns)

if (
edgethresh > 1
Expand Down Expand Up @@ -331,36 +328,37 @@ def find_good_laps(
else:
endoflap = laps.iloc[lap + 1].start_ts

v = V_rest[
np.where(ts == laps.iloc[lap].start_ts)[0][0] : np.where(ts == endoflap)[0][
0
]
]
t = ts[
np.where(ts == laps.iloc[lap].start_ts)[0][0] : np.where(ts == endoflap)[0][
0
]
]
# Find the start and end indices for the current lap
start_idx = np.where(ts == laps.iloc[lap].start_ts)[0]
end_idx = np.where(ts == endoflap)[0]

# Skip if no matching timestamps are found
if len(start_idx) == 0 or len(end_idx) == 0:
lap += 1
continue

v = V_rest[start_idx[0] : end_idx[0]]
t = ts[start_idx[0] : end_idx[0]]

# % find turn around points during this lap
lookformax = laps.iloc[lap].direction == 1
peak, trough = peakdetz(v, delta, lookformax, 0)
peak, trough = __peakdetz(v, delta, lookformax, 0)

if lookformax:
# % find the direct path from bottomend to topend (or mark lap for
# % deleting if the turn around points are not in those ranges)
if len(trough) > 0:
# % find the last trough in range of bottomend (start of lap)
gt = len(trough)
while (gt > 0) & (trough(gt, 2) >= 2 * delta + bottomend):
while (gt > 0) & (trough[gt - 1][1] >= bottomend + 2 * delta):
gt = gt - 1

# % assign the next peak after that trough as the end of the lap
# % (or mark lap for deleting, if that peak is not at topend)
if gt == 0:
if peak[1, 2] > topend - 2 * delta:
t = t[0 : peak[0]]
v = v[0 : peak[0]]
if len(peak) > 0 and peak[0][1] > topend - 2 * delta:
t = t[0 : peak[0][0]]
v = v[0 : peak[0][0]]
else:
# % this marks the lap for deleting
t = t[0:5]
Expand All @@ -372,16 +370,16 @@ def find_good_laps(
t = t[0:2]
v = v[0:2]
else:
t = t[trough[gt, 1] : peak[gt + 1, 1]]
v = v[trough[gt, 1] : peak[gt + 1, 1]]
t = t[trough[gt - 1][0] : peak[gt][0]]
v = v[trough[gt - 1][0] : peak[gt][0]]

else:
# % make sure peak exists and is in range of topend
if len(peak) == 0:
if len(t) > 2:
t = t[0:2]
v = v[0:2]
elif peak[1] < topend - 2 * delta:
elif len(peak) > 0 and peak[0][1] < topend - 2 * delta:
# % this marks the lap for deleting
if len(t) > 5:
t = t[0:5]
Expand All @@ -392,14 +390,15 @@ def find_good_laps(
if len(peak) > 0:
# % find the last peak in range of topend (start of lap)
gt = len(peak)
while (gt > 0) & (peak[gt, 2] <= topend - 2 * delta):
while (gt > 0) & (peak[gt - 1][1] <= topend - 2 * delta):
gt = gt - 1

# % assign the next trough after that peak as the end of the lap
# % (or mark lap for deleting, if that trough is not at bottomend)
if gt == 0:
if trough(1, 2) < bottomend + 2 * delta:
t = t[1 : trough[0]]
v = v[1 : trough[0]]
if len(trough) > 0 and trough[0][1] < bottomend + 2 * delta:
t = t[1 : trough[0][0]]
v = v[1 : trough[0][0]]
else:
# % this marks the lap for deleting
t = t[0:5]
Expand All @@ -411,20 +410,20 @@ def find_good_laps(
v = v[0:2]
gt = 0
else:
t = t[peak[gt, 1] : trough[gt + 1, 1]]
v = v[peak[gt, 1] : trough[gt + 1, 1]]
t = t[peak[gt - 1][0] : trough[gt][0]]
v = v[peak[gt - 1][0] : trough[gt][0]]
else: # % if ~isempty(peak)
# % make sure trough exists and is in range of bottomend
if len(trough) == 0:
if len(t) > 2:
t = t[0:2]
v = v[0:2]

elif trough[1] > bottomend + 2 * delta:
elif len(trough) > 0 and trough[0][1] > bottomend + 2 * delta:
# % this marks the lap for deleting
if len(t) > 5:
t = t[0:5]
v = v[0:5]

vcovered, _ = np.histogram(v, bins=bins)

if len(v) < 3:
Expand Down Expand Up @@ -478,11 +477,11 @@ def get_linear_track_lap_epochs(
good_laps: bool = False,
edgethresh: float = 0.1,
completeprop: float = 0.2,
posbins: int = 50
posbins: int = 50,
) -> Tuple[nel.EpochArray, nel.EpochArray]:
"""
Identifies lap epochs on a linear track and classifies them into outbound and inbound directions.
Parameters:
----------
ts : np.ndarray
Expand Down Expand Up @@ -512,7 +511,7 @@ def get_linear_track_lap_epochs(
- This function calls `find_laps` to determine the lap structure, then segregates epochs into outbound and inbound directions.
- The EpochArray objects represent the start and stop timestamps for each identified lap.
"""
laps = find_laps(
laps = __find_laps(
np.array(ts),
np.array(x),
newLapThreshold=newLapThreshold,
Expand All @@ -522,6 +521,10 @@ def get_linear_track_lap_epochs(
posbins=posbins,
)

# Handle no laps
if len(laps) == 0:
return nel.EpochArray(), nel.EpochArray()

outbound_start = []
outbound_stop = []
inbound_start = []
Expand All @@ -543,15 +546,15 @@ def get_linear_track_lap_epochs(


def find_good_lap_epochs(
pos: nel.AnalogSignalArray,
dir_epoch: nel.EpochArray,
thres: float = 0.5,
binsize: int = 6,
min_laps: int = 10
pos: nel.AnalogSignalArray,
dir_epoch: nel.EpochArray,
thres: float = 0.5,
binsize: int = 6,
min_laps: int = 10,
) -> nel.EpochArray:
"""
Find good laps in behavior data for replay analysis.
Find good laps in behavior data
Parameters
----------
pos : nelpy.AnalogSignalArray
Expand All @@ -564,20 +567,24 @@ def find_good_lap_epochs(
Size of the bins for calculating occupancy, by default 6.
min_laps : int, optional
Minimum number of laps required to consider laps as 'good', by default 10.
Returns
-------
nelpy.EpochArray
An EpochArray containing the good laps based on the occupancy threshold.
Returns an empty EpochArray if no good laps are found or if the number
Returns an empty EpochArray if no good laps are found or if the number
of laps is less than `min_laps`.
Notes
-----
The function calculates the percent occupancy over position bins per lap,
and identifies laps that meet the occupancy threshold criteria. The laps
The function calculates the percent occupancy over position bins per lap,
and identifies laps that meet the occupancy threshold criteria. The laps
that meet this condition are returned as an EpochArray.
"""
# Ensure the input data is valid
if pos.isempty or dir_epoch.isempty:
return nel.EpochArray()

# make bin edges to calc occupancy
x_edges = np.arange(np.nanmin(pos.data[0]), np.nanmax(pos.data[0]), binsize)
# initialize occupancy matrix (position x time)
Expand Down
Loading