Skip to content

Commit 0f85a1e

Browse files
committed
upload test demo and test_cross_dataset demo
1 parent d5c772d commit 0f85a1e

File tree

10 files changed

+107
-10
lines changed

10 files changed

+107
-10
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.

data/TID2008info.mat

-130 KB
Binary file not shown.

data/TID2013info.mat

-152 KB
Binary file not shown.

images/img98.jpg

109 KB
Loading

images/img98_colorblock_5.jpg

145 KB
Loading

main.py

-10
Original file line numberDiff line numberDiff line change
@@ -759,16 +759,6 @@ def final_testing_results(engine):
759759
args.im_dir = '/media/ldq/Others/Data/kadid10k/image/'
760760
args.ref_dir = '/media/ldq/Others/Data/kadid10k/image/'
761761

762-
if args.database == 'TID2013blur':
763-
args.data_info = './data/TID2013info.mat'
764-
args.im_dir = '/media/ldq/Research/Data/tid2013/distorted_images/'
765-
args.ref_dir = '/media/ldq/Research/Data/tid2013/reference_images/'
766-
767-
if args.database == 'TID2008blur':
768-
args.data_info = './data/TID2008info.mat'
769-
args.im_dir = '/media/ldq/Research/Data/tid2008/distorted_images/'
770-
args.ref_dir = '/media/ldq/Research/Data/tid2008/reference_images/'
771-
772762
if args.database == 'LIVE':
773763
args.data_info = './data/LIVEfullinfo.mat'
774764
args.im_dir = '/media/ldq/Research/Data/databaserelease2/'

test.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
Test
3+
For help
4+
```bash
5+
python test.py --help
6+
```
7+
Date: 2019/9/20
8+
"""
9+
10+
from argparse import ArgumentParser
11+
import torch
12+
from torch import nn
13+
import torch.nn.functional as F
14+
from PIL import Image
15+
from main import RandomCropPatches, NonOverlappingCropPatches, FRnet
16+
import numpy as np
17+
import h5py, os
18+
19+
20+
if __name__ == "__main__":
21+
parser = ArgumentParser(description='PyTorch WaDIQaM-FR test')
22+
parser.add_argument("--dist_path", type=str, default='images/img98_colorblock_5.jpg',
23+
help="distorted image path.")
24+
parser.add_argument("--ref_path", type=str, default='images/img98.jpg',
25+
help="reference image path.")
26+
parser.add_argument("--model_file", type=str, default='checkpoints/WaDIQaM-FR-KADID-10K-EXP1000-5-lr=0.0001-bs=4',
27+
help="model file (default: checkpoints/WaDIQaM-FR-KADID-10K-EXP1000-5-lr=0.0001-bs=4)")
28+
29+
args = parser.parse_args()
30+
31+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32+
33+
model = FRnet(weighted_average=True).to(device)
34+
35+
model.load_state_dict(torch.load(args.model_file))
36+
37+
model.eval()
38+
with torch.no_grad():
39+
im = Image.open(args.dist_path).convert('RGB')
40+
ref = Image.open(args.ref_path).convert('RGB')
41+
# data = RandomCropPatches(im, ref)
42+
data = NonOverlappingCropPatches(im, ref)
43+
44+
dist_patches = data[0].unsqueeze(0).to(device)
45+
ref_patches = data[1].unsqueeze(0).to(device)
46+
score = model((dist_patches, ref_patches))
47+
print(score.item())

test_cross_dataset.py

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""
2+
Test Cross Dataset
3+
For help
4+
```bash
5+
python test_cross_dataset.py --help
6+
```
7+
Date: 2018/9/20
8+
"""
9+
10+
from argparse import ArgumentParser
11+
import torch
12+
from torch import nn
13+
import torch.nn.functional as F
14+
from PIL import Image
15+
from main import RandomCropPatches, NonOverlappingCropPatches, FRnet
16+
import numpy as np
17+
import h5py, os
18+
19+
20+
if __name__ == "__main__":
21+
parser = ArgumentParser(description='PyTorch WaDIQaM-FR test on the whole cross dataset')
22+
parser.add_argument("--dist_dir", type=str, default=None,
23+
help="distorted images dir.")
24+
parser.add_argument("--ref_dir", type=str, default=None,
25+
help="reference images dir.")
26+
parser.add_argument("--names_info", type=str, default=None,
27+
help=".mat file that includes image names in the dataset.")
28+
parser.add_argument("--model_file", type=str, default='checkpoints/WaDIQaM-FR-KADID-10K-EXP1000-5-lr=0.0001-bs=4',
29+
help="model file (default: checkpoints/WaDIQaM-FR-KADID-10K-EXP1000-5-lr=0.0001-bs=4)")
30+
parser.add_argument("--save_path", type=str, default='scores',
31+
help="save path (default: scores)")
32+
33+
args = parser.parse_args()
34+
35+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36+
37+
model = FRnet(weighted_average=True).to(device)
38+
39+
model.load_state_dict(torch.load(args.model_file))
40+
41+
Info = h5py.File(args.names_info)
42+
im_names = [Info[Info['im_names'][0, :][i]].value.tobytes()\
43+
[::2].decode() for i in range(len(Info['im_names'][0, :]))]
44+
ref_names = [Info[Info['ref_names'][0, :][i]].value.tobytes()\
45+
[::2].decode() for i in (Info['ref_ids'][0, :]-1).astype(int)]
46+
47+
model.eval()
48+
scores = []
49+
with torch.no_grad():
50+
for i in range(len(im_names)):
51+
im = Image.open(os.path.join(args.dist_dir, im_names[i])).convert('RGB')
52+
ref = Image.open(os.path.join(args.ref_dir, ref_names[i])).convert('RGB')
53+
# data = RandomCropPatches(im, ref)
54+
data = NonOverlappingCropPatches(im, ref)
55+
56+
dist_patches = data[0].unsqueeze(0).to(device)
57+
ref_patches = data[1].unsqueeze(0).to(device)
58+
score = model((dist_patches, ref_patches))
59+
scores.append(score.item())
60+
np.save(args.save_path, scores)

0 commit comments

Comments
 (0)