Skip to content

Commit e66dee7

Browse files
committed
refactor: 🚩 merge split info to meta csv files
1 parent 0575058 commit e66dee7

8 files changed

+107
-56
lines changed

options/default_dataset_opt.yml

+8-9
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ koniq10k:
5454
type: GeneralNRDataset
5555
dataroot_target: './datasets/koniq10k/512x384'
5656
meta_info_file: './datasets/meta_info/meta_info_KonIQ10kDataset.csv'
57-
split_file: './datasets/meta_info/koniq10k_official.pkl'
57+
split_index: 'official_split'
5858
phase: 'test'
5959
mos_range: [0, 100]
6060
lower_better: false
@@ -64,7 +64,7 @@ koniq10k-1024:
6464
type: GeneralNRDataset
6565
dataroot_target: './datasets/koniq10k/1024x768'
6666
meta_info_file: './datasets/meta_info/meta_info_KonIQ10kDataset.csv'
67-
split_file: './datasets/meta_info/koniq10k_official.pkl'
67+
split_index: 'official_split'
6868
phase: 'test'
6969
mos_range: [0, 100]
7070
lower_better: false
@@ -74,7 +74,7 @@ koniq10k++:
7474
type: GeneralNRDataset
7575
dataroot_target: './datasets/koniq10k/512x384'
7676
meta_info_file: './datasets/meta_info/meta_info_KonIQ10k++Dataset.csv'
77-
split_file: './datasets/meta_info/koniq10k_official.pkl'
77+
split_index: 'official_split'
7878
phase: 'test'
7979
mos_range: [1, 5]
8080
lower_better: false
@@ -102,8 +102,7 @@ ava:
102102
type: AVADataset
103103
dataroot_target: './datasets/AVA_dataset/ava_images/'
104104
meta_info_file: './datasets/meta_info/meta_info_AVADataset.csv'
105-
split_file: './datasets/meta_info/ava_official_ilgnet.pkl'
106-
split_index: 1 # use official split
105+
split_index: 'official_split'
107106
mos_range: [1, 10]
108107
lower_better: false
109108

@@ -113,7 +112,7 @@ pipal:
113112
dataroot_target: './datasets/PIPAL/Dist_Imgs'
114113
dataroot_ref: './datasets/PIPAL/Train_Ref'
115114
meta_info_file: './datasets/meta_info/meta_info_PIPALDataset.csv'
116-
split_file: './datasets/meta_info/pipal_official.pkl'
115+
split_index: 'official_split'
117116
mos_range: [0, 1]
118117
lower_better: false
119118

@@ -122,7 +121,7 @@ flive:
122121
type: GeneralNRDataset
123122
dataroot_target: './datasets/FLIVE_Database/database'
124123
meta_info_file: './datasets/meta_info/meta_info_FLIVEDataset.csv'
125-
split_file: './datasets/meta_info/flive_official.pkl'
124+
split_index: 'official_split'
126125
phase: test
127126
mos_range: [0, 100]
128127
lower_better: false
@@ -132,7 +131,7 @@ pieapp:
132131
type: PieAPPDataset
133132
dataroot_target: './datasets/PieAPP_dataset_CVPR_2018/'
134133
meta_info_file: './datasets/meta_info/meta_info_PieAPPDataset.csv'
135-
split_file: './datasets/meta_info/pieapp_official.pkl'
134+
split_index: 'official_split'
136135

137136
bapps:
138137
name: BAPPS
@@ -152,4 +151,4 @@ gfiqa:
152151
type: GeneralNRDataset
153152
dataroot_target: ./datasets/GFIQA/image
154153
meta_info_file: ./datasets/meta_info/meta_info_GFIQADataset.csv
155-
split_file: ./datasets/meta_info/gfiqa_seed123.pkl
154+
split_index: 1

pyiqa/data/ava_dataset.py

+34-16
Original file line numberDiff line numberDiff line change
@@ -29,30 +29,48 @@ class AVADataset(BaseIQADataset):
2929
"""
3030

3131
def init_path_mos(self, opt):
32+
super().init_path_mos(opt)
3233
target_img_folder = opt['dataroot_target']
3334
self.dataroot = target_img_folder
34-
self.paths_mos = pd.read_csv(opt['meta_info_file']).values.tolist()
3535

3636
def get_split(self, opt):
37-
# read train/val/test splits
38-
split_file_path = opt.get('split_file', None)
39-
if split_file_path:
40-
split_index = opt.get('split_index', 1)
41-
with open(opt['split_file'], 'rb') as f:
42-
split_dict = pickle.load(f)
43-
37+
split_index = opt.get('split_index', None)
38+
39+
# compatible with previous version using split file
40+
# when using split file, previous version will use official_split or split_index=1
41+
if opt.get('split_file', None) is not None:
42+
split_index = 'official_split'
43+
44+
if split_index is not None:
4445
# use val_num for validation
4546
val_num = opt.get('val_num', 2000)
46-
train_split = split_dict[split_index]['train']
47-
val_split = split_dict[split_index]['val']
48-
train_split = train_split + val_split[:-val_num]
49-
val_split = val_split[-val_num:]
50-
split_dict[split_index]['train'] = train_split
51-
split_dict[split_index]['val'] = val_split
5247

53-
splits = split_dict[split_index][self.phase]
54-
self.paths_mos = [self.paths_mos[i] for i in splits]
48+
train_split_paths_mos = []
49+
val_split_paths_mos = []
50+
test_split_paths_mos = []
51+
for i in range(len(self.paths_mos)):
52+
if self.meta_info[split_index][i] == 0: # 0 for train
53+
train_split_paths_mos.append(self.paths_mos[i])
54+
elif self.meta_info[split_index][i] == 1: # 1 for val
55+
val_split_paths_mos.append(self.paths_mos[i])
56+
elif self.meta_info[split_index][i] == 2: # 2 for test
57+
test_split_paths_mos.append(self.paths_mos[i])
5558

59+
if len(val_split_paths_mos) < val_num:
60+
val_num = val_num - len(val_split_paths_mos)
61+
val_split_paths_mos = val_split_paths_mos + train_split_paths_mos[-val_num:]
62+
train_split_paths_mos = train_split_paths_mos[:-val_num]
63+
else:
64+
train_split_paths_mos = train_split_paths_mos + val_split_paths_mos[:-val_num]
65+
val_split_paths_mos = val_split_paths_mos[-val_num:]
66+
67+
if self.phase == 'train':
68+
self.paths_mos = train_split_paths_mos
69+
elif self.phase == 'val':
70+
self.paths_mos = val_split_paths_mos
71+
elif self.phase == 'test':
72+
self.paths_mos = test_split_paths_mos
73+
5674
self.mean_mos = np.array([item[1] for item in self.paths_mos]).mean()
5775

5876
def __getitem__(self, index):

pyiqa/data/bapps_dataset.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,10 @@ def init_path_mos(self, opt):
4242
self.paths_mos = pd.read_csv(opt['meta_info_file']).values.tolist()
4343

4444
def get_split(self, opt):
45-
val_types = opt.get('val_types', None)
46-
# read train/val/test splits
47-
split_file_path = opt.get('split_file', None)
48-
if split_file_path:
49-
split_index = opt.get('split_index', 1)
50-
with open(opt['split_file'], 'rb') as f:
51-
split_dict = pickle.load(f)
52-
splits = split_dict[split_index][self.phase]
53-
self.paths_mos = [self.paths_mos[i] for i in splits]
45+
super().get_split(opt)
5446

47+
val_types = opt.get('val_types', None)
48+
5549
if self.dataset_mode == '2afc':
5650
self.paths_mos = [x for x in self.paths_mos if x[0] != 'jnd']
5751
elif self.dataset_mode == 'jnd':

pyiqa/data/base_iqa_dataset.py

+35-6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
import pandas as pd
12
import pickle
23

34
from torch.utils import data as data
45
import torchvision.transforms as tf
56

6-
from pyiqa.data.data_util import read_meta_info_file
77
from pyiqa.data.transforms import transform_mapping, PairedToTensor
88
from pyiqa.utils import get_root_logger
99

@@ -24,6 +24,7 @@ def __init__(self, opt):
2424
self.phase = opt['phase']
2525
else:
2626
self.phase = opt['override_phase']
27+
assert self.phase in ['train', 'val', 'test'], f'phase should be in [train, val, test], got {self.phase}'
2728

2829
# initialize datasets
2930
self.init_path_mos(opt)
@@ -38,10 +39,10 @@ def __init__(self, opt):
3839
self.get_transforms(opt)
3940

4041
def init_path_mos(self, opt):
41-
target_img_folder = opt['dataroot_target']
42-
self.paths_mos = read_meta_info_file(target_img_folder, opt['meta_info_file'])
43-
44-
def get_split(self, opt):
42+
self.meta_info = pd.read_csv(opt['meta_info_file'])
43+
self.paths_mos = self.meta_info.values.tolist()
44+
45+
def get_split_with_file(self, opt):
4546
# read train/val/test splits
4647
split_file_path = opt.get('split_file', None)
4748
if split_file_path:
@@ -50,7 +51,35 @@ def get_split(self, opt):
5051
split_dict = pickle.load(f)
5152
splits = split_dict[split_index][self.phase]
5253
self.paths_mos = [self.paths_mos[i] for i in splits]
53-
54+
55+
def get_split(self, opt):
56+
"""Read train/val/test splits
57+
"""
58+
# compatible with previous version using split file
59+
if opt.get('split_file', None) is not None:
60+
self.get_split_with_file(opt)
61+
return
62+
63+
# get all split column names
64+
all_split_lists = [x for x in self.meta_info.columns.tolist() if 'split' in x]
65+
66+
split_index = opt.get('split_index', None)
67+
68+
if split_index is not None:
69+
if isinstance(split_index, str):
70+
split_name = split_index
71+
elif isinstance(split_index, int):
72+
split_ratio = opt.get('split_ratio', '802')
73+
split_name = f'ratio{split_ratio}_seed123_split_{split_index:02d}'
74+
75+
assert split_name in all_split_lists, f'The given split [{split_name}] is not available in {all_split_lists}'
76+
77+
split_paths_mos = []
78+
for i in range(len(self.paths_mos)):
79+
if self.meta_info[split_name][i] == self.phase:
80+
split_paths_mos.append(self.paths_mos[i])
81+
self.paths_mos = split_paths_mos
82+
5483
def mos_normalize(self, opt):
5584
mos_range = opt.get('mos_range', None)
5685
mos_lower_better = opt.get('lower_better', None)

pyiqa/data/general_fr_dataset.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from PIL import Image
2+
from os import path as osp
23

34
import torch
45
from torch.utils import data as data
56
import torchvision.transforms as tf
67

7-
from pyiqa.data.data_util import read_meta_info_file
88
from pyiqa.data.transforms import transform_mapping, PairedToTensor
99
from pyiqa.utils.registry import DATASET_REGISTRY
1010

@@ -16,9 +16,19 @@ class GeneralFRDataset(BaseIQADataset):
1616
"""
1717

1818
def init_path_mos(self, opt):
19+
super().init_path_mos(opt)
20+
1921
target_img_folder = opt['dataroot_target']
2022
ref_img_folder = opt.get('dataroot_ref', None)
21-
self.paths_mos = read_meta_info_file(target_img_folder, opt['meta_info_file'], mode='fr', ref_dir=ref_img_folder)
23+
if ref_img_folder is None:
24+
ref_img_folder = target_img_folder
25+
26+
self.paths_mos = []
27+
for row in self.meta_info.values:
28+
ref_path = osp.join(ref_img_folder, row[0])
29+
img_path = osp.join(target_img_folder, row[1])
30+
mos_label = float(row[2])
31+
self.paths_mos.append([img_path, ref_path, mos_label])
2232

2333
def get_transforms(self, opt):
2434
# do paired transform first and then do common transform

pyiqa/data/general_nr_dataset.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from PIL import Image
2+
from os import path as osp
3+
24
import torch
35
from torch.utils import data as data
46

5-
from pyiqa.data.data_util import read_meta_info_file
67
from pyiqa.utils.registry import DATASET_REGISTRY
78
from .base_iqa_dataset import BaseIQADataset
89

@@ -11,8 +12,15 @@ class GeneralNRDataset(BaseIQADataset):
1112
"""General No Reference dataset with meta info file.
1213
"""
1314
def init_path_mos(self, opt):
15+
super().init_path_mos(opt)
16+
1417
target_img_folder = opt['dataroot_target']
15-
self.paths_mos = read_meta_info_file(target_img_folder, opt['meta_info_file'])
18+
19+
self.paths_mos = []
20+
for row in self.meta_info.values:
21+
img_path = osp.join(target_img_folder, row[0])
22+
mos_label = float(row[1])
23+
self.paths_mos.append([img_path, mos_label])
1624

1725
def __getitem__(self, index):
1826

pyiqa/data/livechallenge_dataset.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ class LIVEChallengeDataset(GeneralNRDataset):
2020
"""
2121

2222
def init_path_mos(self, opt):
23-
target_img_folder = os.path.join(opt['dataroot_target'], 'Images')
24-
self.paths_mos = read_meta_info_file(target_img_folder, opt['meta_info_file'])
23+
super().init_path_mos(opt)
2524
# remove first 7 training images as previous works
2625
self.paths_mos = self.paths_mos[7:]

pyiqa/data/pieapp_dataset.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,13 @@ def init_path_mos(self, opt):
3232
self.paths_mos = metadata.values.tolist()
3333

3434
def get_split(self, opt):
35-
# read train/val/test splits
36-
split_file_path = opt.get('split_file', None)
37-
if split_file_path:
38-
split_index = opt.get('split_index', 1)
39-
with open(opt['split_file'], 'rb') as f:
40-
split_dict = pickle.load(f)
41-
splits = split_dict[split_index][self.phase]
42-
self.paths_mos = [self.paths_mos[i] for i in splits]
43-
35+
super().get_split(opt)
4436
# remove duplicates
4537
if self.phase == 'test':
4638
temp = []
47-
[temp.append(item) for item in self.paths_mos if not item in temp]
39+
for item in self.paths_mos:
40+
if not item in temp:
41+
temp.append(item)
4842
self.paths_mos = temp
4943

5044
def __getitem__(self, index):

0 commit comments

Comments
 (0)