
Commit

separating loading model from configurations and loading trainer from configurations
michal-g committed Feb 20, 2025
1 parent c1d9b6f commit c51a595
Showing 3 changed files with 157 additions and 43 deletions.
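In short, analyze.py now reads only the parsed configuration values through get_model_configurations, while eval_vol.py builds the reconstruction network directly through the new get_model helper instead of instantiating a full trainer. A minimal sketch of how the two entry points might be called after this change, assuming an output directory containing a train-configs.yaml and a weights file from a prior run (the paths, epoch number, and CUDA device below are illustrative, not part of this commit):

import yaml
import torch
from cryodrgn.models.utils import get_model, get_model_configurations

# Load the training configuration dictionary written by a cryoDRGN training run
with open("outdir/train-configs.yaml", "r") as f:
    cfg_dict = yaml.safe_load(f)

# Configuration values only, with no model or dataset construction (as analyze.py now does)
train_configs = get_model_configurations(cfg_dict)
print(train_configs.z_dim, train_configs.ctf)

# The reconstruction network itself, built per the new get_model() signature;
# weights and device are optional (eval_vol.py's VolumeEvaluator uses the same helper)
model = get_model(cfg_dict, weights="outdir/weights.24.pkl", device=torch.device("cuda"))
model.eval()
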
27 changes: 13 additions & 14 deletions cryodrgn/commands/analyze.py
@@ -19,6 +19,7 @@
from datetime import datetime as dt
import logging
import nbformat
from typing import Any

import numpy as np
import torch
@@ -29,7 +30,7 @@
from cryodrgn import _ROOT
import cryodrgn.analysis
import cryodrgn.utils
from cryodrgn.models.utils import get_model_trainer
from cryodrgn.models.utils import get_model_configurations
from cryodrgn.commands.eval_vol import VolumeEvaluator

logger = logging.getLogger(__name__)
@@ -140,13 +141,13 @@ def __init__(
self.cfg_file = os.path.join(self.traindir, "train-configs.yaml")
if os.path.exists(self.cfg_file):
with open(self.cfg_file, "r") as f:
cfg_dict: dict[str, dict[str, str]] = yaml.safe_load(f)
cfg_dict: dict[str, dict[str, Any]] = yaml.safe_load(f)
else:
raise FileNotFoundError(
f"Cannot find training configurations file `{self.cfg_file}` "
f"— has this model been trained yet?"
)
self.trainer = get_model_trainer(cfg_dict)
self.train_configs = get_model_configurations(cfg_dict)

log_fl = os.path.join(self.traindir, "training.log")
if os.path.exists(log_fl):
@@ -170,8 +171,8 @@ def __init__(

# find A/px from CTF if not given
else:
if self.trainer.configs.ctf:
ctf_params = cryodrgn.utils.load_pkl(self.trainer.configs.ctf)
if self.train_configs.ctf:
ctf_params = cryodrgn.utils.load_pkl(self.train_configs.ctf)
orig_apixs = set(ctf_params[:, 1])

# TODO: add support for multiple optics groups
@@ -186,17 +187,15 @@ def __init__(
orig_apix = tuple(orig_apixs)[0]
orig_sizes = set(ctf_params[:, 0])
orig_size = tuple(orig_sizes)[0]

if len(orig_sizes) > 1:
logger.info(
f"Cannot find unique original box size in CTF "
f"parameters, defaulting to first found: {orig_size}"
)

cur_size = self.trainer.lattice.D
cur_size = cfg_dict["lattice_args"]["D"] - 1
self.apix = round(orig_apix * orig_size / cur_size, 6)
logger.info(f"using A/px={self.apix} as per CTF parameters")

else:
self.apix = 1.0
logger.info(
@@ -219,7 +218,7 @@ def __init__(
f"to epoch {self.epoch} yet?"
)

if self.trainer.configs.z_dim > 0:
if self.train_configs.z_dim > 0:
self.z = cryodrgn.utils.load_pkl(
os.path.join(self.traindir, f"conf.{self.epoch}.pkl")
)
@@ -235,7 +234,7 @@ def __init__(
logger.info(f"Saving results to {self.outdir}")

# We will generate volumes unless told not to or if using a homogeneous model
if skip_vol or self.trainer.configs.z_dim == 0:
if skip_vol or self.train_configs.z_dim == 0:
self.volume_generator = None
else:
cfgs = cryodrgn.utils.load_yaml(self.cfg_file)
@@ -276,19 +275,19 @@ def generate_volumes(
logger.info("Skipping volume generation...")

def analyze(self) -> None:
if self.trainer.configs.z_dim == 0:
if self.train_configs.z_dim == 0:
logger.warning("No analyses available for homogeneous reconstruction!")
return

elif self.trainer.configs.z_dim == 1:
elif self.train_configs.z_dim == 1:
self.analyze_z1()
else:
self.analyze_zN()

# create Jupyter notebooks for data analysis and visualization by
# copying them over from the template directory
ipynbs = ["cryoDRGN_figures"]
if self.trainer.configs.tilt:
if self.train_configs.tilt:
ipynbs += ["ET-viz"]
else:
ipynbs += ["cryoDRGN_viz", "analysis"]
@@ -353,7 +352,7 @@ def analyze(self) -> None:
if self.direct_traversal_txt is not None:
dir_traversal_vertices_ind = np.loadtxt(self.direct_traversal_txt)
travdir = os.path.join(self.outdir, "direct_traversal")
z_values = np.zeros((0, self.trainer.configs.z_dim))
z_values = np.zeros((0, self.train_configs.z_dim))

for i, ind in enumerate(dir_traversal_vertices_ind[:-1]):
z_0 = self.z[int(ind)]
56 changes: 28 additions & 28 deletions cryodrgn/commands/eval_vol.py
@@ -21,9 +21,9 @@
import numpy as np
import torch
import cryodrgn.config
from cryodrgn.models.utils import get_model_trainer
from cryodrgn.lattice import Lattice
from cryodrgn.source import write_mrc
from cryodrgn.models.utils import get_model

logger = logging.getLogger(__name__)

@@ -196,37 +196,41 @@ def __init__(
cfg_data["weights"] = weights
pprint.pprint(cfg_data)

self.trainer = get_model_trainer(cfg_data)
apix = apix if apix is not None else self.trainer.apix
orig_d = self.trainer.lattice.D # image size + 1
self.z_dim = self.trainer.configs.z_dim
self.norm = self.trainer.data.norm
if apix is not None:
self.apix = apix
elif "apix" in cfg_data["dataset_args"]:
self.apix = cfg_data["dataset_args"]["apix"]
else:
self.apix = 1.0

model = get_model(cfg_data)
if downsample:
if downsample % 2 != 0:
raise ValueError("Boxsize must be even")
if downsample > orig_d - 1:
if downsample > model.lattice.D - 1:
raise ValueError(
"Downsampling size must be smaller than original box size"
)

self.coords = self.trainer.lattice.get_downsample_coords(downsample + 1)
self.coords = model.lattice.get_downsample_coords(downsample + 1)
self.D = downsample + 1
self.extent = self.trainer.lattice.extent * (downsample / (orig_d - 1))
self.extent = model.lattice.extent * (downsample / (model.lattice.D - 1))
self.lattice = Lattice(
self.D, extent=self.trainer.lattice.extent, device=self.device
self.D, extent=model.lattice.extent, device=self.device
)
else:
self.lattice = self.trainer.lattice
self.coords = self.trainer.lattice.coords
self.D = self.trainer.lattice.D
self.extent = self.trainer.lattice.extent
self.lattice = model.lattice
self.coords = model.lattice.coords
self.D = model.lattice.D
self.extent = model.lattice.extent

self.verbose = verbose
self.apix = apix
self.flip = flip
self.invert = invert
self.trainer.reconstruction_model.eval()
self.model = model
self.model.eval()
self.norm = cfg_data["dataset_args"]["norm"]

def transform_volume(self, vol):
if self.flip:
@@ -238,16 +242,14 @@ def transform_volume(self, vol):

def evaluate_volume(self, z):
return self.transform_volume(
self.trainer.reconstruction_model.eval_volume(
self.model.eval_volume(
lattice=self.lattice,
coords=self.coords,
resolution=self.D,
extent=self.extent,
norm=self.norm,
zval=z,
radius=self.trainer.mask_dimensions
if hasattr(self.trainer, "mask_dimensions")
else None,
radius=None,
)
)

@@ -326,23 +328,21 @@ def main(args: argparse.Namespace) -> None:
"and z-end of equal length!"
)

# parse user inputs for location(s) in the latent space
# Parse location(s) specified by the user in the latent space in various formats
if args.zfile:
z_vals = np.loadtxt(args.zfile).reshape(-1, evaluator.z_dim)

z_vals = np.loadtxt(args.zfile).reshape(-1, evaluator.model.z_dim)
elif args.z_start:
z_start = np.array(args.z_start)
z_end = np.array(args.z_end)

z_start, z_end = np.array(args.z_start), np.array(args.z_end)
z_dim = cfg["model_args"]["z_dim"]
z_vals = np.repeat(
np.arange(args.volume_count, dtype=np.float32), evaluator.z_dim
).reshape((args.volume_count, evaluator.z_dim))
np.arange(args.volume_count, dtype=np.float32), z_dim
).reshape((args.volume_count, z_dim))
z_vals *= (z_end - z_start) / (args.volume_count - 1) # type: ignore
z_vals += z_start

else:
z_vals = np.array(args.z_val)

# Evaluate the volumes at these locations and save them to file
if len(z_vals):
evaluator.produce_volumes(
z_vals, args.output, args.prefix, args.vol_start_index
117 changes: 116 additions & 1 deletion cryodrgn/models/utils.py
@@ -1,6 +1,8 @@
"""Utilities shared across all types of models."""

from typing import Any, Optional
from typing import Any, Optional, Union
import torch
from torch import nn
from cryodrgn.trainers.reconstruction import (
ReconstructionModelTrainer,
ReconstructionModelConfigurations,
@@ -13,6 +15,11 @@
HierarchicalPoseSearchTrainer,
HierarchicalPoseSearchConfigurations,
)
from cryodrgn.models.amortized_inference import DRGNai
from cryodrgn.models.variational_autoencoder import HetOnlyVAE
from cryodrgn.models.neural_nets import get_decoder
from cryodrgn.lattice import Lattice
from cryodrgn.masking import CircularMask, FrequencyMarchingMask


def update_configs(cfg: dict[str, Any]) -> dict[str, Any]:
@@ -82,3 +89,111 @@ def get_model_configurations(
cfg.update(configs_cls.parse_cfg_keys(add_cfgs))

return configs_cls(**cfg)


def get_model(
cfg: dict[str, Any],
add_cfgs: Optional[list[str]] = None,
weights=None,
device=None,
) -> Union[HetOnlyVAE, DRGNai]:
configs = get_model_configurations(cfg, add_cfgs)
lattice = Lattice(
cfg["lattice_args"]["D"], extent=cfg["lattice_args"]["extent"], device=device
)

if isinstance(configs, AmortizedInferenceConfigurations):
if configs.output_mask == "circ":
radius = configs.max_freq or lattice.D // 2
output_mask = CircularMask(lattice, radius)

elif configs.output_mask == "frequency_marching":
output_mask = FrequencyMarchingMask(
lattice,
radius=configs.l_start_fm,
radius_max=lattice.D // 2,
add_one_every=configs.add_one_frequency_every,
)
else:
raise NotImplementedError

if "particle_count" in cfg["dataset_args"]:
particle_count = cfg["dataset_args"]["particle_count"]
if "image_count" not in cfg["dataset_args"]:
image_count = particle_count
else:
image_count = cfg["dataset_args"]["image_count"]
else:
trainer = get_model_trainer(cfg, add_cfgs)
particle_count, image_count = trainer.particle_count, trainer.image_count

model = DRGNai(
lattice=lattice,
output_mask=output_mask,
n_particles_dataset=particle_count,
n_tilts_dataset=image_count,
cnn_params=cfg["model_args"]["cnn_params"],
conf_regressor_params=cfg["model_args"]["conf_regressor_params"],
hypervolume_params=cfg["model_args"]["hypervolume_params"],
resolution_encoder=configs.resolution_encoder,
no_trans=configs.no_trans,
use_gt_poses=configs.pose_estimation == "fixed",
use_gt_trans=configs.use_gt_trans,
will_use_point_estimates=False,
ps_params=cfg["model_args"]["ps_params"],
verbose_time=configs.verbose_time,
pretrain_with_gt_poses=configs.pretrain_with_gt_poses,
n_tilts_pose_search=configs.n_tilts_pose_search,
)

elif configs.model == "hps":
activation = {"relu": nn.ReLU, "leaky_relu": nn.LeakyReLU}[configs.activation]
if configs.z_dim > 0:
if (
cfg["model_args"]["enc_mask"] is not None
and cfg["model_args"]["enc_mask"] > 0
):
enc_mask = lattice.get_circular_mask(cfg["model_args"]["enc_mask"])
in_dim = int(enc_mask.sum())
else:
enc_mask = None
in_dim = lattice.D**2

model = HetOnlyVAE(
lattice=lattice,
qlayers=cfg["model_args"]["enc_layers"],
qdim=cfg["model_args"]["enc_dim"],
players=cfg["model_args"]["dec_layers"],
pdim=cfg["model_args"]["dec_dim"],
in_dim=in_dim,
z_dim=configs.z_dim,
encode_mode=cfg["model_args"]["encode_mode"],
enc_mask=enc_mask,
enc_type=configs.pe_type,
enc_dim=configs.pe_dim,
domain=configs.volume_domain,
activation=activation,
feat_sigma=cfg["model_args"]["feat_sigma"],
tilt_params=cfg["model_args"].get("tilt_params", {}),
)
else:
model = get_decoder(
in_dim=3,
D=lattice.D,
layers=cfg["model_args"]["qlayers"],
dim=cfg["model_args"]["qlayers"],
domain=configs.volume_domain,
enc_type=configs.pe_type,
enc_dim=configs.pe_dim,
activation=activation,
feat_sigma=configs.feat_sigma,
)
if weights is not None:
ckpt = torch.load(weights, map_location=device)
model.load_state_dict(ckpt["model_state_dict"])
if device is not None:
model.to(device)
else:
raise ValueError(f"Unrecognized model `{configs.model}` specified in config!")

return model
