From 4c6fffc7664588db77b26c07b996ded4d94b0356 Mon Sep 17 00:00:00 2001 From: Joel Greer Date: Tue, 17 Sep 2024 20:44:37 +0100 Subject: [PATCH 01/24] Added simple save button to cryodrgn filter utility --- cryodrgn/commands/filter.py | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/cryodrgn/commands/filter.py b/cryodrgn/commands/filter.py index 4edd5f32..c03ce0d5 100644 --- a/cryodrgn/commands/filter.py +++ b/cryodrgn/commands/filter.py @@ -38,8 +38,9 @@ import seaborn as sns from matplotlib import colors from matplotlib.backend_bases import Event, MouseButton -from matplotlib.widgets import LassoSelector, RadioButtons +from matplotlib.widgets import LassoSelector, RadioButtons, Button from matplotlib.path import Path as PlotPath +from matplotlib.gridspec import GridSpec from scipy.spatial import transform from typing import Optional, Sequence @@ -206,7 +207,7 @@ def main(args: argparse.Namespace) -> None: ) selector = SelectFromScatter(plot_df, pre_indices) - input("Press Enter after making your selection...") + # input("Press Enter after making your selection...") selected_indices = [all_indices[i] for i in selector.indices] plt.close() # Close the figure to avoid interference with other plots @@ -272,7 +273,7 @@ def __init__( self.scatter = None self.fig = plt.figure(constrained_layout=True) - gs = self.fig.add_gridspec(2, 3, width_ratios=[1, 7, 1]) + gs = self.gridspec() self.main_ax = self.fig.add_subplot(gs[:, 1]) self.select_cols = [ @@ -295,6 +296,11 @@ def __init__( self.menu_y = RadioButtons(rax, labels=self.select_cols, active=1) self.menu_y.on_clicked(self.update_yaxis) + # add save button only when selection is made + self.btn_loc = self.fig.add_subplot(gs[2, 0]) + self.sv_btn = Button(self.btn_loc, "Save Selection!", color="darkgreen") + self.btn_loc.set_visible(False) + cax = self.fig.add_subplot(gs[:, 2]) cax.axis("off") cax.set_title("choose\ncolors", size=13) @@ -314,6 +320,9 @@ def __init__( self.plot() + def gridspec(self) -> GridSpec: + return self.fig.add_gridspec(3, 3, width_ratios=[1, 7, 1], height_ratios=[5, 5, 1]) + def plot(self) -> None: self.main_ax.clear() pnt_colors = ["gray" for _ in range(self.data_table.shape[0])] @@ -322,6 +331,14 @@ def plot(self) -> None: for idx in self.indices: pnt_colors[idx] = "goldenrod" + # with selection, set save button visible + self.btn_loc.set_visible(True) + self.sv_btn.on_clicked(self.save_click) + + elif ~len(self.indices): + # remove save button if no selection is made + self.btn_loc.set_visible(False) + elif self.color_col != "None": clr_vals = self.data_table[self.color_col] @@ -367,7 +384,7 @@ def use_norm(x): va="bottom", transform=self.main_ax.transAxes, ) - plt.show(block=False) + plt.show() plt.draw() def update_xaxis(self, xlbl: str) -> None: @@ -446,3 +463,9 @@ def on_release(self, event: Event) -> None: self.handl_id = self.fig.canvas.mpl_connect( "motion_notify_event", self.hover_points ) + + def save_click(self, event: Event) -> None: + """When the save button is clicked, we close display.""" + if hasattr(event, "button") and event.button is MouseButton.LEFT: + # close the plt so we can move onto saving it + plt.close("all") From 979048348c01a49e63e343c4e2b50f90ec7706a0 Mon Sep 17 00:00:00 2001 From: alkinkaz Date: Wed, 16 Oct 2024 05:02:22 -0400 Subject: [PATCH 02/24] fix: lscape full now evals vol in cuda --- cryodrgn/commands/analyze_landscape_full.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cryodrgn/commands/analyze_landscape_full.py b/cryodrgn/commands/analyze_landscape_full.py index c107ba65..713973f2 100644 --- a/cryodrgn/commands/analyze_landscape_full.py +++ b/cryodrgn/commands/analyze_landscape_full.py @@ -9,6 +9,7 @@ $ cryodrgn analyze_landscape_full 005_train-vae/ 39 -N 4000 -d 256 """ + import argparse import os import os.path @@ -210,9 +211,11 @@ def generate_and_map_volumes( pca = utils.load_pkl(pca_obj_pkl) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Load model weights logger.info("Loading weights from {}".format(weights)) - model, lattice = HetOnlyVAE.load(cfg, weights) + model, lattice = HetOnlyVAE.load(cfg, weights, device) model.eval() # Set z From 44cca659eb2b3cb8a4650d5fbd10b91e67585375 Mon Sep 17 00:00:00 2001 From: alkinkaz Date: Fri, 25 Oct 2024 15:56:44 -0400 Subject: [PATCH 03/24] fix: mrcfile in landscape template notebook --- cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb b/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb index 95981184..35dc12ed 100644 --- a/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb @@ -21,7 +21,7 @@ "import subprocess\n", "import os, sys\n", "\n", - "from cryodrgn import mrc\n", + "from cryodrgn import mrcfile as mrc\n", "from cryodrgn import analysis\n", "from cryodrgn import utils\n", "from cryodrgn import dataset\n", From 69ed6155611bdf060de4ab86c5d48abfab2a2f95 Mon Sep 17 00:00:00 2001 From: alkinkaz Date: Mon, 28 Oct 2024 17:14:42 -0400 Subject: [PATCH 04/24] add: volpca grid plotting to lscape template --- .../cryoDRGN_analyze_landscape_template.ipynb | 100 +++++++++++++++++- 1 file changed, 99 insertions(+), 1 deletion(-) diff --git a/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb b/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb index 35dc12ed..793c0d6f 100644 --- a/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb @@ -38,7 +38,11 @@ "\n", "from sklearn.decomposition import PCA\n", "from sklearn.cluster import AgglomerativeClustering\n", - "from scipy.spatial.distance import cdist" + "from scipy.spatial.distance import cdist\n", + "\n", + "import matplotlib.ticker as ticker\n", + "import matplotlib.patches as mpatches\n", + "import matplotlib.gridspec as gridspec" ] }, { @@ -113,6 +117,7 @@ "outputs": [], "source": [ "vol_pc = utils.load_pkl(f'{landscape_dir}/vol_pca_{K}.pkl')\n", + "vol_pca = utils.load_pkl(f'{landscape_dir}/vol_pca_obj.pkl')\n", "vol_pc_all = utils.load_pkl(f'{landscape_full_dir}/vol_pca_all.pkl')" ] }, @@ -288,6 +293,52 @@ " plt.savefig('volpca_landscape_energy.pdf')" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Grid plot landscape -- energy scale\n", + "\n", + "# Set up the triangular grid layout\n", + "n_pcs = 5 # CHANGE ME IF NEEDED\n", + "fig = plt.figure(figsize=(15, 15))\n", + "gs = gridspec.GridSpec(n_pcs-1, n_pcs-1, wspace=0, hspace=0)\n", + "\n", + "# Define the color map and color bar axis\n", + "cmap = 'jet_r'\n", + "norm = plt.Normalize(vmin=0, vmax=5)\n", + "cbar_ax = fig.add_axes([0.92, 0.25, 0.02, 0.5]) # Adjust position as needed\n", + "\n", + "# Loop over each subplot location in the triangular grid\n", + "for i in range(1, n_pcs):\n", + " for j in range(i): \n", + " \n", + " ax = fig.add_subplot(gs[i-1, j])\n", + "\n", + " # Plot hexbin with color map and bins as log scale\n", + " hb = ax.hexbin(vol_pc_all[:, j], vol_pc_all[:, i], gridsize=50, cmap=cmap, bins='log', mincnt=1)\n", + " \n", + " # Only set labels for leftmost and bottom plots\n", + " if j == 0:\n", + " ax.set_ylabel(f'Volume PC{i+1} (EV: {vol_pca.explained_variance_ratio_[i]:.0%})',\n", + " fontsize=14, fontweight='bold')\n", + "\n", + " if i == n_pcs-1:\n", + " ax.set_xlabel(f'Volume PC{j+1} (EV: {vol_pca.explained_variance_ratio_[j]:.0%})',\n", + " fontsize=14, fontweight='bold')\n", + " \n", + " # Exact values are not needed\n", + " ax.set_yticks([])\n", + " ax.set_xticks([])\n", + "\n", + "plt.colorbar(hb, cax=cbar_ax, label='Log Density')\n", + "\n", + "if save_pdf:\n", + " plt.savefig(f'volpca_grid{n_pcs}_landscape_energy.pdf')" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -320,6 +371,53 @@ " plt.savefig('volpca_clusters_all.pdf')" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set up the triangular grid layout\n", + "n_pcs = 5 # CHANGE ME if needed\n", + "fig = plt.figure(figsize=(15, 15))\n", + "gs = gridspec.GridSpec(n_pcs-1, n_pcs-1, wspace=0, hspace=0)\n", + "\n", + "# Define the color map for cluster labels\n", + "cmap = 'tab20'\n", + "\n", + "# Loop over each subplot location in the triangular grid\n", + "for i in range(1, n_pcs):\n", + " for j in range(i): \n", + " ax = fig.add_subplot(gs[i-1, j])\n", + "\n", + " # Plot background scatter with light gray points\n", + " ax.scatter(vol_pc_all[:, j], vol_pc_all[:, i], color='lightgrey', s=1, alpha=0.1, rasterized=True)\n", + " \n", + " # Overlay labeled scatter points with color coding\n", + " sc = ax.scatter(vol_pc[:, j], vol_pc[:, i], c=labels, cmap=cmap, s=25, edgecolor='white', linewidths=0.25)\n", + "\n", + " # Only set labels for leftmost and bottom plots\n", + " if j == 0:\n", + " ax.set_ylabel(f'Volume PC{i+1} (EV: {vol_pca.explained_variance_ratio_[i]:.0%})',\n", + " fontsize=14, fontweight='bold')\n", + " if i == n_pcs-1:\n", + " ax.set_xlabel(f'Volume PC{j+1} (EV: {vol_pca.explained_variance_ratio_[j]:.0%})',\n", + " fontsize=14, fontweight='bold')\n", + "\n", + " # Remove ticks for cleaner look\n", + " ax.xaxis.set_major_locator(ticker.NullLocator())\n", + " ax.yaxis.set_major_locator(ticker.NullLocator())\n", + "\n", + "# Create a legend outside the grid\n", + "unique_labels = np.unique(labels)\n", + "colors = [sc.cmap(sc.norm(label)) for label in unique_labels]\n", + "patches = [mpatches.Patch(color=colors[k], label=f'Cluster {unique_labels[k]}') for k in range(len(unique_labels))]\n", + "fig.legend(handles=patches, fontsize=20)\n", + "\n", + "if save_pdf:\n", + " plt.savefig(f'volpca_grid{n_pcs}_clusters_all.pdf')" + ] + }, { "cell_type": "markdown", "metadata": {}, From cafb13d9858c195ebc1a9a45805960b8c7d7ca1b Mon Sep 17 00:00:00 2001 From: alkinkaz Date: Mon, 4 Nov 2024 20:06:26 -0500 Subject: [PATCH 05/24] add: volPCA scree plot --- .../cryoDRGN_analyze_landscape_template.ipynb | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb b/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb index 793c0d6f..f4b9dc1a 100644 --- a/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb +++ b/cryodrgn/templates/cryoDRGN_analyze_landscape_template.ipynb @@ -239,6 +239,45 @@ "save_pdf = False" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Scree plot for volume PCA" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "explained_variance_ratio = vol_pca.explained_variance_ratio_\n", + "cumulative_variance = np.cumsum(explained_variance_ratio)\n", + "\n", + "# percent\n", + "explained_variance_ratio_percent = explained_variance_ratio * 100\n", + "cumulative_variance_percent = cumulative_variance * 100\n", + "\n", + "# plot EV\n", + "plt.figure(figsize=(10, 6))\n", + "plt.plot(range(1, len(explained_variance_ratio) + 1), explained_variance_ratio_percent, marker='o', linestyle='--')\n", + "plt.xlabel('Principal Component')\n", + "plt.ylabel('Explained Variance Ratio (%)')\n", + "plt.title('Scree Plot for Volume PCA')\n", + "\n", + "# x-axis ticks\n", + "plt.xticks(range(1, len(explained_variance_ratio) + 1))\n", + "\n", + "# Plot cumulative EV\n", + "ax2 = plt.gca().twinx()\n", + "ax2.plot(range(1, len(cumulative_variance_percent) + 1), cumulative_variance_percent, marker='o', color='gray', linestyle='-')\n", + "ax2.set_ylabel('Cumulative Explained Variance (%)')\n", + "\n", + "if save_pdf:\n", + " plt.savefig('volpca_scree.pdf')" + ] + }, { "cell_type": "markdown", "metadata": {}, From acde19def090027904396335c034aa82a4292057 Mon Sep 17 00:00:00 2001 From: Michal Grzadkowski Date: Tue, 5 Nov 2024 12:47:27 -0500 Subject: [PATCH 06/24] fixing use of interactive filtering with filtered ab-initio reconstruction outputs --- cryodrgn/commands/filter.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/cryodrgn/commands/filter.py b/cryodrgn/commands/filter.py index 751029b7..6980a56d 100644 --- a/cryodrgn/commands/filter.py +++ b/cryodrgn/commands/filter.py @@ -115,38 +115,35 @@ def main(args: argparse.Namespace) -> None: ) z = utils.load_pkl(os.path.join(workdir, f"z.{epoch}.pkl")) - # Load poses either from input file or from reconstruction results if ab-initio + # Get poses either from input file or from reconstruction results if ab-initio if "poses" in train_configs["dataset_args"]: pose_pkl = train_configs["dataset_args"]["poses"] else: pose_pkl = os.path.join(workdir, f"pose.{epoch}.pkl") + # Load poses and initial indices for plotting if they have been specified rot, trans = utils.load_pkl(pose_pkl) ctf_params = utils.load_pkl(train_configs["dataset_args"]["ctf"]) + pre_indices = None if plot_inds is None else utils.load_pkl(plot_inds) all_indices = np.array(range(ctf_params.shape[0])) # Load the set of indices used to filter original dataset and apply it to inputs if isinstance(train_configs["dataset_args"]["ind"], int): - ctf_params = ctf_params[: train_configs["dataset_args"]["ind"], :] - all_indices = all_indices[: train_configs["dataset_args"]["ind"]] + indices = slice(train_configs["dataset_args"]["ind"]) + else: + indices = utils.load_pkl(train_configs["dataset_args"]["ind"]) + + ctf_params = ctf_params[indices, :] + all_indices = all_indices[indices] - elif isinstance(train_configs["dataset_args"]["ind"], str): - inds = utils.load_pkl(train_configs["dataset_args"]["ind"]) - ctf_params = ctf_params[inds, :] - rot = rot[inds, :, :] - trans = trans[inds, :] - all_indices = all_indices[inds] + # We only need to filter the poses if they weren't generated by the model + if "poses" in train_configs["dataset_args"]: + rot = rot[indices, :, :] + trans = trans[indices, :] pc, pca = analysis.run_pca(z) umap = utils.load_pkl(os.path.join(anlzdir, "umap.pkl")) - # Load preselected indices for initial plotting if they have been specified - if plot_inds: - with open(plot_inds, "rb") as file: - pre_indices = pickle.load(file) - else: - pre_indices = None - if kmeans == -1: kmeans_dirs = [ d From 087b72f840565c6bd78697dc2f14a663f118f6f7 Mon Sep 17 00:00:00 2001 From: Michal Grzadkowski Date: Tue, 5 Nov 2024 13:55:12 -0500 Subject: [PATCH 07/24] better input error messages for write_star --- cryodrgn/commands_utils/write_star.py | 33 +++++++++++++++++---------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/cryodrgn/commands_utils/write_star.py b/cryodrgn/commands_utils/write_star.py index 7d105bff..38b28049 100644 --- a/cryodrgn/commands_utils/write_star.py +++ b/cryodrgn/commands_utils/write_star.py @@ -54,7 +54,7 @@ def add_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("particles", help="Input particles (.mrcs, .txt, .star)") parser.add_argument( - "-o", type=os.path.abspath, required=True, help="Output .star file" + "-o", "--outfile", type=os.path.abspath, required=True, help="Output .star file" ) parser.add_argument("--ctf", help="Input ctf.pkl") @@ -81,14 +81,11 @@ def add_args(parser: argparse.ArgumentParser) -> None: def main(args: argparse.Namespace) -> None: - assert args.o.endswith(".star"), "Output file must be .star file" - input_ext = os.path.splitext(args.particles)[-1] - assert input_ext in ( - ".mrcs", - ".txt", - ".star", - ), "Input file must be .mrcs/.txt/.star" + if input_ext not in {".mrcs", ".txt", ".star"}: + raise ValueError(f"Input `{args.particles}` not a .mrcs/.txt/.star file!") + if not args.outfile.endswith(".star"): + raise ValueError(f"Output `{args.outfile}` not .star file!") # Either accept an input star file, or an input .mrcs/.txt with CTF .pkl # and an optional pose .pkl file(s) @@ -101,8 +98,8 @@ def main(args: argparse.Namespace) -> None: ) if args.ctf is not None: raise ValueError( - "--ctf cannot be specified when input is a starfile " - "(ctf information are obtained from starfile)" + "--ctf cannot be specified when input is a .star file " + "(CTF information is obtained from .star file)" ) else: if not args.ctf: @@ -123,8 +120,13 @@ def main(args: argparse.Namespace) -> None: f"{len(particles)} != {len(ctf)}, " f"Number of particles != number of CTF parameters" ) + if args.poses: poses = utils.load_pkl(args.poses) + if not isinstance(poses, tuple) or len(poses) != 2: + raise ValueError( + f"Unrecognized pose format found in given file `{args.poses}`!" + ) if len(particles) != len(poses[0]): raise ValueError( f"{len(particles)} != {len(poses)}, " @@ -132,15 +134,22 @@ def main(args: argparse.Namespace) -> None: ) # load the particle filter if given and apply it to the CTF and poses data - ind = np.arange(particles.n) if args.ind: ind = utils.load_pkl(args.ind) + if np.array(ind).ndim != 1: + raise ValueError( + f"Unrecognized indices format found in given file `{args.ind}`!" + ) + logger.info(f"Filtering to {len(ind)} particles") if ctf is not None: ctf = ctf[ind] if poses is not None: poses = (poses[0][ind], poses[1][ind]) + else: + ind = np.arange(particles.n) + # When the input is already a .star file, we just filter the data table directly if input_ext == ".star": assert isinstance(particles, StarfileSource) @@ -213,4 +222,4 @@ def main(args: argparse.Namespace) -> None: df = pd.DataFrame(data=data) - write_star(args.o, data=df, data_optics=optics) + write_star(args.outfile, data=df, data_optics=optics) From d1f2d486335c368f6453bd5c2a6da730100c4aac Mon Sep 17 00:00:00 2001 From: alkinkaz Date: Fri, 8 Nov 2024 13:34:01 -0500 Subject: [PATCH 08/24] add: make_movies init --- cryodrgn/command_line.py | 2 + cryodrgn/commands_utils/make_movies.py | 247 +++++++++++++++++++++++++ 2 files changed, 249 insertions(+) create mode 100644 cryodrgn/commands_utils/make_movies.py diff --git a/cryodrgn/command_line.py b/cryodrgn/command_line.py index 7448a73b..b3f44acc 100644 --- a/cryodrgn/command_line.py +++ b/cryodrgn/command_line.py @@ -11,6 +11,7 @@ since automated scanning for command modules is computationally non-trivial. """ + import argparse import os from importlib import import_module @@ -122,6 +123,7 @@ def util_commands() -> None: "fsc", "gen_mask", "invert_contrast", + "make_movies", "phase_flip", "plot_classes", "plot_fsc", diff --git a/cryodrgn/commands_utils/make_movies.py b/cryodrgn/commands_utils/make_movies.py new file mode 100644 index 00000000..f84ba0e0 --- /dev/null +++ b/cryodrgn/commands_utils/make_movies.py @@ -0,0 +1,247 @@ +"""Make mp4 movies of .mrc's produced by analyze* commands + +Example usage +------------- +# Latent k-means and PCA movies +$ cryodrgn_utils make_movies spr_runs/07/out 19 latent --iso=31 --camera="-0.03377,-0.97371,0.22528,545.75,0.89245,-0.13085,-0.43175,87.21,0.44988,0.18647,0.87341,1039" + +# Volume PCA movies +$ cryodrgn_utils make_movies spr_runs/07/out 19 volume --iso=210 --camera="0.12868,-0.9576,0.25778,95.4,-0.85972,-0.23728,-0.45231,15.356,0.4943,-0.16341,-0.8538,-33.755" + +""" + +import argparse +import os +from datetime import datetime as dt +import logging +import subprocess +from pathlib import Path + +logger = logging.getLogger(__name__) + +generate_movie_prologue = lambda w, h: [ + "set bgColor white", + "graphics silhouettes true", + f"windowsize {w} {h}", +] + + +def generate_movie_epilogue(cam_matrix, iso_threshold, num_vols, directory, all_cornfl): + + l = [ + f"view matrix camera {cam_matrix}", + f"vol all level {iso_threshold}", + "surface dust all size 10", + "lighting soft", + "mov record", + "mseries all pause 1 step 1", + f"wait {num_vols}", + f"mov encode {directory}/movie.mp4 framerate 3", + "exit", + ] + + if all_cornfl: + l = ["vol all color cornfl"] + l + + return l + + +def add_args(parser: argparse.ArgumentParser) -> None: + """The command-line arguments for use with `cryodrgn_utils make_movies`.""" + + parser.add_argument( + "workdir", type=os.path.abspath, help="Directory with cryoDRGN results" + ) + parser.add_argument( + "epoch", + type=str, + help="Epoch number N to analyze " + "(0-based indexing, corresponding to z.N.pkl, weights.N.pkl)", + ) + parser.add_argument( + "type", + type=str, + help="Analysis type to generate movies for ('latent' or 'volume')", + ) + + parser.add_argument( + "--iso", + required=True, + type=str, + help="Isosurface threshold for the movies", + ) + parser.add_argument( + "--camera", + required=True, + type=str, + help="Camera matrix string for the movies", + ) + + parser.add_argument( + "--width", + type=int, + help="Video width in pixels (default: 600)", + ) + parser.add_argument( + "--height", + type=int, + help="Video height in pixels (default: 800)", + ) + + parser.add_argument( + "--analysis-dir", + type=os.path.abspath, + help="Latent space analysis directory (default: [workdir]/analyze.[epoch])", + ) + + parser.add_argument( + "--landscape-dir", + type=os.path.abspath, + help="Landscape analysis directory (default: [workdir]/landscape.[epoch])", + ) + + +def check_chimerax_installation() -> bool: + """Checks if ChimeraX is installed.""" + + command = "chimerax" + try: + subprocess.run( + [command, "--version"], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + logger.info(f"{command} is installed.") + return True + except FileNotFoundError: + logger.info(f"{command} is not installed, aborting.") + return False + except subprocess.CalledProcessError: + logger.info( + f"{command} is installed, but there was an issue executing it. Aborting." + ) + return False + + +def find_subdirs(directory: str, keyword: str) -> list[Path]: + directory_path = Path(directory) + subdirs = [p for p in directory_path.rglob(f"{keyword}*") if p.is_dir()] + + values = [p.name.split(keyword)[-1] for p in subdirs if p.name.startswith(keyword)] + logger.info(f"{len(subdirs)} {keyword} directories were found: {', '.join(values)}") + + return subdirs + + +def get_vols(directory: Path) -> list[str]: + vol_files = sorted(directory.glob("*.mrc")) + logger.info(f"{len(vol_files)} volumes were found in {directory}") + return [f"open {vol_file}" for vol_file in vol_files] + + +def record_movie( + dir_list: list[Path], + iso: str, + cam_matrix: str, + prologue: list[str], + all_cornfl: bool, +): + """Movie recording subroutine""" + + for directory in dir_list: + + vols = get_vols(directory) + + epilogue = generate_movie_epilogue( + cam_matrix, iso, len(vols), directory, all_cornfl + ) + + movie_commands = prologue + vols + epilogue + movie_script = "\n".join(movie_commands) + + script_path = f"{directory}/temp.cxc" + + with open(script_path, "w") as file: + file.write(movie_script) + + logger.info(f"Running chimerax subprocess for movie making.") + subprocess.run( + ["chimerax", "--offscreen", script_path], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + logger.info(f"Movie saved in {directory}/movie.mp4") + os.remove(script_path) + + +def latent_movies( + analysis_dir: str, iso: str, cam_matrix: str, width: int, height: int +): + """Record the movies for latent space analysis""" + + prologue = generate_movie_prologue(width, height) + + # 1) kmeans movies + kmeans_dirs = find_subdirs(analysis_dir, "kmeans") + record_movie(kmeans_dirs, iso, cam_matrix, prologue, False) + + # 2) pc movies + pc_dirs = find_subdirs(analysis_dir, "pc") + record_movie(pc_dirs, iso, cam_matrix, prologue, True) + + return + + +def landscape_movies( + landscape_dir: str, iso: str, cam_matrix: str, width: int, height: int +): + """Record the movies for volume space landscape analysis""" + + prologue = generate_movie_prologue(width, height) + + vol_pc_dirs = find_subdirs(f"{landscape_dir}/vol_pcs", "pc") + record_movie(vol_pc_dirs, iso, cam_matrix, prologue, True) + + return + + +def main(args: argparse.Namespace) -> None: + t1 = dt.now() + logger.info(args) + + # parsing args + E = args.epoch + workdir = args.workdir + analysis_type = args.type + cam_matrix = args.camera + iso = args.iso + width = 600 if args.width is None else args.width + height = 800 if args.height is None else args.height + + # checking chimerax + if not check_chimerax_installation(): + return + + if analysis_type != "latent" and analysis_type != "volume": + logger.info("Analysis type unrecognized, aborting.") + return + + analysis_dir = ( + f"{workdir}/analyze.{E}" if args.analysis_dir is None else args.analysis_dir + ) + landscape_dir = ( + f"{workdir}/landscape.{E}" if args.landscape_dir is None else args.landscape_dir + ) + + if analysis_type == "latent": + logger.info(f"Working in {analysis_dir}") + latent_movies(analysis_dir, iso, cam_matrix, width, height) + else: + logger.info(f"Working in {landscape_dir}") + landscape_movies(landscape_dir, iso, cam_matrix, width, height) + + td = dt.now() - t1 + logger.info(f"Finished in {td}") + return From eb591ed54fc7f24e0d72ae0050de05a10547523a Mon Sep 17 00:00:00 2001 From: alkinkaz Date: Thu, 14 Nov 2024 17:06:42 -0500 Subject: [PATCH 09/24] fix: FutureWarning about amp --- cryodrgn/commands/train_vae.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cryodrgn/commands/train_vae.py b/cryodrgn/commands/train_vae.py index 86b670cc..b557e3ae 100755 --- a/cryodrgn/commands/train_vae.py +++ b/cryodrgn/commands/train_vae.py @@ -17,6 +17,7 @@ --zdim 8 --num-epochs 50 --beta .025 """ + import argparse import os import pickle @@ -376,7 +377,7 @@ def train_batch( y = preprocess_input(y, lattice, trans) # Cast operations to mixed precision if using torch.cuda.amp.GradScaler() if scaler is not None: - with torch.cuda.amp.autocast_mode.autocast(): + with torch.amp.autocast("cuda"): z_mu, z_logvar, z, y_recon, mask = run_batch( model, lattice, y, rot, ntilts, ctf_params, yr ) From b895d76afdfa9beec8ea969e923ddeb5d02847fc Mon Sep 17 00:00:00 2001 From: alkinkaz Date: Thu, 14 Nov 2024 17:49:57 -0500 Subject: [PATCH 10/24] add: custom movie naming --- cryodrgn/commands_utils/make_movies.py | 33 ++++++++++++++++---------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/cryodrgn/commands_utils/make_movies.py b/cryodrgn/commands_utils/make_movies.py index f84ba0e0..88b1cd71 100644 --- a/cryodrgn/commands_utils/make_movies.py +++ b/cryodrgn/commands_utils/make_movies.py @@ -6,7 +6,7 @@ $ cryodrgn_utils make_movies spr_runs/07/out 19 latent --iso=31 --camera="-0.03377,-0.97371,0.22528,545.75,0.89245,-0.13085,-0.43175,87.21,0.44988,0.18647,0.87341,1039" # Volume PCA movies -$ cryodrgn_utils make_movies spr_runs/07/out 19 volume --iso=210 --camera="0.12868,-0.9576,0.25778,95.4,-0.85972,-0.23728,-0.45231,15.356,0.4943,-0.16341,-0.8538,-33.755" +$ cryodrgn_utils make_movies spr_runs/07/out 19 volume --name=front --iso=210 --camera="0.12868,-0.9576,0.25778,95.4,-0.85972,-0.23728,-0.45231,15.356,0.4943,-0.16341,-0.8538,-33.755" """ @@ -26,7 +26,9 @@ ] -def generate_movie_epilogue(cam_matrix, iso_threshold, num_vols, directory, all_cornfl): +def generate_movie_epilogue( + cam_matrix, iso_threshold, num_vols, directory, name, all_cornfl +): l = [ f"view matrix camera {cam_matrix}", @@ -36,7 +38,7 @@ def generate_movie_epilogue(cam_matrix, iso_threshold, num_vols, directory, all_ "mov record", "mseries all pause 1 step 1", f"wait {num_vols}", - f"mov encode {directory}/movie.mp4 framerate 3", + f"mov encode {directory}/{name}.mp4 framerate 3", "exit", ] @@ -77,6 +79,11 @@ def add_args(parser: argparse.ArgumentParser) -> None: help="Camera matrix string for the movies", ) + parser.add_argument( + "--name", + type=str, + help="Video name (default: 'movie')", + ) parser.add_argument( "--width", type=int, @@ -144,6 +151,7 @@ def record_movie( dir_list: list[Path], iso: str, cam_matrix: str, + name: str, prologue: list[str], all_cornfl: bool, ): @@ -154,7 +162,7 @@ def record_movie( vols = get_vols(directory) epilogue = generate_movie_epilogue( - cam_matrix, iso, len(vols), directory, all_cornfl + cam_matrix, iso, len(vols), directory, name, all_cornfl ) movie_commands = prologue + vols + epilogue @@ -172,12 +180,12 @@ def record_movie( stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) - logger.info(f"Movie saved in {directory}/movie.mp4") + logger.info(f"Movie saved in {directory}/{name}.mp4") os.remove(script_path) def latent_movies( - analysis_dir: str, iso: str, cam_matrix: str, width: int, height: int + analysis_dir: str, iso: str, cam_matrix: str, name: str, width: int, height: int ): """Record the movies for latent space analysis""" @@ -185,24 +193,24 @@ def latent_movies( # 1) kmeans movies kmeans_dirs = find_subdirs(analysis_dir, "kmeans") - record_movie(kmeans_dirs, iso, cam_matrix, prologue, False) + record_movie(kmeans_dirs, iso, cam_matrix, name, prologue, False) # 2) pc movies pc_dirs = find_subdirs(analysis_dir, "pc") - record_movie(pc_dirs, iso, cam_matrix, prologue, True) + record_movie(pc_dirs, iso, cam_matrix, name, prologue, True) return def landscape_movies( - landscape_dir: str, iso: str, cam_matrix: str, width: int, height: int + landscape_dir: str, iso: str, cam_matrix: str, name: str, width: int, height: int ): """Record the movies for volume space landscape analysis""" prologue = generate_movie_prologue(width, height) vol_pc_dirs = find_subdirs(f"{landscape_dir}/vol_pcs", "pc") - record_movie(vol_pc_dirs, iso, cam_matrix, prologue, True) + record_movie(vol_pc_dirs, iso, cam_matrix, name, prologue, True) return @@ -218,6 +226,7 @@ def main(args: argparse.Namespace) -> None: cam_matrix = args.camera iso = args.iso width = 600 if args.width is None else args.width + name = "movie" if args.name is None else args.name height = 800 if args.height is None else args.height # checking chimerax @@ -237,10 +246,10 @@ def main(args: argparse.Namespace) -> None: if analysis_type == "latent": logger.info(f"Working in {analysis_dir}") - latent_movies(analysis_dir, iso, cam_matrix, width, height) + latent_movies(analysis_dir, iso, cam_matrix, name, width, height) else: logger.info(f"Working in {landscape_dir}") - landscape_movies(landscape_dir, iso, cam_matrix, width, height) + landscape_movies(landscape_dir, iso, cam_matrix, name, width, height) td = dt.now() - t1 logger.info(f"Finished in {td}") From 1d0217b5e994514bcf5cc93b1dfb7bb3696efee5 Mon Sep 17 00:00:00 2001 From: alkinkaz Date: Sun, 17 Nov 2024 17:09:22 -0500 Subject: [PATCH 11/24] fix: cluster means are now correct --- cryodrgn/commands/analyze_landscape.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cryodrgn/commands/analyze_landscape.py b/cryodrgn/commands/analyze_landscape.py index e344c40a..a315c2cc 100644 --- a/cryodrgn/commands/analyze_landscape.py +++ b/cryodrgn/commands/analyze_landscape.py @@ -9,6 +9,7 @@ $ cryodrgn analyze_landscape 005_train-vae/ 39 -N 5000 -d 256 """ + import argparse import os import shutil @@ -354,8 +355,8 @@ def plot(i, j): if vol_ind is not None: vol_i = np.arange(K)[vol_ind][vol_i] - vol_fl = os.path.join(kmean_dir, f"vol_{vol_start_index+i:03d}.mrc") - vol_i_all = torch.stack([torch.Tensor(parse_mrc(vol_fl)[0]) for i in vol_i]) + vol_fl = lambda j: os.path.join(kmean_dir, f"vol_{vol_start_index+j:03d}.mrc") + vol_i_all = torch.stack([torch.Tensor(parse_mrc(vol_fl(j))[0]) for j in vol_i]) nparticles = np.array([kmeans_counts[i] for i in vol_i]) vol_i_mean = np.average(vol_i_all, axis=0, weights=nparticles) vol_i_std = ( From acf640ccdda5c2890be1c4cec515d845357807b0 Mon Sep 17 00:00:00 2001 From: alkinkaz Date: Mon, 18 Nov 2024 13:46:47 -0500 Subject: [PATCH 12/24] add: make_movies iso now optional --- cryodrgn/commands_utils/make_movies.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/cryodrgn/commands_utils/make_movies.py b/cryodrgn/commands_utils/make_movies.py index 88b1cd71..39913261 100644 --- a/cryodrgn/commands_utils/make_movies.py +++ b/cryodrgn/commands_utils/make_movies.py @@ -30,9 +30,16 @@ def generate_movie_epilogue( cam_matrix, iso_threshold, num_vols, directory, name, all_cornfl ): - l = [ - f"view matrix camera {cam_matrix}", - f"vol all level {iso_threshold}", + l = [] + if all_cornfl: + l = ["vol all color cornfl"] + + l = l + [f"view matrix camera {cam_matrix}"] + + if iso_threshold: + l = l + [f"vol all level {iso_threshold}"] + + l = l + [ "surface dust all size 10", "lighting soft", "mov record", @@ -42,9 +49,6 @@ def generate_movie_epilogue( "exit", ] - if all_cornfl: - l = ["vol all color cornfl"] + l - return l @@ -67,18 +71,16 @@ def add_args(parser: argparse.ArgumentParser) -> None: ) parser.add_argument( - "--iso", + "--camera", required=True, type=str, - help="Isosurface threshold for the movies", + help="Camera matrix string for the movies", ) parser.add_argument( - "--camera", - required=True, + "--iso", type=str, - help="Camera matrix string for the movies", + help="Isosurface threshold for the movies (default: ChimeraX default level)", ) - parser.add_argument( "--name", type=str, From 90a117835483ace3e67f58ff86da052de1618502 Mon Sep 17 00:00:00 2001 From: alkinkaz Date: Mon, 18 Nov 2024 15:07:17 -0500 Subject: [PATCH 13/24] add: make_movies for volume space clustering --- cryodrgn/commands_utils/make_movies.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/cryodrgn/commands_utils/make_movies.py b/cryodrgn/commands_utils/make_movies.py index 39913261..14212eb9 100644 --- a/cryodrgn/commands_utils/make_movies.py +++ b/cryodrgn/commands_utils/make_movies.py @@ -143,9 +143,10 @@ def find_subdirs(directory: str, keyword: str) -> list[Path]: return subdirs -def get_vols(directory: Path) -> list[str]: - vol_files = sorted(directory.glob("*.mrc")) - logger.info(f"{len(vol_files)} volumes were found in {directory}") +def get_vols(directory: Path, postfix_regex: str = "") -> list[str]: + """Postfix regex string enables us to find the state_m_mean.mrc volumes""" + vol_files = sorted(directory.glob(f"*{postfix_regex}.mrc")) + logger.info(f"{len(vol_files)} {postfix_regex} volumes were found in {directory}") return [f"open {vol_file}" for vol_file in vol_files] @@ -155,13 +156,14 @@ def record_movie( cam_matrix: str, name: str, prologue: list[str], - all_cornfl: bool, + all_cornfl: bool = False, + vol_postfix_regex: str = "", ): """Movie recording subroutine""" for directory in dir_list: - vols = get_vols(directory) + vols = get_vols(directory, vol_postfix_regex) epilogue = generate_movie_epilogue( cam_matrix, iso, len(vols), directory, name, all_cornfl @@ -211,6 +213,19 @@ def landscape_movies( prologue = generate_movie_prologue(width, height) + # 1) clustering movies + clustering_dirs = find_subdirs(f"{landscape_dir}", "clustering") + record_movie( + clustering_dirs, + iso, + cam_matrix, + name, + prologue, + False, + vol_postfix_regex="mean", + ) + + # 2) pc movies vol_pc_dirs = find_subdirs(f"{landscape_dir}/vol_pcs", "pc") record_movie(vol_pc_dirs, iso, cam_matrix, name, prologue, True) From 6a1e93e5251e2e2b0e0320af8e7da8340a6e52c7 Mon Sep 17 00:00:00 2001 From: Michal Grzadkowski Date: Mon, 18 Nov 2024 18:25:44 -0500 Subject: [PATCH 14/24] fixing k-means cluster mean volume calculation in analyze_landscape --- cryodrgn/commands/analyze_landscape.py | 46 +++++++++++++++----------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/cryodrgn/commands/analyze_landscape.py b/cryodrgn/commands/analyze_landscape.py index e344c40a..562c088e 100644 --- a/cryodrgn/commands/analyze_landscape.py +++ b/cryodrgn/commands/analyze_landscape.py @@ -348,47 +348,55 @@ def plot(i, j): kmeans_labels = utils.load_pkl(os.path.join(outdir, f"kmeans{K}", "labels.pkl")) kmeans_counts = Counter(kmeans_labels) - for i in range(M): - vol_i = np.where(labels == i)[0] - logger.info(f"State {i}: {len(vol_i)} volumes") + for cluster_i in range(M): + vol_indices = np.where(labels == cluster_i)[0] + logger.info(f"State {cluster_i}: {len(vol_indices)} volumes") if vol_ind is not None: - vol_i = np.arange(K)[vol_ind][vol_i] + vol_indices = np.arange(K)[vol_ind][vol_indices] + + vol_fls = [ + os.path.join(kmean_dir, f"vol_{vol_start_index + vol_i:03d}.mrc") + for vol_i in vol_indices + ] + vol_i_all = torch.stack( + [torch.Tensor(parse_mrc(vol_fl)[0]) for vol_fl in vol_fls] + ) - vol_fl = os.path.join(kmean_dir, f"vol_{vol_start_index+i:03d}.mrc") - vol_i_all = torch.stack([torch.Tensor(parse_mrc(vol_fl)[0]) for i in vol_i]) - nparticles = np.array([kmeans_counts[i] for i in vol_i]) + nparticles = np.array([kmeans_counts[vol_i] for vol_i in vol_indices]) vol_i_mean = np.average(vol_i_all, axis=0, weights=nparticles) vol_i_std = ( np.average((vol_i_all - vol_i_mean) ** 2, axis=0, weights=nparticles) ** 0.5 ) + write_mrc( - os.path.join(subdir, f"state_{i}_mean.mrc"), + os.path.join(subdir, f"state_{cluster_i}_mean.mrc"), vol_i_mean.astype(np.float32), Apix=Apix, ) write_mrc( - os.path.join(subdir, f"state_{i}_std.mrc"), + os.path.join(subdir, f"state_{cluster_i}_std.mrc"), vol_i_std.astype(np.float32), Apix=Apix, ) - os.makedirs(os.path.join(subdir, f"state_{i}"), exist_ok=True) - for v in vol_i: - os.symlink( - os.path.join(kmean_dir, f"vol_{vol_start_index+v:03d}.mrc"), - os.path.join(subdir, f"state_{i}", f"vol_{vol_start_index+v:03d}.mrc"), - ) + statedir = os.path.join(subdir, f"state_{cluster_i}") + os.makedirs(statedir, exist_ok=True) + for vol_i in vol_indices: + kmean_fl = os.path.join(kmean_dir, f"vol_{vol_start_index+vol_i:03d}.mrc") + sub_fl = os.path.join(statedir, f"vol_{vol_start_index+vol_i:03d}.mrc") + os.symlink(kmean_fl, sub_fl) - particle_ind = analysis.get_ind_for_cluster(kmeans_labels, vol_i) - logger.info(f"State {i}: {len(particle_ind)} particles") + particle_ind = analysis.get_ind_for_cluster(kmeans_labels, vol_indices) + logger.info(f"State {cluster_i}: {len(particle_ind)} particles") if particle_ind_orig is not None: utils.save_pkl( particle_ind_orig[particle_ind], - os.path.join(subdir, f"state_{i}_particle_ind.pkl"), + os.path.join(subdir, f"state_{cluster_i}_particle_ind.pkl"), ) else: utils.save_pkl( - particle_ind, os.path.join(subdir, f"state_{i}_particle_ind.pkl") + particle_ind, + os.path.join(subdir, f"state_{cluster_i}_particle_ind.pkl"), ) # plot clustering results From c09d065374dae227ebe97ad53a5884992ac613c3 Mon Sep 17 00:00:00 2001 From: alkinkaz Date: Mon, 18 Nov 2024 23:56:38 -0500 Subject: [PATCH 15/24] fix: only find top subdirectories, not recursively --- cryodrgn/commands_utils/make_movies.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cryodrgn/commands_utils/make_movies.py b/cryodrgn/commands_utils/make_movies.py index 14212eb9..4b0c538e 100644 --- a/cryodrgn/commands_utils/make_movies.py +++ b/cryodrgn/commands_utils/make_movies.py @@ -135,9 +135,11 @@ def check_chimerax_installation() -> bool: def find_subdirs(directory: str, keyword: str) -> list[Path]: directory_path = Path(directory) - subdirs = [p for p in directory_path.rglob(f"{keyword}*") if p.is_dir()] + subdirs = [ + p for p in directory_path.iterdir() if p.is_dir() and p.name.startswith(keyword) + ] - values = [p.name.split(keyword)[-1] for p in subdirs if p.name.startswith(keyword)] + values = [p.name.split(keyword)[-1] for p in subdirs] logger.info(f"{len(subdirs)} {keyword} directories were found: {', '.join(values)}") return subdirs From ccd179eade886070be30f354239fdb27b0a5fd88 Mon Sep 17 00:00:00 2001 From: Joel Greer Date: Thu, 21 Nov 2024 21:20:56 +0000 Subject: [PATCH 16/24] added --sel-dir for filter output dir --- cryodrgn/commands/filter.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cryodrgn/commands/filter.py b/cryodrgn/commands/filter.py index c03ce0d5..69a0d8c3 100644 --- a/cryodrgn/commands/filter.py +++ b/cryodrgn/commands/filter.py @@ -81,6 +81,11 @@ def add_args(parser: argparse.ArgumentParser) -> None: help="path to a file containing previously selected indices " "that will be plotted at the beginning", ) + parser.add_argument( + "--sel-dir", + type=str, + help="directory to save the particle selection into", + ) def main(args: argparse.Namespace) -> None: @@ -88,6 +93,7 @@ def main(args: argparse.Namespace) -> None: epoch = args.epoch kmeans = args.kmeans plot_inds = args.plot_inds + sel_dir = args.sel_dir train_configs_file = os.path.join(workdir, "config.yaml") if not os.path.exists(train_configs_file): @@ -236,6 +242,8 @@ def main(args: argparse.Namespace) -> None: if save_option == "yes": if args.force: filename = "indices" + if args.sel_dir: + filename = os.path.join(args.sel_dir, filename) else: filename = input( "Enter filename to save selection (absolute, without extension): " From ee9a762f8f5691ad4e3aa982c96b02a40a2765ff Mon Sep 17 00:00:00 2001 From: alkinkaz Date: Fri, 22 Nov 2024 17:50:08 -0500 Subject: [PATCH 17/24] add: frame rate as an argument --- cryodrgn/commands_utils/make_movies.py | 42 ++++++++++++++++++++------ 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/cryodrgn/commands_utils/make_movies.py b/cryodrgn/commands_utils/make_movies.py index 4b0c538e..ac38bc54 100644 --- a/cryodrgn/commands_utils/make_movies.py +++ b/cryodrgn/commands_utils/make_movies.py @@ -27,7 +27,7 @@ def generate_movie_epilogue( - cam_matrix, iso_threshold, num_vols, directory, name, all_cornfl + cam_matrix, iso_threshold, num_vols, directory, name, all_cornfl, framerate ): l = [] @@ -45,7 +45,7 @@ def generate_movie_epilogue( "mov record", "mseries all pause 1 step 1", f"wait {num_vols}", - f"mov encode {directory}/{name}.mp4 framerate 3", + f"mov encode {directory}/{name}.mp4 framerate {framerate}", "exit", ] @@ -81,6 +81,11 @@ def add_args(parser: argparse.ArgumentParser) -> None: type=str, help="Isosurface threshold for the movies (default: ChimeraX default level)", ) + parser.add_argument( + "--frame", + type=int, + help="Frame rate (fps) for the movies (default: 3)", + ) parser.add_argument( "--name", type=str, @@ -158,6 +163,7 @@ def record_movie( cam_matrix: str, name: str, prologue: list[str], + frame_rate: int, all_cornfl: bool = False, vol_postfix_regex: str = "", ): @@ -168,7 +174,7 @@ def record_movie( vols = get_vols(directory, vol_postfix_regex) epilogue = generate_movie_epilogue( - cam_matrix, iso, len(vols), directory, name, all_cornfl + cam_matrix, iso, len(vols), directory, name, all_cornfl, frame_rate ) movie_commands = prologue + vols + epilogue @@ -191,7 +197,13 @@ def record_movie( def latent_movies( - analysis_dir: str, iso: str, cam_matrix: str, name: str, width: int, height: int + analysis_dir: str, + iso: str, + cam_matrix: str, + name: str, + width: int, + height: int, + frame_rate: int, ): """Record the movies for latent space analysis""" @@ -199,17 +211,23 @@ def latent_movies( # 1) kmeans movies kmeans_dirs = find_subdirs(analysis_dir, "kmeans") - record_movie(kmeans_dirs, iso, cam_matrix, name, prologue, False) + record_movie(kmeans_dirs, iso, cam_matrix, name, prologue, frame_rate, False) # 2) pc movies pc_dirs = find_subdirs(analysis_dir, "pc") - record_movie(pc_dirs, iso, cam_matrix, name, prologue, True) + record_movie(pc_dirs, iso, cam_matrix, name, prologue, frame_rate, True) return def landscape_movies( - landscape_dir: str, iso: str, cam_matrix: str, name: str, width: int, height: int + landscape_dir: str, + iso: str, + cam_matrix: str, + name: str, + width: int, + height: int, + frame_rate: int, ): """Record the movies for volume space landscape analysis""" @@ -223,13 +241,14 @@ def landscape_movies( cam_matrix, name, prologue, + frame_rate, False, vol_postfix_regex="mean", ) # 2) pc movies vol_pc_dirs = find_subdirs(f"{landscape_dir}/vol_pcs", "pc") - record_movie(vol_pc_dirs, iso, cam_matrix, name, prologue, True) + record_movie(vol_pc_dirs, iso, cam_matrix, name, prologue, frame_rate, True) return @@ -247,6 +266,7 @@ def main(args: argparse.Namespace) -> None: width = 600 if args.width is None else args.width name = "movie" if args.name is None else args.name height = 800 if args.height is None else args.height + frame_rate = 3 if args.frame is None else args.frame # checking chimerax if not check_chimerax_installation(): @@ -265,10 +285,12 @@ def main(args: argparse.Namespace) -> None: if analysis_type == "latent": logger.info(f"Working in {analysis_dir}") - latent_movies(analysis_dir, iso, cam_matrix, name, width, height) + latent_movies(analysis_dir, iso, cam_matrix, name, width, height, frame_rate) else: logger.info(f"Working in {landscape_dir}") - landscape_movies(landscape_dir, iso, cam_matrix, name, width, height) + landscape_movies( + landscape_dir, iso, cam_matrix, name, width, height, frame_rate + ) td = dt.now() - t1 logger.info(f"Finished in {td}") From 99898273ff402848319c9cd2c40cf49e2703a131 Mon Sep 17 00:00:00 2001 From: Michal Grzadkowski Date: Mon, 16 Dec 2024 18:47:22 -0500 Subject: [PATCH 18/24] addressing torch.amp.autocast FutureWarnings --- cryodrgn/commands/abinit_het.py | 5 ++++- cryodrgn/commands/abinit_homo.py | 5 ++++- cryodrgn/commands/train_nn.py | 6 +++++- cryodrgn/commands/train_vae.py | 5 ++++- 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/cryodrgn/commands/abinit_het.py b/cryodrgn/commands/abinit_het.py index a70afb76..885b0473 100644 --- a/cryodrgn/commands/abinit_het.py +++ b/cryodrgn/commands/abinit_het.py @@ -485,7 +485,10 @@ def train( # We do this in pose-supervised train_vae if scaler is not None: - amp_mode = torch.cuda.amp.autocast_mode.autocast() + try: + amp_mode = torch.amp.autocast("cuda") + except AttributeError: + amp_mode = torch.cuda.amp.autocast_mode.autocast() else: amp_mode = contextlib.nullcontext() diff --git a/cryodrgn/commands/abinit_homo.py b/cryodrgn/commands/abinit_homo.py index 10010532..404afd75 100644 --- a/cryodrgn/commands/abinit_homo.py +++ b/cryodrgn/commands/abinit_homo.py @@ -416,7 +416,10 @@ def train( ) if scaler is not None: - amp_mode = torch.cuda.amp.autocast_mode.autocast() + try: + amp_mode = torch.amp.autocast("cuda") + except AttributeError: + amp_mode = torch.cuda.amp.autocast_mode.autocast() else: amp_mode = contextlib.nullcontext() diff --git a/cryodrgn/commands/train_nn.py b/cryodrgn/commands/train_nn.py index 3baa789d..a76cecf3 100644 --- a/cryodrgn/commands/train_nn.py +++ b/cryodrgn/commands/train_nn.py @@ -302,7 +302,11 @@ def run_model(y): # Cast operations to mixed precision if using torch.cuda.amp.GradScaler() if scaler is not None: - with torch.cuda.amp.autocast_mode.autocast(): + try: + amp_mode = torch.amp.autocast("cuda") + except AttributeError: + amp_mode = torch.cuda.amp.autocast_mode.autocast() + with amp_mode: loss = run_model(y) else: loss = run_model(y) diff --git a/cryodrgn/commands/train_vae.py b/cryodrgn/commands/train_vae.py index 1abdfea8..deb2ceac 100755 --- a/cryodrgn/commands/train_vae.py +++ b/cryodrgn/commands/train_vae.py @@ -378,7 +378,10 @@ def train_batch( # Cast operations to mixed precision if using torch.cuda.amp.GradScaler() if scaler is not None: - amp_mode = torch.cuda.amp.autocast_mode.autocast() + try: + amp_mode = torch.amp.autocast("cuda") + except AttributeError: + amp_mode = torch.cuda.amp.autocast_mode.autocast() else: amp_mode = contextlib.nullcontext() From 25f1566f1d7f2de81d6f1542db6d1e77b619dcdf Mon Sep 17 00:00:00 2001 From: Michal Grzadkowski Date: Mon, 16 Dec 2024 21:42:21 -0500 Subject: [PATCH 19/24] addressing torch.amp.GradScaler FutureWarnings --- cryodrgn/commands/abinit_het.py | 5 ++++- cryodrgn/commands/abinit_homo.py | 5 ++++- cryodrgn/commands/train_nn.py | 5 ++++- cryodrgn/commands/train_vae.py | 5 ++++- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/cryodrgn/commands/abinit_het.py b/cryodrgn/commands/abinit_het.py index 885b0473..f165d550 100644 --- a/cryodrgn/commands/abinit_het.py +++ b/cryodrgn/commands/abinit_het.py @@ -910,7 +910,10 @@ def main(args): model, optim = amp.initialize(model, optim, opt_level="O1") # mixed precision with pytorch (v1.6+) except: # noqa: E722 - scaler = torch.cuda.amp.grad_scaler.GradScaler() + try: + scaler = torch.amp.GradScaler("cuda") + except AttributeError: + scaler = torch.cuda.amp.grad_scaler.GradScaler() if args.load == "latest": args = get_latest(args) diff --git a/cryodrgn/commands/abinit_homo.py b/cryodrgn/commands/abinit_homo.py index 404afd75..efcf32ec 100644 --- a/cryodrgn/commands/abinit_homo.py +++ b/cryodrgn/commands/abinit_homo.py @@ -679,7 +679,10 @@ def main(args): model, optim = amp.initialize(model, optim, opt_level="O1") # mixed precision with pytorch (v1.6+) except: # noqa: E722 - scaler = torch.cuda.amp.grad_scaler.GradScaler() + try: + scaler = torch.amp.GradScaler("cuda") + except AttributeError: + scaler = torch.cuda.amp.grad_scaler.GradScaler() sorted_poses = [] if args.load: diff --git a/cryodrgn/commands/train_nn.py b/cryodrgn/commands/train_nn.py index a76cecf3..65b7bcee 100644 --- a/cryodrgn/commands/train_nn.py +++ b/cryodrgn/commands/train_nn.py @@ -513,7 +513,10 @@ def main(args: argparse.Namespace) -> None: model, optim = amp.initialize(model, optim, opt_level="O1") # Mixed precision with pytorch (v1.6+) except: # noqa: E722 - scaler = torch.cuda.amp.grad_scaler.GradScaler() + try: + scaler = torch.amp.GradScaler("cuda") + except AttributeError: + scaler = torch.cuda.amp.grad_scaler.GradScaler() # parallelize if args.multigpu and torch.cuda.device_count() > 1: diff --git a/cryodrgn/commands/train_vae.py b/cryodrgn/commands/train_vae.py index deb2ceac..75535127 100755 --- a/cryodrgn/commands/train_vae.py +++ b/cryodrgn/commands/train_vae.py @@ -864,7 +864,10 @@ def main(args: argparse.Namespace) -> None: model, optim = amp.initialize(model, optim, opt_level="O1") # mixed precision with pytorch (v1.6+) except: # noqa: E722 - scaler = torch.cuda.amp.grad_scaler.GradScaler() + try: + scaler = torch.amp.GradScaler("cuda") + except AttributeError: + scaler = torch.cuda.amp.grad_scaler.GradScaler() # restart from checkpoint if args.load: From 6bf91777d8977feb68054f42080bccd17ff9fe73 Mon Sep 17 00:00:00 2001 From: Michal Grzadkowski Date: Tue, 17 Dec 2024 12:27:53 -0500 Subject: [PATCH 20/24] replacing command-line dialogue with button interface in interactive filtering --- cryodrgn/commands/filter.py | 97 +++++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 41 deletions(-) diff --git a/cryodrgn/commands/filter.py b/cryodrgn/commands/filter.py index 28cd61d3..640b4c68 100644 --- a/cryodrgn/commands/filter.py +++ b/cryodrgn/commands/filter.py @@ -37,7 +37,7 @@ import matplotlib.pyplot as plt import seaborn as sns from matplotlib import colors -from matplotlib.backend_bases import Event, MouseButton, MouseEvent +from matplotlib.backend_bases import MouseEvent, MouseButton from matplotlib.widgets import LassoSelector, RadioButtons, Button from matplotlib.path import Path as PlotPath from matplotlib.gridspec import GridSpec @@ -93,7 +93,6 @@ def main(args: argparse.Namespace) -> None: epoch = args.epoch kmeans = args.kmeans plot_inds = args.plot_inds - sel_dir = args.sel_dir train_configs_file = os.path.join(workdir, "config.yaml") if not os.path.exists(train_configs_file): @@ -209,7 +208,6 @@ def main(args: argparse.Namespace) -> None: # Launch the plot and the interactive command-line prompt; once points are selected, # close the figure to avoid interference with other plots selector = SelectFromScatter(plot_df, pre_indices) - # input("Press Enter after making your selection...") selected_indices = [all_indices[i] for i in selector.indices] plt.close() @@ -227,42 +225,33 @@ def main(args: argparse.Namespace) -> None: ) if args.force: - save_option = "yes" + filename = "indices" else: - save_option = ( - input("Do you want to save the selection to file? (yes/no): ") - .strip() - .lower() - ) - - if save_option == "yes": - if args.force: - filename = "indices" - if args.sel_dir: - filename = os.path.join(args.sel_dir, filename) + if args.sel_dir: + sel_msg = f"Enter filename to save selection under {args.sel_dir} " else: - filename = input( - "Enter filename to save selection (absolute, without extension): " - ).strip() + sel_msg = "Enter filename to save selection " + filename = input(sel_msg + "(absolute, without extension):").strip() + if args.sel_dir: + filename = os.path.join(args.sel_dir, filename) - # Saving the selected indices - if filename: - selected_full_path = filename + ".pkl" + # Saving the selected indices + if filename: + selected_full_path = filename + ".pkl" - with open(selected_full_path, "wb") as file: - pickle.dump(np.array(selected_indices, dtype=int), file) - print(f"Selection saved to {selected_full_path}") + with open(selected_full_path, "wb") as file: + pickle.dump(np.array(selected_indices, dtype=int), file) + print(f"Selection saved to `{selected_full_path}`") - # Saving the inverse selection - inverse_filename = filename + "_inverse.pkl" - inverse_indices = np.setdiff1d(all_indices, selected_indices) + # Saving the inverse selection + inverse_filename = filename + "_inverse.pkl" + inverse_indices = np.setdiff1d(all_indices, selected_indices) - with open(inverse_filename, "wb") as file: - pickle.dump(np.array(inverse_indices, dtype=int), file) + with open(inverse_filename, "wb") as file: + pickle.dump(np.array(inverse_indices, dtype=int), file) + + print(f"Inverse selection saved to `{inverse_filename}`") - print(f"Inverse selection saved to {inverse_filename}") - else: - print("Exiting without saving selection.") else: print("Exiting without having made a selection.") @@ -306,9 +295,26 @@ def __init__( self.menu_y.on_clicked(self.update_yaxis) # add save button only when selection is made - self.btn_loc = self.fig.add_subplot(gs[2, 0]) - self.sv_btn = Button(self.btn_loc, "Save Selection!", color="darkgreen") - self.btn_loc.set_visible(False) + self.save_ax = self.fig.add_subplot(gs[2, 0]) + self.exit_ax = self.fig.add_subplot(gs[3, 0]) + self.save_btn = Button( + self.save_ax, + "Save Selection", + color="#164316", + hovercolor="#01BC01", + ) + self.exit_btn = Button( + self.exit_ax, + "Exit Without Saving", + color="#601515", + hovercolor="#BA0B0B", + ) + self.save_btn.label.set_color("white") + self.exit_btn.label.set_color("white") + self.save_btn.on_clicked(self.save_click) + self.exit_btn.on_clicked(self.exit_click) + self.save_ax.set_visible(False) + self.exit_ax.set_visible(True) cax = self.fig.add_subplot(gs[:, 2]) cax.axis("off") @@ -331,7 +337,10 @@ def __init__( self.plot() def gridspec(self) -> GridSpec: - return self.fig.add_gridspec(3, 3, width_ratios=[1, 7, 1], height_ratios=[5, 5, 1]) + """Defines the layout of the plots and menus in the interactive interface.""" + return self.fig.add_gridspec( + 4, 3, width_ratios=[1, 7, 1], height_ratios=[7, 7, 1, 1] + ) def plot(self) -> None: """Redraw the plot using the current plot info upon e.g. input from user.""" @@ -343,12 +352,11 @@ def plot(self) -> None: pnt_colors[idx] = "goldenrod" # with selection, set save button visible - self.btn_loc.set_visible(True) - self.sv_btn.on_clicked(self.save_click) - + self.save_ax.set_visible(True) + elif ~len(self.indices): # remove save button if no selection is made - self.btn_loc.set_visible(False) + self.save_ax.set_visible(False) elif self.color_col != "None": clr_vals = self.data_table[self.color_col] @@ -476,8 +484,15 @@ def on_release(self, event: MouseEvent) -> None: "motion_notify_event", self.hover_points ) - def save_click(self, event: Event) -> None: + def save_click(self, event: MouseEvent) -> None: """When the save button is clicked, we close display.""" if hasattr(event, "button") and event.button is MouseButton.LEFT: # close the plt so we can move onto saving it plt.close("all") + + def exit_click(self, event: MouseEvent) -> None: + """When the exit button is clicked, we clear the selection and close display.""" + if hasattr(event, "button") and event.button is MouseButton.LEFT: + # close the plt so we can move onto saving it + self.indices = list() + plt.close("all") From 7d9c79a07ac0183a3ff60f3eda14ea778d1f4b40 Mon Sep 17 00:00:00 2001 From: Joel Greer Date: Tue, 17 Dec 2024 21:03:21 +0000 Subject: [PATCH 21/24] updated filter.py for cases without existing list of filtered indices --- cryodrgn/commands/filter.py | 128 ++++++++++++++++++++---------------- 1 file changed, 70 insertions(+), 58 deletions(-) diff --git a/cryodrgn/commands/filter.py b/cryodrgn/commands/filter.py index 6e2de563..10481882 100644 --- a/cryodrgn/commands/filter.py +++ b/cryodrgn/commands/filter.py @@ -37,7 +37,7 @@ import matplotlib.pyplot as plt import seaborn as sns from matplotlib import colors -from matplotlib.backend_bases import Event, MouseButton, MouseEvent +from matplotlib.backend_bases import MouseEvent, MouseButton from matplotlib.widgets import LassoSelector, RadioButtons, Button from matplotlib.path import Path as PlotPath from matplotlib.gridspec import GridSpec @@ -93,7 +93,6 @@ def main(args: argparse.Namespace) -> None: epoch = args.epoch kmeans = args.kmeans plot_inds = args.plot_inds - sel_dir = args.sel_dir train_configs_file = os.path.join(workdir, "config.yaml") if not os.path.exists(train_configs_file): @@ -122,38 +121,35 @@ def main(args: argparse.Namespace) -> None: ) z = utils.load_pkl(os.path.join(workdir, f"z.{epoch}.pkl")) - # Load poses either from input file or from reconstruction results if ab-initio + # Get poses either from input file or from reconstruction results if ab-initio if "poses" in train_configs["dataset_args"]: pose_pkl = train_configs["dataset_args"]["poses"] else: pose_pkl = os.path.join(workdir, f"pose.{epoch}.pkl") + # Load poses and initial indices for plotting if they have been specified rot, trans = utils.load_pkl(pose_pkl) ctf_params = utils.load_pkl(train_configs["dataset_args"]["ctf"]) + pre_indices = None if plot_inds is None else utils.load_pkl(plot_inds) all_indices = np.array(range(ctf_params.shape[0])) # Load the set of indices used to filter original dataset and apply it to inputs if isinstance(train_configs["dataset_args"]["ind"], int): - ctf_params = ctf_params[: train_configs["dataset_args"]["ind"], :] - all_indices = all_indices[: train_configs["dataset_args"]["ind"]] - - elif isinstance(train_configs["dataset_args"]["ind"], str): - inds = utils.load_pkl(train_configs["dataset_args"]["ind"]) - ctf_params = ctf_params[inds, :] - rot = rot[inds, :, :] - trans = trans[inds, :] - all_indices = all_indices[inds] + indices = slice(train_configs["dataset_args"]["ind"]) + # We only need to filter the poses if they weren't generated by the model + rot = rot[indices, :, :] + trans = trans[indices, :] + elif train_configs["dataset_args"]["ind"]: + indices = utils.load_pkl(train_configs["dataset_args"]["ind"]) + ctf_params = ctf_params[indices, :] + all_indices = all_indices[indices] + # We only need to filter the poses if they weren't generated by the model + rot = rot[indices, :, :] + trans = trans[indices, :] pc, pca = analysis.run_pca(z) umap = utils.load_pkl(os.path.join(anlzdir, "umap.pkl")) - # Load preselected indices for initial plotting if they have been specified - if plot_inds: - with open(plot_inds, "rb") as file: - pre_indices = pickle.load(file) - else: - pre_indices = None - if kmeans == -1: kmeans_dirs = [ d @@ -212,7 +208,6 @@ def main(args: argparse.Namespace) -> None: # Launch the plot and the interactive command-line prompt; once points are selected, # close the figure to avoid interference with other plots selector = SelectFromScatter(plot_df, pre_indices) - # input("Press Enter after making your selection...") selected_indices = [all_indices[i] for i in selector.indices] plt.close() @@ -230,42 +225,33 @@ def main(args: argparse.Namespace) -> None: ) if args.force: - save_option = "yes" + filename = "indices" else: - save_option = ( - input("Do you want to save the selection to file? (yes/no): ") - .strip() - .lower() - ) - - if save_option == "yes": - if args.force: - filename = "indices" - if args.sel_dir: - filename = os.path.join(args.sel_dir, filename) + if args.sel_dir: + sel_msg = f"Enter filename to save selection under {args.sel_dir} " else: - filename = input( - "Enter filename to save selection (absolute, without extension): " - ).strip() + sel_msg = "Enter filename to save selection " + filename = input(sel_msg + "(absolute, without extension):").strip() + if args.sel_dir: + filename = os.path.join(args.sel_dir, filename) - # Saving the selected indices - if filename: - selected_full_path = filename + ".pkl" + # Saving the selected indices + if filename: + selected_full_path = filename + ".pkl" - with open(selected_full_path, "wb") as file: - pickle.dump(np.array(selected_indices, dtype=int), file) - print(f"Selection saved to {selected_full_path}") + with open(selected_full_path, "wb") as file: + pickle.dump(np.array(selected_indices, dtype=int), file) + print(f"Selection saved to `{selected_full_path}`") - # Saving the inverse selection - inverse_filename = filename + "_inverse.pkl" - inverse_indices = np.setdiff1d(all_indices, selected_indices) + # Saving the inverse selection + inverse_filename = filename + "_inverse.pkl" + inverse_indices = np.setdiff1d(all_indices, selected_indices) - with open(inverse_filename, "wb") as file: - pickle.dump(np.array(inverse_indices, dtype=int), file) + with open(inverse_filename, "wb") as file: + pickle.dump(np.array(inverse_indices, dtype=int), file) + + print(f"Inverse selection saved to `{inverse_filename}`") - print(f"Inverse selection saved to {inverse_filename}") - else: - print("Exiting without saving selection.") else: print("Exiting without having made a selection.") @@ -309,9 +295,26 @@ def __init__( self.menu_y.on_clicked(self.update_yaxis) # add save button only when selection is made - self.btn_loc = self.fig.add_subplot(gs[2, 0]) - self.sv_btn = Button(self.btn_loc, "Save Selection!", color="darkgreen") - self.btn_loc.set_visible(False) + self.save_ax = self.fig.add_subplot(gs[2, 0]) + self.exit_ax = self.fig.add_subplot(gs[3, 0]) + self.save_btn = Button( + self.save_ax, + "Save Selection", + color="#164316", + hovercolor="#01BC01", + ) + self.exit_btn = Button( + self.exit_ax, + "Exit Without Saving", + color="#601515", + hovercolor="#BA0B0B", + ) + self.save_btn.label.set_color("white") + self.exit_btn.label.set_color("white") + self.save_btn.on_clicked(self.save_click) + self.exit_btn.on_clicked(self.exit_click) + self.save_ax.set_visible(False) + self.exit_ax.set_visible(True) cax = self.fig.add_subplot(gs[:, 2]) cax.axis("off") @@ -334,7 +337,10 @@ def __init__( self.plot() def gridspec(self) -> GridSpec: - return self.fig.add_gridspec(3, 3, width_ratios=[1, 7, 1], height_ratios=[5, 5, 1]) + """Defines the layout of the plots and menus in the interactive interface.""" + return self.fig.add_gridspec( + 4, 3, width_ratios=[1, 7, 1], height_ratios=[7, 7, 1, 1] + ) def plot(self) -> None: """Redraw the plot using the current plot info upon e.g. input from user.""" @@ -346,12 +352,11 @@ def plot(self) -> None: pnt_colors[idx] = "goldenrod" # with selection, set save button visible - self.btn_loc.set_visible(True) - self.sv_btn.on_clicked(self.save_click) - + self.save_ax.set_visible(True) + elif ~len(self.indices): # remove save button if no selection is made - self.btn_loc.set_visible(False) + self.save_ax.set_visible(False) elif self.color_col != "None": clr_vals = self.data_table[self.color_col] @@ -479,8 +484,15 @@ def on_release(self, event: MouseEvent) -> None: "motion_notify_event", self.hover_points ) - def save_click(self, event: Event) -> None: + def save_click(self, event: MouseEvent) -> None: """When the save button is clicked, we close display.""" if hasattr(event, "button") and event.button is MouseButton.LEFT: # close the plt so we can move onto saving it plt.close("all") + + def exit_click(self, event: MouseEvent) -> None: + """When the exit button is clicked, we clear the selection and close display.""" + if hasattr(event, "button") and event.button is MouseButton.LEFT: + # close the plt so we can move onto saving it + self.indices = list() + plt.close("all") From 94872fa88102affac57f8de51a5d3e3abd62ce58 Mon Sep 17 00:00:00 2001 From: Michal Grzadkowski Date: Thu, 19 Dec 2024 10:50:49 -0500 Subject: [PATCH 22/24] adding tests of interactive filtering; splitting reconstruction tests according to SPA/ET --- tests/test_reconstruct.py | 367 ++++----------------------------- tests/test_reconstruct_tilt.py | 346 +++++++++++++++++++++++++++++++ 2 files changed, 382 insertions(+), 331 deletions(-) create mode 100644 tests/test_reconstruct_tilt.py diff --git a/tests/test_reconstruct.py b/tests/test_reconstruct.py index f78d7f0b..571b13a3 100644 --- a/tests/test_reconstruct.py +++ b/tests/test_reconstruct.py @@ -4,7 +4,6 @@ import argparse import os.path import shutil -import pickle import random import nbformat from nbclient.exceptions import CellExecutionError @@ -15,19 +14,17 @@ analyze, analyze_landscape, analyze_landscape_full, - backproject_voxel, direct_traversal, eval_images, eval_vol, + filter, graph_traversal, train_nn, train_vae, - abinit_homo, abinit_het, ) -from cryodrgn.commands_utils import clean, filter_star, plot_classes +from cryodrgn.commands_utils import clean, plot_classes from cryodrgn.source import ImageSource -from cryodrgn.dataset import TiltSeriesData from cryodrgn import utils @@ -131,6 +128,29 @@ def test_analyze(self, tmpdir_factory, particles, poses, ctf, indices, epoch): assert os.path.exists(os.path.join(outdir, f"analyze.{epoch}")) + @pytest.mark.parametrize( + "ctf, epoch", + [ + ("CTF-Test", 3), + ("CTF-Test", None), + pytest.param( + None, + None, + marks=pytest.mark.xfail(raises=NotImplementedError), + ), + ], + indirect=["ctf"], + ) + def test_interactive_filtering( + self, tmpdir_factory, particles, poses, ctf, indices, epoch + ): + """Launch interface for filtering particles using model covariates.""" + outdir = self.get_outdir(tmpdir_factory, particles, indices, poses, ctf) + parser = argparse.ArgumentParser() + filter.add_args(parser) + epoch_args = ["--epoch", str(epoch)] if epoch is not None else list() + filter.main(parser.parse_args([outdir] + epoch_args)) + @pytest.mark.parametrize( "nb_lbl, ctf", [ @@ -533,6 +553,17 @@ def test_notebooks(self, tmpdir_factory, particles, ctf, indices, nb_lbl): ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(nb_in) os.chdir(orig_cwd) + @pytest.mark.parametrize("epoch", [1, None]) + def test_interactive_filtering( + self, tmpdir_factory, particles, ctf, indices, epoch + ): + """Launch interface for filtering particles using model covariates.""" + outdir = self.get_outdir(tmpdir_factory, particles, indices, ctf) + parser = argparse.ArgumentParser() + filter.add_args(parser) + epoch_args = ["--epoch", str(epoch)] if epoch is not None else list() + filter.main(parser.parse_args([outdir] + epoch_args)) + @pytest.mark.parametrize("epoch", [1, 2]) def test_graph_traversal(self, tmpdir_factory, particles, ctf, indices, epoch): outdir = self.get_outdir(tmpdir_factory, particles, indices, ctf) @@ -649,329 +680,3 @@ def test_train_model(self, tmpdir, particles, indices, poses, ctf, datadir): args = train_nn.add_args(argparse.ArgumentParser()).parse_args(args) train_nn.main(args) - - -@pytest.mark.parametrize("particles", ["tilts.star"], indirect=True) -@pytest.mark.parametrize("indices", ["just-4", "just-5"], indirect=True) -@pytest.mark.parametrize("poses", ["tilt-poses"], indirect=True) -@pytest.mark.parametrize("ctf", ["CTF-Tilt"], indirect=True) -@pytest.mark.parametrize("datadir", ["default-datadir"], indirect=True) -class TestTiltFixedHetero: - """Run heterogeneous reconstruction using tilt series from a .star file and poses. - - We use two sets of indices, one that produces a tilt series with all particles - having the same number of tilts and another that produces a ragged tilt-series. - """ - - def get_outdir(self, tmpdir_factory, particles, poses, ctf, indices, datadir): - dirname = os.path.join( - "TiltFixedHetero", - particles.label, - poses.label, - ctf.label, - indices.label, - datadir.label, - ) - odir = os.path.join(tmpdir_factory.getbasetemp(), dirname) - os.makedirs(odir, exist_ok=True) - - return odir - - def test_train_model(self, tmpdir_factory, particles, indices, poses, ctf, datadir): - outdir = self.get_outdir( - tmpdir_factory, particles, poses, ctf, indices, datadir - ) - args = [ - particles.path, - "--datadir", - datadir.path, - "--encode-mode", - "tilt", - "--poses", - poses.path, - "--ctf", - ctf.path, - "--num-epochs", - "5", - "--zdim", - "4", - "-o", - outdir, - "--tdim", - "16", - "--enc-dim", - "16", - "--dec-dim", - "16", - ] - if indices.path is not None: - args += ["--ind", indices.path] - - args = train_vae.add_args(argparse.ArgumentParser()).parse_args(args) - train_vae.main(args) - - def test_filter_command( - self, tmpdir_factory, particles, indices, poses, ctf, datadir - ): - outdir = self.get_outdir( - tmpdir_factory, particles, poses, ctf, indices, datadir - ) - - # filter the tilt-series particles - args = [ - particles.path, - "--et", - "-o", - os.path.join(outdir, "filtered_sta_testing_bin8.star"), - ] - if indices.path is not None: - args += ["--ind", indices.path] - parser = argparse.ArgumentParser() - filter_star.add_args(parser) - filter_star.main(parser.parse_args(args)) - - # need to filter poses and CTFs manually due to tilt indices - pt, tp = TiltSeriesData.parse_particle_tilt(particles.path) - ind = utils.load_pkl(indices.path) - new_ind = ind[:3] - tilt_ind = TiltSeriesData.particles_to_tilts(pt, ind) - - rot, trans = utils.load_pkl(poses.path) - rot, trans = rot[tilt_ind], trans[tilt_ind] - utils.save_pkl((rot, trans), os.path.join(outdir, "filtered_sta_pose.pkl")) - ctf_mat = utils.load_pkl(ctf.path)[tilt_ind] - utils.save_pkl(ctf_mat, os.path.join(outdir, "filtered_sta_ctf.pkl")) - utils.save_pkl(new_ind, os.path.join(outdir, "filtered_ind.pkl")) - args = [ - os.path.join(outdir, "filtered_sta_testing_bin8.star"), - "--datadir", - datadir.path, - "--encode-mode", - "tilt", - "--ntilts", - "5", - "--poses", - os.path.join(outdir, "filtered_sta_pose.pkl"), - "--ctf", - os.path.join(outdir, "filtered_sta_ctf.pkl"), - "--ind", - os.path.join(outdir, "filtered_ind.pkl"), - "--num-epochs", - "5", - "--zdim", - "4", - "-o", - os.path.join(outdir, "filtered"), - "--tdim", - "16", - "--enc-dim", - "16", - "--dec-dim", - "16", - ] - train_vae.main(train_vae.add_args(argparse.ArgumentParser()).parse_args(args)) - - def test_analyze(self, tmpdir_factory, particles, indices, poses, ctf, datadir): - """Produce standard analyses for a particular epoch.""" - outdir = self.get_outdir( - tmpdir_factory, particles, poses, ctf, indices, datadir - ) - - parser = argparse.ArgumentParser() - analyze.add_args(parser) - analyze.main( - parser.parse_args( - [ - outdir, - "4", # Epoch number to analyze - 0-indexed - "--pc", - "3", # Number of principal component traversals to generate - "--ksample", - "2", # Number of kmeans samples to generate - "--vol-start-index", - "1", - ] - ) - ) - assert os.path.exists(os.path.join(outdir, "analyze.4")) - - @pytest.mark.parametrize( - "new_indices_file", - [None, "filtered_ind.pkl"], - ids=("no-new-ind", "new-ind"), - ) - def test_backproject( - self, tmpdir_factory, particles, indices, poses, ctf, datadir, new_indices_file - ): - """Run backprojection using the given particles.""" - outdir = self.get_outdir( - tmpdir_factory, particles, poses, ctf, indices, datadir - ) - args = [ - os.path.join(outdir, "filtered_sta_testing_bin8.star"), - "--datadir", - datadir.path, - "--tilt", - "--poses", - os.path.join(outdir, "filtered_sta_pose.pkl"), - "--ctf", - os.path.join(outdir, "filtered_sta_ctf.pkl"), - "-o", - os.path.join(outdir, "filtered"), - "-d", - "2.93", - "--no-half-maps", - ] - if new_indices_file is not None: - args += ["--ind", os.path.join(outdir, new_indices_file)] - - parser = argparse.ArgumentParser() - backproject_voxel.add_args(parser) - backproject_voxel.main(parser.parse_args(args)) - assert os.path.exists(os.path.join(outdir, "filtered", "backproject.mrc")) - shutil.rmtree(os.path.join(outdir, "filtered")) - - @pytest.mark.parametrize("nb_lbl", ["cryoDRGN_figures", "cryoDRGN_ET_viz"]) - def test_notebooks( - self, tmpdir_factory, particles, indices, poses, ctf, datadir, nb_lbl - ): - """Execute the demonstration Jupyter notebooks produced by analysis.""" - outdir = self.get_outdir( - tmpdir_factory, particles, poses, ctf, indices, datadir - ) - orig_cwd = os.path.abspath(os.getcwd()) - os.chdir(os.path.join(outdir, "analyze.4")) - assert os.path.exists(f"{nb_lbl}.ipynb"), "Upstream tests have failed!" - - with open(f"{nb_lbl}.ipynb") as ff: - nb_in = nbformat.read(ff, nbformat.NO_CONVERT) - - ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(nb_in) - os.chdir(orig_cwd) - - def test_refiltering(self, tmpdir_factory, particles, indices, poses, ctf, datadir): - """Use particle index creating during analysis.""" - outdir = self.get_outdir( - tmpdir_factory, particles, poses, ctf, indices, datadir - ) - orig_cwd = os.path.abspath(os.getcwd()) - os.chdir(os.path.join(outdir, "analyze.4")) - assert os.path.exists("tmp_ind_selected.pkl"), "Upstream tests have failed!" - - with open("tmp_ind_selected.pkl", "rb") as f: - indices = pickle.load(f) - - new_indices = indices[:3] - with open("tmp_ind_selected.pkl", "wb") as f: - pickle.dump(new_indices, f) - - args = [ - particles.path, - "--datadir", - datadir.path, - "--encode-mode", - "tilt", - "--poses", - poses.path, - "--ctf", - ctf.path, - "--ind", - "tmp_ind_selected.pkl", - "--num-epochs", - "5", - "--zdim", - "4", - "-o", - outdir, - "--tdim", - "16", - "--enc-dim", - "16", - "--dec-dim", - "16", - ] - - args = train_vae.add_args(argparse.ArgumentParser()).parse_args(args) - train_vae.main(args) - os.chdir(orig_cwd) - - shutil.rmtree(outdir) - - -@pytest.mark.parametrize("particles", ["tilts.star"], indirect=True) -@pytest.mark.parametrize("indices", [None, "just-4"], indirect=True) -@pytest.mark.parametrize("ctf", ["CTF-Tilt"], indirect=True) -@pytest.mark.parametrize("datadir", ["default-datadir"], indirect=True) -class TestTiltAbinitHomo: - def test_train_model(self, tmpdir, particles, indices, ctf, datadir): - args = [ - particles.path, - "--datadir", - datadir.path, - "--ctf", - ctf.path, - "-o", - str(tmpdir), - "--dim", - "4", - "--layers", - "2", - "--t-extent", - "4.0", - "--t-ngrid", - "2", - "--pretrain=1", - "--num-epochs", - "3", - "--ps-freq", - "2", - ] - if indices.path is not None: - args += ["--ind", indices.path] - - args = abinit_homo.add_args(argparse.ArgumentParser()).parse_args(args) - abinit_homo.main(args) - - -@pytest.mark.parametrize("particles", ["tilts.star"], indirect=True) -@pytest.mark.parametrize("indices", [None, "just-4"], indirect=True) -@pytest.mark.parametrize("ctf", ["CTF-Tilt"], indirect=True) -@pytest.mark.parametrize("datadir", ["default-datadir"], indirect=True) -class TestTiltAbinitHetero: - def test_train_model(self, tmpdir, particles, indices, ctf, datadir): - args = [ - particles.path, - "--datadir", - datadir.path, - "--ctf", - ctf.path, - "--zdim", - "8", - "-o", - str(tmpdir), - "--enc-dim", - "4", - "--enc-layers", - "2", - "--dec-dim", - "4", - "--dec-layers", - "2", - "--pe-dim", - "4", - "--enc-only", - "--t-extent", - "4.0", - "--t-ngrid", - "2", - "--pretrain=1", - "--num-epochs", - "3", - "--ps-freq", - "2", - ] - if indices.path is not None: - args += ["--ind", indices.path] - - args = abinit_het.add_args(argparse.ArgumentParser()).parse_args(args) - abinit_het.main(args) diff --git a/tests/test_reconstruct_tilt.py b/tests/test_reconstruct_tilt.py new file mode 100644 index 00000000..c19242f2 --- /dev/null +++ b/tests/test_reconstruct_tilt.py @@ -0,0 +1,346 @@ +"""Running an experiment of training followed by downstream analyses.""" + +import pytest +import argparse +import os.path +import shutil +import pickle +import nbformat +from nbconvert.preprocessors import ExecutePreprocessor + +from cryodrgn.commands import ( + analyze, + backproject_voxel, + train_vae, + abinit_homo, + abinit_het, +) +from cryodrgn.commands_utils import filter_star +from cryodrgn.dataset import TiltSeriesData +from cryodrgn import utils + + +@pytest.mark.parametrize("particles", ["tilts.star"], indirect=True) +@pytest.mark.parametrize("indices", ["just-4", "just-5"], indirect=True) +@pytest.mark.parametrize("poses", ["tilt-poses"], indirect=True) +@pytest.mark.parametrize("ctf", ["CTF-Tilt"], indirect=True) +@pytest.mark.parametrize("datadir", ["default-datadir"], indirect=True) +class TestTiltFixedHetero: + """Run heterogeneous reconstruction using tilt series from a .star file and poses. + + We use two sets of indices, one that produces a tilt series with all particles + having the same number of tilts and another that produces a ragged tilt-series. + """ + + def get_outdir(self, tmpdir_factory, particles, poses, ctf, indices, datadir): + dirname = os.path.join( + "TiltFixedHetero", + particles.label, + poses.label, + ctf.label, + indices.label, + datadir.label, + ) + odir = os.path.join(tmpdir_factory.getbasetemp(), dirname) + os.makedirs(odir, exist_ok=True) + + return odir + + def test_train_model(self, tmpdir_factory, particles, indices, poses, ctf, datadir): + outdir = self.get_outdir( + tmpdir_factory, particles, poses, ctf, indices, datadir + ) + args = [ + particles.path, + "--datadir", + datadir.path, + "--encode-mode", + "tilt", + "--poses", + poses.path, + "--ctf", + ctf.path, + "--num-epochs", + "5", + "--zdim", + "4", + "-o", + outdir, + "--tdim", + "16", + "--enc-dim", + "16", + "--dec-dim", + "16", + ] + if indices.path is not None: + args += ["--ind", indices.path] + + args = train_vae.add_args(argparse.ArgumentParser()).parse_args(args) + train_vae.main(args) + + def test_filter_command( + self, tmpdir_factory, particles, indices, poses, ctf, datadir + ): + outdir = self.get_outdir( + tmpdir_factory, particles, poses, ctf, indices, datadir + ) + + # filter the tilt-series particles + args = [ + particles.path, + "--et", + "-o", + os.path.join(outdir, "filtered_sta_testing_bin8.star"), + ] + if indices.path is not None: + args += ["--ind", indices.path] + parser = argparse.ArgumentParser() + filter_star.add_args(parser) + filter_star.main(parser.parse_args(args)) + + # need to filter poses and CTFs manually due to tilt indices + pt, tp = TiltSeriesData.parse_particle_tilt(particles.path) + ind = utils.load_pkl(indices.path) + new_ind = ind[:3] + tilt_ind = TiltSeriesData.particles_to_tilts(pt, ind) + + rot, trans = utils.load_pkl(poses.path) + rot, trans = rot[tilt_ind], trans[tilt_ind] + utils.save_pkl((rot, trans), os.path.join(outdir, "filtered_sta_pose.pkl")) + ctf_mat = utils.load_pkl(ctf.path)[tilt_ind] + utils.save_pkl(ctf_mat, os.path.join(outdir, "filtered_sta_ctf.pkl")) + utils.save_pkl(new_ind, os.path.join(outdir, "filtered_ind.pkl")) + args = [ + os.path.join(outdir, "filtered_sta_testing_bin8.star"), + "--datadir", + datadir.path, + "--encode-mode", + "tilt", + "--ntilts", + "5", + "--poses", + os.path.join(outdir, "filtered_sta_pose.pkl"), + "--ctf", + os.path.join(outdir, "filtered_sta_ctf.pkl"), + "--ind", + os.path.join(outdir, "filtered_ind.pkl"), + "--num-epochs", + "5", + "--zdim", + "4", + "-o", + os.path.join(outdir, "filtered"), + "--tdim", + "16", + "--enc-dim", + "16", + "--dec-dim", + "16", + ] + train_vae.main(train_vae.add_args(argparse.ArgumentParser()).parse_args(args)) + + def test_analyze(self, tmpdir_factory, particles, indices, poses, ctf, datadir): + """Produce standard analyses for a particular epoch.""" + outdir = self.get_outdir( + tmpdir_factory, particles, poses, ctf, indices, datadir + ) + + parser = argparse.ArgumentParser() + analyze.add_args(parser) + analyze.main( + parser.parse_args( + [ + outdir, + "4", # Epoch number to analyze - 0-indexed + "--pc", + "3", # Number of principal component traversals to generate + "--ksample", + "2", # Number of kmeans samples to generate + "--vol-start-index", + "1", + ] + ) + ) + assert os.path.exists(os.path.join(outdir, "analyze.4")) + + @pytest.mark.parametrize( + "new_indices_file", + [None, "filtered_ind.pkl"], + ids=("no-new-ind", "new-ind"), + ) + def test_backproject( + self, tmpdir_factory, particles, indices, poses, ctf, datadir, new_indices_file + ): + """Run backprojection using the given particles.""" + outdir = self.get_outdir( + tmpdir_factory, particles, poses, ctf, indices, datadir + ) + args = [ + os.path.join(outdir, "filtered_sta_testing_bin8.star"), + "--datadir", + datadir.path, + "--tilt", + "--poses", + os.path.join(outdir, "filtered_sta_pose.pkl"), + "--ctf", + os.path.join(outdir, "filtered_sta_ctf.pkl"), + "-o", + os.path.join(outdir, "filtered"), + "-d", + "2.93", + "--no-half-maps", + ] + if new_indices_file is not None: + args += ["--ind", os.path.join(outdir, new_indices_file)] + + parser = argparse.ArgumentParser() + backproject_voxel.add_args(parser) + backproject_voxel.main(parser.parse_args(args)) + assert os.path.exists(os.path.join(outdir, "filtered", "backproject.mrc")) + shutil.rmtree(os.path.join(outdir, "filtered")) + + @pytest.mark.parametrize("nb_lbl", ["cryoDRGN_figures", "cryoDRGN_ET_viz"]) + def test_notebooks( + self, tmpdir_factory, particles, indices, poses, ctf, datadir, nb_lbl + ): + """Execute the demonstration Jupyter notebooks produced by analysis.""" + outdir = self.get_outdir( + tmpdir_factory, particles, poses, ctf, indices, datadir + ) + orig_cwd = os.path.abspath(os.getcwd()) + os.chdir(os.path.join(outdir, "analyze.4")) + assert os.path.exists(f"{nb_lbl}.ipynb"), "Upstream tests have failed!" + + with open(f"{nb_lbl}.ipynb") as ff: + nb_in = nbformat.read(ff, nbformat.NO_CONVERT) + + ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(nb_in) + os.chdir(orig_cwd) + + def test_refiltering(self, tmpdir_factory, particles, indices, poses, ctf, datadir): + """Use particle index creating during analysis.""" + outdir = self.get_outdir( + tmpdir_factory, particles, poses, ctf, indices, datadir + ) + orig_cwd = os.path.abspath(os.getcwd()) + os.chdir(os.path.join(outdir, "analyze.4")) + assert os.path.exists("tmp_ind_selected.pkl"), "Upstream tests have failed!" + + with open("tmp_ind_selected.pkl", "rb") as f: + indices = pickle.load(f) + + new_indices = indices[:3] + with open("tmp_ind_selected.pkl", "wb") as f: + pickle.dump(new_indices, f) + + args = [ + particles.path, + "--datadir", + datadir.path, + "--encode-mode", + "tilt", + "--poses", + poses.path, + "--ctf", + ctf.path, + "--ind", + "tmp_ind_selected.pkl", + "--num-epochs", + "5", + "--zdim", + "4", + "-o", + outdir, + "--tdim", + "16", + "--enc-dim", + "16", + "--dec-dim", + "16", + ] + + args = train_vae.add_args(argparse.ArgumentParser()).parse_args(args) + train_vae.main(args) + os.chdir(orig_cwd) + + shutil.rmtree(outdir) + + +@pytest.mark.parametrize("particles", ["tilts.star"], indirect=True) +@pytest.mark.parametrize("indices", [None, "just-4"], indirect=True) +@pytest.mark.parametrize("ctf", ["CTF-Tilt"], indirect=True) +@pytest.mark.parametrize("datadir", ["default-datadir"], indirect=True) +class TestTiltAbinitHomo: + def test_train_model(self, tmpdir, particles, indices, ctf, datadir): + args = [ + particles.path, + "--datadir", + datadir.path, + "--ctf", + ctf.path, + "-o", + str(tmpdir), + "--dim", + "4", + "--layers", + "2", + "--t-extent", + "4.0", + "--t-ngrid", + "2", + "--pretrain=1", + "--num-epochs", + "3", + "--ps-freq", + "2", + ] + if indices.path is not None: + args += ["--ind", indices.path] + + args = abinit_homo.add_args(argparse.ArgumentParser()).parse_args(args) + abinit_homo.main(args) + + +@pytest.mark.parametrize("particles", ["tilts.star"], indirect=True) +@pytest.mark.parametrize("indices", [None, "just-4"], indirect=True) +@pytest.mark.parametrize("ctf", ["CTF-Tilt"], indirect=True) +@pytest.mark.parametrize("datadir", ["default-datadir"], indirect=True) +class TestTiltAbinitHetero: + def test_train_model(self, tmpdir, particles, indices, ctf, datadir): + args = [ + particles.path, + "--datadir", + datadir.path, + "--ctf", + ctf.path, + "--zdim", + "8", + "-o", + str(tmpdir), + "--enc-dim", + "4", + "--enc-layers", + "2", + "--dec-dim", + "4", + "--dec-layers", + "2", + "--pe-dim", + "4", + "--enc-only", + "--t-extent", + "4.0", + "--t-ngrid", + "2", + "--pretrain=1", + "--num-epochs", + "3", + "--ps-freq", + "2", + ] + if indices.path is not None: + args += ["--ind", indices.path] + + args = abinit_het.add_args(argparse.ArgumentParser()).parse_args(args) + abinit_het.main(args) From 593944fb5f225ad76c0bb35e28d0d9f50221b7f4 Mon Sep 17 00:00:00 2001 From: Michal Grzadkowski Date: Thu, 19 Dec 2024 12:51:42 -0500 Subject: [PATCH 23/24] better handling of different missing inputs in interactive filtering --- cryodrgn/commands/filter.py | 53 +++++++++++++++++++++++---------- tests/test_reconstruct.py | 59 ++++++++++++++++++++++--------------- 2 files changed, 73 insertions(+), 39 deletions(-) diff --git a/cryodrgn/commands/filter.py b/cryodrgn/commands/filter.py index 10481882..f228b616 100644 --- a/cryodrgn/commands/filter.py +++ b/cryodrgn/commands/filter.py @@ -18,6 +18,9 @@ # Choose an epoch yourself; save final selection to `indices.pkl` without prompting $ cryodrgn filter my_outdir --epoch 30 -f +# Choose another epoch; this time choose file name but pre-select directory to save in +$ cryodrgn filter my_outdir --epoch 30 --sel-dir /data/my_indices/ + # If you have done multiple k-means clusterings, you can pick which one to use $ cryodrgn filter my_outdir/ -k 25 @@ -44,7 +47,7 @@ from scipy.spatial import transform from typing import Optional, Sequence -from cryodrgn import analysis, utils +from cryodrgn import analysis, utils, dataset logger = logging.getLogger(__name__) @@ -89,6 +92,7 @@ def add_args(parser: argparse.ArgumentParser) -> None: def main(args: argparse.Namespace) -> None: + """Launching the interactive interface for filtering particles from command-line.""" workdir = args.outdir epoch = args.epoch kmeans = args.kmeans @@ -129,23 +133,40 @@ def main(args: argparse.Namespace) -> None: # Load poses and initial indices for plotting if they have been specified rot, trans = utils.load_pkl(pose_pkl) - ctf_params = utils.load_pkl(train_configs["dataset_args"]["ctf"]) - pre_indices = None if plot_inds is None else utils.load_pkl(plot_inds) - all_indices = np.array(range(ctf_params.shape[0])) + if train_configs["dataset_args"]["ctf"] is not None: + ctf_params = utils.load_pkl(train_configs["dataset_args"]["ctf"]) + else: + ctf_params = None # Load the set of indices used to filter original dataset and apply it to inputs + pre_indices = None if plot_inds is None else utils.load_pkl(plot_inds) + if ctf_params is not None: + all_indices = np.array(range(ctf_params.shape[0])) + else: + all_indices = np.array( + range( + dataset.ImageDataset( + mrcfile=train_configs["dataset_args"]["particles"], lazy=True + ).N + ) + ) + + # Load the set of indices used to filter the original dataset and apply it to inputs if isinstance(train_configs["dataset_args"]["ind"], int): indices = slice(train_configs["dataset_args"]["ind"]) - # We only need to filter the poses if they weren't generated by the model - rot = rot[indices, :, :] - trans = trans[indices, :] - elif train_configs["dataset_args"]["ind"]: + elif train_configs["dataset_args"]["ind"] is not None: indices = utils.load_pkl(train_configs["dataset_args"]["ind"]) - ctf_params = ctf_params[indices, :] + else: + indices = None + + if indices is not None: + ctf_params = ctf_params[indices, :] if ctf_params is not None else None all_indices = all_indices[indices] + # We only need to filter the poses if they weren't generated by the model - rot = rot[indices, :, :] - trans = trans[indices, :] + if "poses" in train_configs["dataset_args"]: + rot = rot[indices, :, :] + trans = trans[indices, :] pc, pca = analysis.run_pca(z) umap = utils.load_pkl(os.path.join(anlzdir, "umap.pkl")) @@ -192,12 +213,14 @@ def main(args: argparse.Namespace) -> None: trans=trans, labels=kmeans_lbls, umap=umap, - df1=ctf_params[:, 2], - df2=ctf_params[:, 3], - dfang=ctf_params[:, 4], - phase=ctf_params[:, 8], znorm=znorm, ) + if ctf_params is not None: + plot_df["df1"] = ctf_params[:, 2] + plot_df["df2"] = ctf_params[:, 3] + plot_df["dfang"] = ctf_params[:, 4] + plot_df["phase"] = ctf_params[:, 8] + # Tilt-series outputs have tilt-level CTFs and poses but particle-level model # results, thus we ignore the former in this case for now else: diff --git a/tests/test_reconstruct.py b/tests/test_reconstruct.py index 571b13a3..41daf435 100644 --- a/tests/test_reconstruct.py +++ b/tests/test_reconstruct.py @@ -128,29 +128,6 @@ def test_analyze(self, tmpdir_factory, particles, poses, ctf, indices, epoch): assert os.path.exists(os.path.join(outdir, f"analyze.{epoch}")) - @pytest.mark.parametrize( - "ctf, epoch", - [ - ("CTF-Test", 3), - ("CTF-Test", None), - pytest.param( - None, - None, - marks=pytest.mark.xfail(raises=NotImplementedError), - ), - ], - indirect=["ctf"], - ) - def test_interactive_filtering( - self, tmpdir_factory, particles, poses, ctf, indices, epoch - ): - """Launch interface for filtering particles using model covariates.""" - outdir = self.get_outdir(tmpdir_factory, particles, indices, poses, ctf) - parser = argparse.ArgumentParser() - filter.add_args(parser) - epoch_args = ["--epoch", str(epoch)] if epoch is not None else list() - filter.main(parser.parse_args([outdir] + epoch_args)) - @pytest.mark.parametrize( "nb_lbl, ctf", [ @@ -186,6 +163,29 @@ def test_notebooks(self, tmpdir_factory, particles, poses, ctf, indices, nb_lbl) os.chdir(orig_cwd) + @pytest.mark.parametrize( + "ctf, epoch", + [ + ("CTF-Test", 3), + ("CTF-Test", None), + pytest.param( + None, + None, + marks=pytest.mark.xfail(raises=NotImplementedError), + ), + ], + indirect=["ctf"], + ) + def test_interactive_filtering( + self, tmpdir_factory, particles, poses, ctf, indices, epoch + ): + """Launch interface for filtering particles using model covariates.""" + outdir = self.get_outdir(tmpdir_factory, particles, indices, poses, ctf) + parser = argparse.ArgumentParser() + filter.add_args(parser) + epoch_args = ["--epoch", str(epoch)] if epoch is not None else list() + filter.main(parser.parse_args([outdir] + epoch_args)) + @pytest.mark.parametrize( "ctf, downsample_dim, flip_vol", [ @@ -553,7 +553,18 @@ def test_notebooks(self, tmpdir_factory, particles, ctf, indices, nb_lbl): ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(nb_in) os.chdir(orig_cwd) - @pytest.mark.parametrize("epoch", [1, None]) + @pytest.mark.parametrize( + "epoch", + [ + 1, + pytest.param( + None, + marks=pytest.mark.xfail( + raises=ValueError, reason="missing analysis epoch" + ), + ), + ], + ) def test_interactive_filtering( self, tmpdir_factory, particles, ctf, indices, epoch ): From ec3293247573603c549482a787e5f54ac6e21621 Mon Sep 17 00:00:00 2001 From: Michal Grzadkowski Date: Thu, 19 Dec 2024 19:55:55 -0500 Subject: [PATCH 24/24] fixing logging typos and updating documentation --- README.md | 7 ++++--- cryodrgn/commands/abinit_het.py | 2 +- cryodrgn/commands/train_vae.py | 2 +- cryodrgn/commands_utils/make_movies.py | 6 ++++-- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 95f6bd7f..7d988fe7 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,9 @@ cryoDRGN installation, training and analysis. A brief quick start is provided be For any feedback, questions, or bugs, please file a Github issue or start a Github discussion. -### New in Version 3.4.x -* [NEW] `cryodrgn plot_classes` for analysis visualizations colored by a given set of class labels +### Updates in Version 3.4.x +* [NEW] `cryodrgn_utils plot_classes` for analysis visualizations colored by a given set of class labels +* [NEW] `cryodrgn_utils make_movies` for animations of `analyze*` output volumes * implementing [automatic mixed-precision training](https://pytorch.org/docs/stable/amp.html) for ab-initio reconstruction for 2-4x speedup * support for RELION 3.1 .star files with separate optics tables, np.float16 number formats used in RELION .mrcs outputs @@ -33,7 +34,7 @@ For any feedback, questions, or bugs, please file a Github issue or start a Gith * official support for Python 3.11 -### New in Version 3.x +### Updates in Version 3.x The official release of [cryoDRGN-ET](https://www.biorxiv.org/content/10.1101/2023.08.18.553799v1) for heterogeneous subtomogram analysis. diff --git a/cryodrgn/commands/abinit_het.py b/cryodrgn/commands/abinit_het.py index f165d550..4d9e9c97 100644 --- a/cryodrgn/commands/abinit_het.py +++ b/cryodrgn/commands/abinit_het.py @@ -901,7 +901,7 @@ def main(args): ) if in_dim % 8 != 0: logger.warning( - f"Masked input image dimension {in_dim} is not a mutiple of 8 " + f"Masked input image dimension {in_dim} is not a multiple of 8 " "-- AMP training speedup is not optimized!" ) diff --git a/cryodrgn/commands/train_vae.py b/cryodrgn/commands/train_vae.py index 75535127..88f6506d 100755 --- a/cryodrgn/commands/train_vae.py +++ b/cryodrgn/commands/train_vae.py @@ -855,7 +855,7 @@ def main(args: argparse.Namespace) -> None: ) if in_dim % 8 != 0: logger.warning( - f"Masked input image dimension {in_dim} is not a mutiple of 8 " + f"Masked input image dimension {in_dim} is not a multiple of 8 " "-- AMP training speedup is not optimized!" ) diff --git a/cryodrgn/commands_utils/make_movies.py b/cryodrgn/commands_utils/make_movies.py index 34e0fc57..ae64b515 100644 --- a/cryodrgn/commands_utils/make_movies.py +++ b/cryodrgn/commands_utils/make_movies.py @@ -1,4 +1,7 @@ -"""Make mp4 movies of .mrc's produced by analyze* commands +"""Make MP4 movies of .mrc volumes produced by cryodrgn analyze* commands. + +You must install ChimeraX under the alias `chimerax` before running this command, see: +https://www.cgl.ucsf.edu/chimerax/download.html Example usage ------------- @@ -9,7 +12,6 @@ $ cryodrgn_utils make_movies spr_runs/07/out 19 volume --name=front --iso=210 --camera="0.12868,-0.9576,0.25778,95.4,-0.85972,-0.23728,-0.45231,15.356,0.4943,-0.16341,-0.8538,-33.755" """ - import argparse import os from datetime import datetime as dt