diff --git a/cryodrgn/commands/train.py b/cryodrgn/commands/train.py
index 1f1c5680..1cda18bb 100644
--- a/cryodrgn/commands/train.py
+++ b/cryodrgn/commands/train.py
@@ -31,7 +31,7 @@ def add_args(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
         "--model",
         choices={"amort", "hps"},
-        help="which model to use for reconstruction",
+        help="which model to use for reconstruction; default is `amort` (cryoDRGN-AI)",
     )
     parser.add_argument(
         "--no-analysis",
@@ -72,7 +72,7 @@ def main(
     if additional_configs is not None:
         configs.update(additional_configs)

-    trainer = get_model_trainer(configs)
+    trainer = get_model_trainer(configs, add_cfgs=args.cfgs)
     cryodrgn.utils._verbose = False
     trainer.train()

diff --git a/cryodrgn/trainers/amortinf_trainer.py b/cryodrgn/trainers/amortinf_trainer.py
index 68189266..c1205c21 100644
--- a/cryodrgn/trainers/amortinf_trainer.py
+++ b/cryodrgn/trainers/amortinf_trainer.py
@@ -33,6 +33,75 @@

 @dataclass
 class AmortizedInferenceConfigurations(ReconstructionModelConfigurations):
+    """The configurations used by the cryoDRGN-AI model training engine.
+
+    Arguments
+    ---------
+    > inherited from `BaseConfigurations`:
+    verbose     An integer specifying the verbosity level for this engine, with
+                the default value of 0 generally specifying no/minimum verbosity.
+    outdir      Path to where output produced by the engine will be saved.
+    seed        A non-negative integer used to fix the stochasticity of the random
+                number generators used by this engine for reproducibility.
+                The default is to not fix stochasticity and thus use a different
+                random seed upon each run of the engine.
+    test_installation   Only perform a smoke test that this module has been
+                        installed correctly and exit immediately without running
+                        anything if this boolean value is set to `True`.
+                        Default is not to run this test.
+
+    > inherited from `ReconstructionModelConfigurations`:
+    model       A label for the reconstruction algorithm to be used — must be either
+                `hps` for cryoDRGN v3 models or `amort` for cryoDRGN-AI models.
+    z_dim       The dimensionality of the latent space of conformations.
+                Thus z_dim=0 for homogeneous models
+                and z_dim>0 for heterogeneous models.
+    num_epochs  The total number of epochs to use when training the model, not
+                including pretraining epoch(s).
+
+    dataset     Label for the particle dataset to be used as input for the model.
+                If used, remaining input parameters can be omitted.
+    particles   Path to the stack of particle images to use as input for the model.
+                Must be a .mrcs/.txt/.star/.cs file.
+    ctf         Path to the file storing contrast transfer function parameters
+                used to process the input particle images.
+    poses       Path to the input particle poses data (.pkl).
+    datadir     Path prefix to particle stack if loading relative paths from
+                a .star or .cs file.
+    ind         Path to a numpy array saved as a .pkl used to filter
+                input particles.
+
+    pose_estimation     Whether to perform ab-initio reconstruction ("abinit"),
+                        reconstruction using fixed poses ("fixed"), or
+                        reconstruction with SGD refinement of poses ("refine").
+                        Default is to use fixed poses if a poses file is given
+                        and ab-initio otherwise.
+
+    load        Load model from the given weights.*.pkl output file saved from a
+                previous run of the engine.
+                Can also be given as "latest", in which case the latest saved
+                epoch in the given output directory will be used.
+    lazy        Whether to use lazy loading of data into memory in smaller batches.
+                Necessary if the input dataset is too large to fit into memory.
+
+    batch_size  The number of input images to use at a time when updating
+                the learning algorithm.
+
+    multigpu    Whether to use all GPUs available on this machine.
+                The default is to use only one GPU.
+
+    log_interval    Print a log message every `N` training images.
+    checkpoint      Save model results to file every `N` training epochs.
+
+    pe_type     Label for the type of positional encoding to use.
+    pe_dim      Number of frequencies to use in the positional encoding
+                (default: 64).
+    volume_domain   Representation to use in the volume
+                    decoder ("hartley" or "fourier").
+
+    n_imgs_pose_search  The number of images the model needs to see in order to
+                        complete the pose search portion of training.
+    """

     # A parameter belongs to this configuration set if and only if it has a type and a
     # default value defined here, note that children classes inherit these parameters
@@ -94,9 +163,8 @@ def __post_init__(self) -> None:

         if self.model != "amort":
             raise ValueError(
-                f"Mismatched model {self.model} for AmortizedInferenceTrainer!"
+                f"Mismatched model `{self.model=}` for {self.__class__.__name__}!"
             )
-
         if self.pose_estimation is not None:
             if self.pose_estimation == "refine":
                 self.pose_learning_rate = 1.0e-4
@@ -104,21 +172,12 @@ def __post_init__(self) -> None:
         if self.conf_estimation == "encoder":
             self.use_conf_encoder = True

-        if self.batch_size_sgd is None:
-            self.batch_size_sgd = self.batch_size
-        if self.batch_size_known_poses is None:
-            self.batch_size_known_poses = self.batch_size
-        if self.batch_size_hps is None:
-            self.batch_size_hps = self.batch_size
-
         if self.explicit_volume and self.z_dim >= 1:
             raise ValueError(
                 "Explicit volumes do not support heterogeneous reconstruction."
             )

         if self.dataset is None:
-            if self.particles is None:
-                raise ValueError("Dataset wasn't specified: please specify particles!")
             if self.ctf is None:
                 raise ValueError("Dataset wasn't specified: please specify ctf!")

@@ -168,10 +227,6 @@ def __post_init__(self) -> None:
                 "Conformations cannot be initialized when also using an encoder!"
             )

-        if self.pose_estimation == "fixed" and self.poses is None:
-            raise ValueError(
-                "Poses must be specified to use ground-truth translations!"
-            )
         if self.pose_estimation == "refine":
             self.n_imgs_pose_search = 0
             if self.poses is None:
@@ -205,9 +260,6 @@ def __post_init__(self) -> None:
         if self.pose_estimation == "fixed":
             # "poses" include translations
             self.use_gt_trans = True
-            if self.poses is None:
-                raise ValueError("Ground truth poses must be specified!")
-
         if self.no_trans:
             self.t_extent = 0.0
         if self.t_extent == 0.0:
@@ -267,7 +319,7 @@ def make_output_mask(self) -> CircularMask:

         return output_mask

-    def make_reconstruction_model(self) -> nn.Module:
+    def make_reconstruction_model(self, weights=None) -> nn.Module:
         output_mask = self.make_output_mask()

         if self.configs.z_dim > 0:
diff --git a/cryodrgn/trainers/hps_trainer.py b/cryodrgn/trainers/hps_trainer.py
index 324de089..37767f39 100644
--- a/cryodrgn/trainers/hps_trainer.py
+++ b/cryodrgn/trainers/hps_trainer.py
@@ -38,6 +38,75 @@

 @dataclass
 class HierarchicalPoseSearchConfigurations(ReconstructionModelConfigurations):
+    """The configurations used by the cryoDRGN v3 model training engine.
+
+    Arguments
+    ---------
+    > inherited from `BaseConfigurations`:
+    verbose     An integer specifying the verbosity level for this engine, with
+                the default value of 0 generally specifying no/minimum verbosity.
+    outdir      Path to where output produced by the engine will be saved.
+    seed        A non-negative integer used to fix the stochasticity of the random
+                number generators used by this engine for reproducibility.
+                The default is to not fix stochasticity and thus use a different
+                random seed upon each run of the engine.
+    test_installation   Only perform a smoke test that this module has been
+                        installed correctly and exit immediately without running
+                        anything if this boolean value is set to `True`.
+                        Default is not to run this test.
+
+    > inherited from `ReconstructionModelConfigurations`:
+    model       A label for the reconstruction algorithm to be used — must be either
+                `hps` for cryoDRGN v3 models or `amort` for cryoDRGN-AI models.
+    z_dim       The dimensionality of the latent space of conformations.
+                Thus z_dim=0 for homogeneous models
+                and z_dim>0 for heterogeneous models.
+    num_epochs  The total number of epochs to use when training the model, not
+                including pretraining epoch(s).
+
+    dataset     Label for the particle dataset to be used as input for the model.
+                If used, remaining input parameters can be omitted.
+    particles   Path to the stack of particle images to use as input for the model.
+                Must be a .mrcs/.txt/.star/.cs file.
+    ctf         Path to the file storing contrast transfer function parameters
+                used to process the input particle images.
+    poses       Path to the input particle poses data (.pkl).
+    datadir     Path prefix to particle stack if loading relative paths from
+                a .star or .cs file.
+    ind         Path to a numpy array saved as a .pkl used to filter
+                input particles.
+
+    pose_estimation     Whether to perform ab-initio reconstruction ("abinit"),
+                        reconstruction using fixed poses ("fixed"), or
+                        reconstruction with SGD refinement of poses ("refine").
+                        Default is to use fixed poses if a poses file is given
+                        and ab-initio otherwise.
+
+    load        Load model from the given weights.*.pkl output file saved from a
+                previous run of the engine.
+                Can also be given as "latest", in which case the latest saved
+                epoch in the given output directory will be used.
+    lazy        Whether to use lazy loading of data into memory in smaller batches.
+                Necessary if the input dataset is too large to fit into memory.
+
+    batch_size  The number of input images to use at a time when updating
+                the learning algorithm.
+
+    multigpu    Whether to use all GPUs available on this machine.
+                The default is to use only one GPU.
+
+    log_interval    Print a log message every `N` training images.
+    checkpoint      Save model results to file every `N` training epochs.
+
+    pe_type     Label for the type of positional encoding to use.
+    pe_dim      Number of frequencies to use in the positional encoding
+                (default: 64).
+    volume_domain   Representation to use in the volume
+                    decoder ("hartley" or "fourier").
+
+    enc_layers  The number of hidden layers in the encoder used by the model.
+    dec_layers  The number of hidden layers in the decoder used by the model.
+    """

     # A parameter belongs to this configuration set if and only if it has a type and a
     # default value defined here, note that children classes inherit these parameters
@@ -70,14 +139,11 @@ class HierarchicalPoseSearchConfigurations(ReconstructionModelConfigurations):

     def __post_init__(self) -> None:
         super().__post_init__()
-        assert self.model == "hps"
-
-        if self.dataset is None:
-            if self.particles is None:
-                raise ValueError(
-                    "As dataset was not specified, please specify particles!"
-                )
+        if self.model != "hps":
+            raise ValueError(
+                f"Mismatched model `{self.model=}` for {self.__class__.__name__}!"
+            )

         if self.beta is not None:
             if not self.z_dim:
                 raise ValueError("Cannot use beta with homogeneous reconstruction!.")
diff --git a/cryodrgn/trainers/reconstruction.py b/cryodrgn/trainers/reconstruction.py
index 61824b24..e7b4fce6 100644
--- a/cryodrgn/trainers/reconstruction.py
+++ b/cryodrgn/trainers/reconstruction.py
@@ -50,7 +50,7 @@ class ReconstructionModelConfigurations(BaseConfigurations):
     z_dim The dimensionality of the latent space of conformations. Thus z_dim=0 for
           homogeneous models and z_dim>0 for hetergeneous models.
     num_epochs The total number of epochs to use when training the model, not including
-               pretraining epochs.
+               pretraining epoch(s).

     dataset Label for the particle dataset to be used as input for the model.
             If used, remaining input parameters can be omitted.
@@ -224,7 +224,7 @@ def __post_init__(self) -> None:
                 raise ValueError(
                     "To specify datasets using a label, first specify"
                     "a .yaml catalogue of datasets using the "
-                    "environment variable $DRGNAI_DATASETS!"
+                    "environment variable $CRYODRGN_DATASETS!"
                 )

             # you can also give the dataset as a label in the global dataset list
@@ -234,15 +234,17 @@ def __post_init__(self) -> None:
                 for k, v in paths.items():
                     setattr(self, k, v)

-            elif self.particles is None:
+            elif not self.particles:
                 raise ValueError(
                     "Must specify either a dataset label stored in "
                     f"{paths_file} or the paths to a particles and "
                     "ctf settings file!"
                 )
+            elif not os.path.isfile(self.particles):
+                raise ValueError(f"Given particles file `{self.particles}` does not exist!")

-        if isinstance(self.ind, str) and not os.path.exists(self.ind):
-            raise ValueError(f"Subset indices file {self.ind} does not exist!")
+        if isinstance(self.ind, str) and not os.path.isfile(self.ind):
+            raise ValueError(f"Given subset indices file {self.ind} does not exist!")

         if self.pose_estimation is None:
             self.pose_estimation = "fixed" if self.poses else "abinit"
@@ -258,6 +260,8 @@ def __post_init__(self) -> None:
             raise ValueError(
                 "Specify an input file (poses=) if using ground truth poses!"
             )
+        if isinstance(self.poses, str) and not os.path.isfile(self.poses):
+            raise ValueError(f"Given poses file {self.poses} does not exist!")

         if self.batch_size_known_poses is None:
             self.batch_size_known_poses = self.batch_size
@@ -411,7 +415,7 @@ def __init__(self, configs: dict[str, Any]) -> None:
             self.data.voltage = float(ctf_params[0, 4])
             self.ctf_params = torch.tensor(ctf_params, device=self.device)
-            self.apix = self.ctf_params[0, 0]
+            self.apix = float(self.ctf_params[0, 0])
         else:
             self.ctf_params = None
             self.apix = None

@@ -787,7 +791,16 @@ def pretrain(self) -> None:
         )

     def get_configs(self) -> dict[str, Any]:
-        """Retrieves all given and inferred configurations for downstream use."""
+        """Retrieves all given and inferred configurations for downstream use.
+
+        Note that we need this in addition to the class attribute `configs`, which is
+        a `ReconstructionModelConfigurations` object, in order to 1) define a way of
+        saving these configurations to file in a more structured, hierarchical,
+        human-readable format where the configurations are sorted and partitioned
+        according to the facet of the engine they concern, and 2) provide a method of
+        retrieving configuration values that need to be computed from the particle
+        stack (such as `data.norm`) without having to reload the stack from file.
+ """ dataset_args = dict( particles=self.configs.particles, @@ -795,12 +808,15 @@ def get_configs(self) -> dict[str, Any]: norm=self.data.norm, invert_data=self.configs.invert_data, ind=self.configs.ind, + n_particles=self.particle_count, + n_images=self.image_count, keepreal=self.configs.use_real, window=self.configs.window, window_r=self.configs.window_r, datadir=self.configs.datadir, ctf=self.configs.ctf, tilt=self.configs.tilt, + apix=self.apix, ) if self.lattice is not None: