
Commit 3835ad9

feat: 🎨 add color difference metric

1 parent e66dee7

5 files changed: +237 −1 lines changed

README.md (+1)

@@ -23,6 +23,7 @@ This is a comprehensive image quality assessment (IQA) toolbox built with **pure
 ---
 
 ### :triangular_flag_on_post: Updates/Changelog
+- 🎨 **Oct, 2024**. Add perceptual color difference metric `msswd` proposed in [MS-SWD (ECCV2024)](https://github.com/real-hjq/MS-SWD). Thanks to their work! 🤗
 - **Sep, 2024**. Add [efficiency benchmark](tests/Efficiency_benchmark.csv). With $1080\times800$ images as input, all metrics complete **in under 1 second on the GPU** (NVIDIA V100), and most of them, except for `qalign` and `qalign_8bit`, require **less than 6GB of GPU memory**.
 - **Aug, 2024**. Add `qalign_4bit` and `qalign_8bit` with much lower memory requirements and similar performance.
 - **Aug, 2024**. Add `piqe` metric, and `niqe_matlab`, `brisque_matlab` with default MATLAB parameters (results have been calibrated with MATLAB R2021b).
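Since `msswd` is registered as a full-reference metric, a minimal usage sketch with the toolbox's `pyiqa.create_metric` API (default options assumed; the file paths below are placeholders) could look like:

import pyiqa
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# create the new perceptual color difference metric; lower scores mean smaller color difference
msswd = pyiqa.create_metric('msswd', device=device)

# full-reference usage: distorted image vs. reference image (placeholder paths)
score = msswd('./distorted.png', './reference.png')
print(score)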

docs/ModelCard.md (+1)

@@ -61,6 +61,7 @@ print(pyiqa.list_models())
 
 | Task           | Method          | Description |
 | -------------- | --------------- | ----------- |
+| Color IQA      | `msswd`         | Perceptual color difference metric MS-SWD, ECCV2024, [Arxiv](http://arxiv.org/abs/2407.10181), [Github](https://github.com/real-hjq/MS-SWD) |
 | Face IQA       | `topiq_nr-face` | TOPIQ model trained with face IQA dataset (GFIQA) |
 | Underwater IQA | `uranker`       | A ranking-based underwater image quality assessment (UIQA) method, AAAI2023, [Arxiv](https://arxiv.org/abs/2208.06857), [Github](https://github.com/RQ-Wu/UnderwaterRanker) |
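As a quick sanity check that the new table row corresponds to a registered model, the list printed by the snippet in the hunk header should now contain the metric name (assuming an installation that includes this commit):

import pyiqa

# 'msswd' should appear among the registered models after this change
print('msswd' in pyiqa.list_models())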

inference_iqa.py (+8 −1)

@@ -24,6 +24,13 @@ def main():
     parser.add_argument('-m', '--metric_name', type=str, default='PSNR', help='IQA metric name, case sensitive.')
     parser.add_argument('--save_file', type=str, default=None, help='path to save results.')
 
+    # Add a --verbose flag
+    parser.add_argument(
+        '-v', '--verbose',
+        action='store_true',  # makes it a flag (True when used, False otherwise)
+        help='Enable verbose output'
+    )
+
     args = parser.parse_args()
 
     metric_name = args.metric_name.lower()

@@ -72,7 +79,7 @@ def main():
     assert os.path.isdir(args.target), 'input path must be a folder for FID.'
     avg_score = iqa_model(args.target, args.ref)
 
-    if torch.cuda.is_available():
+    if args.verbose and torch.cuda.is_available():
         print(torch.cuda.memory_summary())
 
     msg = f'Average {metric_name} score of {args.target} with {test_img_num} images is: {avg_score}'
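With this change, the CUDA memory summary is printed only when verbose output is requested. A hypothetical invocation (only `-m`, `-v`, and `--save_file` appear in this diff; the script's input/reference options are elided):

# hypothetical command line; options for the input/reference folders are not shown in this diff
python inference_iqa.py -m msswd -v --save_file results.txt ...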

pyiqa/archs/msswd_arch.py (+219)

@@ -0,0 +1,219 @@
"""Perceptual color difference metric, MS-SWD.

@inproceedings{he2024ms-swd,
    title={Multiscale Sliced {Wasserstein} Distances as Perceptual Color Difference Measures},
    author={He, Jiaqi and Wang, Zhihua and Wang, Leon and Liu, Tsein-I and Fang, Yuming and Sun, Qilin and Ma, Kede},
    booktitle={European Conference on Computer Vision},
    pages={1--18},
    year={2024},
    url={http://arxiv.org/abs/2407.10181}
}

Reference:
    - Official github: https://github.com/real-hjq/MS-SWD
"""

import torch
from torch import nn
import torchvision.transforms.functional as TF

from huggingface_hub import hf_hub_url

from pyiqa.archs.arch_util import load_pretrained_network
from pyiqa.utils.registry import ARCH_REGISTRY


def color_space_transform(input_color, fromSpace2toSpace):
    """
    Transforms inputs between different color spaces
    :param input_color: tensor of colors to transform (with NxCxHxW layout)
    :param fromSpace2toSpace: string describing transform
    :return: transformed tensor (with NxCxHxW layout)
    """
    dim = input_color.size()
    device = input_color.device

    # Assume D65 standard illuminant
    reference_illuminant = torch.tensor([[[0.950428545]], [[1.000000000]], [[1.088900371]]]).to(device)
    inv_reference_illuminant = torch.tensor([[[1.052156925]], [[1.000000000]], [[0.918357670]]]).to(device)

    if fromSpace2toSpace == "srgb2linrgb":
        limit = 0.04045
        transformed_color = torch.where(
            input_color > limit,
            torch.pow((torch.clamp(input_color, min=limit) + 0.055) / 1.055, 2.4),
            input_color / 12.92
        )  # clamp to stabilize training

    elif fromSpace2toSpace == "linrgb2srgb":
        limit = 0.0031308
        transformed_color = torch.where(
            input_color > limit,
            1.055 * torch.pow(torch.clamp(input_color, min=limit), (1.0 / 2.4)) - 0.055,
            12.92 * input_color
        )

    elif fromSpace2toSpace in ["linrgb2xyz", "xyz2linrgb"]:
        # Source: https://www.image-engineering.de/library/technotes/958-how-to-convert-between-srgb-and-ciexyz
        # Assumes D65 standard illuminant
        if fromSpace2toSpace == "linrgb2xyz":
            a11 = 10135552 / 24577794
            a12 = 8788810 / 24577794
            a13 = 4435075 / 24577794
            a21 = 2613072 / 12288897
            a22 = 8788810 / 12288897
            a23 = 887015 / 12288897
            a31 = 1425312 / 73733382
            a32 = 8788810 / 73733382
            a33 = 70074185 / 73733382
        else:
            # Constants found by taking the inverse of the matrix
            # defined by the constants for linrgb2xyz
            a11 = 3.241003275
            a12 = -1.537398934
            a13 = -0.498615861
            a21 = -0.969224334
            a22 = 1.875930071
            a23 = 0.041554224
            a31 = 0.055639423
            a32 = -0.204011202
            a33 = 1.057148933
        A = torch.Tensor([[a11, a12, a13],
                          [a21, a22, a23],
                          [a31, a32, a33]])

        input_color = input_color.view(dim[0], dim[1], dim[2] * dim[3])  # NC(HW)

        transformed_color = torch.matmul(A.to(device), input_color)
        transformed_color = transformed_color.view(dim[0], dim[1], dim[2], dim[3])

    elif fromSpace2toSpace == "xyz2ycxcz":
        input_color = torch.mul(input_color, inv_reference_illuminant)
        y = 116 * input_color[:, 1:2, :, :] - 16
        cx = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
        cz = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])
        transformed_color = torch.cat((y, cx, cz), 1)

    elif fromSpace2toSpace == "ycxcz2xyz":
        y = (input_color[:, 0:1, :, :] + 16) / 116
        cx = input_color[:, 1:2, :, :] / 500
        cz = input_color[:, 2:3, :, :] / 200

        x = y + cx
        z = y - cz
        transformed_color = torch.cat((x, y, z), 1)

        transformed_color = torch.mul(transformed_color, reference_illuminant)

    elif fromSpace2toSpace == "xyz2lab":
        input_color = torch.mul(input_color, inv_reference_illuminant)
        delta = 6 / 29
        delta_square = delta * delta
        delta_cube = delta * delta_square
        factor = 1 / (3 * delta_square)

        clamped_term = torch.pow(torch.clamp(input_color, min=delta_cube), 1.0 / 3.0).to(dtype=input_color.dtype)
        div = (factor * input_color + (4 / 29)).to(dtype=input_color.dtype)
        input_color = torch.where(input_color > delta_cube, clamped_term, div)  # clamp to stabilize training

        L = 116 * input_color[:, 1:2, :, :] - 16
        a = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
        b = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])

        transformed_color = torch.cat((L, a, b), 1)

    elif fromSpace2toSpace == "lab2xyz":
        y = (input_color[:, 0:1, :, :] + 16) / 116
        a = input_color[:, 1:2, :, :] / 500
        b = input_color[:, 2:3, :, :] / 200

        x = y + a
        z = y - b

        xyz = torch.cat((x, y, z), 1)
        delta = 6 / 29
        delta_square = delta * delta
        factor = 3 * delta_square
        xyz = torch.where(xyz > delta, torch.pow(xyz, 3), factor * (xyz - 4 / 29))

        transformed_color = torch.mul(xyz, reference_illuminant)

    elif fromSpace2toSpace == "srgb2xyz":
        transformed_color = color_space_transform(input_color, 'srgb2linrgb')
        transformed_color = color_space_transform(transformed_color, 'linrgb2xyz')
    elif fromSpace2toSpace == "srgb2ycxcz":
        transformed_color = color_space_transform(input_color, 'srgb2linrgb')
        transformed_color = color_space_transform(transformed_color, 'linrgb2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2ycxcz')
    elif fromSpace2toSpace == "linrgb2ycxcz":
        transformed_color = color_space_transform(input_color, 'linrgb2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2ycxcz')
    elif fromSpace2toSpace == "srgb2lab":
        transformed_color = color_space_transform(input_color, 'srgb2linrgb')
        transformed_color = color_space_transform(transformed_color, 'linrgb2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2lab')
    elif fromSpace2toSpace == "linrgb2lab":
        transformed_color = color_space_transform(input_color, 'linrgb2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2lab')
    elif fromSpace2toSpace == "ycxcz2linrgb":
        transformed_color = color_space_transform(input_color, 'ycxcz2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2linrgb')
    elif fromSpace2toSpace == "lab2srgb":
        transformed_color = color_space_transform(input_color, 'lab2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2linrgb')
        transformed_color = color_space_transform(transformed_color, 'linrgb2srgb')
    elif fromSpace2toSpace == "ycxcz2lab":
        transformed_color = color_space_transform(input_color, 'ycxcz2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2lab')
    else:
        raise ValueError('Error: The color transform %s is not defined!' % fromSpace2toSpace)

    return transformed_color


@ARCH_REGISTRY.register()
class MS_SWD_learned(nn.Module):
    def __init__(self,
                 resize_input: bool = True,
                 pretrained: bool = True,
                 pretrained_model_path: str = None,
                 **kwargs
                 ):
        super(MS_SWD_learned, self).__init__()

        self.conv11x11 = nn.Conv2d(3, 128, kernel_size=11, stride=1, padding=5, padding_mode='reflect', dilation=1, bias=False)
        self.conv_m1 = nn.Conv2d(128, 64, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.relu = nn.LeakyReLU()

        self.resize_input = resize_input

        if pretrained_model_path is not None:
            load_pretrained_network(self, pretrained_model_path, weight_keys='params')
        elif pretrained:
            url = hf_hub_url(repo_id="chaofengc/IQA-Toolbox-Weights", filename="msswd_weights.pth")
            load_pretrained_network(self, url, weight_keys='net_dict')

    def preprocess_img(self, x):
        if self.resize_input and min(x.shape[2:]) > 256:
            x = TF.resize(x, 256)
        return x

    def forward_once(self, x):
        x = color_space_transform(x, 'srgb2lab')
        x = self.conv11x11(x)
        x = self.relu(x)
        x = self.conv_m1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        return x

    def forward(self, x, y):
        x = self.preprocess_img(x)
        y = self.preprocess_img(y)
        output_x = self.forward_once(x)
        output_y = self.forward_once(y)
        # Sort and compute L1 distance
        output_x, _ = torch.sort(output_x, dim=2)
        output_y, _ = torch.sort(output_y, dim=2)
        swd = torch.abs(output_x - output_y)
        swd = torch.mean(swd, dim=[1, 2])
        return swd
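In `forward`, each image is converted from sRGB to CIELAB, projected by the two learned convolutions into 64 feature channels, and flattened spatially; sorting along the flattened dimension and averaging the absolute differences is the standard way to compute a 1-D (sliced) Wasserstein distance per channel. A minimal shape-check sketch with random inputs (illustrative only; `pretrained=False` keeps randomly initialized weights and avoids downloading the checkpoint):

import torch

# assumes the module above is importable as pyiqa.archs.msswd_arch
from pyiqa.archs.msswd_arch import MS_SWD_learned

model = MS_SWD_learned(pretrained=False).eval()  # random weights, for shape checking only

x = torch.rand(2, 3, 256, 256)  # batch of "distorted" sRGB images in [0, 1]
y = torch.rand(2, 3, 256, 256)  # batch of reference images

with torch.no_grad():
    swd = model(x, y)  # one score per image pair; lower means smaller color difference

print(swd.shape)  # torch.Size([2])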

pyiqa/default_model_configs.py (+8)

@@ -707,4 +707,12 @@
         'metric_mode': 'NR',
         'score_range': '0, 1',
     },
+    'msswd': {
+        'metric_opts': {
+            'type': 'MS_SWD_learned',
+        },
+        'metric_mode': 'FR',
+        'score_range': '0, ~10',
+        'lower_better': True,
+    }
 })
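This config entry registers `msswd` as a full-reference (FR) metric with scores roughly in the range 0 to ~10, where lower means better (smaller color difference). A quick sketch of how these fields surface through the toolbox wrapper, assuming the usual `create_metric` interface exposes `metric_mode` and `lower_better` attributes:

import pyiqa

metric = pyiqa.create_metric('msswd')
print(metric.metric_mode)   # expected 'FR': compares a distorted image against a reference
print(metric.lower_better)  # expected True: smaller scores indicate smaller perceptual color difference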
