
Commit b57dede

feat: 🚩 add lpips+ and lpips-vgg+
1 parent 2c1de94

6 files changed (+55, -17 lines)

README.md (+1)

@@ -36,6 +36,7 @@ This is a image quality assessment toolbox with **pure python and pytorch**. We
 ---
 
 ### :triangular_flag_on_post: Updates/Changelog
+- 🔥**Aug, 2024**. Add `lpips+` and `lpips-vgg+` proposed in our paper [TOPIQ](https://arxiv.org/abs/2308.03060).
 - 🔥**June, 2024**. Add `arniqa` and its variances trained on different datasets, refer to official repo [here](https://github.com/miccunifi/ARNIQA). Thanks for the contribution from [Lorenzo Agnolucci](https://github.com/LorenzoAgnolucci) 🤗.
 - **Apr 24, 2024**. Add `inception_score` and console entry point with `pyiqa` command.
 - **Mar 11, 2024**. Add `unique`, refer to official repo [here](https://github.com/zwx8981/UNIQUE). Thanks for the contribution from [Weixia Zhang](https://github.com/zwx8981) 🤗.

benchmark_results.py (+3, -5)

@@ -126,9 +126,9 @@ def main():
         for data in dataloader:
             gt_labels += flatten_list(data['mos_label'].cpu().tolist())
             if metric_mode == 'FR':
-                iqa_score = iqa_model(data['img'], data['ref_img']).cpu().tolist()
+                iqa_score = iqa_model(data['img'], data['ref_img']).squeeze().cpu().tolist()
             else:
-                iqa_score = iqa_model(data['img']).cpu().tolist()
+                iqa_score = iqa_model(data['img']).squeeze().cpu().tolist()
             result_scores += flatten_list(iqa_score)
             pbar.update(1)
         pbar.close()
@@ -151,14 +151,12 @@ def main():
         if save_result_path is not None:
             csv_writer.writerow(results_row)
 
-
-
     if save_result_path is not None:
         csv_file.close()
 
     if update_benchmark_file is not None:
+        benchmark = benchmark.sort_values(by=benchmark.columns[0], key=lambda x: x.str.split('/').str[0].astype(float))
         benchmark.to_csv(update_benchmark_file)
 
-
 if __name__ == '__main__':
     main()
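
The `.squeeze()` added above matters because IQA models typically return scores with trailing singleton dimensions, so `tolist()` yields nested lists rather than plain floats. A minimal sketch of the shape behavior (tensor values are illustrative):

```python
import torch

# An FR metric such as LPIPS returns one score per image, often with
# trailing singleton dims, e.g. shape (N, 1, 1, 1) when keepdim=True.
scores = torch.rand(4, 1, 1, 1)

print(scores.cpu().tolist())            # deeply nested: [[[[0.83]]], [[[0.12]]], ...]
print(scores.squeeze().cpu().tolist())  # flat: [0.83, 0.12, 0.55, 0.94]
```

With the flattened output, `flatten_list` no longer has to unwrap several levels of nesting before the scores are accumulated in `result_scores`.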

pyiqa/archs/lpips_arch.py (+20, -5)

@@ -4,6 +4,14 @@
 
 Modified by: Jiadi Mo (https://github.com/JiadiMo)
 
+Reference:
+    Zhang, Richard, et al. "The unreasonable effectiveness of deep features as
+    a perceptual metric." Proceedings of the IEEE conference on computer vision
+    and pattern recognition. 2018.
+
+    TOPIQ: A Top-down Approach from Semantics to Distortions for Image Quality Assessment.
+    Chaofeng Chen, Jiadi Mo, Jingwen Hou, Haoning Wu, Liang Liao, Wenxiu Sun, Qiong Yan, Weisi Lin.
+    Transactions on Image Processing, 2024.
 """
 
 import torch
@@ -63,10 +71,6 @@ class LPIPS(nn.Module):
         pnet_tune (Boolean): Whether to tune the base/trunk network.
         use_dropout (Boolean): Whether to use dropout when training linear layers.
 
-    Reference:
-        Zhang, Richard, et al. "The unreasonable effectiveness of deep features as
-        a perceptual metric." Proceedings of the IEEE conference on computer vision
-        and pattern recognition. 2018.
 
     """
@@ -81,6 +85,7 @@ def __init__(self,
                  use_dropout=True,
                  pretrained_model_path=None,
                  eval_mode=True,
+                 semantic_weight_layer=-1,
                  **kwargs):
 
         super(LPIPS, self).__init__()
@@ -93,6 +98,8 @@ def __init__(self,
         self.version = version
         self.scaling_layer = ScalingLayer()
 
+        self.semantic_weight_layer = semantic_weight_layer
+
         if (self.pnet_type in ['vgg', 'vgg16']):
             net_type = vgg16
             self.chns = [64, 128, 256, 512, 512]
@@ -156,8 +163,16 @@ def forward(self, in1, in0, retPerLayer=False, normalize=True):
             diffs[kk] = (feats0[kk] - feats1[kk])**2
 
         if (self.lpips):
-            if (self.spatial):
+            if self.spatial:
                 res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
+            elif self.semantic_weight_layer >= 0:
+                res = []
+                semantic_feat = outs0[self.semantic_weight_layer]
+                for kk in range(self.L):
+                    diff_score = self.lins[kk](diffs[kk])
+                    semantic_weight = torch.nn.functional.interpolate(semantic_feat, size=diff_score.shape[2:], mode='bilinear', align_corners=False)
+                    avg_score = torch.sum(diff_score * semantic_weight, dim=[1, 2, 3], keepdim=True) / torch.sum(semantic_weight, dim=[1, 2, 3], keepdim=True)
+                    res.append(avg_score)
             else:
                 res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
         else:
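
The new `elif` branch swaps plain spatial averaging for a semantic-weighted average: a feature map from one backbone stage of the first input serves as a spatial weight over each layer's difference scores. A standalone sketch of that pooling step, with hypothetical tensor shapes standing in for `self.lins[kk](diffs[kk])` and `outs0[self.semantic_weight_layer]`:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: a per-layer difference score map (N, 1, H, W) and a
# semantic feature map (N, C, H', W') from an earlier backbone stage.
diff_score = torch.rand(2, 1, 28, 28)
semantic_feat = torch.rand(2, 256, 56, 56)

# Resize the semantic features to the score map's spatial resolution.
semantic_weight = F.interpolate(semantic_feat, size=diff_score.shape[2:],
                                mode='bilinear', align_corners=False)

# Weighted spatial average: locations with stronger semantic activations
# contribute more; the product broadcasts over the channel dimension.
avg_score = (torch.sum(diff_score * semantic_weight, dim=[1, 2, 3], keepdim=True)
             / torch.sum(semantic_weight, dim=[1, 2, 3], keepdim=True))
print(avg_score.shape)  # torch.Size([2, 1, 1, 1])
```

Only the pooling over spatial locations changes; the per-layer linear heads are untouched, so the existing pretrained LPIPS weights can be reused.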

pyiqa/archs/topiq_arch.py (+1, -1)

@@ -2,7 +2,7 @@
 
 TOPIQ: A Top-down Approach from Semantics to Distortions for Image Quality Assessment.
 Chaofeng Chen, Jiadi Mo, Jingwen Hou, Haoning Wu, Liang Liao, Wenxiu Sun, Qiong Yan, Weisi Lin.
-Arxiv 2023.
+Transactions on Image Processing, 2024.
 
 Paper link: https://arxiv.org/abs/2308.03060
 
pyiqa/default_model_configs.py (+20)

@@ -31,6 +31,26 @@
         'metric_mode': 'FR',
         'lower_better': True,
     },
+    'lpips+': {
+        'metric_opts': {
+            'type': 'LPIPS',
+            'net': 'alex',
+            'version': '0.1',
+            'semantic_weight_layer': 2,
+        },
+        'metric_mode': 'FR',
+        'lower_better': True,
+    },
+    'lpips-vgg+': {
+        'metric_opts': {
+            'type': 'LPIPS',
+            'net': 'vgg',
+            'version': '0.1',
+            'semantic_weight_layer': 2,
+        },
+        'metric_mode': 'FR',
+        'lower_better': True,
+    },
     'stlpips': {
         'metric_opts': {
             'type': 'STLPIPS',
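
With these entries registered, the new metrics should be reachable through pyiqa's usual factory function. A minimal usage sketch, assuming the standard `pyiqa.create_metric` entry point and 4-D image tensors in [0, 1]:

```python
import torch
import pyiqa

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 'lpips+' uses the AlexNet backbone, 'lpips-vgg+' the VGG backbone; both
# set semantic_weight_layer=2 and are full-reference, lower-is-better.
metric = pyiqa.create_metric('lpips+', device=device)

dist = torch.rand(1, 3, 224, 224, device=device)  # distorted image
ref = torch.rand(1, 3, 224, 224, device=device)   # reference image
print(metric(dist, ref))  # lower score = perceptually closer to the reference
```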

tests/FR_benchmark_results.csv (+10, -6)

@@ -1,13 +1,17 @@
 Metric name,csiq(PLCC/SRCC/KRCC),live(PLCC/SRCC/KRCC),tid2008(PLCC/SRCC/KRCC),tid2013(PLCC/SRCC/KRCC)
-psnr,0.7857/0.8087/0.5989,0.7633/0.8013/0.5964,0.489/0.5245/0.3696,0.6601/0.6869/0.4958
+cw_ssim,0.6078/0.7588/0.5562,0.5714/0.7681/0.5673,0.5965/0.6473/0.4625,0.5815/0.6533/0.4715
 ssim,0.765/0.8367/0.6323,0.7369/0.8509/0.6547,0.6003/0.6242/0.4521,0.6558/0.6269/0.455
 ms_ssim,0.7717/0.9125/0.7372,0.679/0.9027/0.7227,0.7894/0.8531/0.6555,0.7814/0.7852/0.6033
-cw_ssim,0.6078/0.7588/0.5562,0.5714/0.7681/0.5673,0.5965/0.6473/0.4625,0.5815/0.6533/0.4715
+psnr,0.7857/0.8087/0.5989,0.7633/0.8013/0.5964,0.489/0.5245/0.3696,0.6601/0.6869/0.4958
 fsim,0.8207/0.9309/0.7683,0.7747/0.9204/0.7515,0.8341/0.884/0.6991,0.8322/0.8509/0.6665
-vif,0.9219/0.9194/0.7532,0.9409/0.9526/0.8067,0.7769/0.7491/0.5861,0.7336/0.677/0.5148
-lpips,0.9005/0.9233/0.7499,0.7672/0.869/0.6768,0.711/0.7151/0.5221,0.7529/0.7445/0.5477
-dists,0.9324/0.9296/0.7644,0.8392/0.9051/0.7283,0.7032/0.6648/0.4861,0.7538/0.7077/0.5212
-pieapp,0.838/0.8968/0.7109,0.8577/0.9182/0.7491,0.6443/0.7971/0.6089,0.7195/0.8438/0.6571
+stlpips,0.823/0.8952/0.7094,0.813/0.8826/0.6931,0.624/0.6404/0.454,0.7147/0.7365/0.5387
 ahiq,0.8234/0.8273/0.6168,0.8039/0.8967/0.7066,0.6772/0.6807/0.4842,0.7379/0.7075/0.5127
+pieapp,0.838/0.8968/0.7109,0.8577/0.9182/0.7491,0.6443/0.7971/0.6089,0.7195/0.8438/0.6571
+lpips,0.9005/0.9233/0.7499,0.7672/0.869/0.6768,0.711/0.7151/0.5221,0.7529/0.7445/0.5477
+lpips+,0.9041/0.9285/0.7575,0.8455/0.9248/0.7546,0.7318/0.7379/0.5424,0.7656/0.7622/0.5639
+lpips-vgg,0.9043/0.883/0.6968,0.9336/0.9318/0.7646,0.6974/0.6536/0.4822,0.7324/0.6696/0.497
 wadiqam_fr,0.9087/0.922/0.7461,0.9163/0.9308/0.7584,0.8221/0.8222/0.6245,0.8424/0.8264/0.628
+lpips-vgg+,0.9169/0.894/0.7128,0.9499/0.9503/0.7983,0.7406/0.6869/0.5113,0.7606/0.6913/0.5152
+vif,0.9219/0.9194/0.7532,0.9409/0.9526/0.8067,0.7769/0.7491/0.5861,0.7336/0.677/0.5148
+dists,0.9324/0.9296/0.7644,0.8392/0.9051/0.7283,0.7032/0.6648/0.4861,0.7538/0.7077/0.5212
 topiq_fr,0.9589/0.9674/0.8379,0.9542/0.9759/0.8617,0.9044/0.9226/0.7554,0.9158/0.9165/0.7441
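
The reordering above is a direct effect of the new `sort_values` call in `benchmark_results.py`: rows are sorted ascending by the PLCC value of the first dataset column, parsed out of the `PLCC/SRCC/KRCC` string. A small pandas sketch with a made-up three-row excerpt, mirroring the call from the commit:

```python
import pandas as pd

# Made-up excerpt: the first data column holds 'PLCC/SRCC/KRCC' strings.
benchmark = pd.DataFrame(
    {'csiq(PLCC/SRCC/KRCC)': ['0.9324/0.9296/0.7644',
                              '0.7857/0.8087/0.5989',
                              '0.9005/0.9233/0.7499']},
    index=['dists', 'psnr', 'lpips'])

# Sort rows by the leading PLCC value (the part before the first '/').
benchmark = benchmark.sort_values(
    by=benchmark.columns[0],
    key=lambda x: x.str.split('/').str[0].astype(float))

print(benchmark.index.tolist())  # ['psnr', 'lpips', 'dists']
```

That is why lower-scoring metrics such as `cw_ssim` now sit at the top of the CSV while `topiq_fr` stays at the bottom.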
