diff --git a/cryodrgn/commands/analyze.py b/cryodrgn/commands/analyze.py index c0df0881..e34904c0 100644 --- a/cryodrgn/commands/analyze.py +++ b/cryodrgn/commands/analyze.py @@ -19,6 +19,7 @@ from datetime import datetime as dt import logging import nbformat +from typing import Any import numpy as np import torch @@ -29,7 +30,7 @@ from cryodrgn import _ROOT import cryodrgn.analysis import cryodrgn.utils -from cryodrgn.models.utils import get_model_trainer +from cryodrgn.models.utils import get_model_configurations from cryodrgn.commands.eval_vol import VolumeEvaluator logger = logging.getLogger(__name__) @@ -140,13 +141,13 @@ def __init__( self.cfg_file = os.path.join(self.traindir, "train-configs.yaml") if os.path.exists(self.cfg_file): with open(self.cfg_file, "r") as f: - cfg_dict: dict[str, dict[str, str]] = yaml.safe_load(f) + cfg_dict: dict[str, dict[str, Any]] = yaml.safe_load(f) else: raise FileNotFoundError( f"Cannot find training configurations file `{self.cfg_file}` " f"— has this model been trained yet?" ) - self.trainer = get_model_trainer(cfg_dict) + self.train_configs = get_model_configurations(cfg_dict) log_fl = os.path.join(self.traindir, "training.log") if os.path.exists(log_fl): @@ -170,8 +171,8 @@ def __init__( # find A/px from CTF if not given else: - if self.trainer.configs.ctf: - ctf_params = cryodrgn.utils.load_pkl(self.trainer.configs.ctf) + if self.train_configs.ctf: + ctf_params = cryodrgn.utils.load_pkl(self.train_configs.ctf) orig_apixs = set(ctf_params[:, 1]) # TODO: add support for multiple optics groups @@ -186,17 +187,15 @@ def __init__( orig_apix = tuple(orig_apixs)[0] orig_sizes = set(ctf_params[:, 0]) orig_size = tuple(orig_sizes)[0] - if len(orig_sizes) > 1: logger.info( f"Cannot find unique original box size in CTF " f"parameters, defaulting to first found: {orig_size}" ) - cur_size = self.trainer.lattice.D + cur_size = cfg_dict["lattice_args"]["D"] - 1 self.apix = round(orig_apix * orig_size / cur_size, 6) logger.info(f"using A/px={self.apix} as per CTF parameters") - else: self.apix = 1.0 logger.info( @@ -219,7 +218,7 @@ def __init__( f"to epoch {self.epoch} yet?" ) - if self.trainer.configs.z_dim > 0: + if self.train_configs.z_dim > 0: self.z = cryodrgn.utils.load_pkl( os.path.join(self.traindir, f"conf.{self.epoch}.pkl") ) @@ -235,7 +234,7 @@ def __init__( logger.info(f"Saving results to {self.outdir}") # We will generate volumes unless told not to or if using a homogeneous model - if skip_vol or self.trainer.configs.z_dim == 0: + if skip_vol or self.train_configs.z_dim == 0: self.volume_generator = None else: cfgs = cryodrgn.utils.load_yaml(self.cfg_file) @@ -276,11 +275,11 @@ def generate_volumes( logger.info("Skipping volume generation...") def analyze(self) -> None: - if self.trainer.configs.z_dim == 0: + if self.train_configs.z_dim == 0: logger.warning("No analyses available for homogeneous reconstruction!") return - elif self.trainer.configs.z_dim == 1: + elif self.train_configs.z_dim == 1: self.analyze_z1() else: self.analyze_zN() @@ -288,7 +287,7 @@ def analyze(self) -> None: # create Jupyter notebooks for data analysis and visualization by # copying them over from the template directory ipynbs = ["cryoDRGN_figures"] - if self.trainer.configs.tilt: + if self.train_configs.tilt: ipynbs += ["ET-viz"] else: ipynbs += ["cryoDRGN_viz", "analysis"] @@ -353,7 +352,7 @@ def analyze(self) -> None: if self.direct_traversal_txt is not None: dir_traversal_vertices_ind = np.loadtxt(self.direct_traversal_txt) travdir = os.path.join(self.outdir, "direct_traversal") - z_values = np.zeros((0, self.trainer.configs.z_dim)) + z_values = np.zeros((0, self.train_configs.z_dim)) for i, ind in enumerate(dir_traversal_vertices_ind[:-1]): z_0 = self.z[int(int)] diff --git a/cryodrgn/commands/eval_vol.py b/cryodrgn/commands/eval_vol.py index 45a624b8..d38ca5cb 100644 --- a/cryodrgn/commands/eval_vol.py +++ b/cryodrgn/commands/eval_vol.py @@ -21,9 +21,9 @@ import numpy as np import torch import cryodrgn.config -from cryodrgn.models.utils import get_model_trainer from cryodrgn.lattice import Lattice from cryodrgn.source import write_mrc +from cryodrgn.models.utils import get_model logger = logging.getLogger(__name__) @@ -196,37 +196,41 @@ def __init__( cfg_data["weights"] = weights pprint.pprint(cfg_data) - self.trainer = get_model_trainer(cfg_data) - apix = apix if apix is not None else self.trainer.apix - orig_d = self.trainer.lattice.D # image size + 1 - self.z_dim = self.trainer.configs.z_dim - self.norm = self.trainer.data.norm + if apix is not None: + self.apix = apix + elif "apix" in cfg_data["dataset_args"]: + self.apix = cfg_data["dataset_args"]["apix"] + else: + self.apix = 1.0 + model = get_model(cfg_data) if downsample: if downsample % 2 != 0: raise ValueError("Boxsize must be even") - if downsample > orig_d - 1: + if downsample > model.lattice.D - 1: raise ValueError( "Downsampling size must be smaller than original box size" ) - self.coords = self.trainer.lattice.get_downsample_coords(downsample + 1) + self.coords = model.lattice.get_downsample_coords(downsample + 1) self.D = downsample + 1 - self.extent = self.trainer.lattice.extent * (downsample / (orig_d - 1)) + self.extent = model.lattice.extent * (downsample / (model.lattice.D - 1)) self.lattice = Lattice( - self.D, extent=self.trainer.lattice.extent, device=self.device + self.D, extent=model.lattice.extent, device=self.device ) else: - self.lattice = self.trainer.lattice - self.coords = self.trainer.lattice.coords - self.D = self.trainer.lattice.D - self.extent = self.trainer.lattice.extent + self.lattice = model.lattice + self.coords = model.lattice.coords + self.D = model.lattice.D + self.extent = model.lattice.extent self.verbose = verbose self.apix = apix self.flip = flip self.invert = invert - self.trainer.reconstruction_model.eval() + self.model = model + self.model.eval() + self.norm = cfg_data["dataset_args"]["norm"] def transform_volume(self, vol): if self.flip: @@ -238,16 +242,14 @@ def transform_volume(self, vol): def evaluate_volume(self, z): return self.transform_volume( - self.trainer.reconstruction_model.eval_volume( + self.model.eval_volume( lattice=self.lattice, coords=self.coords, resolution=self.D, extent=self.extent, norm=self.norm, zval=z, - radius=self.trainer.mask_dimensions - if hasattr(self.trainer, "mask_dimensions") - else None, + radius=None, ) ) @@ -326,23 +328,21 @@ def main(args: argparse.Namespace) -> None: "and z-end of equal length!" ) - # parse user inputs for location(s) in the latent space + # Parse location(s) specified by the user in the latent space in various formats if args.zfile: - z_vals = np.loadtxt(args.zfile).reshape(-1, evaluator.z_dim) - + z_vals = np.loadtxt(args.zfile).reshape(-1, evaluator.model.z_dim) elif args.z_start: - z_start = np.array(args.z_start) - z_end = np.array(args.z_end) - + z_start, z_end = np.array(args.z_start), np.array(args.z_end) + z_dim = cfg["model_args"]["z_dim"] z_vals = np.repeat( - np.arange(args.volume_count, dtype=np.float32), evaluator.z_dim - ).reshape((args.volume_count, evaluator.z_dim)) + np.arange(args.volume_count, dtype=np.float32), z_dim + ).reshape((args.volume_count, z_dim)) z_vals *= (z_end - z_start) / (args.volume_count - 1) # type: ignore z_vals += z_start - else: z_vals = np.array(args.z_val) + # Evaluate the volumes at these locations and save them to file if len(z_vals): evaluator.produce_volumes( z_vals, args.output, args.prefix, args.vol_start_index diff --git a/cryodrgn/models/utils.py b/cryodrgn/models/utils.py index d1589a69..5b246db9 100644 --- a/cryodrgn/models/utils.py +++ b/cryodrgn/models/utils.py @@ -1,6 +1,8 @@ """Utilities shared across all types of models.""" -from typing import Any, Optional +from typing import Any, Optional, Union +import torch +from torch import nn from cryodrgn.trainers.reconstruction import ( ReconstructionModelTrainer, ReconstructionModelConfigurations, @@ -13,6 +15,11 @@ HierarchicalPoseSearchTrainer, HierarchicalPoseSearchConfigurations, ) +from cryodrgn.models.amortized_inference import DRGNai +from cryodrgn.models.variational_autoencoder import HetOnlyVAE +from cryodrgn.models.neural_nets import get_decoder +from cryodrgn.lattice import Lattice +from cryodrgn.masking import CircularMask, FrequencyMarchingMask def update_configs(cfg: dict[str, Any]) -> dict[str, Any]: @@ -82,3 +89,111 @@ def get_model_configurations( cfg.update(configs_cls.parse_cfg_keys(add_cfgs)) return configs_cls(**cfg) + + +def get_model( + cfg: dict[str, Any], + add_cfgs: Optional[list[str]] = None, + weights=None, + device=None, +) -> Union[HetOnlyVAE, DRGNai]: + configs = get_model_configurations(cfg, add_cfgs) + lattice = Lattice( + cfg["lattice_args"]["D"], extent=cfg["lattice_args"]["extent"], device=device + ) + + if isinstance(configs, AmortizedInferenceConfigurations): + if configs.output_mask == "circ": + radius = configs.max_freq or lattice.D // 2 + output_mask = CircularMask(lattice, radius) + + elif configs.output_mask == "frequency_marching": + output_mask = FrequencyMarchingMask( + lattice, + radius=configs.l_start_fm, + radius_max=lattice.D // 2, + add_one_every=configs.add_one_frequency_every, + ) + else: + raise NotImplementedError + + if "particle_count" in cfg["dataset_args"]: + particle_count = cfg["dataset_args"]["particle_count"] + if "image_count" not in cfg["dataset_args"]: + image_count = particle_count + else: + image_count = cfg["dataset_args"]["image_count"] + else: + trainer = get_model_trainer(cfg, add_cfgs) + particle_count, image_count = trainer.particle_count, trainer.image_count + + model = DRGNai( + lattice=lattice, + output_mask=output_mask, + n_particles_dataset=particle_count, + n_tilts_dataset=image_count, + cnn_params=cfg["model_args"]["cnn_params"], + conf_regressor_params=cfg["model_args"]["conf_regressor_params"], + hypervolume_params=cfg["model_args"]["hypervolume_params"], + resolution_encoder=configs.resolution_encoder, + no_trans=configs.no_trans, + use_gt_poses=configs.pose_estimation == "fixed", + use_gt_trans=configs.use_gt_trans, + will_use_point_estimates=False, + ps_params=cfg["model_args"]["ps_params"], + verbose_time=configs.verbose_time, + pretrain_with_gt_poses=configs.pretrain_with_gt_poses, + n_tilts_pose_search=configs.n_tilts_pose_search, + ) + + elif configs.model == "hps": + activation = {"relu": nn.ReLU, "leaky_relu": nn.LeakyReLU}[configs.activation] + if configs.z_dim > 0: + if ( + cfg["model_args"]["enc_mask"] is not None + and cfg["model_args"]["enc_mask"] > 0 + ): + enc_mask = lattice.get_circular_mask(cfg["model_args"]["enc_mask"]) + in_dim = int(enc_mask.sum()) + else: + enc_mask = None + in_dim = lattice.D**2 + + model = HetOnlyVAE( + lattice=lattice, + qlayers=cfg["model_args"]["enc_layers"], + qdim=cfg["model_args"]["enc_dim"], + players=cfg["model_args"]["dec_layers"], + pdim=cfg["model_args"]["dec_dim"], + in_dim=in_dim, + z_dim=configs.z_dim, + encode_mode=cfg["model_args"]["encode_mode"], + enc_mask=enc_mask, + enc_type=configs.pe_type, + enc_dim=configs.pe_dim, + domain=configs.volume_domain, + activation=activation, + feat_sigma=cfg["model_args"]["feat_sigma"], + tilt_params=cfg["model_args"].get("tilt_params", {}), + ) + else: + model = get_decoder( + in_dim=3, + D=lattice.D, + layers=cfg["model_args"]["qlayers"], + dim=cfg["model_args"]["qlayers"], + domain=configs.volume_domain, + enc_type=configs.pe_type, + enc_dim=configs.pe_dim, + activation=activation, + feat_sigma=configs.feat_sigma, + ) + if weights is not None: + ckpt = torch.load(weights, device=device) + model.load_state_dict(ckpt["model_state_dict"]) + if device is not None: + model.to(device) + else: + raise ValueError(f"Unrecognized model `{configs.model}` specified in config!") + + return model