"""Perceptual color difference metric, MS-SWD.

@inproceedings{he2024ms-swd,
    title={Multiscale Sliced {Wasserstein} Distances as Perceptual Color Difference Measures},
    author={He, Jiaqi and Wang, Zhihua and Wang, Leon and Liu, Tsein-I and Fang, Yuming and Sun, Qilin and Ma, Kede},
    booktitle={European Conference on Computer Vision},
    pages={1--18},
    year={2024},
    url={http://arxiv.org/abs/2407.10181}
}

Reference:
    - Official github: https://github.com/real-hjq/MS-SWD
"""

from typing import Optional

import torch
from torch import nn
import torchvision.transforms.functional as TF

from huggingface_hub import hf_hub_url

from pyiqa.archs.arch_util import load_pretrained_network
from pyiqa.utils.registry import ARCH_REGISTRY


def color_space_transform(input_color, fromSpace2toSpace):
    """
    Transforms inputs between different color spaces.
    :param input_color: tensor of colors to transform (with NxCxHxW layout)
    :param fromSpace2toSpace: string describing transform
    :return: transformed tensor (with NxCxHxW layout)
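
    Example (illustrative):
        >>> img = torch.rand(1, 3, 64, 64)  # sRGB batch in [0, 1]
        >>> lab = color_space_transform(img, 'srgb2lab')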
    """
    dim = input_color.size()
    device = input_color.device

    # Assume D65 standard illuminant
    reference_illuminant = torch.tensor([[[0.950428545]], [[1.000000000]], [[1.088900371]]]).to(device)
    inv_reference_illuminant = torch.tensor([[[1.052156925]], [[1.000000000]], [[0.918357670]]]).to(device)

    if fromSpace2toSpace == "srgb2linrgb":
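        # Standard sRGB decoding (IEC 61966-2-1): a 2.4-power segment above the
        # 0.04045 threshold and a linear segment (divide by 12.92) below it.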
        limit = 0.04045
        transformed_color = torch.where(
            input_color > limit,
            torch.pow((torch.clamp(input_color, min=limit) + 0.055) / 1.055, 2.4),
            input_color / 12.92
        )  # clamp to stabilize training

    elif fromSpace2toSpace == "linrgb2srgb":
        limit = 0.0031308
        transformed_color = torch.where(
            input_color > limit,
            1.055 * torch.pow(torch.clamp(input_color, min=limit), (1.0 / 2.4)) - 0.055,
            12.92 * input_color
        )

    elif fromSpace2toSpace in ["linrgb2xyz", "xyz2linrgb"]:
        # Source: https://www.image-engineering.de/library/technotes/958-how-to-convert-between-srgb-and-ciexyz
        # Assumes D65 standard illuminant
        if fromSpace2toSpace == "linrgb2xyz":
            a11 = 10135552 / 24577794
            a12 = 8788810 / 24577794
            a13 = 4435075 / 24577794
            a21 = 2613072 / 12288897
            a22 = 8788810 / 12288897
            a23 = 887015 / 12288897
            a31 = 1425312 / 73733382
            a32 = 8788810 / 73733382
            a33 = 70074185 / 73733382
        else:
            # Constants found by taking the inverse of the matrix
            # defined by the constants for linrgb2xyz
            a11 = 3.241003275
            a12 = -1.537398934
            a13 = -0.498615861
            a21 = -0.969224334
            a22 = 1.875930071
            a23 = 0.041554224
            a31 = 0.055639423
            a32 = -0.204011202
            a33 = 1.057148933
        A = torch.tensor([[a11, a12, a13],
                          [a21, a22, a23],
                          [a31, a32, a33]], device=device, dtype=input_color.dtype)

        input_color = input_color.view(dim[0], dim[1], dim[2] * dim[3])  # NC(HW)

        transformed_color = torch.matmul(A, input_color)
        transformed_color = transformed_color.view(dim[0], dim[1], dim[2], dim[3])

    elif fromSpace2toSpace == "xyz2ycxcz":
        input_color = torch.mul(input_color, inv_reference_illuminant)
        y = 116 * input_color[:, 1:2, :, :] - 16
        cx = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
        cz = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])
        transformed_color = torch.cat((y, cx, cz), 1)

    elif fromSpace2toSpace == "ycxcz2xyz":
        y = (input_color[:, 0:1, :, :] + 16) / 116
        cx = input_color[:, 1:2, :, :] / 500
        cz = input_color[:, 2:3, :, :] / 200

        x = y + cx
        z = y - cz
        transformed_color = torch.cat((x, y, z), 1)

        transformed_color = torch.mul(transformed_color, reference_illuminant)

    elif fromSpace2toSpace == "xyz2lab":
        input_color = torch.mul(input_color, inv_reference_illuminant)
        delta = 6 / 29
        delta_square = delta * delta
        delta_cube = delta * delta_square
        factor = 1 / (3 * delta_square)

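        # CIE f(t): cube root above (6/29)^3, with the linear extension
        # t / (3 * delta^2) + 4/29 below; applied channel-wise to the
        # illuminant-normalized X, Y, Z values.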
        clamped_term = torch.pow(torch.clamp(input_color, min=delta_cube), 1.0 / 3.0).to(dtype=input_color.dtype)
        div = (factor * input_color + (4 / 29)).to(dtype=input_color.dtype)
        input_color = torch.where(input_color > delta_cube, clamped_term, div)  # clamp to stabilize training

        L = 116 * input_color[:, 1:2, :, :] - 16
        a = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
        b = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])

        transformed_color = torch.cat((L, a, b), 1)

    elif fromSpace2toSpace == "lab2xyz":
        y = (input_color[:, 0:1, :, :] + 16) / 116
        a = input_color[:, 1:2, :, :] / 500
        b = input_color[:, 2:3, :, :] / 200

        x = y + a
        z = y - b

        xyz = torch.cat((x, y, z), 1)
        delta = 6 / 29
        delta_square = delta * delta
        factor = 3 * delta_square
        xyz = torch.where(xyz > delta, torch.pow(xyz, 3), factor * (xyz - 4 / 29))

        transformed_color = torch.mul(xyz, reference_illuminant)

    elif fromSpace2toSpace == "srgb2xyz":
        transformed_color = color_space_transform(input_color, 'srgb2linrgb')
        transformed_color = color_space_transform(transformed_color, 'linrgb2xyz')
    elif fromSpace2toSpace == "srgb2ycxcz":
        transformed_color = color_space_transform(input_color, 'srgb2linrgb')
        transformed_color = color_space_transform(transformed_color, 'linrgb2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2ycxcz')
    elif fromSpace2toSpace == "linrgb2ycxcz":
        transformed_color = color_space_transform(input_color, 'linrgb2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2ycxcz')
    elif fromSpace2toSpace == "srgb2lab":
        transformed_color = color_space_transform(input_color, 'srgb2linrgb')
        transformed_color = color_space_transform(transformed_color, 'linrgb2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2lab')
    elif fromSpace2toSpace == "linrgb2lab":
        transformed_color = color_space_transform(input_color, 'linrgb2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2lab')
    elif fromSpace2toSpace == "ycxcz2linrgb":
        transformed_color = color_space_transform(input_color, 'ycxcz2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2linrgb')
    elif fromSpace2toSpace == "lab2srgb":
        transformed_color = color_space_transform(input_color, 'lab2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2linrgb')
        transformed_color = color_space_transform(transformed_color, 'linrgb2srgb')
    elif fromSpace2toSpace == "ycxcz2lab":
        transformed_color = color_space_transform(input_color, 'ycxcz2xyz')
        transformed_color = color_space_transform(transformed_color, 'xyz2lab')
    else:
        raise ValueError(f'Error: The color transform {fromSpace2toSpace} is not defined!')

    return transformed_color


@ARCH_REGISTRY.register()
class MS_SWD_learned(nn.Module):
    def __init__(self,
                 resize_input: bool = True,
                 pretrained: bool = True,
                 pretrained_model_path: Optional[str] = None,
                 **kwargs
                 ):
        super().__init__()

        self.conv11x11 = nn.Conv2d(3, 128, kernel_size=11, stride=1, padding=5, padding_mode='reflect', dilation=1, bias=False)
        self.conv_m1 = nn.Conv2d(128, 64, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.relu = nn.LeakyReLU()

        self.resize_input = resize_input

        if pretrained_model_path is not None:
            load_pretrained_network(self, pretrained_model_path, weight_keys='params')
        elif pretrained:
            url = hf_hub_url(repo_id="chaofengc/IQA-Toolbox-Weights", filename="msswd_weights.pth")
            load_pretrained_network(self, url, weight_keys='net_dict')

    def preprocess_img(self, x):
        if self.resize_input and min(x.shape[2:]) > 256:
            x = TF.resize(x, 256)
        return x

    def forward_once(self, x):
        x = color_space_transform(x, 'srgb2lab')
        x = self.conv11x11(x)
        x = self.relu(x)
        x = self.conv_m1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        return x

    def forward(self, x, y):
        x = self.preprocess_img(x)
        y = self.preprocess_img(y)
        output_x = self.forward_once(x)
        output_y = self.forward_once(y)
        # Sort each channel's spatial samples and average the L1 distance: for
        # sorted 1-D empirical distributions this is the Wasserstein-1 distance,
        # and the mean over the 64 learned projections gives the sliced
        # Wasserstein distance.
        output_x, _ = torch.sort(output_x, dim=2)
        output_y, _ = torch.sort(output_y, dim=2)
        swd = torch.abs(output_x - output_y)
        swd = torch.mean(swd, dim=[1, 2])
        return swd
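

# Minimal usage sketch (illustrative; assumes the default pretrained weights are
# reachable on the Hugging Face Hub and that inputs are sRGB tensors in [0, 1]).
if __name__ == '__main__':
    metric = MS_SWD_learned(pretrained=True)
    metric.eval()
    ref = torch.rand(1, 3, 256, 256)   # placeholder reference batch
    dist = torch.rand(1, 3, 256, 256)  # placeholder distorted batch
    with torch.no_grad():
        score = metric(ref, dist)      # one score per batch item; lower = smaller color difference
    print(score)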