r"""PIQE metric implementation.

Paper:
    N. Venkatanath, D. Praneeth, Bh. M. Chandrasekhar, S. S. Channappayya, and S. S. Medasani.
    "Blind Image Quality Evaluation Using Perception Based Features",
    In Proceedings of the 21st National Conference on Communications (NCC).
    Piscataway, NJ: IEEE, 2015.

References:
    - Matlab: https://www.mathworks.com/help/images/ref/piqe.html
    - Python: https://github.com/michael-rutherford/pypiqe

This PyTorch implementation by: Chaofeng Chen (https://github.com/chaofengc)
"""

import torch

from pyiqa.utils.color_util import to_y_channel
from pyiqa.matlab_utils import symm_pad
from pyiqa.archs.func_util import normalize_img_with_guass
from pyiqa.utils.registry import ARCH_REGISTRY


def piqe(
    img: torch.Tensor,
    block_size: int = 16,
    activity_threshold: float = 0.1,
    block_impaired_threshold: float = 0.1,
    window_size: int = 6,
) -> tuple:
    """Calculate the Perception-based Image Quality Evaluator (PIQE) score of an input image.

    Args:
        img (torch.Tensor): Input image tensor of shape (B, C, H, W).
        block_size (int, optional): Size of the blocks used for processing. Defaults to 16.
        activity_threshold (float, optional): Variance threshold above which a block is considered spatially active. Defaults to 0.1.
        block_impaired_threshold (float, optional): Threshold below which an edge-segment standard deviation marks a block as impaired. Defaults to 0.1.
        window_size (int, optional): Size of the window used for edge-segment analysis. Defaults to 6.

    Returns:
        tuple: (score, noticeable_artifacts_mask, noise_mask, activity_mask), where score has shape (B,)
        and lies roughly in [0, 100]; lower scores indicate better perceptual quality.
    """
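    # Processing pipeline (following the MATLAB / pypiqe references):
    #   1. Convert an RGB input to its luminance channel.
    #   2. Rescale to [0, 255] and symmetrically pad to a multiple of block_size.
    #   3. Apply Gaussian-weighted local mean/std (MSCN-like) normalization.
    #   4. Classify each spatially active block as artifact-impaired and/or noisy.
    #   5. Pool the per-block distortion scores into one quality score per image.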

    # RGB to gray conversion
    if img.shape[1] == 3:
        img = to_y_channel(img, out_data_range=1, color_space='yiq')

    # Scale the input image to the range [0, 255], following the reference implementation
    img = torch.round(255 * (img / torch.max(img.flatten(1), dim=-1)[0].reshape(img.shape[0], 1, 1, 1)))

    # Symmetric pad so that the image size is divisible by block_size
    bsz, _, height, width = img.shape
    col_pad = (block_size - width % block_size) % block_size
    row_pad = (block_size - height % block_size) % block_size
    img = symm_pad(img, (0, col_pad, 0, row_pad))

    # Normalize the image to zero mean and ~unit std using a circularly-symmetric
    # Gaussian weighting function sampled out to 3 standard deviations
    img_normalized = normalize_img_with_guass(img, padding='replicate')

    # Preallocation for masks and per-image scores
    noticeable_artifacts_mask = torch.zeros_like(img_normalized, dtype=torch.bool)
    noise_mask = torch.zeros_like(img_normalized, dtype=torch.bool)
    activity_mask = torch.zeros_like(img_normalized, dtype=torch.bool)
    score = torch.zeros(bsz, device=img.device)

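    # A block is "spatially active" if its variance exceeds activity_threshold.
    # Active blocks are further tested for noticeable artifacts (low-variance edge
    # segments) and for Gaussian noise (center-surround deviation criterion);
    # distorted blocks contribute to the score weighted by (1 - variance) for
    # artifacts and by variance for noise.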
    nsegments = block_size - window_size + 1
    # Start of block by block processing
    for b in range(0, bsz):
        NHSA = 0
        dist_block_scores = 0
        for i in range(0, height, block_size):
            for j in range(0, width, block_size):

                # Weights initialization
                WNDC = WNC = 0

                # Compute block variance
                block = img_normalized[b, 0, i:i + block_size, j:j + block_size]
                block_var = torch.var(block, unbiased=True)

                # Considering spatially prominent blocks
                if block_var > activity_threshold:
                    activity_mask[b, 0, i:i + block_size, j:j + block_size] = True
                    WHSA = 1
                    NHSA += 1

                    # Analyze block for noticeable artifacts
                    block_impaired = notice_dist_criterion(
                        block, nsegments, block_size - 1, window_size, block_impaired_threshold, block_size
                    )

                    if block_impaired:
                        WNDC = 1
                        noticeable_artifacts_mask[b, 0, i:i + block_size, j:j + block_size] = True

                    # Analyze block for Gaussian noise distortions
                    block_sigma, block_beta = noise_criterion(block, block_size - 1, block_var)

                    if block_sigma > 2 * block_beta:
                        WNC = 1
                        noise_mask[b, 0, i:i + block_size, j:j + block_size] = True

                    # Pooling / distortion assignment
                    dist_block_scores += WHSA * WNDC * (1 - block_var) + WHSA * WNC * block_var

        # Quality score computation
        # C is a positive constant; it is included to prevent numerical instability
        C = 1
        score[b] = ((dist_block_scores + C) / (C + NHSA)) * 100

    noticeable_artifacts_mask = noticeable_artifacts_mask[..., :height, :width]
    noise_mask = noise_mask[..., :height, :width]
    activity_mask = activity_mask[..., :height, :width]

    return score, noticeable_artifacts_mask, noise_mask, activity_mask


def noise_criterion(block, block_size, block_var):
    """Function to analyze a block for Gaussian noise distortions."""
    # Compute block standard deviation
    block_sigma = torch.sqrt(block_var)
    # Compute ratio of center and surround standard deviation
    cen_sur_dev = cal_center_sur_dev(block, block_size)
    # Relation between center-surround deviation and the block standard deviation
    block_beta = torch.abs(block_sigma - cen_sur_dev) / torch.max(block_sigma, cen_sur_dev)
    return block_sigma, block_beta


def cal_center_sur_dev(block, block_size):
    """Function to compute the center-surround deviation of a block."""
    # Block center
    center1 = (block_size + 1) // 2
    center2 = center1 + 1
    center = torch.cat((block[..., center1 - 1], block[..., center2 - 1]), dim=0)

    # Block surround
    block = torch.cat((block[..., :center1 - 1], block[..., center1:]), dim=-1)
    block = torch.cat((block[..., :center2 - 1], block[..., center2:]), dim=-1)

    # Compute standard deviation of block center and block surround
    center_std = torch.std(center, unbiased=True)
    surround_std = torch.std(block, unbiased=True)
    # Ratio of center and surround standard deviation
    cen_sur_dev = center_std / surround_std
    # Check for NaN (zero surround deviation) and fall back to zero
    if torch.isnan(cen_sur_dev):
        cen_sur_dev = torch.zeros_like(cen_sur_dev)
    return cen_sur_dev


def notice_dist_criterion(block, nsegments, block_size, window_size, block_impaired_threshold, N):
    """Analyze a block for noticeable artifacts by checking its edge segments for low variance."""
    # Top edge of the block
    top_edge = block[0, :]
    seg_top_edge = segment_edge(top_edge, nsegments, block_size, window_size)

    # Right edge of the block
    right_side_edge = block[:, N - 1]
    seg_right_side_edge = segment_edge(right_side_edge, nsegments, block_size, window_size)

    # Bottom edge of the block
    down_side_edge = block[N - 1, :]
    seg_down_side_edge = segment_edge(down_side_edge, nsegments, block_size, window_size)

    # Left edge of the block
    left_side_edge = block[:, 0]
    seg_left_side_edge = segment_edge(left_side_edge, nsegments, block_size, window_size)

    # Compute the standard deviation of segments on the left, right, top and bottom edges of the block
    seg_top_edge_std_dev = torch.std(seg_top_edge, dim=1, unbiased=True)
    seg_right_side_edge_std_dev = torch.std(seg_right_side_edge, dim=1, unbiased=True)
    seg_down_side_edge_std_dev = torch.std(seg_down_side_edge, dim=1, unbiased=True)
    seg_left_side_edge_std_dev = torch.std(seg_left_side_edge, dim=1, unbiased=True)

    # The block is flagged as impaired if any edge segment has a standard deviation
    # below block_impaired_threshold.
    block_impaired = 0
    for seg_index in range(seg_top_edge.shape[0]):
        if (
            (seg_top_edge_std_dev[seg_index] < block_impaired_threshold)
            or (seg_right_side_edge_std_dev[seg_index] < block_impaired_threshold)
            or (seg_down_side_edge_std_dev[seg_index] < block_impaired_threshold)
            or (seg_left_side_edge_std_dev[seg_index] < block_impaired_threshold)
        ):
            block_impaired = 1
            break

    return block_impaired


def segment_edge(block_edge, nsegments, block_size, window_size):
    """Split a block edge into overlapping segments of `window_size` contiguous pixels."""
    segments = torch.zeros(nsegments, window_size, device=block_edge.device)
    for i in range(nsegments):
        # Sliding window: window_size is incremented each iteration so that the
        # slice [i:window_size] always covers window_size pixels starting at i.
        segments[i, :] = block_edge[i: window_size]
        if window_size <= (block_size + 1):
            window_size += 1
    return segments


@ARCH_REGISTRY.register()
class PIQE(torch.nn.Module):
    """
    PIQE module.

    Args:
        x (torch.Tensor): Input tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: PIQE score.
    """

    def __init__(self) -> None:
        super().__init__()
        self.results = None

    def get_masks(self):
        assert self.results is not None, 'Please calculate the PIQE score first.'
        return {
            'noticeable_artifacts_mask': self.results[1],
            'noise_mask': self.results[2],
            'activity_mask': self.results[3],
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.results = piqe(x)
        return self.results[0]
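

# A minimal usage sketch: compute the PIQE score and the distortion masks for a
# random image. In practice the input should be a (B, C, H, W) float tensor with
# values in [0, 1].
if __name__ == '__main__':
    x = torch.rand(1, 3, 256, 256)
    model = PIQE()
    score = model(x)
    masks = model.get_masks()
    print(score, {k: v.shape for k, v in masks.items()})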