From c05114978488d86bf3ba42d5395b423097c8694a Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 5 Oct 2021 00:24:55 +0800 Subject: [PATCH 01/13] add ecbsr arch --- basicsr/archs/ecbsr_arch.py | 257 ++++++++++++++++++++++++++++++++++++ 1 file changed, 257 insertions(+) create mode 100644 basicsr/archs/ecbsr_arch.py diff --git a/basicsr/archs/ecbsr_arch.py b/basicsr/archs/ecbsr_arch.py new file mode 100644 index 000000000..3bfe52b71 --- /dev/null +++ b/basicsr/archs/ecbsr_arch.py @@ -0,0 +1,257 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SeqConv3x3(nn.Module): + + def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier): + super(SeqConv3x3, self).__init__() + + self.type = seq_type + self.inp_planes = inp_planes + self.out_planes = out_planes + + if self.type == 'conv1x1-conv3x3': + self.mid_planes = int(out_planes * depth_multiplier) + conv0 = torch.nn.Conv2d(self.inp_planes, self.mid_planes, kernel_size=1, padding=0) + self.k0 = conv0.weight + self.b0 = conv0.bias + + conv1 = torch.nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3) + self.k1 = conv1.weight + self.b1 = conv1.bias + + elif self.type == 'conv1x1-sobelx': + conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0) + self.k0 = conv0.weight + self.b0 = conv0.bias + + # init scale & bias + scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3 + self.scale = nn.Parameter(scale) + # bias = 0.0 + # bias = [bias for c in range(self.out_planes)] + # bias = torch.FloatTensor(bias) + bias = torch.randn(self.out_planes) * 1e-3 + bias = torch.reshape(bias, (self.out_planes, )) + self.bias = nn.Parameter(bias) + # init mask + self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32) + for i in range(self.out_planes): + self.mask[i, 0, 0, 0] = 1.0 + self.mask[i, 0, 1, 0] = 2.0 + self.mask[i, 0, 2, 0] = 1.0 + self.mask[i, 0, 0, 2] = -1.0 + self.mask[i, 0, 1, 2] = -2.0 + self.mask[i, 0, 2, 2] = -1.0 + self.mask = nn.Parameter(data=self.mask, requires_grad=False) + + elif self.type == 'conv1x1-sobely': + conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0) + self.k0 = conv0.weight + self.b0 = conv0.bias + + # init scale & bias + scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3 + self.scale = nn.Parameter(torch.FloatTensor(scale)) + # bias = 0.0 + # bias = [bias for c in range(self.out_planes)] + # bias = torch.FloatTensor(bias) + bias = torch.randn(self.out_planes) * 1e-3 + bias = torch.reshape(bias, (self.out_planes, )) + self.bias = nn.Parameter(torch.FloatTensor(bias)) + # init mask + self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32) + for i in range(self.out_planes): + self.mask[i, 0, 0, 0] = 1.0 + self.mask[i, 0, 0, 1] = 2.0 + self.mask[i, 0, 0, 2] = 1.0 + self.mask[i, 0, 2, 0] = -1.0 + self.mask[i, 0, 2, 1] = -2.0 + self.mask[i, 0, 2, 2] = -1.0 + self.mask = nn.Parameter(data=self.mask, requires_grad=False) + + elif self.type == 'conv1x1-laplacian': + conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0) + self.k0 = conv0.weight + self.b0 = conv0.bias + + # init scale & bias + scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3 + self.scale = nn.Parameter(torch.FloatTensor(scale)) + # bias = 0.0 + # bias = [bias for c in range(self.out_planes)] + # bias = torch.FloatTensor(bias) + bias = torch.randn(self.out_planes) * 1e-3 + bias = torch.reshape(bias, (self.out_planes, )) + self.bias = nn.Parameter(torch.FloatTensor(bias)) + # 
init mask + self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32) + for i in range(self.out_planes): + self.mask[i, 0, 0, 1] = 1.0 + self.mask[i, 0, 1, 0] = 1.0 + self.mask[i, 0, 1, 2] = 1.0 + self.mask[i, 0, 2, 1] = 1.0 + self.mask[i, 0, 1, 1] = -4.0 + self.mask = nn.Parameter(data=self.mask, requires_grad=False) + else: + raise ValueError('the type of seqconv is not supported!') + + def forward(self, x): + if self.type == 'conv1x1-conv3x3': + # conv-1x1 + y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1) + # explicitly padding with bias + y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0) + b0_pad = self.b0.view(1, -1, 1, 1) + y0[:, :, 0:1, :] = b0_pad + y0[:, :, -1:, :] = b0_pad + y0[:, :, :, 0:1] = b0_pad + y0[:, :, :, -1:] = b0_pad + # conv-3x3 + y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1) + else: + y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1) + # explicitly padding with bias + y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0) + b0_pad = self.b0.view(1, -1, 1, 1) + y0[:, :, 0:1, :] = b0_pad + y0[:, :, -1:, :] = b0_pad + y0[:, :, :, 0:1] = b0_pad + y0[:, :, :, -1:] = b0_pad + # conv-3x3 + y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_planes) + return y1 + + def rep_params(self): + device = self.k0.get_device() + if device < 0: + device = None + + if self.type == 'conv1x1-conv3x3': + # re-param conv kernel + RK = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3)) + # re-param conv bias + RB = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1) + RB = F.conv2d(input=RB, weight=self.k1).view(-1, ) + self.b1 + else: + tmp = self.scale * self.mask + k1 = torch.zeros((self.out_planes, self.out_planes, 3, 3), device=device) + for i in range(self.out_planes): + k1[i, i, :, :] = tmp[i, 0, :, :] + b1 = self.bias + # re-param conv kernel + RK = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3)) + # re-param conv bias + RB = torch.ones(1, self.out_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1) + RB = F.conv2d(input=RB, weight=k1).view(-1, ) + b1 + return RK, RB + + +class ECB(nn.Module): + + def __init__(self, inp_planes, out_planes, depth_multiplier, act_type='prelu', with_idt=False): + super(ECB, self).__init__() + + self.depth_multiplier = depth_multiplier + self.inp_planes = inp_planes + self.out_planes = out_planes + self.act_type = act_type + + if with_idt and (self.inp_planes == self.out_planes): + self.with_idt = True + else: + self.with_idt = False + + self.conv3x3 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=3, padding=1) + self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.inp_planes, self.out_planes, self.depth_multiplier) + self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.inp_planes, self.out_planes, -1) + self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.inp_planes, self.out_planes, -1) + self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.inp_planes, self.out_planes, -1) + + if self.act_type == 'prelu': + self.act = nn.PReLU(num_parameters=self.out_planes) + elif self.act_type == 'relu': + self.act = nn.ReLU(inplace=True) + elif self.act_type == 'rrelu': + self.act = nn.RReLU(lower=-0.05, upper=0.05) + elif self.act_type == 'softplus': + self.act = nn.Softplus() + elif self.act_type == 'linear': + pass + else: + raise ValueError('The type of activation if not support!') + + def forward(self, x): + if self.training: + y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + 
self.conv1x1_sby(x) + self.conv1x1_lpl(x) + if self.with_idt: + y += x + else: + RK, RB = self.rep_params() + y = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1) + if self.act_type != 'linear': + y = self.act(y) + return y + + def rep_params(self): + K0, B0 = self.conv3x3.weight, self.conv3x3.bias + K1, B1 = self.conv1x1_3x3.rep_params() + K2, B2 = self.conv1x1_sbx.rep_params() + K3, B3 = self.conv1x1_sby.rep_params() + K4, B4 = self.conv1x1_lpl.rep_params() + RK, RB = (K0 + K1 + K2 + K3 + K4), (B0 + B1 + B2 + B3 + B4) + + if self.with_idt: + device = RK.get_device() + if device < 0: + device = None + K_idt = torch.zeros(self.out_planes, self.out_planes, 3, 3, device=device) + for i in range(self.out_planes): + K_idt[i, i, 1, 1] = 1.0 + B_idt = 0.0 + RK, RB = RK + K_idt, RB + B_idt + return RK, RB + + +class ECBSR(nn.Module): + + def __init__(self, module_nums, channel_nums, with_idt, act_type, scale, colors): + super(ECBSR, self).__init__() + self.module_nums = module_nums + self.channel_nums = channel_nums + self.scale = scale + self.colors = colors + self.with_idt = with_idt + self.act_type = act_type + self.backbone = None + self.upsampler = None + + backbone = [] + backbone += [ + ECB(self.colors, self.channel_nums, depth_multiplier=2.0, act_type=self.act_type, with_idt=self.with_idt) + ] + for i in range(self.module_nums): + backbone += [ + ECB(self.channel_nums, + self.channel_nums, + depth_multiplier=2.0, + act_type=self.act_type, + with_idt=self.with_idt) + ] + backbone += [ + ECB(self.channel_nums, + self.colors * self.scale * self.scale, + depth_multiplier=2.0, + act_type='linear', + with_idt=self.with_idt) + ] + + self.backbone = nn.Sequential(*backbone) + self.upsampler = nn.PixelShuffle(self.scale) + + def forward(self, x): + y = self.backbone(x) + x + y = self.upsampler(y) + return y From b4ce575a9e82ce8ee3957fc82b769869788782d3 Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 5 Oct 2021 01:03:18 +0800 Subject: [PATCH 02/13] first run train_ECBSR_x4_m4c16_prelu --- basicsr/archs/ecbsr_arch.py | 3 + basicsr/data/paired_image_dataset.py | 9 +- basicsr/models/sr_model.py | 2 +- .../ECBSR/train_ECBSR_x4_m4c16_prelu.yml | 145 ++++++++++++++++++ 4 files changed, 157 insertions(+), 2 deletions(-) create mode 100644 options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml diff --git a/basicsr/archs/ecbsr_arch.py b/basicsr/archs/ecbsr_arch.py index 3bfe52b71..2dd270779 100644 --- a/basicsr/archs/ecbsr_arch.py +++ b/basicsr/archs/ecbsr_arch.py @@ -2,6 +2,8 @@ import torch.nn as nn import torch.nn.functional as F +from basicsr.utils.registry import ARCH_REGISTRY + class SeqConv3x3(nn.Module): @@ -215,6 +217,7 @@ def rep_params(self): return RK, RB +@ARCH_REGISTRY.register() class ECBSR(nn.Module): def __init__(self, module_nums, channel_nums, with_idt, act_type, scale, colors): diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py index c6a6c07b1..6e95a35be 100644 --- a/basicsr/data/paired_image_dataset.py +++ b/basicsr/data/paired_image_dataset.py @@ -4,6 +4,7 @@ from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file from basicsr.data.transforms import augment, paired_random_crop from basicsr.utils import FileClient, imfrombytes, img2tensor +from basicsr.utils.matlab_functions import rgb2ycbcr from basicsr.utils.registry import DATASET_REGISTRY @@ -87,7 +88,13 @@ def __getitem__(self, index): # flip, rotation img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], 
self.opt['use_rot']) - # TODO: color space transform + if self.opt['color'] == 'y': + img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None] + img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] + + # TODO: fix me during release + img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] + # BGR to RGB, HWC to CHW, numpy to tensor img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) # normalize diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py index 62ffe22bb..3467a0496 100644 --- a/basicsr/models/sr_model.py +++ b/basicsr/models/sr_model.py @@ -208,7 +208,7 @@ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): logger.info(log_str) if tb_logger: for metric, value in self.metric_results.items(): - tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) def get_current_visuals(self): out_dict = OrderedDict() diff --git a/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml b/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml new file mode 100644 index 000000000..251a8d83a --- /dev/null +++ b/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml @@ -0,0 +1,145 @@ +# general settings +name: 100_train_ECBSR_x4_m4c16_prelu +model_type: SRModel +scale: 4 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub + dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub + meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt + # (for lmdb) + # dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb + # dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb + filename_tmpl: '{}' + io_backend: + type: disk + # (for lmdb) + # type: lmdb + + gt_size: 256 + use_flip: true + use_rot: true + color: y + + # data loader + use_shuffle: true + num_worker_per_gpu: 12 + batch_size_per_gpu: 32 + dataset_enlarge_ratio: 100 + prefetch_mode: ~ + + val: + name: Set5 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 + filename_tmpl: '{}x4' + color: y + io_backend: + type: disk + + val_2: + name: Set14 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set14/HR + dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4 + filename_tmpl: '{}x4' + color: y + io_backend: + type: disk + + val_3: + name: B100 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/B100/HR + dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4 + filename_tmpl: '{}x4' + color: y + io_backend: + type: disk + + val_4: + name: Urban100 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Urban100/HR + dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4 + filename_tmpl: '{}x4' + color: y + io_backend: + type: disk + +# network structures +network_g: + type: ECBSR + module_nums: 4 + channel_nums: 16 + with_idt: False + act_type: prelu + scale: 4 + colors: 1 + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0 + optim_g: + type: Adam + lr: !!float 5e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [51200000] + gamma: 1 + + total_iter: 51200000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 51200 + save_img: false 
+ + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 4 + test_y_channel: true + better: higher # the higher, the better. Default: higher + ssim: + type: calculate_ssim + crop_border: 4 + test_y_channel: true + better: higher # the higher, the better. Default: higher + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 51200 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 From 83838f2f7d6587cb3bb924eceb6e889a2aa93041 Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 5 Oct 2021 11:57:30 +0800 Subject: [PATCH 03/13] 255 range --- basicsr/archs/ecbsr_arch.py | 2 + basicsr/models/sr_model.py | 12 +- .../ECBSR/train_ECBSR_x2_m4c16_prelu.yml | 146 ++++++++++++++++++ .../ECBSR/train_ECBSR_x4_m4c16_prelu.yml | 19 +-- 4 files changed, 166 insertions(+), 13 deletions(-) create mode 100644 options/train/ECBSR/train_ECBSR_x2_m4c16_prelu.yml diff --git a/basicsr/archs/ecbsr_arch.py b/basicsr/archs/ecbsr_arch.py index 2dd270779..e7be5b9c9 100644 --- a/basicsr/archs/ecbsr_arch.py +++ b/basicsr/archs/ecbsr_arch.py @@ -255,6 +255,8 @@ def __init__(self, module_nums, channel_nums, with_idt, act_type, scale, colors) self.upsampler = nn.PixelShuffle(self.scale) def forward(self, x): + x = x * 255. y = self.backbone(x) + x y = self.upsampler(y) + y = y / 255. return y diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py index 3467a0496..fdbb88678 100644 --- a/basicsr/models/sr_model.py +++ b/basicsr/models/sr_model.py @@ -136,6 +136,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img): def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): dataset_name = dataloader.dataset.opt['name'] with_metrics = self.opt['val'].get('metrics') is not None + use_pbar = self.opt['val'].get('pbar', False) if with_metrics and not hasattr(self, 'metric_results'): # only execute in the first run self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} @@ -146,7 +147,8 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): self.metric_results = {metric: 0 for metric in self.metric_results} metric_data = dict() - pbar = tqdm(total=len(dataloader), unit='image') + if use_pbar: + pbar = tqdm(total=len(dataloader), unit='image') for idx, val_data in enumerate(dataloader): img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] @@ -183,9 +185,11 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): # calculate metrics for name, opt_ in self.opt['val']['metrics'].items(): self.metric_results[name] += calculate_metric(metric_data, opt_) - pbar.update(1) - pbar.set_description(f'Test {img_name}') - pbar.close() + if use_pbar: + pbar.update(1) + pbar.set_description(f'Test {img_name}') + if use_pbar: + pbar.close() if with_metrics: for metric in self.metric_results.keys(): diff --git a/options/train/ECBSR/train_ECBSR_x2_m4c16_prelu.yml b/options/train/ECBSR/train_ECBSR_x2_m4c16_prelu.yml new file mode 100644 index 000000000..0f5f47e11 --- /dev/null +++ b/options/train/ECBSR/train_ECBSR_x2_m4c16_prelu.yml @@ -0,0 +1,146 @@ +# general settings +name: 101_train_ECBSR_x2_m4c16_prelu255 +model_type: SRModel +scale: 2 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub + dataroot_lq: 
datasets/DF2K/DIV2K_train_LR_bicubic_X2_sub + meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt + # (for lmdb) + # dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb + # dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb + filename_tmpl: '{}' + io_backend: + type: disk + # (for lmdb) + # type: lmdb + + gt_size: 128 + use_flip: true + use_rot: true + color: y + + # data loader + use_shuffle: true + num_worker_per_gpu: 12 + batch_size_per_gpu: 32 + dataset_enlarge_ratio: 10 + prefetch_mode: ~ + + val: + name: Set5 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2 + filename_tmpl: '{}x2' + color: y + io_backend: + type: disk + + val_2: + name: Set14 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set14/HR + dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X2 + filename_tmpl: '{}x2' + color: y + io_backend: + type: disk + + val_3: + name: B100 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/B100/HR + dataroot_lq: datasets/benchmark/B100/LR_bicubic/X2 + filename_tmpl: '{}x2' + color: y + io_backend: + type: disk + + val_4: + name: Urban100 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Urban100/HR + dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X2 + filename_tmpl: '{}x2' + color: y + io_backend: + type: disk + +# network structures +network_g: + type: ECBSR + module_nums: 4 + channel_nums: 16 + with_idt: False + act_type: prelu + scale: 2 + colors: 1 + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0 + optim_g: + type: Adam + lr: !!float 5e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [1600000] + gamma: 1 + + total_iter: 1600000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 1600 + save_img: false + pbar: False + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 2 + test_y_channel: true + better: higher # the higher, the better. Default: higher + ssim: + type: calculate_ssim + crop_border: 2 + test_y_channel: true + better: higher # the higher, the better. 
Default: higher + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 1600 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml b/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml index 251a8d83a..53db23704 100644 --- a/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml +++ b/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml @@ -1,5 +1,5 @@ # general settings -name: 100_train_ECBSR_x4_m4c16_prelu +name: 100_train_ECBSR_x4_m4c16_prelu_lmdb255 model_type: SRModel scale: 4 num_gpu: 1 # set num_gpu: 0 for cpu mode @@ -10,15 +10,15 @@ datasets: train: name: DIV2K type: PairedImageDataset - dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub - dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub + dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub.lmdb + dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub.lmdb meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt # (for lmdb) # dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb # dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb filename_tmpl: '{}' io_backend: - type: disk + type: lmdb # (for lmdb) # type: lmdb @@ -31,7 +31,7 @@ datasets: use_shuffle: true num_worker_per_gpu: 12 batch_size_per_gpu: 32 - dataset_enlarge_ratio: 100 + dataset_enlarge_ratio: 10 prefetch_mode: ~ val: @@ -101,10 +101,10 @@ train: scheduler: type: MultiStepLR - milestones: [51200000] + milestones: [1600000] gamma: 1 - total_iter: 51200000 + total_iter: 1600000 warmup_iter: -1 # no warm up # losses @@ -115,8 +115,9 @@ train: # validation settings val: - val_freq: !!float 51200 + val_freq: !!float 1600 save_img: false + pbar: False metrics: psnr: # metric name, can be arbitrary @@ -133,7 +134,7 @@ val: # logging settings logger: print_freq: 100 - save_checkpoint_freq: !!float 51200 + save_checkpoint_freq: !!float 1600 use_tb_logger: true wandb: project: ~ From 07222d07d72dbba4e9556395109eff75a4fce0d3 Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 5 Oct 2021 17:29:11 +0800 Subject: [PATCH 04/13] clean arch --- basicsr/archs/ecbsr_arch.py | 135 ++++++++---------- .../ECBSR/train_ECBSR_x2_m4c16_prelu.yml | 7 +- .../ECBSR/train_ECBSR_x4_m4c16_prelu.yml | 7 +- 3 files changed, 67 insertions(+), 82 deletions(-) diff --git a/basicsr/archs/ecbsr_arch.py b/basicsr/archs/ecbsr_arch.py index e7be5b9c9..a05c0a027 100644 --- a/basicsr/archs/ecbsr_arch.py +++ b/basicsr/archs/ecbsr_arch.py @@ -7,14 +7,13 @@ class SeqConv3x3(nn.Module): - def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier): + def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier=1): super(SeqConv3x3, self).__init__() - - self.type = seq_type + self.seq_type = seq_type self.inp_planes = inp_planes self.out_planes = out_planes - if self.type == 'conv1x1-conv3x3': + if self.seq_type == 'conv1x1-conv3x3': self.mid_planes = int(out_planes * depth_multiplier) conv0 = torch.nn.Conv2d(self.inp_planes, self.mid_planes, kernel_size=1, padding=0) self.k0 = conv0.weight @@ -24,17 +23,14 @@ def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier): self.k1 = conv1.weight self.b1 = conv1.bias - elif self.type == 'conv1x1-sobelx': + elif self.seq_type == 'conv1x1-sobelx': conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0) self.k0 = conv0.weight self.b0 = conv0.bias - # init scale & bias + # init scale and bias scale = torch.randn(size=(self.out_planes, 1, 1, 
1)) * 1e-3 self.scale = nn.Parameter(scale) - # bias = 0.0 - # bias = [bias for c in range(self.out_planes)] - # bias = torch.FloatTensor(bias) bias = torch.randn(self.out_planes) * 1e-3 bias = torch.reshape(bias, (self.out_planes, )) self.bias = nn.Parameter(bias) @@ -49,17 +45,14 @@ def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier): self.mask[i, 0, 2, 2] = -1.0 self.mask = nn.Parameter(data=self.mask, requires_grad=False) - elif self.type == 'conv1x1-sobely': + elif self.seq_type == 'conv1x1-sobely': conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0) self.k0 = conv0.weight self.b0 = conv0.bias - # init scale & bias + # init scale and bias scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3 self.scale = nn.Parameter(torch.FloatTensor(scale)) - # bias = 0.0 - # bias = [bias for c in range(self.out_planes)] - # bias = torch.FloatTensor(bias) bias = torch.randn(self.out_planes) * 1e-3 bias = torch.reshape(bias, (self.out_planes, )) self.bias = nn.Parameter(torch.FloatTensor(bias)) @@ -74,17 +67,14 @@ def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier): self.mask[i, 0, 2, 2] = -1.0 self.mask = nn.Parameter(data=self.mask, requires_grad=False) - elif self.type == 'conv1x1-laplacian': + elif self.seq_type == 'conv1x1-laplacian': conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0) self.k0 = conv0.weight self.b0 = conv0.bias - # init scale & bias + # init scale and bias scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3 self.scale = nn.Parameter(torch.FloatTensor(scale)) - # bias = 0.0 - # bias = [bias for c in range(self.out_planes)] - # bias = torch.FloatTensor(bias) bias = torch.randn(self.out_planes) * 1e-3 bias = torch.reshape(bias, (self.out_planes, )) self.bias = nn.Parameter(torch.FloatTensor(bias)) @@ -98,10 +88,10 @@ def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier): self.mask[i, 0, 1, 1] = -4.0 self.mask = nn.Parameter(data=self.mask, requires_grad=False) else: - raise ValueError('the type of seqconv is not supported!') + raise ValueError('The type of seqconv is not supported!') def forward(self, x): - if self.type == 'conv1x1-conv3x3': + if self.seq_type == 'conv1x1-conv3x3': # conv-1x1 y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1) # explicitly padding with bias @@ -131,12 +121,12 @@ def rep_params(self): if device < 0: device = None - if self.type == 'conv1x1-conv3x3': + if self.seq_type == 'conv1x1-conv3x3': # re-param conv kernel - RK = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3)) + rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3)) # re-param conv bias - RB = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1) - RB = F.conv2d(input=RB, weight=self.k1).view(-1, ) + self.b1 + rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1) + rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1 else: tmp = self.scale * self.mask k1 = torch.zeros((self.out_planes, self.out_planes, 3, 3), device=device) @@ -144,11 +134,11 @@ def rep_params(self): k1[i, i, :, :] = tmp[i, 0, :, :] b1 = self.bias # re-param conv kernel - RK = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3)) + rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3)) # re-param conv bias - RB = torch.ones(1, self.out_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1) - RB = F.conv2d(input=RB, weight=k1).view(-1, ) + b1 - return RK, RB 
+ rep_bias = torch.ones(1, self.out_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1) + rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1 + return rep_weight, rep_bias class ECB(nn.Module): @@ -168,9 +158,9 @@ def __init__(self, inp_planes, out_planes, depth_multiplier, act_type='prelu', w self.conv3x3 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=3, padding=1) self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.inp_planes, self.out_planes, self.depth_multiplier) - self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.inp_planes, self.out_planes, -1) - self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.inp_planes, self.out_planes, -1) - self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.inp_planes, self.out_planes, -1) + self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.inp_planes, self.out_planes) + self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.inp_planes, self.out_planes) + self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.inp_planes, self.out_planes) if self.act_type == 'prelu': self.act = nn.PReLU(num_parameters=self.out_planes) @@ -191,72 +181,65 @@ def forward(self, x): if self.with_idt: y += x else: - RK, RB = self.rep_params() - y = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1) + rep_weight, rep_bias = self.rep_params() + y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1) if self.act_type != 'linear': y = self.act(y) return y def rep_params(self): - K0, B0 = self.conv3x3.weight, self.conv3x3.bias - K1, B1 = self.conv1x1_3x3.rep_params() - K2, B2 = self.conv1x1_sbx.rep_params() - K3, B3 = self.conv1x1_sby.rep_params() - K4, B4 = self.conv1x1_lpl.rep_params() - RK, RB = (K0 + K1 + K2 + K3 + K4), (B0 + B1 + B2 + B3 + B4) + weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias + weight1, bias1 = self.conv1x1_3x3.rep_params() + weight2, bias2 = self.conv1x1_sbx.rep_params() + weight3, bias3 = self.conv1x1_sby.rep_params() + weight4, bias4 = self.conv1x1_lpl.rep_params() + rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), ( + bias0 + bias1 + bias2 + bias3 + bias4) if self.with_idt: - device = RK.get_device() + device = rep_weight.get_device() if device < 0: device = None - K_idt = torch.zeros(self.out_planes, self.out_planes, 3, 3, device=device) + weight_idt = torch.zeros(self.out_planes, self.out_planes, 3, 3, device=device) for i in range(self.out_planes): - K_idt[i, i, 1, 1] = 1.0 - B_idt = 0.0 - RK, RB = RK + K_idt, RB + B_idt - return RK, RB + weight_idt[i, i, 1, 1] = 1.0 + bias_idt = 0.0 + rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt + return rep_weight, rep_bias @ARCH_REGISTRY.register() class ECBSR(nn.Module): - - def __init__(self, module_nums, channel_nums, with_idt, act_type, scale, colors): + """ECBSR architecture. + + Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices + Ref git repo: https://github.com/xindongzhang/ECBSR + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_block (int): Block number in the trunk network. + num_channel (int): Channel number. + with_idt (bool): Whether use identity in convolution layers. + act_type (str): Activation type. + scale (int): Upsampling factor. 
+ """ + + def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale): super(ECBSR, self).__init__() - self.module_nums = module_nums - self.channel_nums = channel_nums - self.scale = scale - self.colors = colors - self.with_idt = with_idt - self.act_type = act_type - self.backbone = None - self.upsampler = None backbone = [] + backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)] + for _ in range(num_block): + backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)] backbone += [ - ECB(self.colors, self.channel_nums, depth_multiplier=2.0, act_type=self.act_type, with_idt=self.with_idt) - ] - for i in range(self.module_nums): - backbone += [ - ECB(self.channel_nums, - self.channel_nums, - depth_multiplier=2.0, - act_type=self.act_type, - with_idt=self.with_idt) - ] - backbone += [ - ECB(self.channel_nums, - self.colors * self.scale * self.scale, - depth_multiplier=2.0, - act_type='linear', - with_idt=self.with_idt) + ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt) ] self.backbone = nn.Sequential(*backbone) - self.upsampler = nn.PixelShuffle(self.scale) + self.upsampler = nn.PixelShuffle(scale) def forward(self, x): - x = x * 255. - y = self.backbone(x) + x + y = self.backbone(x) + x # will repeat the input in the channel dimension (repeat scale * scale times) y = self.upsampler(y) - y = y / 255. return y diff --git a/options/train/ECBSR/train_ECBSR_x2_m4c16_prelu.yml b/options/train/ECBSR/train_ECBSR_x2_m4c16_prelu.yml index 0f5f47e11..9ff68fcda 100644 --- a/options/train/ECBSR/train_ECBSR_x2_m4c16_prelu.yml +++ b/options/train/ECBSR/train_ECBSR_x2_m4c16_prelu.yml @@ -77,12 +77,13 @@ datasets: # network structures network_g: type: ECBSR - module_nums: 4 - channel_nums: 16 + num_in_ch: 1 + num_out_ch: 1 + num_block: 4 + num_channel: 16 with_idt: False act_type: prelu scale: 2 - colors: 1 # path path: diff --git a/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml b/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml index 53db23704..0664bee3d 100644 --- a/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml +++ b/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml @@ -77,12 +77,13 @@ datasets: # network structures network_g: type: ECBSR - module_nums: 4 - channel_nums: 16 + num_in_ch: 1 + num_out_ch: 1 + num_block: 4 + num_channel: 16 with_idt: False act_type: prelu scale: 4 - colors: 1 # path path: From 1e03a94b03ffbbc61c7fa6e68c5966a91a33b4b8 Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 5 Oct 2021 17:33:00 +0800 Subject: [PATCH 05/13] improve datasets --- basicsr/data/paired_image_dataset.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py index 6e95a35be..ee15d5a83 100644 --- a/basicsr/data/paired_image_dataset.py +++ b/basicsr/data/paired_image_dataset.py @@ -92,8 +92,10 @@ def __getitem__(self, index): img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None] img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] - # TODO: fix me during release - img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] + # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets + # TODO: It is better to update the datasets, rather than force to crop + if self.opt['phase'] != 'train': + img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] # BGR to RGB, HWC to 
CHW, numpy to tensor img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) From bd99f6290a2e1a1246cc38ef79ceb1044bf9d85e Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 5 Oct 2021 17:37:57 +0800 Subject: [PATCH 06/13] update ecbsr option files --- options/train/ECBSR/train_ECBSR_x2_m4c16_prelu.yml | 13 +++++-------- options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml | 13 +++++-------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/options/train/ECBSR/train_ECBSR_x2_m4c16_prelu.yml b/options/train/ECBSR/train_ECBSR_x2_m4c16_prelu.yml index 9ff68fcda..b5827f4ab 100644 --- a/options/train/ECBSR/train_ECBSR_x2_m4c16_prelu.yml +++ b/options/train/ECBSR/train_ECBSR_x2_m4c16_prelu.yml @@ -1,5 +1,5 @@ # general settings -name: 101_train_ECBSR_x2_m4c16_prelu255 +name: 101_train_ECBSR_x2_m4c16_prelu model_type: SRModel scale: 2 num_gpu: 1 # set num_gpu: 0 for cpu mode @@ -10,17 +10,13 @@ datasets: train: name: DIV2K type: PairedImageDataset + # It is strongly recommended to use lmdb for faster IO speed, especially for small networks dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X2_sub meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt - # (for lmdb) - # dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb - # dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb filename_tmpl: '{}' io_backend: type: disk - # (for lmdb) - # type: lmdb gt_size: 128 use_flip: true @@ -34,6 +30,7 @@ datasets: dataset_enlarge_ratio: 10 prefetch_mode: ~ + # we use multiple validation datasets. The SR benchmark datasets can be download from: https://cv.snu.ac.kr/research/EDSR/benchmark.tar val: name: Set5 type: PairedImageDataset @@ -116,12 +113,12 @@ train: # validation settings val: - val_freq: !!float 1600 + val_freq: !!float 1600 # the same as the original setting. # TODO: Can be larger save_img: false pbar: False metrics: - psnr: # metric name, can be arbitrary + psnr: type: calculate_psnr crop_border: 2 test_y_channel: true diff --git a/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml b/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml index 0664bee3d..61175f947 100644 --- a/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml +++ b/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu.yml @@ -1,5 +1,5 @@ # general settings -name: 100_train_ECBSR_x4_m4c16_prelu_lmdb255 +name: 100_train_ECBSR_x4_m4c16_prelu model_type: SRModel scale: 4 num_gpu: 1 # set num_gpu: 0 for cpu mode @@ -10,17 +10,13 @@ datasets: train: name: DIV2K type: PairedImageDataset + # It is strongly recommended to use lmdb for faster IO speed, especially for small networks dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub.lmdb dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub.lmdb meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt - # (for lmdb) - # dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb - # dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb filename_tmpl: '{}' io_backend: type: lmdb - # (for lmdb) - # type: lmdb gt_size: 256 use_flip: true @@ -34,6 +30,7 @@ datasets: dataset_enlarge_ratio: 10 prefetch_mode: ~ + # we use multiple validation datasets. The SR benchmark datasets can be download from: https://cv.snu.ac.kr/research/EDSR/benchmark.tar val: name: Set5 type: PairedImageDataset @@ -116,12 +113,12 @@ train: # validation settings val: - val_freq: !!float 1600 + val_freq: !!float 1600 # the same as the original setting. 
# TODO: Can be larger save_img: false pbar: False metrics: - psnr: # metric name, can be arbitrary + psnr: type: calculate_psnr crop_border: 4 test_y_channel: true From 3caee557f62484a2abc13b99bab319b4e8f3219f Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 5 Oct 2021 17:44:59 +0800 Subject: [PATCH 07/13] update readme --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index f544cca84..3c8637340 100644 --- a/README.md +++ b/README.md @@ -38,13 +38,12 @@ BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源 - :white_check_mark: July 31, 2021. Add **bi-directional video super-resolution** codes: [**BasicVSR** and IconVSR](https://arxiv.org/abs/2012.02181). - :white_check_mark: July 20, 2021. Add **dual-blind face restoration** codes: [HiFaceGAN](https://github.com/Lotayou/Face-Renovation) codes by [Lotayou](https://lotayou.github.io/). - :white_check_mark: Nov 29, 2020. Add **ESRGAN** and **DFDNet** [colab demo](colab) -- :white_check_mark: Sep 8, 2020. Add **blind face restoration** inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet). - :white_check_mark: Aug 27, 2020. Add **StyleGAN2 training and testing** codes: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch).
More</summary>
<ul>
  <li>Sep 8, 2020. Add blind face restoration inference codes: DFDNet.<br>
    <i>ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries</i><br>
    <i>Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang</i></li>
  <li>Aug 27, 2020. Add StyleGAN2 training and testing codes.<br>
    <i>CVPR20: Analyzing and Improving the Image Quality of StyleGAN</i><br>
    <i>Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila</i></li>
  <li>Aug 19, 2020. A brand-new BasicSR v1.0.0 online.</li>
</ul>
</details>
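The ECB patches above all serve one idea, structural re-parameterization: during training, ECB.forward() sums five parallel branches (a plain 3x3 convolution, a 1x1-then-3x3 expansion, and Sobel-x, Sobel-y and Laplacian edge filters), while at inference ECB.rep_params() collapses all of them into a single 3x3 kernel and bias, so the deployed network is just a stack of plain convolutions. A minimal equivalence check, not part of the patch series (it assumes a basicsr checkout with the above patches applied; layer sizes here are arbitrary):

import torch
from basicsr.archs.ecbsr_arch import ECB

torch.manual_seed(0)
# act_type='linear' skips the activation so the two paths are directly comparable;
# with_idt=True also exercises the identity branch (requires inp_planes == out_planes)
ecb = ECB(inp_planes=8, out_planes=8, depth_multiplier=2.0, act_type='linear', with_idt=True)
x = torch.randn(2, 8, 16, 16)

ecb.train()
y_branches = ecb(x)   # training path: sum of the five branches + identity
ecb.eval()
y_fused = ecb(x)      # inference path: one 3x3 conv built from rep_params()

print((y_branches - y_fused).abs().max())  # expected ~1e-6 (float32 round-off)

The explicit bias padding in SeqConv3x3.forward() is what makes the match exact at image borders as well, not only in the interior.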
From f9816d66e5b6438b4ee693b4e5e319ff2871be16 Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 5 Oct 2021 17:47:45 +0800 Subject: [PATCH 08/13] update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3c8637340..4977769ab 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源
More</summary>
<ul>
  <li>Sep 8, 2020. Add blind face restoration inference codes: DFDNet.<br>
    <i>ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries</i><br>
    <i>Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang</i></li>
  <li>Aug 27, 2020. Add StyleGAN2 training and testing codes.<br>
    <i>CVPR20: Analyzing and Improving the Image Quality of StyleGAN</i><br>
    <i>Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila</i></li>
  <li>Aug 19, 2020. A brand-new BasicSR v1.0.0 online.</li>
</ul>
</details>
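Once trained, the same rep_params() machinery folds the whole backbone into plain 3x3 convolutions for deployment. A rough sketch, not part of the patches (hyper-parameters mirror network_g in train_ECBSR_x4_m4c16_prelu.yml; the fused stack still needs the residual add and PixelShuffle from ECBSR.forward()):

import torch
import torch.nn as nn
from basicsr.archs.ecbsr_arch import ECBSR

model = ECBSR(num_in_ch=1, num_out_ch=1, num_block=4, num_channel=16,
              with_idt=False, act_type='prelu', scale=4)
# (load the trained checkpoint here before fusing)

fused = []
for ecb in model.backbone:
    rep_weight, rep_bias = ecb.rep_params()
    conv = nn.Conv2d(rep_weight.shape[1], rep_weight.shape[0], kernel_size=3, padding=1)
    with torch.no_grad():
        conv.weight.copy_(rep_weight)
        conv.bias.copy_(rep_bias)
    fused.append(conv)
    if ecb.act_type != 'linear':   # the last ECB is linear; the others keep their PReLU
        fused.append(ecb.act)
fused_backbone = nn.Sequential(*fused)

# sanity check: identical to the eval-mode multi-branch backbone
model.eval()
x = torch.randn(1, 1, 32, 32)
print((model.backbone(x) - fused_backbone(x)).abs().max())  # expected 0

This is what makes ECBSR real-time on mobile: the exported backbone contains no Sobel or Laplacian branches at all, only num_block + 2 plain convolutions.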
From e7692a2fedffb8c171bac9f4eca2ad0ba06ed653 Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 5 Oct 2021 17:55:45 +0800 Subject: [PATCH 09/13] reorganize history updates --- README.md | 13 +------------ docs/history_updates.md | 13 +++++++++++++ 2 files changed, 14 insertions(+), 12 deletions(-) create mode 100644 docs/history_updates.md diff --git a/README.md b/README.md index 4977769ab..9e41c0727 100644 --- a/README.md +++ b/README.md @@ -36,18 +36,7 @@ BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源 - :white_check_mark: Sep 2, 2021. Add **SwinIR training and testing** codes: [SwinIR](https://github.com/JingyunLiang/SwinIR) by [Jingyun Liang](https://github.com/JingyunLiang):+1:. More details are in [HOWTOs.md](docs/HOWTOs.md#how-to-train-swinir-sr) - :white_check_mark: Aug 5, 2021. Add NIQE, which produces the same results as MATLAB (both are 5.7296 for tests/data/baboon.png). - :white_check_mark: July 31, 2021. Add **bi-directional video super-resolution** codes: [**BasicVSR** and IconVSR](https://arxiv.org/abs/2012.02181). -- :white_check_mark: July 20, 2021. Add **dual-blind face restoration** codes: [HiFaceGAN](https://github.com/Lotayou/Face-Renovation) codes by [Lotayou](https://lotayou.github.io/). -- :white_check_mark: Nov 29, 2020. Add **ESRGAN** and **DFDNet** [colab demo](colab) -- :white_check_mark: Aug 27, 2020. Add **StyleGAN2 training and testing** codes: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch). - -
-<details>
-  <summary>More</summary>
-<ul>
-  <li>Sep 8, 2020. Add blind face restoration inference codes: DFDNet.<br>
-    <i>ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries</i><br>
-    <i>Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang</i></li>
-  <li>Aug 27, 2020. Add StyleGAN2 training and testing codes.<br>
-    <i>CVPR20: Analyzing and Improving the Image Quality of StyleGAN</i><br>
-    <i>Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila</i></li>
-  <li>Aug 19, 2020. A brand-new BasicSR v1.0.0 online.</li>
-</ul>
-</details>
+- **[More](docs/history_updates.md)** :sparkles: **Projects that use BasicSR** - [**Real-ESRGAN**](https://github.com/xinntao/Real-ESRGAN): A practical algorithm for general image restoration diff --git a/docs/history_updates.md b/docs/history_updates.md new file mode 100644 index 000000000..9f6cfa7ae --- /dev/null +++ b/docs/history_updates.md @@ -0,0 +1,13 @@ +# History of New Features/Updates + +:triangular_flag_on_post: **New Features/Updates** + +- :white_check_mark: July 20, 2021. Add **dual-blind face restoration** codes: [HiFaceGAN](https://github.com/Lotayou/Face-Renovation) codes by [Lotayou](https://lotayou.github.io/). +- :white_check_mark: Nov 29, 2020. Add **ESRGAN** and **DFDNet** [colab demo](../colab) +- :white_check_mark: Sep 8, 2020. Add **blind face restoration** inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet). + > ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
+ > Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang +- :white_check_mark: Aug 27, 2020. Add **StyleGAN2 training and testing** codes: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch). + > CVPR20: Analyzing and Improving the Image Quality of StyleGAN
+ > Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila +- :white_check_mark: Aug 19, 2020. A **brand-new** BasicSR v1.0.0 online. From 99f6e2bd9fcbdbb7e7ef885618c57a687738ee32 Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 5 Oct 2021 18:00:02 +0800 Subject: [PATCH 10/13] update readme: ecbsr --- README.md | 3 +++ README_CN.md | 17 ++++------------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 9e41c0727..56e9b9894 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,9 @@ BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源 :triangular_flag_on_post: **New Features/Updates** +- :white_check_mark: Oct 5, 2021. Add **ECBSR training and testing** codes: [ECBSR](https://github.com/xindongzhang/ECBSR). + > ACMMM21: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
+ > Xindong Zhang, Hui Zeng, Lei Zhang - :white_check_mark: Sep 2, 2021. Add **SwinIR training and testing** codes: [SwinIR](https://github.com/JingyunLiang/SwinIR) by [Jingyun Liang](https://github.com/JingyunLiang):+1:. More details are in [HOWTOs.md](docs/HOWTOs.md#how-to-train-swinir-sr) - :white_check_mark: Aug 5, 2021. Add NIQE, which produces the same results as MATLAB (both are 5.7296 for tests/data/baboon.png). - :white_check_mark: July 31, 2021. Add **bi-directional video super-resolution** codes: [**BasicVSR** and IconVSR](https://arxiv.org/abs/2012.02181). diff --git a/README_CN.md b/README_CN.md index 37ed588bb..67acdb496 100644 --- a/README_CN.md +++ b/README_CN.md @@ -31,22 +31,13 @@ BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源 :triangular_flag_on_post: **新的特性/更新** +- :white_check_mark: Oct 5, 2021. 添加 **ECBSR 训练和测试** 代码: [ECBSR](https://github.com/xindongzhang/ECBSR). + > ACMMM21: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
+ > Xindong Zhang, Hui Zeng, Lei Zhang - :white_check_mark: Sep 2, 2021. 添加 **SwinIR 训练和测试** 代码: [SwinIR](https://github.com/JingyunLiang/SwinIR) by [Jingyun Liang](https://github.com/JingyunLiang):+1:. 更多内容参见 [HOWTOs.md](docs/HOWTOs.md#how-to-train-swinir-sr) - :white_check_mark: Aug 5, 2021. 添加了NIQE, 它输出和MATLAB一样的结果 (both are 5.7296 for tests/data/baboon.png). - :white_check_mark: July 31, 2021. Add **bi-directional video super-resolution** codes: [**BasicVSR** and IconVSR](https://arxiv.org/abs/2012.02181). -- :white_check_mark: July 20, 2021. Add **dual-blind face restoration** codes: [**HiFaceGAN**](https://github.com/Lotayou/Face-Renovation) codes by [Lotayou](https://lotayou.github.io/). -- :white_check_mark: Nov 29, 2020. 添加 **ESRGAN** and **DFDNet** [colab demo](colab). -- :white_check_mark: Sep 8, 2020. 添加 **盲人脸复原**测试代码: [DFDNet](https://github.com/csxmli2016/DFDNet). -- :white_check_mark: Aug 27, 2020. 添加 **StyleGAN2 训练和测试** 代码: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch). - -
-<details>
-  <summary>更多</summary>
-<ul>
-  <li>Sep 8, 2020. 添加 盲人脸复原 测试代码: DFDNet.<br>
-    <i>ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries</i><br>
-    <i>Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang</i></li>
-  <li>Aug 27, 2020. 添加 StyleGAN2 训练和测试代码.<br>
-    <i>CVPR20: Analyzing and Improving the Image Quality of StyleGAN</i><br>
-    <i>Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila</i></li>
-  <li>Aug 19, 2020. 全新的 BasicSR v1.0.0 上线.</li>
-</ul>
-</details>
+- **[更多](docs/history_updates.md)** :sparkles: **使用 BasicSR 的项目** - [**Real-ESRGAN**](https://github.com/xinntao/Real-ESRGAN): 通用图像复原的实用算法 From fcbee2caa1f805be184068d84c221a34d940b6c7 Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 5 Oct 2021 18:09:54 +0800 Subject: [PATCH 11/13] update readme --- README.md | 6 +++--- README_CN.md | 6 +++--- docs/history_updates.md | 18 ++++++++++++++---- setup.cfg | 3 ++- 4 files changed, 22 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 56e9b9894..3cf219f26 100644 --- a/README.md +++ b/README.md @@ -34,11 +34,11 @@ BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源 :triangular_flag_on_post: **New Features/Updates** - :white_check_mark: Oct 5, 2021. Add **ECBSR training and testing** codes: [ECBSR](https://github.com/xindongzhang/ECBSR). - > ACMMM21: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
- > Xindong Zhang, Hui Zeng, Lei Zhang -- :white_check_mark: Sep 2, 2021. Add **SwinIR training and testing** codes: [SwinIR](https://github.com/JingyunLiang/SwinIR) by [Jingyun Liang](https://github.com/JingyunLiang):+1:. More details are in [HOWTOs.md](docs/HOWTOs.md#how-to-train-swinir-sr) + > ACMMM21: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices +- :white_check_mark: Sep 2, 2021. Add **SwinIR training and testing** codes: [SwinIR](https://github.com/JingyunLiang/SwinIR) by [Jingyun Liang](https://github.com/JingyunLiang). More details are in [HOWTOs.md](docs/HOWTOs.md#how-to-train-swinir-sr) - :white_check_mark: Aug 5, 2021. Add NIQE, which produces the same results as MATLAB (both are 5.7296 for tests/data/baboon.png). - :white_check_mark: July 31, 2021. Add **bi-directional video super-resolution** codes: [**BasicVSR** and IconVSR](https://arxiv.org/abs/2012.02181). + > CVPR21: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond - **[More](docs/history_updates.md)** :sparkles: **Projects that use BasicSR** diff --git a/README_CN.md b/README_CN.md index 67acdb496..a979da704 100644 --- a/README_CN.md +++ b/README_CN.md @@ -32,11 +32,11 @@ BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源 :triangular_flag_on_post: **新的特性/更新** - :white_check_mark: Oct 5, 2021. 添加 **ECBSR 训练和测试** 代码: [ECBSR](https://github.com/xindongzhang/ECBSR). - > ACMMM21: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
- > Xindong Zhang, Hui Zeng, Lei Zhang -- :white_check_mark: Sep 2, 2021. 添加 **SwinIR 训练和测试** 代码: [SwinIR](https://github.com/JingyunLiang/SwinIR) by [Jingyun Liang](https://github.com/JingyunLiang):+1:. 更多内容参见 [HOWTOs.md](docs/HOWTOs.md#how-to-train-swinir-sr) + > ACMMM21: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices +- :white_check_mark: Sep 2, 2021. 添加 **SwinIR 训练和测试** 代码: [SwinIR](https://github.com/JingyunLiang/SwinIR) by [Jingyun Liang](https://github.com/JingyunLiang). 更多内容参见 [HOWTOs.md](docs/HOWTOs.md#how-to-train-swinir-sr) - :white_check_mark: Aug 5, 2021. 添加了NIQE, 它输出和MATLAB一样的结果 (both are 5.7296 for tests/data/baboon.png). - :white_check_mark: July 31, 2021. Add **bi-directional video super-resolution** codes: [**BasicVSR** and IconVSR](https://arxiv.org/abs/2012.02181). + > CVPR21: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond - **[更多](docs/history_updates.md)** :sparkles: **使用 BasicSR 的项目** diff --git a/docs/history_updates.md b/docs/history_updates.md index 9f6cfa7ae..c9c8779fe 100644 --- a/docs/history_updates.md +++ b/docs/history_updates.md @@ -2,12 +2,22 @@ :triangular_flag_on_post: **New Features/Updates** +- :white_check_mark: Oct 5, 2021. Add **ECBSR training and testing** codes: [ECBSR](https://github.com/xindongzhang/ECBSR). + > ACMMM21: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
+ > Xindong Zhang, Hui Zeng, Lei Zhang
+- :white_check_mark: Sep 2, 2021. Add **SwinIR training and testing** codes: [SwinIR](https://github.com/JingyunLiang/SwinIR) by [Jingyun Liang](https://github.com/JingyunLiang). More details are in [HOWTOs.md](HOWTOs.md#how-to-train-swinir-sr)
+ > ICCVW21: SwinIR: Image Restoration Using Swin Transformer
+ > Jingyun Liang, Jiezhang Cao, Guolei Sun, Kai Zhang, Luc Van Gool and Radu Timofte
+- :white_check_mark: Aug 5, 2021. Add NIQE, which produces the same results as MATLAB (both are 5.7296 for tests/data/baboon.png).
+- :white_check_mark: July 31, 2021. Add **bi-directional video super-resolution** codes: [**BasicVSR** and IconVSR](https://arxiv.org/abs/2012.02181).
+ > CVPR21: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond
+ > Kelvin C.K. Chan, Xintao Wang, Ke Yu, Chao Dong, Chen Change Loy
 - :white_check_mark: July 20, 2021. Add **dual-blind face restoration** codes: [HiFaceGAN](https://github.com/Lotayou/Face-Renovation) codes by [Lotayou](https://lotayou.github.io/).
 - :white_check_mark: Nov 29, 2020. Add **ESRGAN** and **DFDNet** [colab demo](../colab)
 - :white_check_mark: Sep 8, 2020. Add **blind face restoration** inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet).
-  > ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
- > Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang + > ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries
+ > Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang - :white_check_mark: Aug 27, 2020. Add **StyleGAN2 training and testing** codes: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch). - > CVPR20: Analyzing and Improving the Image Quality of StyleGAN
- > Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila + > CVPR20: Analyzing and Improving the Image Quality of StyleGAN
+ > Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila - :white_check_mark: Aug 19, 2020. A **brand-new** BasicSR v1.0.0 online. diff --git a/setup.cfg b/setup.cfg index 7eae2ea29..ac9d44ddd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,6 +22,7 @@ no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY [codespell] -skip = .git,./docs/build +skip = .git,./docs/build,*.cfg count = quiet-level = 3 +ignore-words-list = gool From e8d466c5469622a7067cdee2bb0be7deaa0d1414 Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 5 Oct 2021 18:13:01 +0800 Subject: [PATCH 12/13] update readme --- README.md | 2 +- README_CN.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3cf219f26..d49f2f056 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ :loudspeaker: **技术交流QQ群**:**320960100**   入群答案:**互帮互助共同进步** -:compass: [入群二维码](#e-mail-contact)    [入群指南 (腾讯文档)](https://docs.qq.com/doc/DYXBSUmxOT0xBZ05u) +:compass: [入群二维码](#e-mail-contact) (QQ、微信)    [入群指南 (腾讯文档)](https://docs.qq.com/doc/DYXBSUmxOT0xBZ05u) --- diff --git a/README_CN.md b/README_CN.md index a979da704..7b3424f74 100644 --- a/README_CN.md +++ b/README_CN.md @@ -13,7 +13,7 @@ :loudspeaker: **技术交流QQ群**:**320960100**   入群答案:**互帮互助共同进步** -:compass: [入群二维码](#e-mail-%E8%81%94%E7%B3%BB)    [入群指南 (腾讯文档)](https://docs.qq.com/doc/DYXBSUmxOT0xBZ05u) +:compass: [入群二维码](#e-mail-%E8%81%94%E7%B3%BB) (QQ、微信)    [入群指南 (腾讯文档)](https://docs.qq.com/doc/DYXBSUmxOT0xBZ05u) --- From 3ece51b93f9ca70d4a1a8ec7ef192001fb20e9f7 Mon Sep 17 00:00:00 2001 From: Xintao Date: Tue, 5 Oct 2021 18:15:24 +0800 Subject: [PATCH 13/13] update license of ecbsr --- LICENSE/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/LICENSE/README.md b/LICENSE/README.md index 93179cc13..3bd86f341 100644 --- a/LICENSE/README.md +++ b/LICENSE/README.md @@ -13,6 +13,8 @@ This BasicSR project is released under the Apache 2.0 license. - We use the implementation of `DropPath` and `trunc_normal_` from [pytorch-image-models](https://github.com/rwightman/pytorch-image-models/). The LICENSE is included as [LICENSE_pytorch-image-models](LICENSE/LICENSE_pytorch-image-models). - [SwinIR](https://github.com/JingyunLiang/SwinIR) - The arch implementation of SwinIR is from [SwinIR](https://github.com/JingyunLiang/SwinIR). The LICENSE is included as [LICENSE_SwinIR](LICENSE/LICENSE_SwinIR). +- [ECBSR](https://github.com/xindongzhang/ECBSR) + - The arch implementation of ECBSR is from [ECBSR](https://github.com/xindongzhang/ECBSR). The LICENSE of ECBSR is [Apache License 2.0](https://github.com/xindongzhang/ECBSR/blob/main/LICENSE) ## References