Skip to content

Commit

Permalink
fixing issue with using indices with non-ragged tilt-series, adding a…
Browse files Browse the repository at this point in the history
…ssociated tests
  • Loading branch information
michal-g committed May 5, 2024
1 parent 1ccf011 commit d131c5a
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 8 deletions.
19 changes: 13 additions & 6 deletions cryodrgn/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,29 +190,36 @@ def __getitem__(self, index):
return images, tilt_indices, index

@classmethod
def parse_particle_tilt(cls, tiltstar):
def parse_particle_tilt(
cls, tiltstar: str
) -> tuple[list[np.ndarray], dict[np.int64, int]]:
# Parse unique particles from _rlnGroupName
s = starfile.Starfile.load(tiltstar)
group_name = list(s.df["_rlnGroupName"])
particles = OrderedDict()

for ii, gn in enumerate(group_name):
if gn not in particles:
particles[gn] = []
particles[gn].append(ii)
particles = np.array(
[np.asarray(pp, dtype=int) for pp in particles.values()], dtype=object
)

particles = [np.asarray(pp, dtype=int) for pp in particles.values()]
particles_to_tilts = particles
tilts_to_particles = {}

for i, j in enumerate(particles):
for jj in j:
tilts_to_particles[jj] = i

return particles_to_tilts, tilts_to_particles

@classmethod
def particles_to_tilts(cls, particles_to_tilts, particles):
tilts = [particles_to_tilts[i] for i in particles]
def particles_to_tilts(
cls, particles_to_tilts: list[np.ndarray], particles: np.ndarray
) -> np.ndarray:
tilts = [particles_to_tilts[int(i)] for i in particles]
tilts = np.concatenate(tilts)

return tilts

@classmethod
Expand Down
Binary file modified testing/data/ind4.pkl
Binary file not shown.
Binary file added testing/data/ind5.pkl
Binary file not shown.
108 changes: 106 additions & 2 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import argparse
import os.path
import pickle
import nbformat
from nbconvert.preprocessors import ExecutePreprocessor

Expand All @@ -18,6 +19,9 @@
abinit_homo,
abinit_het,
)
from cryodrgn.commands_utils import filter_star
from cryodrgn.dataset import TiltSeriesData
import cryodrgn.utils

DATA_FOLDER = os.path.join(os.path.dirname(__file__), "..", "testing", "data")

Expand Down Expand Up @@ -312,8 +316,8 @@ def test_train_model(self, outdir, star_particles, indices_file):
)
@pytest.mark.parametrize(
"indices_file",
[None, os.path.join(DATA_FOLDER, "ind4.pkl")],
ids=("no-ind", "ind4"),
[os.path.join(DATA_FOLDER, "ind4.pkl"), os.path.join(DATA_FOLDER, "ind5.pkl")],
ids=("ind4", "ind5"),
)
class TestStarFixedHetero:
"""Run reconstructions using particles from a .star file as input."""
Expand Down Expand Up @@ -351,6 +355,65 @@ def test_train_model(self, outdir, star_particles, indices_file):
args = train_vae.add_args(argparse.ArgumentParser()).parse_args(args)
train_vae.main(args)

def test_filter_command(self, outdir, star_particles, indices_file):
# filter the tilt-series particles
args = [
star_particles,
"--ind",
indices_file,
"--et",
"-o",
os.path.join(outdir, "filtered_sta_testing_bin8.star"),
]
parser = argparse.ArgumentParser()
filter_star.add_args(parser)
filter_star.main(parser.parse_args(args))

# need to filter poses and CTFs manually due to tilt indices
pt, tp = TiltSeriesData.parse_particle_tilt(star_particles)
indices = cryodrgn.utils.load_pkl(indices_file)
new_indices = indices[:3]
tilt_indices = TiltSeriesData.particles_to_tilts(pt, indices)

rot, trans = cryodrgn.utils.load_pkl(self.poses_file)
rot, trans = rot[tilt_indices], trans[tilt_indices]
cryodrgn.utils.save_pkl(
(rot, trans), os.path.join(outdir, "filtered_sta_pose.pkl")
)

ctf_mat = cryodrgn.utils.load_pkl(self.ctf_file)[tilt_indices]
cryodrgn.utils.save_pkl(ctf_mat, os.path.join(outdir, "filtered_sta_ctf.pkl"))
cryodrgn.utils.save_pkl(new_indices, os.path.join(outdir, "filtered_ind.pkl"))

args = [
os.path.join(outdir, "filtered_sta_testing_bin8.star"),
"--datadir",
DATA_FOLDER,
"--encode-mode",
"tilt",
"--ntilts",
"5",
"--poses",
os.path.join(outdir, "filtered_sta_pose.pkl"),
"--ctf",
os.path.join(outdir, "filtered_sta_ctf.pkl"),
"--ind",
os.path.join(outdir, "filtered_ind.pkl"),
"--num-epochs",
"5",
"--zdim",
"4",
"-o",
os.path.join(outdir, "filtered"),
"--tdim",
"16",
"--enc-dim",
"16",
"--dec-dim",
"16",
]
train_vae.main(train_vae.add_args(argparse.ArgumentParser()).parse_args(args))

def test_analyze(self, outdir, star_particles, indices_file):
"""Produce standard analyses for a particular epoch."""
args = analyze.add_args(argparse.ArgumentParser()).parse_args(
Expand Down Expand Up @@ -380,6 +443,47 @@ def test_notebooks(self, outdir, nb_lbl, star_particles, indices_file):
ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(nb_in)
os.chdir(os.path.join("..", ".."))

def test_refiltering(self, outdir, star_particles, indices_file):
"""Use particle index creating during analysis."""
os.chdir(os.path.join(outdir, "analyze.4"))
assert os.path.exists("tmp_ind_selected.pkl"), "Upstream tests have failed!"

with open("tmp_ind_selected.pkl", "rb") as f:
indices = pickle.load(f)

new_indices = indices[:3]
with open("tmp_ind_selected.pkl", "wb") as f:
pickle.dump(new_indices, f)

args = [
star_particles,
"--datadir",
DATA_FOLDER,
"--encode-mode",
"tilt",
"--poses",
self.poses_file,
"--ctf",
self.ctf_file,
"--ind",
"tmp_ind_selected.pkl",
"--num-epochs",
"5",
"--zdim",
"4",
"-o",
outdir,
"--tdim",
"16",
"--enc-dim",
"16",
"--dec-dim",
"16",
]

args = train_vae.add_args(argparse.ArgumentParser()).parse_args(args)
train_vae.main(args)


@pytest.mark.parametrize(
"star_particles",
Expand Down

0 comments on commit d131c5a

Please sign in to comment.