Skip to content

Commit

Permalink
cleaning up training engine documentation and config attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-g committed Feb 20, 2025
1 parent c51a595 commit baa4bf8
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 35 deletions.
4 changes: 2 additions & 2 deletions cryodrgn/commands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def add_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--model",
choices={"amort", "hps"},
help="which model to use for reconstruction",
help="which model to use for reconstruction; default is `amort` (cryoDRGN-AI)",
)
parser.add_argument(
"--no-analysis",
Expand Down Expand Up @@ -72,7 +72,7 @@ def main(
if additional_configs is not None:
configs.update(additional_configs)

trainer = get_model_trainer(configs)
trainer = get_model_trainer(configs, add_cfgs=args.cfgs)
cryodrgn.utils._verbose = False
trainer.train()

Expand Down
90 changes: 71 additions & 19 deletions cryodrgn/trainers/amortinf_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,75 @@

@dataclass
class AmortizedInferenceConfigurations(ReconstructionModelConfigurations):
"""The configurations used by the cryoDRGN-AI model training engine.
Arguments
---------
> inherited from `BaseConfigurations`:
verbose An integer specifying the verbosity level for this engine, with
the default value of 0 generally specifying no/minimum verbosity.
outdir Path to where output produced by the engine will be saved.
seed A non-negative integer used to fix the stochasticity of the random
number generators used by this engine for reproducibility.
The default is to not fix stochasticity and thus use a different
random seed upon each run of the engine.
test_installation Only perform a smoke test that this module has been
installed correctly and exit immediately without running
anything if this boolean value is set to `True`.
Default is not to run this test.
> inherited from `ReconstructionModelConfigurations`:
model A label for the reconstruction algorithm to be used — must be either
`hps` for cryoDRGN v3 models or `amort` for cryoDRGN-AI models.
z_dim The dimensionality of the latent space of conformations.
Thus z_dim=0 for homogeneous models
and z_dim>0 for heterogeneous models.
num_epochs The total number of epochs to use when training the model, not
including pretraining epoch(s).
dataset Label for the particle dataset to be used as input for the model.
If used, remaining input parameters can be omitted.
particles Path to the stack of particle images to use as input for the model.
Must be a (.mrcs/.txt/.star/.cs file).
ctf Path to the file storing contrast transfer function parameters
used to process the input particle images.
poses Path to the input particle poses data (.pkl).
datadir Path prefix to particle stack if loading relative paths from
a .star or .cs file.
ind Path to a numpy array saved as a .pkl used to filter
input particles.
pose_estimation Whether to perform ab-initio reconstruction ("abinit"),
reconstruction using fixed poses ("fixed"), or
reconstruction with SGD refinement of poses ("refine").
Default is to use fixed poses if poses file is given
and ab-initio otherwise.
load Load model from given weights.<epoch>.pkl output file saved from
a previous run of the engine.
Can also be given as "latest", in which case the latest saved epoch
in the given output directory will be used.
lazy Whether to use lazy loading of data into memory in smaller batches.
Necessary if input dataset is too large to fit into memory.
batch_size The number of input images to use at a time when updating
the learning algorithm.
multigpu Whether to use all GPUs available on this machine.
The default is to use only one GPU.
log_interval Print a log message every `N` number of training images.
checkpoint Save model results to file every `N` training epochs.
pe_type Label for the type of positional encoding to use.
pe_dim Number of frequencies to use in the positional encoding
(default: 64).
volume_domain Representation to use in the volume
decoder ("hartley" or "fourier").
n_imgs_pose_search The number of images the model needs to see in order to
complete the pose search portion of training.
"""

# A parameter belongs to this configuration set if and only if it has a type and a
# default value defined here, note that children classes inherit these parameters
Expand Down Expand Up @@ -94,31 +163,21 @@ def __post_init__(self) -> None:

if self.model != "amort":
raise ValueError(
f"Mismatched model {self.model} for AmortizedInferenceTrainer!"
f"Mismatched model `{self.model=}` for {self.__class__.__name__}!"
)

if self.pose_estimation is not None:
if self.pose_estimation == "refine":
self.pose_learning_rate = 1.0e-4
if self.conf_estimation is not None:
if self.conf_estimation == "encoder":
self.use_conf_encoder = True

if self.batch_size_sgd is None:
self.batch_size_sgd = self.batch_size
if self.batch_size_known_poses is None:
self.batch_size_known_poses = self.batch_size
if self.batch_size_hps is None:
self.batch_size_hps = self.batch_size

if self.explicit_volume and self.z_dim >= 1:
raise ValueError(
"Explicit volumes do not support heterogeneous reconstruction."
)

if self.dataset is None:
if self.particles is None:
raise ValueError("Dataset wasn't specified: please specify particles!")
if self.ctf is None:
raise ValueError("Dataset wasn't specified: please specify ctf!")

Expand Down Expand Up @@ -168,10 +227,6 @@ def __post_init__(self) -> None:
"Conformations cannot be initialized when also using an encoder!"
)

if self.pose_estimation == "fixed" and self.poses is None:
raise ValueError(
"Poses must be specified to use ground-truth translations!"
)
if self.pose_estimation == "refine":
self.n_imgs_pose_search = 0
if self.poses is None:
Expand Down Expand Up @@ -205,9 +260,6 @@ def __post_init__(self) -> None:
if self.pose_estimation == "fixed":
# "poses" include translations
self.use_gt_trans = True
if self.poses is None:
raise ValueError("Ground truth poses must be specified!")

if self.no_trans:
self.t_extent = 0.0
if self.t_extent == 0.0:
Expand Down Expand Up @@ -267,7 +319,7 @@ def make_output_mask(self) -> CircularMask:

return output_mask

def make_reconstruction_model(self) -> nn.Module:
def make_reconstruction_model(self, weights=None) -> nn.Module:
output_mask = self.make_output_mask()

if self.configs.z_dim > 0:
Expand Down
80 changes: 73 additions & 7 deletions cryodrgn/trainers/hps_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,75 @@

@dataclass
class HierarchicalPoseSearchConfigurations(ReconstructionModelConfigurations):
"""The configurations used by the cryoDRGN v3 model training engine.
Arguments
---------
> inherited from `BaseConfigurations`:
verbose An integer specifying the verbosity level for this engine, with
the default value of 0 generally specifying no/minimum verbosity.
outdir Path to where output produced by the engine will be saved.
seed A non-negative integer used to fix the stochasticity of the random
number generators used by this engine for reproducibility.
The default is to not fix stochasticity and thus use a different
random seed upon each run of the engine.
test_installation Only perform a smoke test that this module has been
installed correctly and exit immediately without running
anything if this boolean value is set to `True`.
Default is not to run this test.
> inherited from `ReconstructionModelConfigurations`:
model A label for the reconstruction algorithm to be used — must be either
`hps` for cryoDRGN v3 models or `amort` for cryoDRGN-AI models.
z_dim The dimensionality of the latent space of conformations.
Thus z_dim=0 for homogeneous models
and z_dim>0 for heterogeneous models.
num_epochs The total number of epochs to use when training the model, not
including pretraining epoch(s).
dataset Label for the particle dataset to be used as input for the model.
If used, remaining input parameters can be omitted.
particles Path to the stack of particle images to use as input for the model.
Must be a (.mrcs/.txt/.star/.cs file).
ctf Path to the file storing contrast transfer function parameters
used to process the input particle images.
poses Path to the input particle poses data (.pkl).
datadir Path prefix to particle stack if loading relative paths from
a .star or .cs file.
ind Path to a numpy array saved as a .pkl used to filter
input particles.
pose_estimation Whether to perform ab-initio reconstruction ("abinit"),
reconstruction using fixed poses ("fixed"), or
reconstruction with SGD refinement of poses ("refine").
Default is to use fixed poses if poses file is given
and ab-initio otherwise.
load Load model from given weights.<epoch>.pkl output file saved from
a previous run of the engine.
Can also be given as "latest", in which case the latest saved epoch
in the given output directory will be used.
lazy Whether to use lazy loading of data into memory in smaller batches.
Necessary if input dataset is too large to fit into memory.
batch_size The number of input images to use at a time when updating
the learning algorithm.
multigpu Whether to use all GPUs available on this machine.
The default is to use only one GPU.
log_interval Print a log message every `N` number of training images.
checkpoint Save model results to file every `N` training epochs.
pe_type Label for the type of positional encoding to use.
pe_dim Number of frequencies to use in the positional encoding
(default: 64).
volume_domain Representation to use in the volume
decoder ("hartley" or "fourier").
enc_layers The number of hidden layers in the encoder used by the model.
dec_layers The number of hidden layers in the decoder used by the model.
"""

# A parameter belongs to this configuration set if and only if it has a type and a
# default value defined here, note that children classes inherit these parameters
Expand Down Expand Up @@ -70,14 +139,11 @@ class HierarchicalPoseSearchConfigurations(ReconstructionModelConfigurations):

def __post_init__(self) -> None:
super().__post_init__()
assert self.model == "hps"

if self.dataset is None:
if self.particles is None:
raise ValueError(
"As dataset was not specified, please specify particles!"
)

if self.model != "hps":
raise ValueError(
f"Mismatched model `{self.model=}` for {self.__class__.__name__}!"
)
if self.beta is not None:
if not self.z_dim:
raise ValueError("Cannot use beta with homogeneous reconstruction!.")
Expand Down
30 changes: 23 additions & 7 deletions cryodrgn/trainers/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class ReconstructionModelConfigurations(BaseConfigurations):
z_dim The dimensionality of the latent space of conformations.
Thus z_dim=0 for homogeneous models and z_dim>0 for heterogeneous models.
num_epochs The total number of epochs to use when training the model, not including
pretraining epochs.
pretraining epoch(s).
dataset Label for the particle dataset to be used as input for the model.
If used, remaining input parameters can be omitted.
Expand Down Expand Up @@ -224,7 +224,7 @@ def __post_init__(self) -> None:
raise ValueError(
"To specify datasets using a label, first specify"
"a .yaml catalogue of datasets using the "
"environment variable $DRGNAI_DATASETS!"
"environment variable $CRYODRGN_DATASETS!"
)

# you can also give the dataset as a label in the global dataset list
Expand All @@ -234,15 +234,17 @@ def __post_init__(self) -> None:
for k, v in paths.items():
setattr(self, k, v)

elif self.particles is None:
elif not self.particles:
raise ValueError(
"Must specify either a dataset label stored in "
f"{paths_file} or the paths to a particles and "
"ctf settings file!"
)
elif not os.path.isfile(self.particles):
raise ValueError(f"Given particles file `{self.particles}` does not exist!")

if isinstance(self.ind, str) and not os.path.exists(self.ind):
raise ValueError(f"Subset indices file {self.ind} does not exist!")
if isinstance(self.ind, str) and not os.path.isfile(self.ind):
raise ValueError(f"Given subset indices file {self.ind} does not exist!")

if self.pose_estimation is None:
self.pose_estimation = "fixed" if self.poses else "abinit"
Expand All @@ -258,6 +260,8 @@ def __post_init__(self) -> None:
raise ValueError(
"Specify an input file (poses=) if using ground truth poses!"
)
if isinstance(self.poses, str) and not os.path.isfile(self.poses):
raise ValueError(f"Given poses file {self.poses} does not exist!")

if self.batch_size_known_poses is None:
self.batch_size_known_poses = self.batch_size
Expand Down Expand Up @@ -411,7 +415,7 @@ def __init__(self, configs: dict[str, Any]) -> None:
self.data.voltage = float(ctf_params[0, 4])

self.ctf_params = torch.tensor(ctf_params, device=self.device)
self.apix = self.ctf_params[0, 0]
self.apix = float(self.ctf_params[0, 0])
else:
self.ctf_params = None
self.apix = None
Expand Down Expand Up @@ -787,20 +791,32 @@ def pretrain(self) -> None:
)

def get_configs(self) -> dict[str, Any]:
"""Retrieves all given and inferred configurations for downstream use."""
"""Retrieves all given and inferred configurations for downstream use.
Note that we need this in addition to the class attribute `configs` which is a
`ReconstructionModelConfigurations` object in order to 1) define a way of saving
these configurations to file in a more structured, hierarchical, human-readable
format where the configurations are sorted and partitioned according to the
facet of the engine they concern and 2) provide a method of retrieving
configuration values that need to be computed from the particle stack (such as
`data.norm`) without having to reload the stack from file.
"""

dataset_args = dict(
particles=self.configs.particles,
poses=self.configs.poses,
norm=self.data.norm,
invert_data=self.configs.invert_data,
ind=self.configs.ind,
n_particles=self.particle_count,
n_images=self.image_count,
keepreal=self.configs.use_real,
window=self.configs.window,
window_r=self.configs.window_r,
datadir=self.configs.datadir,
ctf=self.configs.ctf,
tilt=self.configs.tilt,
apix=self.apix,
)

if self.lattice is not None:
Expand Down

0 comments on commit baa4bf8

Please sign in to comment.