17
17
from pathlib import Path
18
18
from typing import Union
19
19
20
+ import torch
20
21
import yaml
21
22
from huggingface_hub import snapshot_download
22
23
23
24
from .base import CometModel
24
- from .download_utils import download_model_legacy
25
25
from .multitask .unified_metric import UnifiedMetric
26
26
from .ranking .ranking_metric import RankingMetric
27
27
from .regression .referenceless import ReferencelessRegression
28
28
from .regression .regression_metric import RegressionMetric
29
+ from .download_utils import download_model_legacy
30
+
29
31
30
32
# Registry mapping the `class_identifier` stored in a checkpoint's hparams.yaml
# to the concrete COMET model class used to load it.
str2model = {
    "referenceless_regression_metric": ReferencelessRegression,
    "regression_metric": RegressionMetric,
    "ranking_metric": RankingMetric,
    "unified_metric": UnifiedMetric,
}
36
38
39
+
37
40
def download_model (
38
41
model : str ,
39
42
saving_directory : Union [str , Path , None ] = None ,
40
- local_files_only : bool = False
43
+ local_files_only : bool = False ,
41
44
) -> str :
42
45
try :
43
46
model_path = snapshot_download (
@@ -52,29 +55,40 @@ def download_model(
52
55
checkpoint_path = os .path .join (* [model_path , "checkpoints" , "model.ckpt" ])
53
56
return checkpoint_path
54
57
55
- def load_from_checkpoint (checkpoint_path : str ) -> CometModel :
58
+
59
+ def load_from_checkpoint (
60
+ checkpoint_path : str , reload_hparams : bool = False , strict : bool = False
61
+ ) -> CometModel :
56
62
"""Loads models from a checkpoint path.
57
63
58
64
Args:
59
65
checkpoint_path (str): Path to a model checkpoint.
60
-
66
+ reload_hparams (bool): hparams.yaml file located in the parent folder is
67
+ only use for deciding the `class_identifier`. By setting this flag
68
+ to True all hparams will be reloaded.
69
+ strict (bool): Strictly enforce that the keys in checkpoint_path match the
70
+ keys returned by this module's state dict. Defaults to False
61
71
Return:
62
72
COMET model.
63
73
"""
64
74
checkpoint_path = Path (checkpoint_path )
65
75
66
76
if not checkpoint_path .is_file ():
67
77
raise Exception (f"Invalid checkpoint path: { checkpoint_path } " )
68
-
69
- parent_folder = checkpoint_path .parents [1 ] # .parent.parent
78
+
79
+ parent_folder = checkpoint_path .parents [1 ] # .parent.parent
70
80
hparams_file = parent_folder / "hparams.yaml"
71
81
72
82
if hparams_file .is_file ():
73
83
with open (hparams_file ) as yaml_file :
74
84
hparams = yaml .load (yaml_file .read (), Loader = yaml .FullLoader )
75
85
model_class = str2model [hparams ["class_identifier" ]]
76
86
model = model_class .load_from_checkpoint (
77
- checkpoint_path , load_pretrained_weights = False
87
+ checkpoint_path ,
88
+ load_pretrained_weights = False ,
89
+ hparams_file = hparams_file if reload_hparams else None ,
90
+ map_location = torch .device ("cpu" ),
91
+ strict = strict ,
78
92
)
79
93
return model
80
94
else :
0 commit comments