harmonizing how pretraining is done between v3 and v4 trainers
michal-g committed Mar 23, 2024
1 parent 5b0b83e commit c1fdda3
Showing 3 changed files with 59 additions and 72 deletions.
65 changes: 33 additions & 32 deletions cryodrgn/trainers/_base.py
@@ -4,9 +4,10 @@
import shutil
import sys
import pickle
import time
from collections import OrderedDict
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any
from typing_extensions import Self
import yaml
from datetime import datetime as dt
@@ -198,7 +199,6 @@ class ModelConfigurations(BaseConfigurations):
"datadir",
"ind",
"log_interval",
"verbose",
"load",
"load_poses",
"initial_conf",
@@ -233,9 +233,6 @@ class ModelConfigurations(BaseConfigurations):
"t_yshift",
"hidden_layers",
"hidden_dim",
"encode_mode",
"enc_mask",
"use_real",
"pe_type",
"pe_dim",
"volume_domain",
@@ -244,8 +241,10 @@
"base_healpy",
"subtomo_averaging",
"volume_optim_type",
"use_real",
"no_trans",
"amp",
"reset_optim_after_pretrain",
)
default_values = OrderedDict(
{
@@ -257,7 +256,6 @@
"datadir": None,
"ind": None,
"log_interval": 1000,
"verbose": False,
"load": None,
"load_poses": None,
"checkpoint": 1,
@@ -291,9 +289,6 @@
"t_yshift": 0,
"hidden_layers": 3,
"hidden_dim": 256,
"encode_mode": "resid",
"enc_mask": None,
"use_real": False,
"pe_type": "gaussian",
"pe_dim": 64,
"volume_domain": None,
@@ -302,8 +297,10 @@
"base_healpy": 2,
"subtomo_averaging": False,
"volume_optim_type": "adam",
"use_real": False,
"no_trans": False,
"amp": True,
"reset_optim_after_pretrain": False,
}
)

@@ -484,7 +481,7 @@ def __init__(self, configs: dict[str, Any]) -> None:
self.logger.info(self.volume_model)

# parallelize
if self.volume_model.zdim > 0 and self.configs.multigpu and self.n_prcs > 1:
if self.volume_model.z_dim > 0 and self.configs.multigpu and self.n_prcs > 1:
if self.configs.multigpu and torch.cuda.device_count() > 1:
self.logger.info(f"Using {torch.cuda.device_count()} GPUs!")
self.configs.batch_size *= torch.cuda.device_count()
@@ -611,9 +608,10 @@ def __init__(self, configs: dict[str, Any]) -> None:
self.configs.write(os.path.join(self.outdir, "train-configs.yaml"))

def train(self) -> None:
self.configs: ModelConfigurations
t0 = dt.now()

# self.pretrain()
self.pretrain()
self.current_epoch = self.start_epoch
self.logger.info("--- Training Starts Now ---")

@@ -645,44 +643,47 @@ def train(self) -> None:

t_total = dt.now() - t0
self.logger.info(
f"Finished in {t_total} ({t_total / self.num_epochs} per epoch)"
f"Finished in {t_total} ({t_total / self.configs.num_epochs} per epoch)"
)

def pretrain(self):
def pretrain(self) -> None:
"""Pretrain the decoder using random initial poses."""
particles_seen = 0
loss = None
self.configs: ModelConfigurations
end_time = time.time()
self.logger.info(f"Using random poses for {self.configs.pretrain} iterations")

for batch in self.data_iterator:
particles_seen += len(batch[0])
particles_seen = 0
while particles_seen < self.configs.pretrain:
for batch in self.data_iterator:
particles_seen += len(batch[0])

batch = (
(batch[0].to(self.device), None)
if batch[1] is None
else (batch[0].to(self.device), batch[1].to(self.device))
)
loss = self.pretrain_step(batch, end_time=end_time)

batch = (
(batch[0].to(self.device), None)
if self.configs.tilt is None
else (batch[0].to(self.device), batch[1].to(self.device))
)
loss = self.pretrain_step(batch)
if self.configs.verbose_time:
torch.cuda.synchronize()

if particles_seen % self.configs.log_interval == 0:
self.logger.info(
f"[Pretrain Iteration {particles_seen}] loss={loss:4f}"
)
if particles_seen % self.configs.log_interval == 0:
self.logger.info(
f"[Pretrain Iteration {particles_seen}] loss={loss:4f}"
)

if particles_seen > self.configs.pretrain_iter:
break
if particles_seen > self.configs.pretrain:
break

# reset model after pretraining
if self.configs.reset_optim_after_pretrain:
self.logger.info(">> Resetting optim after pretrain")
self.optim = torch.optim.Adam(
self.model.parameters(),
self.volume_model.parameters(),
lr=self.configs.learning_rate,
weight_decay=self.configs.weight_decay,
)

return loss

@abstractmethod
def train_epoch(self):
pass
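For readers skimming the diff, the loop below is a minimal, self-contained sketch of the pretraining behaviour this commit moves into the shared base trainer: keep drawing batches until a target number of particles has been seen, log at a fixed interval, and optionally rebuild the optimizer before the main training run. The trainer argument, its data_iterator, pretrain_step, and configs attributes, and the helper name run_pretrain are illustrative stand-ins rather than the exact cryoDRGN API.

import torch


def run_pretrain(trainer, pretrain_particles: int, log_interval: int = 1000) -> float:
    """Sketch of the harmonized ModelTrainer.pretrain(): random-pose decoder pretraining."""
    loss = None
    particles_seen = 0

    # Cycle over the data loader until enough particles have been used,
    # mirroring the `while particles_seen < self.configs.pretrain` loop in the diff.
    while particles_seen < pretrain_particles:
        for images, tilts in trainer.data_iterator:
            particles_seen += len(images)
            batch = (
                (images.to(trainer.device), None)
                if tilts is None
                else (images.to(trainer.device), tilts.to(trainer.device))
            )
            loss = trainer.pretrain_step(batch)

            if particles_seen % log_interval == 0:
                print(f"[Pretrain Iteration {particles_seen}] loss={loss:.4f}")
            if particles_seen > pretrain_particles:
                break

    # Optionally start the main training run with fresh optimizer state,
    # as gated by the new reset_optim_after_pretrain configuration.
    if trainer.configs.reset_optim_after_pretrain:
        trainer.optim = torch.optim.Adam(
            trainer.volume_model.parameters(),
            lr=trainer.configs.learning_rate,
            weight_decay=trainer.configs.weight_decay,
        )

    return loss
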
19 changes: 2 additions & 17 deletions cryodrgn/trainers/amortinf_trainer.py
@@ -2,7 +2,7 @@
import pickle
from collections import OrderedDict
import numpy as np
from typing import Any, Optional
from typing import Any
import time

import torch
@@ -14,7 +14,7 @@
from cryodrgn import ctf, mrc
from cryodrgn.dataset import make_dataloader
from cryodrgn.trainers import summary
from cryodrgn.losses import kl_divergence_conf, l1_regularizer, l2_frequency_bias
from cryodrgn.models.losses import kl_divergence_conf, l1_regularizer, l2_frequency_bias
from cryodrgn.models.amortized_inference import DRGNai, MyDataParallel
from cryodrgn.masking import CircularMask, FrequencyMarchingMask
from cryodrgn.trainers._base import ModelTrainer, ModelConfigurations
@@ -737,21 +737,6 @@ def train_epoch(self):
if hasattr(self.model.output_mask, "update_epoch") and self.use_point_estimates:
self.model.output_mask.update_epoch(self.configs.n_frequencies_per_epoch)

def pretrain(self):
end_time = time.time()

for batch_idx, (batch, tilt_ind, ind) in enumerate(self.data_iterator):
self.batch_idx = batch_idx

self.train_step(batch, tilt_ind, ind, end_time=end_time)
if self.configs.verbose_time:
torch.cuda.synchronize()

end_time = time.time()

if self.current_epoch_particles_count > self.n_particles_pretrain:
break

def pretrain_step(self, batch, **pretrain_kwargs):
self.train_step(batch, **pretrain_kwargs)

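The net effect on the amortized-inference trainer is that its bespoke pretrain() loop is gone; the shared loop in the base class now drives pretraining, and this subclass only defines what one pretraining step does by forwarding to its regular training step. A rough sketch of that division of labour, with illustrative class names rather than the actual cryoDRGN classes:

from abc import ABC, abstractmethod


class BaseTrainer(ABC):
    """Owns the shared pretraining loop (see the sketch after the _base.py diff)."""

    def pretrain(self):
        for batch in self.data_iterator:
            self.pretrain_step(batch)

    @abstractmethod
    def pretrain_step(self, batch, **pretrain_kwargs):
        ...


class AmortizedInferenceTrainer(BaseTrainer):
    """Pretrains by running its ordinary training step on each batch."""

    def pretrain_step(self, batch, **pretrain_kwargs):
        return self.train_step(batch, **pretrain_kwargs)

    def train_step(self, batch, **kwargs):
        ...  # actual model and pose update
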
47 changes: 24 additions & 23 deletions cryodrgn/trainers/hps_trainer.py
@@ -11,11 +11,12 @@
import torch.nn.functional as F

import cryodrgn.config
from cryodrgn import ctf, dataset, lie_tools
from cryodrgn.losses import EquivarianceLoss
from cryodrgn import ctf, dataset
from cryodrgn.models import lie_tools
from cryodrgn.models.losses import EquivarianceLoss
from cryodrgn.models.variational_autoencoder import unparallelize, HetOnlyVAE
from cryodrgn.models.neural_nets import get_decoder
from cryodrgn.pose_search import PoseSearch
from cryodrgn.models.pose_search import PoseSearch
from cryodrgn.trainers._base import ModelTrainer, ModelConfigurations


@@ -28,7 +29,6 @@ class HierarchicalPoseSearchConfigurations(ModelConfigurations):
"equivariance",
"equivariance_start",
"equivariance_stop",
"data_norm",
"l_ramp_epochs",
"l_ramp_model",
"reset_model_every",
@@ -42,7 +42,6 @@ class HierarchicalPoseSearchConfigurations(ModelConfigurations):
"enc_dim",
"encode_mode",
"enc_mask",
"use_real",
"dec_layers",
"dec_dim",
)
@@ -54,12 +53,10 @@ class HierarchicalPoseSearchConfigurations(ModelConfigurations):
"equivariance": None,
"equivariance_start": 100000,
"equivariance_stop": 200000,
"data_norm": None,
"l_ramp_epochs": 0,
"l_ramp_model": 0,
"reset_model_every": None,
"reset_optim_every": None,
"reset_optim_after_pretrain": None,
"grid_niter": 4,
"ps_freq": 5,
"n_kept_poses": 8,
@@ -68,7 +65,6 @@
"enc_dim": 1024,
"encode_mode": "resid",
"enc_mask": None,
"use_real": False,
"dec_layers": 3,
"dec_dim": 1024,
}
@@ -395,13 +391,18 @@ def train_epoch(self):
self.ctf_params[ind] if self.ctf_params is not None else None
)

if batch_poses:
rot, trans = batch_poses
else:
rot, trans = None, None

# train the model
losses, new_poses, new_base_poses = self.train_step(
batch,
tilt_ind,
equivariance_tuple,
rot=batch_poses[0],
trans=batch_poses[1],
rot=rot,
trans=trans,
ctf_params=ctf_param,
)

Expand All @@ -410,7 +411,7 @@ def train_epoch(self):
self.pose_optimizer.step()

all_poses.append((ind.cpu().numpy(), new_poses))
if new_base_poses:
if new_base_poses is not None:
base_poses.append((ind_np, new_base_poses))

# logging
@@ -612,22 +613,22 @@ def preprocess_input(self, y, trans):
).view(y.size(0), self.resolution, self.resolution)

def pretrain_step(self, batch):
if self.z_dim > 0:
if self.configs.z_dim > 0:
y, yt = batch
use_tilt = yt is not None
B = y.size(0)

self.model.train()
self.optim.zero_grad()
self.volume_model.train()
self.volume_optimizer.zero_grad()

rot = lie_tools.random_SO3(B, device=y.device)
z = torch.randn((B, self.z_dim), device=y.device)
z = torch.randn((B, self.configs.z_dim), device=y.device)

# reconstruct circle of pixels instead of whole image
mask = self.lattice.get_circular_mask(self.lattice.D // 2)

def gen_slice(R):
_model = unparallelize(self.model)
_model = unparallelize(self.volume_model)
assert isinstance(_model, HetOnlyVAE)
return _model.decode(self.lattice.coords[mask] @ R, z).view(B, -1)

@@ -641,34 +642,34 @@ def gen_slice(R):
gen_loss = F.mse_loss(gen_slice(rot), y)

gen_loss.backward()
self.optim.step()
self.volume_optimizer.step()

return gen_loss.item()

else:
y, yt = batch
B = y.size(0)
self.model.train()
self.optim.zero_grad()
self.volume_model.train()
self.volume_optimizer.zero_grad()

mask = self.lattice.get_circular_mask(self.lattice.D // 2)

def gen_slice(R):
slice_ = self.model(self.lattice.coords[mask] @ R)
slice_ = self.volume_model(self.lattice.coords[mask] @ R)
return slice_.view(B, -1)

rot = lie_tools.random_SO3(B, device=y.device)

y = y.view(B, -1)[:, mask]
if self.tilt is not None:
if self.configs.tilt is not None:
yt = yt.view(B, -1)[:, mask]
loss = 0.5 * F.mse_loss(gen_slice(rot), y) + 0.5 * F.mse_loss(
gen_slice(self.tilt @ rot), yt
gen_slice(self.configs.tilt @ rot), yt
)
else:
loss = F.mse_loss(gen_slice(rot), y)
loss.backward()
self.optim.step()
self.volume_optimizer.step()

return loss.item()

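In the hierarchical pose-search trainer, pretrain_step() fits the decoder to images rendered at random poses (and, when z_dim > 0, random latent vectors), reconstructing only a circular mask of pixels. The sketch below is self-contained and hedged: decoder, optimizer, coords, and mask stand in for the cryoDRGN decoder, its optimizer, the lattice coordinates, and the circular lattice mask, and random_rotations substitutes for lie_tools.random_SO3.

import torch
import torch.nn.functional as F


def random_rotations(n: int, device=None) -> torch.Tensor:
    """Uniform random rotation matrices of shape (n, 3, 3) via normalized quaternions."""
    q = torch.randn(n, 4, device=device)
    q = q / q.norm(dim=1, keepdim=True)
    w, x, y, z = q.unbind(dim=1)
    return torch.stack(
        [
            1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
            2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
            2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y),
        ],
        dim=1,
    ).view(n, 3, 3)


def pretrain_step(decoder, optimizer, coords, mask, images, z_dim: int) -> float:
    """One random-pose pretraining step; returns the scalar reconstruction loss."""
    B = images.size(0)
    optimizer.zero_grad()

    rot = random_rotations(B, device=images.device)   # random poses
    z = torch.randn(B, z_dim, device=images.device)   # random latents (z_dim > 0 case)

    # Render only the masked circle of pixels, matching the masking in the real step.
    slices = decoder(coords[mask] @ rot, z).view(B, -1)
    target = images.view(B, -1)[:, mask]

    loss = F.mse_loss(slices, target)
    loss.backward()
    optimizer.step()
    return loss.item()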
