Skip to content

Commit

Permalink
better handling of different missing inputs in interactive filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-g committed Dec 19, 2024
1 parent 94872fa commit 593944f
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 39 deletions.
53 changes: 38 additions & 15 deletions cryodrgn/commands/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
# Choose an epoch yourself; save final selection to `indices.pkl` without prompting
$ cryodrgn filter my_outdir --epoch 30 -f
# Choose another epoch; this time choose file name but pre-select directory to save in
$ cryodrgn filter my_outdir --epoch 30 --sel-dir /data/my_indices/
# If you have done multiple k-means clusterings, you can pick which one to use
$ cryodrgn filter my_outdir/ -k 25
Expand All @@ -44,7 +47,7 @@
from scipy.spatial import transform
from typing import Optional, Sequence

from cryodrgn import analysis, utils
from cryodrgn import analysis, utils, dataset

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -89,6 +92,7 @@ def add_args(parser: argparse.ArgumentParser) -> None:


def main(args: argparse.Namespace) -> None:
"""Launching the interactive interface for filtering particles from command-line."""
workdir = args.outdir
epoch = args.epoch
kmeans = args.kmeans
Expand Down Expand Up @@ -129,23 +133,40 @@ def main(args: argparse.Namespace) -> None:

# Load poses and initial indices for plotting if they have been specified
rot, trans = utils.load_pkl(pose_pkl)
ctf_params = utils.load_pkl(train_configs["dataset_args"]["ctf"])
pre_indices = None if plot_inds is None else utils.load_pkl(plot_inds)
all_indices = np.array(range(ctf_params.shape[0]))
if train_configs["dataset_args"]["ctf"] is not None:
ctf_params = utils.load_pkl(train_configs["dataset_args"]["ctf"])
else:
ctf_params = None

# Load the set of indices used to filter original dataset and apply it to inputs
pre_indices = None if plot_inds is None else utils.load_pkl(plot_inds)
if ctf_params is not None:
all_indices = np.array(range(ctf_params.shape[0]))
else:
all_indices = np.array(
range(
dataset.ImageDataset(
mrcfile=train_configs["dataset_args"]["particles"], lazy=True
).N
)
)

# Load the set of indices used to filter the original dataset and apply it to inputs
if isinstance(train_configs["dataset_args"]["ind"], int):
indices = slice(train_configs["dataset_args"]["ind"])
# We only need to filter the poses if they weren't generated by the model
rot = rot[indices, :, :]
trans = trans[indices, :]
elif train_configs["dataset_args"]["ind"]:
elif train_configs["dataset_args"]["ind"] is not None:
indices = utils.load_pkl(train_configs["dataset_args"]["ind"])
ctf_params = ctf_params[indices, :]
else:
indices = None

if indices is not None:
ctf_params = ctf_params[indices, :] if ctf_params is not None else None
all_indices = all_indices[indices]

# We only need to filter the poses if they weren't generated by the model
rot = rot[indices, :, :]
trans = trans[indices, :]
if "poses" in train_configs["dataset_args"]:
rot = rot[indices, :, :]
trans = trans[indices, :]

pc, pca = analysis.run_pca(z)
umap = utils.load_pkl(os.path.join(anlzdir, "umap.pkl"))
Expand Down Expand Up @@ -192,12 +213,14 @@ def main(args: argparse.Namespace) -> None:
trans=trans,
labels=kmeans_lbls,
umap=umap,
df1=ctf_params[:, 2],
df2=ctf_params[:, 3],
dfang=ctf_params[:, 4],
phase=ctf_params[:, 8],
znorm=znorm,
)
if ctf_params is not None:
plot_df["df1"] = ctf_params[:, 2]
plot_df["df2"] = ctf_params[:, 3]
plot_df["dfang"] = ctf_params[:, 4]
plot_df["phase"] = ctf_params[:, 8]

# Tilt-series outputs have tilt-level CTFs and poses but particle-level model
# results, so for now we leave the CTFs and poses out in this case
else:
Expand Down
59 changes: 35 additions & 24 deletions tests/test_reconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,29 +128,6 @@ def test_analyze(self, tmpdir_factory, particles, poses, ctf, indices, epoch):

assert os.path.exists(os.path.join(outdir, f"analyze.{epoch}"))

@pytest.mark.parametrize(
    "ctf, epoch",
    [
        ("CTF-Test", 3),
        ("CTF-Test", None),
        pytest.param(
            None, None, marks=pytest.mark.xfail(raises=NotImplementedError)
        ),
    ],
    indirect=["ctf"],
)
def test_interactive_filtering(
    self, tmpdir_factory, particles, poses, ctf, indices, epoch
):
    """Run the interactive filtering command on the directory from `get_outdir`.

    The `--epoch` flag is only placed on the command line when an epoch is
    given; otherwise the command is invoked with the output directory alone.
    """
    outdir = self.get_outdir(tmpdir_factory, particles, indices, poses, ctf)

    # Assemble the command-line arguments before handing them to the parser
    cli_args = [outdir]
    if epoch is not None:
        cli_args += ["--epoch", str(epoch)]

    arg_parser = argparse.ArgumentParser()
    filter.add_args(arg_parser)
    filter.main(arg_parser.parse_args(cli_args))

@pytest.mark.parametrize(
"nb_lbl, ctf",
[
Expand Down Expand Up @@ -186,6 +163,29 @@ def test_notebooks(self, tmpdir_factory, particles, poses, ctf, indices, nb_lbl)

os.chdir(orig_cwd)

@pytest.mark.parametrize(
    "ctf, epoch",
    [
        ("CTF-Test", 3),
        ("CTF-Test", None),
        pytest.param(
            None, None, marks=pytest.mark.xfail(raises=NotImplementedError)
        ),
    ],
    indirect=["ctf"],
)
def test_interactive_filtering(
    self, tmpdir_factory, particles, poses, ctf, indices, epoch
):
    """Run the interactive filtering command on the directory from `get_outdir`.

    The `--epoch` flag is only placed on the command line when an epoch is
    given; otherwise the command is invoked with the output directory alone.
    """
    outdir = self.get_outdir(tmpdir_factory, particles, indices, poses, ctf)

    # Assemble the command-line arguments before handing them to the parser
    cli_args = [outdir]
    if epoch is not None:
        cli_args += ["--epoch", str(epoch)]

    arg_parser = argparse.ArgumentParser()
    filter.add_args(arg_parser)
    filter.main(arg_parser.parse_args(cli_args))

@pytest.mark.parametrize(
"ctf, downsample_dim, flip_vol",
[
Expand Down Expand Up @@ -553,7 +553,18 @@ def test_notebooks(self, tmpdir_factory, particles, ctf, indices, nb_lbl):
ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(nb_in)
os.chdir(orig_cwd)

@pytest.mark.parametrize("epoch", [1, None])
@pytest.mark.parametrize(
"epoch",
[
1,
pytest.param(
None,
marks=pytest.mark.xfail(
raises=ValueError, reason="missing analysis epoch"
),
),
],
)
def test_interactive_filtering(
self, tmpdir_factory, particles, ctf, indices, epoch
):
Expand Down

0 comments on commit 593944f

Please sign in to comment.