Skip to content

Commit

Permalink
Merge pull request #427 from ml-struct-bio/v3.4.3
Browse files Browse the repository at this point in the history
v3.4.3: Making movies, improving filtering interface, and fixing landscape analysis
  • Loading branch information
michal-g authored Dec 21, 2024
2 parents 196365d + ec32932 commit 4ba7550
Show file tree
Hide file tree
Showing 13 changed files with 1,036 additions and 433 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ cryoDRGN installation, training and analysis. A brief quick start is provided be
For any feedback, questions, or bugs, please file a GitHub issue or start a GitHub discussion.


### New in Version 3.4.x
* [NEW] `cryodrgn plot_classes` for analysis visualizations colored by a given set of class labels
### Updates in Version 3.4.x
* [NEW] `cryodrgn_utils plot_classes` for analysis visualizations colored by a given set of class labels
* [NEW] `cryodrgn_utils make_movies` for animations of `analyze*` output volumes
* implementing [automatic mixed-precision training](https://pytorch.org/docs/stable/amp.html)
for ab-initio reconstruction for 2-4x speedup
* support for RELION 3.1 .star files with separate optics tables, np.float16 number formats used in RELION .mrcs outputs
Expand All @@ -33,7 +34,7 @@ For any feedback, questions, or bugs, please file a Github issue or start a Gith
* official support for Python 3.11


### New in Version 3.x
### Updates in Version 3.x

The official release of [cryoDRGN-ET](https://www.biorxiv.org/content/10.1101/2023.08.18.553799v1) for heterogeneous subtomogram analysis.

Expand Down
2 changes: 2 additions & 0 deletions cryodrgn/command_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
since automated scanning for command modules is computationally non-trivial.
"""

import argparse
import os
from importlib import import_module
Expand Down Expand Up @@ -122,6 +123,7 @@ def util_commands() -> None:
"fsc",
"gen_mask",
"invert_contrast",
"make_movies",
"phase_flip",
"plot_classes",
"plot_fsc",
Expand Down
12 changes: 9 additions & 3 deletions cryodrgn/commands/abinit_het.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,10 @@ def train(
# We do this in pose-supervised train_vae

if scaler is not None:
amp_mode = torch.cuda.amp.autocast_mode.autocast()
try:
amp_mode = torch.amp.autocast("cuda")
except AttributeError:
amp_mode = torch.cuda.amp.autocast_mode.autocast()
else:
amp_mode = contextlib.nullcontext()

Expand Down Expand Up @@ -898,7 +901,7 @@ def main(args):
)
if in_dim % 8 != 0:
logger.warning(
f"Masked input image dimension {in_dim} is not a mutiple of 8 "
f"Masked input image dimension {in_dim} is not a multiple of 8 "
"-- AMP training speedup is not optimized!"
)

Expand All @@ -907,7 +910,10 @@ def main(args):
model, optim = amp.initialize(model, optim, opt_level="O1")
# mixed precision with pytorch (v1.6+)
except: # noqa: E722
scaler = torch.cuda.amp.grad_scaler.GradScaler()
try:
scaler = torch.amp.GradScaler("cuda")
except AttributeError:
scaler = torch.cuda.amp.grad_scaler.GradScaler()

if args.load == "latest":
args = get_latest(args)
Expand Down
10 changes: 8 additions & 2 deletions cryodrgn/commands/abinit_homo.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,10 @@ def train(
)

if scaler is not None:
amp_mode = torch.cuda.amp.autocast_mode.autocast()
try:
amp_mode = torch.amp.autocast("cuda")
except AttributeError:
amp_mode = torch.cuda.amp.autocast_mode.autocast()
else:
amp_mode = contextlib.nullcontext()

Expand Down Expand Up @@ -676,7 +679,10 @@ def main(args):
model, optim = amp.initialize(model, optim, opt_level="O1")
# mixed precision with pytorch (v1.6+)
except: # noqa: E722
scaler = torch.cuda.amp.grad_scaler.GradScaler()
try:
scaler = torch.amp.GradScaler("cuda")
except AttributeError:
scaler = torch.cuda.amp.grad_scaler.GradScaler()

sorted_poses = []
if args.load:
Expand Down
46 changes: 27 additions & 19 deletions cryodrgn/commands/analyze_landscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,47 +348,55 @@ def plot(i, j):

kmeans_labels = utils.load_pkl(os.path.join(outdir, f"kmeans{K}", "labels.pkl"))
kmeans_counts = Counter(kmeans_labels)
for i in range(M):
vol_i = np.where(labels == i)[0]
logger.info(f"State {i}: {len(vol_i)} volumes")
for cluster_i in range(M):
vol_indices = np.where(labels == cluster_i)[0]
logger.info(f"State {cluster_i}: {len(vol_indices)} volumes")
if vol_ind is not None:
vol_i = np.arange(K)[vol_ind][vol_i]
vol_indices = np.arange(K)[vol_ind][vol_indices]

vol_fls = [
os.path.join(kmean_dir, f"vol_{vol_start_index + vol_i:03d}.mrc")
for vol_i in vol_indices
]
vol_i_all = torch.stack(
[torch.Tensor(parse_mrc(vol_fl)[0]) for vol_fl in vol_fls]
)

vol_fl = os.path.join(kmean_dir, f"vol_{vol_start_index+i:03d}.mrc")
vol_i_all = torch.stack([torch.Tensor(parse_mrc(vol_fl)[0]) for i in vol_i])
nparticles = np.array([kmeans_counts[i] for i in vol_i])
nparticles = np.array([kmeans_counts[vol_i] for vol_i in vol_indices])
vol_i_mean = np.average(vol_i_all, axis=0, weights=nparticles)
vol_i_std = (
np.average((vol_i_all - vol_i_mean) ** 2, axis=0, weights=nparticles) ** 0.5
)

write_mrc(
os.path.join(subdir, f"state_{i}_mean.mrc"),
os.path.join(subdir, f"state_{cluster_i}_mean.mrc"),
vol_i_mean.astype(np.float32),
Apix=Apix,
)
write_mrc(
os.path.join(subdir, f"state_{i}_std.mrc"),
os.path.join(subdir, f"state_{cluster_i}_std.mrc"),
vol_i_std.astype(np.float32),
Apix=Apix,
)

os.makedirs(os.path.join(subdir, f"state_{i}"), exist_ok=True)
for v in vol_i:
os.symlink(
os.path.join(kmean_dir, f"vol_{vol_start_index+v:03d}.mrc"),
os.path.join(subdir, f"state_{i}", f"vol_{vol_start_index+v:03d}.mrc"),
)
statedir = os.path.join(subdir, f"state_{cluster_i}")
os.makedirs(statedir, exist_ok=True)
for vol_i in vol_indices:
kmean_fl = os.path.join(kmean_dir, f"vol_{vol_start_index+vol_i:03d}.mrc")
sub_fl = os.path.join(statedir, f"vol_{vol_start_index+vol_i:03d}.mrc")
os.symlink(kmean_fl, sub_fl)

particle_ind = analysis.get_ind_for_cluster(kmeans_labels, vol_i)
logger.info(f"State {i}: {len(particle_ind)} particles")
particle_ind = analysis.get_ind_for_cluster(kmeans_labels, vol_indices)
logger.info(f"State {cluster_i}: {len(particle_ind)} particles")
if particle_ind_orig is not None:
utils.save_pkl(
particle_ind_orig[particle_ind],
os.path.join(subdir, f"state_{i}_particle_ind.pkl"),
os.path.join(subdir, f"state_{cluster_i}_particle_ind.pkl"),
)
else:
utils.save_pkl(
particle_ind, os.path.join(subdir, f"state_{i}_particle_ind.pkl")
particle_ind,
os.path.join(subdir, f"state_{cluster_i}_particle_ind.pkl"),
)

# plot clustering results
Expand Down
Loading

0 comments on commit 4ba7550

Please sign in to comment.