
Commit 5b5afb3: feat: ✨ add piqe metric
1 parent a699b9a

File tree: 7 files changed, +245 -27 lines


README.md (+3, -25)
```diff
@@ -9,23 +9,9 @@ An IQA toolbox with pure python and pytorch. Please refer to [Awesome-Image-Qual
 [![Awesome](https://cdn.rawgit.com/sindresorhus/awesome/d7305f38d29fed78fa85652e3a63e154dd8e8829/media/badge.svg)](https://github.com/chaofengc/Awesome-Image-Quality-Assessment)
 [![Citation](https://img.shields.io/badge/Citation-bibtex-green)](https://github.com/chaofengc/IQA-PyTorch/blob/main/README.md#bookmark_tabs-citation)
 
-<!-- ![demo](docs/demo.gif) -->
-
-<!-- - [:open\_book: Introduction](#open_book-introduction)
-- [:zap: Quick Start](#zap-quick-start)
-  - [Dependencies and Installation](#dependencies-and-installation)
-  - [Basic Usage](#basic-usage)
-- [:1st\_place\_medal: Benchmark Performances and Model Zoo](#1st_place_medal-benchmark-performances-and-model-zoo)
-  - [Results Calibration](#results-calibration)
-  - [Performance Evaluation Protocol](#performance-evaluation-protocol)
-  - [Benchmark Performance with Provided Script](#benchmark-performance-with-provided-script)
-- [:hammer\_and\_wrench: Train](#hammer_and_wrench-train)
-  - [Dataset Preparation](#dataset-preparation)
-  - [Example Train Script](#example-trai-script) -->
-
 ## :open_book: Introduction
 
-This is a image quality assessment toolbox with **pure python and pytorch**. We provide reimplementation of many mainstream full reference (FR) and no reference (NR) metrics (results are calibrated with official matlab scripts if exist). **With GPU acceleration, most of our implementations are much faster than Matlab.** Please refer to the following documents for details:
+This is a comprehensive image quality assessment (IQA) toolbox built with **pure Python and PyTorch**. We provide reimplementations of many mainstream full-reference (FR) and no-reference (NR) metrics (results are calibrated against the official MATLAB scripts where available). **With GPU acceleration, most of our implementations are much faster than MATLAB.** Please refer to the following documents for details:
 
 <div align="center">
 
@@ -36,19 +22,11 @@ This is a image quality assessment toolbox with **pure python and pytorch**. We
 ---
 
 ### :triangular_flag_on_post: Updates/Changelog
-- 🔥**Aug, 2024**. Add `lpips+` and `lpips-vgg+` proposed in our paper [TOPIQ](https://arxiv.org/abs/2308.03060).
+- **Aug, 2024**. Add `piqe` metric.
+- 💥**Aug, 2024**. Add `lpips+` and `lpips-vgg+` proposed in our paper [TOPIQ](https://arxiv.org/abs/2308.03060).
 - 🔥**June, 2024**. Add `arniqa` and its variants trained on different datasets, refer to the official repo [here](https://github.com/miccunifi/ARNIQA). Thanks for the contribution from [Lorenzo Agnolucci](https://github.com/LorenzoAgnolucci) 🤗.
 - **Apr 24, 2024**. Add `inception_score` and console entry point with `pyiqa` command.
 - **Mar 11, 2024**. Add `unique`, refer to the official repo [here](https://github.com/zwx8981/UNIQUE). Thanks for the contribution from [Weixia Zhang](https://github.com/zwx8981) 🤗.
-- :boom: **Jan 31, 2024**. Add `qalign` for both NR and IAA. It is our most powerful unified metric based on large vision-language models, and shows remarkable performance and robustness. Refer to [Q-Align](https://github.com/Q-Future/Q-Align) for more details. Use it with the following code:
-```
-qalign = create_metric('qalign').cuda()
-quality_score = qalign(input, task_='quality')
-aesthetic_score = qalign(input, task_='aesthetic')
-```
-- **Jan 19, 2024**. Add `wadiqam_fr` and `wadiqam_nr`. All implemented methods are usable now 🍻.
-- **Dec 23, 2023**. Add `liqe` and `liqe_mix`. Thanks for the contribution from [Weixia Zhang](https://github.com/zwx8981) 🤗.
-- **Oct 09, 2023**. Add datasets: [PIQ2023](https://github.com/DXOMARK-Research/PIQ2023), [GFIQA](http://database.mmsp-kn.de/gfiqa-20k-database.html). Add metric `topiq_nr-face`. We release example results on FFHQ [here](tests/ffhq_score_topiq_nr-face.csv) for reference.
 - [**More**](docs/history_changelog.md)
 
 ---
```
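Once this commit is installed, the new metric is available through the standard `pyiqa` interface like any other NR metric. A minimal usage sketch (the random tensor is purely illustrative; any batch of RGB images in [0, 1] works):

```python
import torch
import pyiqa

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the new NR metric by its registered name
piqe_metric = pyiqa.create_metric('piqe', device=device)

# Score a batch of RGB images in [0, 1]; lower PIQE means better quality
x = torch.rand(1, 3, 256, 256)
print(piqe_metric(x))
```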

ResultsCalibra/calibration_summary.csv (+2)
```diff
@@ -43,6 +43,8 @@ paq2piq,44.134,73.6015,74.3297,76.8748,70.9153
 paq2piq(ours),44.1341,73.6015,74.3297,76.8748,70.9153
 pi,11.9235,3.072,2.618,2.8074,6.7713
 pi(ours),11.9286,3.073,2.6357,2.7979,6.9546
+piqe,100.0,21.62,35.86,41.15,76.95
+piqe(ours),100.0,21.6242,35.8646,41.147,76.9485
 psnr,21.11,20.99,27.01,23.3,21.62
 psnr(ours),21.1136,20.9872,27.0139,23.3002,21.6186
 ssim,0.6993,0.9978,0.9989,0.9669,0.6519
```
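In the rows above, `piqe` is the official MATLAB output and `piqe(ours)` is this implementation on the same five calibration images; the values agree to two decimal places. A sketch of how such a row could be regenerated, assuming the distorted calibration images live under `ResultsCalibra/dist_dir/` as in the repository layout (the path is an assumption, not shown in this diff):

```python
import os
import pyiqa

piqe_metric = pyiqa.create_metric('piqe')

# Score every calibration image; the directory path is the repo's assumed layout
dist_dir = 'ResultsCalibra/dist_dir'
for name in sorted(os.listdir(dist_dir)):
    score = piqe_metric(os.path.join(dist_dir, name))
    print(f'{name}: {score.item():.4f}')
```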

ResultsCalibra/results_official.csv (+1, -1)
```diff
@@ -9,6 +9,7 @@ lpips,0.7237,0.2572,0.05079,0.05205,0.4253
 mad,195.2796,80.8379,30.3918,84.3542,202.2371
 ms_ssim,0.6733,0.9996,0.9998,0.9566,0.8462
 niqe,15.7536293917814,3.65492152353770,3.23547743716998,3.18403333858339,8.63519663862637
+piqe,100.00,21.62,35.86,41.15,76.95
 nlpd,0.561610096893874,0.019534798560102,0.015915631543598,0.302802106557736,0.432604962261603
 psnr,21.11,20.99,27.01,23.3,21.62
 ssim,0.6993,0.9978,0.9989,0.9669,0.6519
@@ -19,7 +20,6 @@ pi,11.9235,3.0720,2.6180,2.8074,6.7713
 ilniqe,113.4801,23.9968,19.9750,22.4493,56.6721
 musiq,12.494,75.332,73.429,75.188,36.938
 musiq-ava,3.398,5.648,4.635,5.186,4.128
-musiq-koniq,12.494,75.332,73.429,75.188,36.938
 musiq-paq2piq,46.035,72.660,73.625,74.361,69.006
 musiq-spaq,17.685,70.492,78.740,79.015,49.105
 paq2piq,44.1340,73.6015,74.3297,76.8748,70.9153
```

docs/history_changelog.md (+9)
```diff
@@ -1,5 +1,14 @@
 # History of Changelog
 
+- :boom: **Jan 31, 2024**. Add `qalign` for both NR and IAA. It is our most powerful unified metric based on large vision-language models, and shows remarkable performance and robustness. Refer to [Q-Align](https://github.com/Q-Future/Q-Align) for more details. Use it with the following code:
+```
+qalign = create_metric('qalign').cuda()
+quality_score = qalign(input, task_='quality')
+aesthetic_score = qalign(input, task_='aesthetic')
+```
+- **Jan 19, 2024**. Add `wadiqam_fr` and `wadiqam_nr`. All implemented methods are usable now 🍻.
+- **Dec 23, 2023**. Add `liqe` and `liqe_mix`. Thanks for the contribution from [Weixia Zhang](https://github.com/zwx8981) 🤗.
+- **Oct 09, 2023**. Add datasets: [PIQ2023](https://github.com/DXOMARK-Research/PIQ2023), [GFIQA](http://database.mmsp-kn.de/gfiqa-20k-database.html). Add metric `topiq_nr-face`. We release example results on FFHQ [here](tests/ffhq_score_topiq_nr-face.csv) for reference.
 - **Aug 15, 2023**. Add `st-lpips` and `laion_aes`. Refer to the official repos at [ShiftTolerant-LPIPS](https://github.com/abhijay9/ShiftTolerant-LPIPS) and [improved-aesthetic-predictor](https://github.com/christophschuhmann/improved-aesthetic-predictor).
 - **Aug 05, 2023**. Add our work [TOPIQ](https://arxiv.org/abs/2308.03060) with remarkable performance on almost all benchmarks via an efficient ResNet50 backbone. Use it with `topiq_fr`, `topiq_nr`, `topiq_iaa` for Full-Reference, No-Reference and Aesthetic assessment respectively.
 - **March 30, 2023**. Add [URanker](https://github.com/RQ-Wu/UnderwaterRanker) for IQA of underwater images.
```

pyiqa/archs/piqe_arch.py (+218, new file)
```python
r"""PIQE metric implementation.

Paper:
    N. Venkatanath, D. Praneeth, Bh. M. Chandrasekhar, S. S. Channappayya, and S. S. Medasani.
    "Blind Image Quality Evaluation Using Perception Based Features",
    In Proceedings of the 21st National Conference on Communications (NCC).
    Piscataway, NJ: IEEE, 2015.

References:
    - Matlab: https://www.mathworks.com/help/images/ref/piqe.html
    - Python: https://github.com/michael-rutherford/pypiqe

This PyTorch implementation by: Chaofeng Chen (https://github.com/chaofengc)
"""

import torch

from pyiqa.utils.color_util import to_y_channel
from pyiqa.utils import scandir_images, imread2tensor
from pyiqa.matlab_utils import symm_pad
from pyiqa.archs.func_util import normalize_img_with_guass
from pyiqa.utils.registry import ARCH_REGISTRY


def piqe(
    img: torch.Tensor,
    block_size: int = 16,
    activity_threshold: float = 0.1,
    block_impaired_threshold: float = 0.1,
    window_size: int = 6,
) -> torch.Tensor:
    """Calculate the Perception based Image Quality Evaluator (PIQE) score for an input image.

    Args:
        img (torch.Tensor): The input image tensor of shape (B, C, H, W).
        block_size (int, optional): Size of the blocks used for processing. Defaults to 16.
        activity_threshold (float, optional): Threshold for considering a block as active. Defaults to 0.1.
        block_impaired_threshold (float, optional): Threshold for considering a block as impaired. Defaults to 0.1.
        window_size (int, optional): Size of the window used for block analysis. Defaults to 6.

    Returns:
        torch.Tensor: The PIQE score for each input image, together with the
        noticeable-artifacts, noise, and activity masks.
    """
    # RGB to gray conversion
    if img.shape[1] == 3:
        img = to_y_channel(img, out_data_range=1, color_space='yiq')

    # Convert the input image to double and scale it to the range 0-255
    img = torch.round(255 * (img / torch.max(img.flatten(1), dim=-1)[0].reshape(img.shape[0], 1, 1, 1)))

    # Symmetric pad if the image size is not divisible by block_size
    bsz, _, height, width = img.shape
    col_pad, row_pad = width % block_size, height % block_size
    img = symm_pad(img, (0, col_pad, 0, row_pad))

    # Normalize the image to zero mean and ~unit std, using a
    # circularly-symmetric Gaussian weighting function sampled out
    # to 3 standard deviations
    img_normalized = normalize_img_with_guass(img, padding='replicate')

    # Preallocate masks
    noticeable_artifacts_mask = torch.zeros_like(img_normalized, dtype=bool)
    noise_mask = torch.zeros_like(img_normalized, dtype=bool)
    activity_mask = torch.zeros_like(img_normalized, dtype=bool)
    score = torch.zeros(bsz)

    nsegments = block_size - window_size + 1
    # Start of block-by-block processing
    for b in range(0, bsz):
        NHSA = 0
        dist_block_scores = 0
        for i in range(0, height, block_size):
            for j in range(0, width, block_size):

                # Weight initialization
                WNDC = WNC = 0

                # Compute block variance
                block = img_normalized[b, 0, i:i + block_size, j:j + block_size]
                block_var = torch.var(block, unbiased=True)

                # Consider only spatially prominent (high spatial activity) blocks
                if block_var > activity_threshold:
                    activity_mask[b, 0, i:i + block_size, j:j + block_size] = True
                    WHSA = 1
                    NHSA += 1

                    # Analyze the block for noticeable artifacts
                    block_impaired = notice_dist_criterion(
                        block, nsegments, block_size - 1, window_size, block_impaired_threshold, block_size)

                    if block_impaired:
                        WNDC = 1
                        noticeable_artifacts_mask[b, 0, i:i + block_size, j:j + block_size] = True

                    # Analyze the block for Gaussian noise distortions
                    block_sigma, block_beta = noise_criterion(block, block_size - 1, block_var)

                    if block_sigma > 2 * block_beta:
                        WNC = 1
                        noise_mask[b, 0, i:i + block_size, j:j + block_size] = True

                    # Pooling / distortion assignment
                    dist_block_scores += WHSA * WNDC * (1 - block_var) + WHSA * WNC * block_var

        # Quality score computation
        # C is a positive constant included to prevent numerical instability
        C = 1
        score[b] = ((dist_block_scores + C) / (C + NHSA)) * 100

    noticeable_artifacts_mask = noticeable_artifacts_mask[..., :height, :width]
    noise_mask = noise_mask[..., :height, :width]
    activity_mask = activity_mask[..., :height, :width]

    return score, noticeable_artifacts_mask, noise_mask, activity_mask


def noise_criterion(block, block_size, block_var):
    """Analyze a block for Gaussian noise distortions."""
    # Compute the block standard deviation
    block_sigma = torch.sqrt(block_var)
    # Compute the ratio of center and surround standard deviations
    cen_sur_dev = cal_center_sur_dev(block, block_size)
    # Relation between the center-surround deviation and the block standard deviation
    block_beta = torch.abs(block_sigma - cen_sur_dev) / torch.max(block_sigma, cen_sur_dev)
    return block_sigma, block_beta


def cal_center_sur_dev(block, block_size):
    """Compute the center-surround deviation of a block."""
    # Block center
    center1 = (block_size + 1) // 2
    center2 = center1 + 1
    center = torch.cat((block[..., center1 - 1], block[..., center2 - 1]), dim=0)

    # Block surround
    block = torch.cat((block[..., :center1 - 1], block[..., center1:]), dim=-1)
    block = torch.cat((block[..., :center2 - 1], block[..., center2:]), dim=-1)

    # Compute the standard deviations of the block center and block surround
    center_std = torch.std(center, unbiased=True)
    surround_std = torch.std(block, unbiased=True)
    # Ratio of center and surround standard deviations
    cen_sur_dev = center_std / surround_std
    # Check for NaNs
    if torch.isnan(cen_sur_dev):
        cen_sur_dev = 0
    return cen_sur_dev


def notice_dist_criterion(block, nsegments, block_size, window_size, block_impaired_threshold, N):
    """Analyze a block for noticeable artifacts based on its edge segments."""
    # Top edge of the block
    top_edge = block[0, :]
    seg_top_edge = segment_edge(top_edge, nsegments, block_size, window_size)

    # Right edge of the block
    right_side_edge = block[:, N - 1]
    seg_right_side_edge = segment_edge(right_side_edge, nsegments, block_size, window_size)

    # Bottom edge of the block
    down_side_edge = block[N - 1, :]
    seg_down_side_edge = segment_edge(down_side_edge, nsegments, block_size, window_size)

    # Left edge of the block
    left_side_edge = block[:, 0]
    seg_left_side_edge = segment_edge(left_side_edge, nsegments, block_size, window_size)

    # Compute the standard deviation of segments in the left, right, top, and bottom edges of the block
    seg_top_edge_std_dev = torch.std(seg_top_edge, dim=1, unbiased=True)
    seg_right_side_edge_std_dev = torch.std(seg_right_side_edge, dim=1, unbiased=True)
    seg_down_side_edge_std_dev = torch.std(seg_down_side_edge, dim=1, unbiased=True)
    seg_left_side_edge_std_dev = torch.std(seg_left_side_edge, dim=1, unbiased=True)

    # A segment exhibits impairment if its standard deviation is below block_impaired_threshold
    block_impaired = 0
    for seg_index in range(seg_top_edge.shape[0]):
        if (
            (seg_top_edge_std_dev[seg_index] < block_impaired_threshold)
            or (seg_right_side_edge_std_dev[seg_index] < block_impaired_threshold)
            or (seg_down_side_edge_std_dev[seg_index] < block_impaired_threshold)
            or (seg_left_side_edge_std_dev[seg_index] < block_impaired_threshold)
        ):
            block_impaired = 1
            break

    return block_impaired


def segment_edge(block_edge, nsegments, block_size, window_size):
    """Split a block edge into overlapping segments of window_size contiguous pixels."""
    segments = torch.zeros(nsegments, window_size)
    for i in range(nsegments):
        # window_size is incremented each iteration, so block_edge[i:window_size]
        # is a sliding window of fixed width starting at pixel i
        segments[i, :] = block_edge[i: window_size]
        if window_size <= (block_size + 1):
            window_size += 1
    return segments


@ARCH_REGISTRY.register()
class PIQE(torch.nn.Module):
    """PIQE module.

    Args:
        x (torch.Tensor): Input tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: PIQE score.
    """

    def get_masks(self):
        assert getattr(self, 'results', None) is not None, "Please calculate the PIQE score first."
        return {
            'noticeable_artifacts_mask': self.results[1],
            'noise_mask': self.results[2],
            'activity_mask': self.results[3],
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.results = piqe(x)
        return self.results[0]
```
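Beyond the scalar score, `PIQE.forward` caches the full `piqe()` output, so the three diagnostic masks from the last forward pass can be retrieved with `get_masks()`. A small sketch of using the arch directly (random input for illustration; the image size is a multiple of the default 16-pixel block):

```python
import torch
from pyiqa.archs.piqe_arch import PIQE

model = PIQE()
x = torch.rand(2, 3, 224, 224)  # batch of RGB images in [0, 1]

scores = model(x)          # one PIQE score per image; lower is better
masks = model.get_masks()  # boolean masks from the same forward pass
print(scores.shape)                 # torch.Size([2])
print(masks['activity_mask'].shape) # torch.Size([2, 1, 224, 224])
```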

pyiqa/default_model_configs.py (+8)
```diff
@@ -193,6 +193,14 @@
         'score_range': '0, ~',
     },
     # =============================================================
+    'piqe': {
+        'metric_opts': {
+            'type': 'PIQE',
+        },
+        'metric_mode': 'NR',
+        'lower_better': True,
+        'score_range': '0, 100',
+    },
     'niqe': {
         'metric_opts': {
             'type': 'NIQE',
```
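This config entry is what wires the new arch into the top-level API: `pyiqa.create_metric('piqe')` looks up the name here, instantiates the registered `PIQE` class via `metric_opts`, and exposes the metadata on the returned model. A short sketch (attribute access mirrors the config keys above):

```python
import pyiqa

metric = pyiqa.create_metric('piqe')
print(metric.metric_mode)   # 'NR': no reference image is needed
print(metric.lower_better)  # True: lower PIQE indicates better quality
```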

pyiqa/matlab_utils/padding.py (+4, -1)
```diff
@@ -94,4 +94,7 @@ def __init__(self, kernel, stride=1, dilation=1, mode='same'):
         self.mode = mode
 
     def forward(self, x):
-        return exact_padding_2d(x, self.kernel, self.stride, self.dilation, self.mode)
+        if self.mode is None:
+            return x
+        else:
+            return exact_padding_2d(x, self.kernel, self.stride, self.dilation, self.mode)
```
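The change turns the padding module into a no-op when it is constructed with `mode=None`, rather than always calling `exact_padding_2d`. A minimal sketch, assuming the enclosing class is `ExactPadding2d` exported from `pyiqa.matlab_utils` (the class name is not shown in this hunk):

```python
import torch
from pyiqa.matlab_utils import ExactPadding2d  # assumed export location

x = torch.rand(1, 1, 32, 32)

pad_same = ExactPadding2d(7, mode='same')  # pad so a 7x7 stride-1 filter keeps 32x32
pad_none = ExactPadding2d(7, mode=None)    # new behavior: input passes through unchanged

print(pad_same(x).shape)  # torch.Size([1, 1, 38, 38])
print(pad_none(x).shape)  # torch.Size([1, 1, 32, 32])
```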
