Skip to content

Commit cbba398

Browse files
committed
feat: 🧑‍💻 add **rough** score range for each metric
1 parent b57dede commit cbba398

9 files changed

+130
-34
lines changed

pyiqa/api_helpers.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import fnmatch
22
import re
33
from pyiqa.default_model_configs import DEFAULT_CONFIGS
4+
from pyiqa.dataset_info import DATASET_INFO
45

56
from pyiqa.utils import get_root_logger
67
from pyiqa.models.inference_model import InferenceModel
@@ -49,3 +50,8 @@ def list_models(metric_mode=None, filter='', exclude_filters=''):
4950
if len(exclude_models):
5051
models = set(models).difference(exclude_models)
5152
return list(sorted(models, key=_natural_key))
53+
54+
55+
def get_dataset_info(dataset_name):
56+
assert dataset_name in DATASET_INFO.keys(), f'Dataset {dataset_name} not implemented yet.'
57+
return DATASET_INFO[dataset_name]

pyiqa/archs/ahiq_arch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def __init__(
198198
)
199199
elif pretrained:
200200
weight_path = load_file_from_url(default_model_urls["pipal"])
201-
checkpoint = torch.load(weight_path)
201+
checkpoint = torch.load(weight_path, map_location='cpu', weights_only=False)
202202
self.regressor.load_state_dict(checkpoint["regressor_model_state_dict"])
203203
self.deform_net.load_state_dict(checkpoint["deform_net_model_state_dict"])
204204

pyiqa/archs/arch_util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def load_pretrained_network(net, model_path, strict=True, weight_keys=None):
163163
if model_path.startswith("https://") or model_path.startswith("http://"):
164164
model_path = load_file_from_url(model_path)
165165
print(f"Loading pretrained model {net.__class__.__name__} from {model_path}")
166-
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
166+
state_dict = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
167167
if weight_keys is not None:
168168
state_dict = state_dict[weight_keys]
169169
state_dict = clean_state_dict(state_dict)

pyiqa/archs/brisque_arch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def brisque(x: torch.Tensor,
6161
scaled_features = scale_features(features)
6262

6363
if pretrained_model_path:
64-
sv_coef, sv = torch.load(pretrained_model_path)
64+
sv_coef, sv = torch.load(pretrained_model_path, weights_only=False)
6565
sv_coef = sv_coef.to(x)
6666
sv = sv.to(x)
6767

pyiqa/archs/clipiqa_arch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def __init__(self,
138138

139139
if pretrained and 'clipiqa+' in model_type:
140140
if model_type == 'clipiqa+' and backbone == 'RN50':
141-
self.prompt_learner.ctx.data = torch.load(load_file_from_url(default_model_urls['clipiqa+']))
141+
self.prompt_learner.ctx.data = torch.load(load_file_from_url(default_model_urls['clipiqa+']), weights_only=False)
142142
elif model_type in default_model_urls.keys():
143143
load_pretrained_network(self, default_model_urls[model_type], True, 'params')
144144
else:

pyiqa/archs/liqe_arch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(self,
7575
text_feat_cache_path = os.path.expanduser("~/.cache/pyiqa/liqe_text_feat.pt")
7676

7777
if os.path.exists(text_feat_cache_path):
78-
self.text_features = torch.load(text_feat_cache_path, map_location='cpu')
78+
self.text_features = torch.load(text_feat_cache_path, map_location='cpu', weights_only=False)
7979
else:
8080
print(f'Generating text features for LIQE model, will be cached at {text_feat_cache_path}.')
8181
if self.mtl:

pyiqa/dataset_info.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
2+
DATASET_INFO = {
3+
"live": {
4+
"score_range": (1, 100),
5+
"mos_type": "dmos"
6+
},
7+
"csiq": {
8+
"score_range": (0, 1),
9+
"mos_type": "dmos"
10+
},
11+
"tid": {
12+
"score_range": (0, 9),
13+
"mos_type": "mos"
14+
},
15+
"kadid": {
16+
"score_range": (1, 5),
17+
"mos_type": "mos"
18+
},
19+
"koniq": {
20+
"score_range": (1, 100),
21+
"mos_type": "mos"
22+
},
23+
"clive": {
24+
"score_range": (1, 100),
25+
"mos_type": "mos"
26+
},
27+
"flive": {
28+
"score_range": (1, 100),
29+
"mos_type": "mos"
30+
},
31+
"spaq": {
32+
"score_range": (1, 100),
33+
"mos_type": "mos"
34+
},
35+
"ava": {
36+
"score_range": (1, 10),
37+
"mos_type": "mos"
38+
},
39+
}

0 commit comments

Comments
 (0)