Skip to content

Commit 065c2b7

Browse files
committed
feat: ✨ add niqe_matlab and brisque_matlab
1 parent 5c2e4ba commit 065c2b7

10 files changed

+141
-15
lines changed

Makefile

+3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ test:
2323
pytest tests/test_metric_general.py::test_forward -v
2424
pytest tests/test_metric_general.py::test_cpu_gpu_consistency -v
2525

26+
test_cal:
27+
pytest tests/ -m calibration -v
28+
2629
test_cs:
2730
pytest tests/test_metric_general.py::test_cpu_gpu_consistency -v
2831

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ This is a comprehensive image quality assessment (IQA) toolbox built with **pure
2222
---
2323

2424
### :triangular_flag_on_post: Updates/Changelog
25-
-**Aug, 2024**. Add `piqe` metric.
25+
-**Aug, 2024**. Add `piqe` metric, and `niqe_matlab, brisque_matlab` with default matlab parameters (results have been calibrated with MATLAB R2021b).
2626
- 💥**Aug, 2024**. Add `lpips+` and `lpips-vgg+` proposed in our paper [TOPIQ](https://arxiv.org/abs/2308.03060).
2727
- 🔥**June, 2024**. Add `arniqa` and its variances trained on different datasets, refer to official repo [here](https://github.com/miccunifi/ARNIQA). Thanks for the contribution from [Lorenzo Agnolucci](https://github.com/LorenzoAgnolucci) 🤗.
2828
- **Apr 24, 2024**. Add `inception_score` and console entry point with `pyiqa` command.

ResultsCalibra/calibration_summary.csv

+4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
Method,I03.bmp,I04.bmp,I06.bmp,I08.bmp,I19.bmp
22
brisque,94.6421,-0.1076,0.9929,5.3583,72.2617
33
brisque(ours),94.6452,-0.1082,1.0759,5.1423,66.8381
4+
brisque_matlab,70.68,-2.36,15.8,23.28,60.97
5+
brisque_matlab(ours),70.6809,-2.369,15.8031,23.0785,60.8247
46
ckdn,0.2833,0.5767,0.6367,0.658,0.5999
57
ckdn(ours),0.284,0.565,0.6264,0.6414,0.5935
68
cw_ssim,0.2763,0.9996,1.0,0.9068,0.8658
@@ -35,6 +37,8 @@ musiq-spaq,17.685,70.492,78.74,79.015,49.105
3537
musiq-spaq(ours),17.6804,70.6531,79.0364,79.3189,50.4526
3638
niqe,15.7536,3.6549,3.2355,3.184,8.6352
3739
niqe(ours),15.6538,3.6549,3.2342,3.1921,9.0722
40+
niqe_matlab,7.2,2.99,3.17,3.71,7.69
41+
niqe_matlab(ours),7.1707,2.9908,3.1551,3.6781,7.4782
3842
nlpd,0.5616,0.0195,0.0159,0.3028,0.4326
3943
nlpd(ours),0.5616,0.0139,0.011,0.3033,0.4335
4044
nrqm,1.3894,8.9394,8.9735,6.829,6.312

ResultsCalibra/results_official.csv

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
Method,I03.bmp,I04.bmp,I06.bmp,I08.bmp,I19.bmp
22
brisque,94.642100000000000,-0.107618000000000,0.992889000000000,5.358270000000000,72.261700000000000
3+
brisque_matlab,70.68,-2.36,15.80,23.28,60.97
34
ckdn,0.2833,0.5767,0.6367,0.658,0.5999
45
cw_ssim,0.2763,0.9996,1.0,0.9068,0.8658
56
dists,0.4742,0.1424,0.06825,0.02867,0.3123
@@ -9,6 +10,7 @@ lpips,0.7237,0.2572,0.05079,0.05205,0.4253
910
mad,195.2796,80.8379,30.3918,84.3542,202.2371
1011
ms_ssim,0.6733,0.9996,0.9998,0.9566,0.8462
1112
niqe,15.7536293917814,3.65492152353770,3.23547743716998,3.18403333858339,8.63519663862637
13+
niqe_matlab,7.20,2.99,3.17,3.71,7.69
1214
piqe,100.00,21.62,35.86,41.15,76.95
1315
nlpd,0.561610096893874,0.019534798560102,0.015915631543598,0.302802106557736,0.432604962261603
1416
psnr,21.11,20.99,27.01,23.3,21.62

docs/ModelCard.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ print(pyiqa.list_models())
4747
| CNNIQA | `cnniqa` |
4848
| NRQM(Ma)<sup>[2](#fn2)</sup> | `nrqm` | No backward |
4949
| PI(Perceptual Index) | `pi` | No backward |
50-
| BRISQUE | `brisque` | No backward |
50+
| BRISQUE | `brisque`, `brisque_matlab` | No backward |
5151
| ILNIQE | `ilniqe` | No backward |
52-
| NIQE | `niqe` | No backward |
52+
| NIQE | `niqe`, `niqe_matlab` | No backward |
5353
| PIQE | `piqe` | No backward |
5454
<!-- </tr>
5555
</table> -->

pyiqa/archs/brisque_arch.py

+34-11
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,19 @@
1010
1111
"""
1212

13+
import scipy
14+
import numpy as np
1315
import torch
1416
from pyiqa.utils.color_util import to_y_channel
1517
from pyiqa.matlab_utils import imresize
18+
from pyiqa.matlab_utils.nss_feature import compute_nss_features
1619
from .func_util import estimate_ggd_param, estimate_aggd_param, normalize_img_with_guass
1720
from pyiqa.utils.download_util import load_file_from_url
1821
from pyiqa.utils.registry import ARCH_REGISTRY
1922

2023
default_model_urls = {
21-
'url': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/brisque_svm_weights.pth'
24+
'url': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/brisque_svm_weights.pth',
25+
'brisque_matlab': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/brisque_matlab.mat',
2226
}
2327

2428

@@ -50,29 +54,45 @@ def brisque(x: torch.Tensor,
5054
x = to_y_channel(x, 255.)
5155
else:
5256
x = x * 255
57+
58+
use_matlab_version = 'matlab' in pretrained_model_path
5359

5460
features = []
5561
num_of_scales = 2
5662
for _ in range(num_of_scales):
57-
features.append(natural_scene_statistics(x, kernel_size, kernel_sigma))
63+
if use_matlab_version:
64+
xnorm = normalize_img_with_guass(x, kernel_size, kernel_sigma, padding='replicate')
65+
features.append(compute_nss_features(xnorm))
66+
else:
67+
features.append(natural_scene_statistics(x, kernel_size, kernel_sigma))
5868
x = imresize(x, scale=0.5, antialiasing=True)
59-
6069
features = torch.cat(features, dim=-1)
6170
scaled_features = scale_features(features)
6271

63-
if pretrained_model_path:
72+
if pretrained_model_path and not use_matlab_version:
73+
# gamma and rho are SVM model parameters taken from official implementation of BRISQUE on MATLAB
74+
# Source: https://live.ece.utexas.edu/research/Quality/index_algorithms.htm
6475
sv_coef, sv = torch.load(pretrained_model_path, weights_only=False)
6576
sv_coef = sv_coef.to(x)
6677
sv = sv.to(x)
78+
gamma = 0.05
79+
rho = -153.591
80+
elif use_matlab_version:
81+
params = scipy.io.loadmat(pretrained_model_path)
82+
sv = params['sv']
83+
sv_coef = np.ravel(params['sv_coef'])
84+
sv = torch.from_numpy(sv).to(features)
85+
sv_coef = torch.from_numpy(sv_coef).to(features)
86+
scale = 0.3210
87+
scaled_features = features / scale
88+
sv = sv / scale
89+
gamma = 1
90+
rho = - 43.4582
6791

68-
# gamma and rho are SVM model parameters taken from official implementation of BRISQUE on MATLAB
69-
# Source: https://live.ece.utexas.edu/research/Quality/index_algorithms.htm
70-
gamma = 0.05
71-
rho = -153.591
7292
sv.t_()
7393
kernel_features = rbf_kernel(features=scaled_features, sv=sv, gamma=gamma)
74-
score = kernel_features @ sv_coef
75-
return score - rho
94+
score = kernel_features @ sv_coef - rho
95+
return score
7696

7797

7898
def natural_scene_statistics(luma: torch.Tensor, kernel_size: int = 7, sigma: float = 7. / 6) -> torch.Tensor:
@@ -137,6 +157,7 @@ def __init__(self,
137157
kernel_size: int = 7,
138158
kernel_sigma: float = 7 / 6,
139159
test_y_channel: bool = True,
160+
version: str = 'original',
140161
pretrained_model_path: str = None) -> None:
141162
super().__init__()
142163
self.kernel_size = kernel_size
@@ -150,8 +171,10 @@ def __init__(self,
150171
self.test_y_channel = test_y_channel
151172
if pretrained_model_path is not None:
152173
self.pretrained_model_path = pretrained_model_path
153-
else:
174+
elif version == 'original':
154175
self.pretrained_model_path = load_file_from_url(default_model_urls['url'])
176+
elif version == 'matlab':
177+
self.pretrained_model_path = load_file_from_url(default_model_urls['brisque_matlab'])
155178

156179
def forward(self, x: torch.Tensor) -> torch.Tensor:
157180
r"""Computation of BRISQUE score as a loss function.

pyiqa/archs/niqe_arch.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
default_model_urls = {
2929
'url': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/niqe_modelparameters.mat',
3030
'niqe': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/niqe_modelparameters.mat',
31+
'niqe_matlab': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/niqe_matlab_params.mat',
3132
'ilniqe': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/ILNIQE_templateModel.mat',
3233
}
3334

@@ -427,6 +428,7 @@ def __init__(self,
427428
test_y_channel: bool = True,
428429
color_space: str = 'yiq',
429430
crop_border: int = 0,
431+
version: str = 'original',
430432
pretrained_model_path: str = None) -> None:
431433

432434
super(NIQE, self).__init__()
@@ -436,8 +438,10 @@ def __init__(self,
436438
self.crop_border = crop_border
437439
if pretrained_model_path is not None:
438440
self.pretrained_model_path = pretrained_model_path
439-
else:
441+
elif version == 'original':
440442
self.pretrained_model_path = load_file_from_url(default_model_urls['url'])
443+
elif version == 'matlab':
444+
self.pretrained_model_path = load_file_from_url(default_model_urls['niqe_matlab'])
441445

442446
def forward(self, x: torch.Tensor) -> torch.Tensor:
443447
r"""Computation of NIQE metric.

pyiqa/default_model_configs.py

+20
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,16 @@
210210
'lower_better': True,
211211
'score_range': '~0, ~100',
212212
},
213+
'niqe_matlab': {
214+
'metric_opts': {
215+
'type': 'NIQE',
216+
'test_y_channel': True,
217+
'version': 'matlab',
218+
},
219+
'metric_mode': 'NR',
220+
'lower_better': True,
221+
'score_range': '~0, ~100',
222+
},
213223
'ilniqe': {
214224
'metric_opts': {
215225
'type': 'ILNIQE',
@@ -227,6 +237,16 @@
227237
'lower_better': True,
228238
'score_range': '~0, ~150',
229239
},
240+
'brisque_matlab': {
241+
'metric_opts': {
242+
'type': 'BRISQUE',
243+
'test_y_channel': True,
244+
'version': 'matlab',
245+
},
246+
'metric_mode': 'NR',
247+
'lower_better': True,
248+
'score_range': '~0, ~150',
249+
},
230250
'nrqm': {
231251
'metric_opts': {
232252
'type': 'NRQM',

pyiqa/matlab_utils/nss_feature.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import torch
2+
from typing import Tuple
3+
4+
5+
def estimate_aggd_param(
6+
block: torch.Tensor, return_sigma=False
7+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
8+
"""Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.
9+
Args:
10+
block (Tensor): Image block with shape (b, 1, h, w).
11+
Returns:
12+
Tensor: alpha, beta_l and beta_r for the AGGD distribution
13+
(Estimating the parames in Equation 7 in the paper).
14+
"""
15+
gam = torch.arange(0.2, 10 + 0.001, 0.001).to(block)
16+
r_gam = (
17+
2 * torch.lgamma(2.0 / gam)
18+
- (torch.lgamma(1.0 / gam) + torch.lgamma(3.0 / gam))
19+
).exp()
20+
r_gam = r_gam.repeat(block.shape[0], 1)
21+
22+
mask_left = block < 0
23+
mask_right = block > 0
24+
count_left = mask_left.sum(dim=(-1, -2), dtype=torch.float32)
25+
count_right = mask_right.sum(dim=(-1, -2), dtype=torch.float32)
26+
27+
left_std = torch.sqrt((block * mask_left).pow(2).sum(dim=(-1, -2)) / (count_left))
28+
right_std = torch.sqrt(
29+
(block * mask_right).pow(2).sum(dim=(-1, -2)) / (count_right)
30+
)
31+
32+
gammahat = left_std / right_std
33+
rhat = block.abs().mean(dim=(-1, -2)).pow(2) / block.pow(2).mean(dim=(-1, -2))
34+
rhatnorm = (rhat * (gammahat.pow(3) + 1) * (gammahat + 1)) / (
35+
gammahat.pow(2) + 1
36+
).pow(2)
37+
array_position = (r_gam - rhatnorm).abs().argmin(dim=-1)
38+
39+
alpha = gam[array_position]
40+
beta_l = (
41+
left_std.squeeze(-1)
42+
* (torch.lgamma(1 / alpha) - torch.lgamma(3 / alpha)).exp().sqrt()
43+
)
44+
beta_r = (
45+
right_std.squeeze(-1)
46+
* (torch.lgamma(1 / alpha) - torch.lgamma(3 / alpha)).exp().sqrt()
47+
)
48+
49+
if return_sigma:
50+
return alpha, left_std.squeeze(-1), right_std.squeeze(-1)
51+
else:
52+
return alpha, beta_l, beta_r
53+
54+
55+
def compute_nss_features(luma_nrmlzd: torch.Tensor) -> torch.Tensor:
56+
57+
alpha, betal, betar = estimate_aggd_param(luma_nrmlzd, return_sigma=False)
58+
features = [alpha, (betal + betar) / 2]
59+
60+
shifts = [(0, 1), (1, 0), (1, 1), (-1, 1)]
61+
62+
for shift in shifts:
63+
shifted_luma_nrmlzd = torch.roll(luma_nrmlzd, shifts=shift, dims=(-2, -1))
64+
alpha, betal, betar = estimate_aggd_param(luma_nrmlzd * shifted_luma_nrmlzd, return_sigma=False)
65+
distmean = (betar - betal) * torch.exp(torch.lgamma(2/alpha) - torch.lgamma(1/alpha))
66+
features.extend((alpha, distmean, betal, betar))
67+
68+
return torch.stack(features, dim=-1)
69+

tests/test_metric_general.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
TOL_DICT = {
1515
'brisque': (1e-2, 8e-2),
1616
'niqe': (1e-2, 6e-2),
17+
'niqe_matlab': (1e-2, 6e-2),
1718
'pi': (1e-2, 3e-2),
1819
'ilniqe': (1e-2, 4e-2),
1920
'ckdn': (1e-2, 3e-2),

0 commit comments

Comments
 (0)