Skip to content

Commit a324526

Browse files
committed
added flexibility to load_from_checkpoint #156
1 parent 400c5eb commit a324526

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

comet/models/__init__.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,17 @@
1717
from pathlib import Path
1818
from typing import Union
1919

20+
import torch
2021
import yaml
2122
from huggingface_hub import snapshot_download
2223

2324
from .base import CometModel
24-
from .download_utils import download_model_legacy
2525
from .multitask.unified_metric import UnifiedMetric
2626
from .ranking.ranking_metric import RankingMetric
2727
from .regression.referenceless import ReferencelessRegression
2828
from .regression.regression_metric import RegressionMetric
29+
from .download_utils import download_model_legacy
30+
2931

3032
str2model = {
3133
"referenceless_regression_metric": ReferencelessRegression,
@@ -34,10 +36,11 @@
3436
"unified_metric": UnifiedMetric,
3537
}
3638

39+
3740
def download_model(
3841
model: str,
3942
saving_directory: Union[str, Path, None] = None,
40-
local_files_only: bool = False
43+
local_files_only: bool = False,
4144
) -> str:
4245
try:
4346
model_path = snapshot_download(
@@ -52,29 +55,40 @@ def download_model(
5255
checkpoint_path = os.path.join(*[model_path, "checkpoints", "model.ckpt"])
5356
return checkpoint_path
5457

55-
def load_from_checkpoint(checkpoint_path: str) -> CometModel:
58+
59+
def load_from_checkpoint(
60+
checkpoint_path: str, reload_hparams: bool = False, strict: bool = False
61+
) -> CometModel:
5662
"""Loads models from a checkpoint path.
5763
5864
Args:
5965
checkpoint_path (str): Path to a model checkpoint.
60-
66+
reload_hparams (bool): hparams.yaml file located in the parent folder is
67+
only use for deciding the `class_identifier`. By setting this flag
68+
to True all hparams will be reloaded.
69+
strict (bool): Strictly enforce that the keys in checkpoint_path match the
70+
keys returned by this module's state dict. Defaults to False
6171
Return:
6272
COMET model.
6373
"""
6474
checkpoint_path = Path(checkpoint_path)
6575

6676
if not checkpoint_path.is_file():
6777
raise Exception(f"Invalid checkpoint path: {checkpoint_path}")
68-
69-
parent_folder = checkpoint_path.parents[1] # .parent.parent
78+
79+
parent_folder = checkpoint_path.parents[1] # .parent.parent
7080
hparams_file = parent_folder / "hparams.yaml"
7181

7282
if hparams_file.is_file():
7383
with open(hparams_file) as yaml_file:
7484
hparams = yaml.load(yaml_file.read(), Loader=yaml.FullLoader)
7585
model_class = str2model[hparams["class_identifier"]]
7686
model = model_class.load_from_checkpoint(
77-
checkpoint_path, load_pretrained_weights=False
87+
checkpoint_path,
88+
load_pretrained_weights=False,
89+
hparams_file=hparams_file if reload_hparams else None,
90+
map_location=torch.device("cpu"),
91+
strict=strict,
7892
)
7993
return model
8094
else:

0 commit comments

Comments
 (0)