1
+ import pandas as pd
1
2
import pickle
2
3
3
4
from torch .utils import data as data
4
5
import torchvision .transforms as tf
5
6
6
- from pyiqa .data .data_util import read_meta_info_file
7
7
from pyiqa .data .transforms import transform_mapping , PairedToTensor
8
8
from pyiqa .utils import get_root_logger
9
9
@@ -24,6 +24,7 @@ def __init__(self, opt):
24
24
self .phase = opt ['phase' ]
25
25
else :
26
26
self .phase = opt ['override_phase' ]
27
+ assert self .phase in ['train' , 'val' , 'test' ], f'phase should be in [train, val, test], got { self .phase } '
27
28
28
29
# initialize datasets
29
30
self .init_path_mos (opt )
@@ -38,10 +39,10 @@ def __init__(self, opt):
38
39
self .get_transforms (opt )
39
40
40
41
def init_path_mos (self , opt ):
41
- target_img_folder = opt ['dataroot_target' ]
42
- self .paths_mos = read_meta_info_file ( target_img_folder , opt [ 'meta_info_file' ] )
43
-
44
- def get_split (self , opt ):
42
+ self . meta_info = pd . read_csv ( opt ['meta_info_file' ])
43
+ self .paths_mos = self . meta_info . values . tolist ( )
44
+
45
+ def get_split_with_file (self , opt ):
45
46
# read train/val/test splits
46
47
split_file_path = opt .get ('split_file' , None )
47
48
if split_file_path :
@@ -50,7 +51,35 @@ def get_split(self, opt):
50
51
split_dict = pickle .load (f )
51
52
splits = split_dict [split_index ][self .phase ]
52
53
self .paths_mos = [self .paths_mos [i ] for i in splits ]
53
-
54
+
55
+ def get_split (self , opt ):
56
+ """Read train/val/test splits
57
+ """
58
+ # compatible with previous version using split file
59
+ if opt .get ('split_file' , None ) is not None :
60
+ self .get_split_with_file (opt )
61
+ return
62
+
63
+ # get all split column names
64
+ all_split_lists = [x for x in self .meta_info .columns .tolist () if 'split' in x ]
65
+
66
+ split_index = opt .get ('split_index' , None )
67
+
68
+ if split_index is not None :
69
+ if isinstance (split_index , str ):
70
+ split_name = split_index
71
+ elif isinstance (split_index , int ):
72
+ split_ratio = opt .get ('split_ratio' , '802' )
73
+ split_name = f'ratio{ split_ratio } _seed123_split_{ split_index :02d} '
74
+
75
+ assert split_name in all_split_lists , f'The given split [{ split_name } ] is not available in { all_split_lists } '
76
+
77
+ split_paths_mos = []
78
+ for i in range (len (self .paths_mos )):
79
+ if self .meta_info [split_name ][i ] == self .phase :
80
+ split_paths_mos .append (self .paths_mos [i ])
81
+ self .paths_mos = split_paths_mos
82
+
54
83
def mos_normalize (self , opt ):
55
84
mos_range = opt .get ('mos_range' , None )
56
85
mos_lower_better = opt .get ('lower_better' , None )
0 commit comments