diff --git a/cryodrgn/commands/filter.py b/cryodrgn/commands/filter.py index 10481882..f228b616 100644 --- a/cryodrgn/commands/filter.py +++ b/cryodrgn/commands/filter.py @@ -18,6 +18,9 @@ # Choose an epoch yourself; save final selection to `indices.pkl` without prompting $ cryodrgn filter my_outdir --epoch 30 -f +# Choose another epoch; this time choose file name but pre-select directory to save in +$ cryodrgn filter my_outdir --epoch 30 --sel-dir /data/my_indices/ + # If you have done multiple k-means clusterings, you can pick which one to use $ cryodrgn filter my_outdir/ -k 25 @@ -44,7 +47,7 @@ from scipy.spatial import transform from typing import Optional, Sequence -from cryodrgn import analysis, utils +from cryodrgn import analysis, utils, dataset logger = logging.getLogger(__name__) @@ -89,6 +92,7 @@ def add_args(parser: argparse.ArgumentParser) -> None: def main(args: argparse.Namespace) -> None: + """Launching the interactive interface for filtering particles from command-line.""" workdir = args.outdir epoch = args.epoch kmeans = args.kmeans @@ -129,23 +133,40 @@ def main(args: argparse.Namespace) -> None: # Load poses and initial indices for plotting if they have been specified rot, trans = utils.load_pkl(pose_pkl) - ctf_params = utils.load_pkl(train_configs["dataset_args"]["ctf"]) - pre_indices = None if plot_inds is None else utils.load_pkl(plot_inds) - all_indices = np.array(range(ctf_params.shape[0])) + if train_configs["dataset_args"]["ctf"] is not None: + ctf_params = utils.load_pkl(train_configs["dataset_args"]["ctf"]) + else: + ctf_params = None # Load the set of indices used to filter original dataset and apply it to inputs + pre_indices = None if plot_inds is None else utils.load_pkl(plot_inds) + if ctf_params is not None: + all_indices = np.array(range(ctf_params.shape[0])) + else: + all_indices = np.array( + range( + dataset.ImageDataset( + mrcfile=train_configs["dataset_args"]["particles"], lazy=True + ).N + ) + ) + + # Load the set of indices used to filter the original dataset and apply it to inputs if isinstance(train_configs["dataset_args"]["ind"], int): indices = slice(train_configs["dataset_args"]["ind"]) - # We only need to filter the poses if they weren't generated by the model - rot = rot[indices, :, :] - trans = trans[indices, :] - elif train_configs["dataset_args"]["ind"]: + elif train_configs["dataset_args"]["ind"] is not None: indices = utils.load_pkl(train_configs["dataset_args"]["ind"]) - ctf_params = ctf_params[indices, :] + else: + indices = None + + if indices is not None: + ctf_params = ctf_params[indices, :] if ctf_params is not None else None all_indices = all_indices[indices] + # We only need to filter the poses if they weren't generated by the model - rot = rot[indices, :, :] - trans = trans[indices, :] + if "poses" in train_configs["dataset_args"]: + rot = rot[indices, :, :] + trans = trans[indices, :] pc, pca = analysis.run_pca(z) umap = utils.load_pkl(os.path.join(anlzdir, "umap.pkl")) @@ -192,12 +213,14 @@ def main(args: argparse.Namespace) -> None: trans=trans, labels=kmeans_lbls, umap=umap, - df1=ctf_params[:, 2], - df2=ctf_params[:, 3], - dfang=ctf_params[:, 4], - phase=ctf_params[:, 8], znorm=znorm, ) + if ctf_params is not None: + plot_df["df1"] = ctf_params[:, 2] + plot_df["df2"] = ctf_params[:, 3] + plot_df["dfang"] = ctf_params[:, 4] + plot_df["phase"] = ctf_params[:, 8] + # Tilt-series outputs have tilt-level CTFs and poses but particle-level model # results, thus we ignore the former in this case for now else: diff --git a/tests/test_reconstruct.py b/tests/test_reconstruct.py index 571b13a3..41daf435 100644 --- a/tests/test_reconstruct.py +++ b/tests/test_reconstruct.py @@ -128,29 +128,6 @@ def test_analyze(self, tmpdir_factory, particles, poses, ctf, indices, epoch): assert os.path.exists(os.path.join(outdir, f"analyze.{epoch}")) - @pytest.mark.parametrize( - "ctf, epoch", - [ - ("CTF-Test", 3), - ("CTF-Test", None), - pytest.param( - None, - None, - marks=pytest.mark.xfail(raises=NotImplementedError), - ), - ], - indirect=["ctf"], - ) - def test_interactive_filtering( - self, tmpdir_factory, particles, poses, ctf, indices, epoch - ): - """Launch interface for filtering particles using model covariates.""" - outdir = self.get_outdir(tmpdir_factory, particles, indices, poses, ctf) - parser = argparse.ArgumentParser() - filter.add_args(parser) - epoch_args = ["--epoch", str(epoch)] if epoch is not None else list() - filter.main(parser.parse_args([outdir] + epoch_args)) - @pytest.mark.parametrize( "nb_lbl, ctf", [ @@ -186,6 +163,29 @@ def test_notebooks(self, tmpdir_factory, particles, poses, ctf, indices, nb_lbl) os.chdir(orig_cwd) + @pytest.mark.parametrize( + "ctf, epoch", + [ + ("CTF-Test", 3), + ("CTF-Test", None), + pytest.param( + None, + None, + marks=pytest.mark.xfail(raises=NotImplementedError), + ), + ], + indirect=["ctf"], + ) + def test_interactive_filtering( + self, tmpdir_factory, particles, poses, ctf, indices, epoch + ): + """Launch interface for filtering particles using model covariates.""" + outdir = self.get_outdir(tmpdir_factory, particles, indices, poses, ctf) + parser = argparse.ArgumentParser() + filter.add_args(parser) + epoch_args = ["--epoch", str(epoch)] if epoch is not None else list() + filter.main(parser.parse_args([outdir] + epoch_args)) + @pytest.mark.parametrize( "ctf, downsample_dim, flip_vol", [ @@ -553,7 +553,18 @@ def test_notebooks(self, tmpdir_factory, particles, ctf, indices, nb_lbl): ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(nb_in) os.chdir(orig_cwd) - @pytest.mark.parametrize("epoch", [1, None]) + @pytest.mark.parametrize( + "epoch", + [ + 1, + pytest.param( + None, + marks=pytest.mark.xfail( + raises=ValueError, reason="missing analysis epoch" + ), + ), + ], + ) def test_interactive_filtering( self, tmpdir_factory, particles, ctf, indices, epoch ):