Skip to content

Commit

Permalink
assorted fixed in preparation for alpha release
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-g committed Jan 31, 2025
1 parent 0096041 commit 17576fc
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 14 deletions.
3 changes: 0 additions & 3 deletions cryodrgn/commands/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,6 @@ def create_configs(

if model:
configs["model"] = model
elif "model" not in configs:
configs["model"] = "amort"

if dataset:
configs["dataset"] = dataset
if particles:
Expand Down
12 changes: 7 additions & 5 deletions cryodrgn/commands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,26 +55,28 @@ def add_args(parser: argparse.ArgumentParser) -> None:
)


def main(args: argparse.Namespace, configs: Optional[dict[str, Any]] = None) -> None:
def main(
args: argparse.Namespace, additional_configs: Optional[dict[str, Any]] = None
) -> None:
"""Running the `cryodrgn train` command (see `add_args` above for arguments).
An additional `configs` dictionary of configuration values can also be passed, which
will be appended to (and override) the values in the given config file and any
values specified through the `--cfgs` command-line argument.
"""
if configs is None:
configs = dict()
if additional_configs is None:
additional_configs = dict()

file_configs = SetupHelper(args.config_file, update_existing=False).create_configs(
configs = SetupHelper(args.config_file, update_existing=False).create_configs(
model=args.model,
)

trainer_cls = TRAINER_CLASSES[configs["model"]]
if args.cfgs:
configs = {**configs, **trainer_cls.config_cls.parse_cfg_keys(args.cfgs)}
configs = {**configs, **additional_configs, "outdir": args.outdir}

configs = {**file_configs, **configs, "outdir": args.outdir}
trainer = trainer_cls(configs)
cryodrgn.utils._verbose = False
trainer.train()
Expand Down
11 changes: 9 additions & 2 deletions cryodrgn/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,16 @@ def __getitem__(self, index):
# take the first ntilts
tilt_index = self.particles[ii][0 : self.ntilts]
tilt_indices.append(tilt_index)

tilt_indices = np.concatenate(tilt_indices)
images = self._process(self.src.images(tilt_indices).to(self.device))
return images, tilt_indices, index
r_imgs, f_imgs = self._process(self.src.images(tilt_indices).to(self.device))

return {
"y": f_imgs,
"y_real": r_imgs,
"tilt_indices": tilt_indices,
"indices": index,
}

@classmethod
def parse_particle_tilt(
Expand Down
3 changes: 3 additions & 0 deletions cryodrgn/mrcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def make_default_header(
else:
data_dtype = np.dtype("float32") # default to np.float 32 mode

if Apix is None:
Apix = 1.0

if data is not None:
nz, ny, nx = data.shape

Expand Down
6 changes: 3 additions & 3 deletions cryodrgn/trainers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,8 @@ class ReconstructionModelConfigurations(BaseConfigurations):
volume_optim_type: str = "adam"
pose_sgd_emb_type: str = "quat"
verbose_time: bool = False
angle_per_tilt: int = 3
dose_per_tilt: float = 2.97

# quick configs
capture_setup: str = None
Expand Down Expand Up @@ -510,14 +512,12 @@ def __init__(self, configs: dict[str, Any]) -> None:
ind=self.ind,
ntilts=self.configs.n_tilts,
angle_per_tilt=self.configs.angle_per_tilt,
window_r=self.configs.window_radius_gt_real,
window_r=self.configs.window_r,
datadir=self.configs.datadir,
max_threads=self.configs.max_threads,
dose_per_tilt=self.configs.dose_per_tilt,
device=self.device,
poses_gt_pkl=use_poses,
tilt_axis_angle=self.configs.tilt_axis_angle,
no_trans=self.configs.no_trans,
)
self.particle_count = self.data.Np

Expand Down
9 changes: 8 additions & 1 deletion cryodrgn/trainers/hps_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@

@dataclass
class HierarchicalPoseSearchConfigurations(ReconstructionModelConfigurations):
quick_config = {
"capture_setup": {
"spa": {"lazy": True},
"et": {"tilt": True, "dose_per_tilt": 2.97, "angle_per_tilt": 3},
},
"reconstruction_type": {"homo": {"z_dim": 0}, "het": {"z_dim": 8}},
}

# a parameter belongs to this configuration set if and only if it has a default
# value defined here, note that children classes inherit these from parents
Expand Down Expand Up @@ -338,7 +345,7 @@ def train_batch(self, batch: dict[str, torch.Tensor]) -> tuple:

y = y.to(self.device)
ind_np = ind.cpu().numpy()
B = y.size(0)
B = len(ind)

if self.configs.tilt:
tilt_ind = batch["tilt_indices"]
Expand Down

0 comments on commit 17576fc

Please sign in to comment.