Skip to content

Commit

Permalink
removing deprecated commands and attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-g committed Feb 4, 2025
1 parent 86e1c76 commit beed1cf
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 201 deletions.
1 change: 0 additions & 1 deletion cryodrgn/command_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def main_commands() -> None:
"train",
"train_nn",
"train_vae",
"view_config",
],
doc_str="Commands installed with cryoDRGN",
)
Expand Down
2 changes: 1 addition & 1 deletion cryodrgn/commands/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def linear_interpolation(z_0, z_1, n, exclude_last=False):
return z_0[None] * (1.0 - t) + z_1[None] * t


def main(args):
def main(args: argparse.Namespace) -> None:
matplotlib.use("Agg") # non-interactive backend
t0 = dt.now()

Expand Down
15 changes: 4 additions & 11 deletions cryodrgn/commands/analyze_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from cryodrgn import analysis, fft, utils
from cryodrgn.source import ImageSource
from cryodrgn.mrc import MRCFile
from cryodrgn.mrcfile import write_mrc
import cryodrgn.config

try:
Expand All @@ -41,7 +41,7 @@
logger = logging.getLogger(__name__)


def add_args(parser):
def add_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"workdir", type=os.path.abspath, help="Directory with cryoDRGN results"
)
Expand Down Expand Up @@ -191,8 +191,6 @@ def add_args(parser):
help="Number of voxels over which to apply a soft cosine falling edge from dilated mask boundary",
)

return parser


def plot_loss(logfile, outdir, E):
"""
Expand Down Expand Up @@ -810,7 +808,7 @@ def mask_volume(volpath, outpath, Apix, thresh=None, dilate=3, dist=10):
# used to write out mask separately from masked volume, now apply and save the masked vol to minimize future I/O
# MRCFile.write(outpath, z.astype(np.float32))
vol *= z
MRCFile.write(outpath, vol.astype(np.float32), Apix=Apix)
write_mrc(outpath, vol.astype(np.float32), Apix=Apix)


def mask_volumes(
Expand Down Expand Up @@ -1046,7 +1044,7 @@ def calc_fsc(vol1_path: str, vol2_path: str):
)


def main(args):
def main(args: argparse.Namespace) -> None:
t1 = dt.now()

# Configure paths
Expand Down Expand Up @@ -1245,8 +1243,3 @@ def main(args):
calculate_FSCs(outdir, epochs, labels, img_size, chimerax_colors)

logger.info(f"Finished in {dt.now() - t1}")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
main(add_args(parser).parse_args())
10 changes: 2 additions & 8 deletions cryodrgn/commands/parse_ctf_csparc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
logger = logging.getLogger(__name__)


def add_args(parser):
def add_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument("cs", help="Input cryosparc particles.cs file")
parser.add_argument(
"-o", type=os.path.abspath, required=True, help="Output pkl of CTF parameters"
Expand All @@ -22,10 +22,9 @@ def add_args(parser):
group = parser.add_argument_group("Optionally provide missing image parameters")
group.add_argument("-D", type=int, help="Image size in pixels")
group.add_argument("--Apix", type=float, help="Angstroms per pixel")
return parser


def main(args):
def main(args: argparse.Namespace) -> None:
assert args.cs.endswith(".cs"), "Input file must be a .cs file"
assert args.o.endswith(".pkl"), "Output CTF parameters must be .pkl file"

Expand Down Expand Up @@ -70,8 +69,3 @@ def main(args):
ctf.plot_ctf(int(ctf_params[0, 0]), ctf_params[0, 1], ctf_params[0, 2:])
plt.savefig(args.png)
logger.info(args.png)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
main(add_args(parser).parse_args())
52 changes: 0 additions & 52 deletions cryodrgn/commands/view_config.py

This file was deleted.

65 changes: 3 additions & 62 deletions cryodrgn/trainers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@ class BaseConfigurations(ABC):
of this abstract class' children configuration classes contain these parameters
in addition to the ones they themselves define.
This class also defines special behaviour for the `quick_config` class variable,
which is not treated as a data field and instead defines a set of shortcuts used as
values for the data field parameters listed as its keys. These shortcuts each define
a list of fields and values that are used as the new defaults when the shortcut is
used, but can still be overridden by values specified by the user.
Note that unlike regular data classes these config classes must define defaults for
all their parameters to ensure that default engine behaviour is explicitly stated,
with an AssertionError being thrown upon initialization otherwise.
Expand All @@ -51,78 +45,25 @@ class BaseConfigurations(ABC):
correctly and exit immediately without running anything if this
boolean value is set to `True`.
Default is not to run this test.
Attributes
----------
quick_config: A dictionary with keys consisting of special `quick_config` shortcut
parameters; each value is a dictionary of non-quick_config
parameter keys and shortcut values that are used when the
corresponding quick configuration parameter value is used.
"""

# This class variable is not a dataclass field and is instead used to define shortcut
# labels to set values for a number of other fields
quick_config = dict()

# A parameter belongs to this configuration set if and only if it has a type and a
# default value defined here, note that children classes inherit these parameters
verbose: int = 0
outdir: str = os.getcwd()
seed: int = None
test_installation: bool = False

def __init__(self, **config_args: dict[str, Any]) -> None:
"""Setting given config values as attributes; saving values given by user."""
self.given_configs = config_args
def __post_init__(self) -> None:
"""Parsing given configuration parameter values and checking their validity."""
self.outdir = os.path.abspath(self.outdir)

# for configuration values not given by the user we use the defined defaults
for this_field in self.fields():
assert this_field.default is not MISSING, (
f"`{self.__class__.__name__}` class has no default value defined "
f"for parameter `{this_field.name}`!"
)

# set values specified explicitly by the user as attributes of this class
for k, v in self.given_configs.items():
setattr(self, k, v)

self.__post_init__()

def __post_init__(self) -> None:
"""Parsing given configuration parameter values and checking their validity."""
self.outdir = os.path.abspath(self.outdir)

for quick_cfg_k, quick_cfg_dict in self.quick_config.items():
assert quick_cfg_k in self, (
f"Configuration class `{self.__class__.__name__}` has a `quick_config` "
f"entry `{quick_cfg_k}` that is not a valid configuration parameter!"
)
for quick_cfg_label, quick_label_dict in quick_cfg_dict.items():
for quick_cfg_param, quick_cfg_val in quick_label_dict.items():
assert quick_cfg_param in self, (
f"Configuration class `{self.__class__.__name__}` has a "
f"`quick_config` entry `{quick_cfg_label}` under "
f"`{quick_cfg_k}` with a value for `{quick_cfg_param}` which "
f"is not a valid configuration parameter!"
)

if quick_cfg_k in self.given_configs:
quick_cfg_val = getattr(self, quick_cfg_k)
if quick_cfg_val is not None:
if quick_cfg_val not in self.quick_config[quick_cfg_k]:
raise ValueError(
f"Given value `{quick_cfg_val}` is not a valid entry "
f"for quick config shortcut parameter `{quick_cfg_k}`!"
)

# We only use the `quick_config` value if the parameter is not
# also being set explicitly by the user
for param_k, param_val in self.quick_config[quick_cfg_k][
quick_cfg_val
].items():
if param_k not in self.given_configs:
setattr(self, param_k, param_val)

if self.test_installation:
print("Installation was successful!")
sys.exit()
Expand Down
27 changes: 2 additions & 25 deletions cryodrgn/trainers/amortinf_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,6 @@
@dataclass
class AmortizedInferenceConfigurations(ReconstructionModelConfigurations):

# This parameter is not a data class field and is instead used to define shortcut
# labels to set values for a number of other fields
quick_config = {
"capture_setup": {
"spa": {"lazy": True},
"et": {
"subtomo_averaging": True,
"lazy": True,
"shuffler_size": 0,
"num_workers": 0,
"t_extent": 0.0,
"batch_size_known_poses": 8,
"batch_size_sgd": 32,
"n_imgs_pose_search": 150000,
"pose_only_phase": 50000,
"lr_pose_table": 1.0e-5,
},
},
"reconstruction_type": {"homo": {"z_dim": 0}, "het": {"z_dim": 8}},
}

# A parameter belongs to this configuration set if and only if it has a type and a
# default value defined here, note that children classes inherit these parameters
model: str = "amort"
Expand All @@ -63,6 +42,7 @@ class AmortizedInferenceConfigurations(ReconstructionModelConfigurations):
n_imgs_pose_search: int = 500000
epochs_sgd: int = None
pose_only_phase: int = 0
use_gt_trans: bool = False
invert_data: bool = False
# optimizers
pose_table_optim_type: str = "adam"
Expand Down Expand Up @@ -108,9 +88,6 @@ class AmortizedInferenceConfigurations(ReconstructionModelConfigurations):
# quick configs
conf_estimation: str = None

def __init__(self, **config_args: dict[str, Any]) -> None:
super().__init__(**config_args)

def __post_init__(self) -> None:
super().__post_init__()

Expand Down Expand Up @@ -831,7 +808,7 @@ def closure():
batch["indices"],
rot_pred,
trans_pred,
latent_variables_dict["z"],
latent_variables_dict["z"] if "z" in latent_variables_dict else None,
None,
)

Expand Down
13 changes: 0 additions & 13 deletions cryodrgn/trainers/hps_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,6 @@
@dataclass
class HierarchicalPoseSearchConfigurations(ReconstructionModelConfigurations):

# This parameter is not a data class field and is instead used to define shortcut
# labels to set values for a number of other fields
quick_config = {
"capture_setup": {
"spa": {"lazy": True},
"et": {"tilt": True, "dose_per_tilt": 2.97, "angle_per_tilt": 3},
},
"reconstruction_type": {"homo": {"z_dim": 0}, "het": {"z_dim": 8}},
}

# A parameter belongs to this configuration set if and only if it has a type and a
# default value defined here, note that children classes inherit these parameters
model = "hps"
Expand Down Expand Up @@ -78,9 +68,6 @@ class HierarchicalPoseSearchConfigurations(ReconstructionModelConfigurations):
reset_model_every: int = None
reset_optim_every: int = None

def __init__(self, **config_args: dict[str, Any]) -> None:
super().__init__(**config_args)

def __post_init__(self) -> None:
super().__post_init__()
assert self.model == "hps"
Expand Down
Loading

0 comments on commit beed1cf

Please sign in to comment.