Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

looking forward to the code release #2

Open
FlotingDream opened this issue Apr 19, 2024 · 1 comment
Open

looking forward to the code release #2

FlotingDream opened this issue Apr 19, 2024 · 1 comment

Comments

@FlotingDream
Copy link

Great work! I'm looking forward to the code release, and I'm curious about how it compares with SPAN: https://github.com/hongyuanyu/SPAN

@dslisleedh
Copy link
Owner

Thank you for your interest! Unfortunately, we cannot compare ours with SPAN as the parameters in the released code and the paper do not match. The code we used is as follows:

"""SPAN: Swift Parameter-free Attention Network for Efficient Super-Resolution.

Contains the building blocks (`Conv3XC`, `SPAB`) and the top-level `SPAN`
network, plus a small script entry point that runs `test_direct_metrics`
on a freshly constructed model.
"""
from collections import OrderedDict

import torch
from torch import nn as nn
import torch.nn.functional as F

try:
    from basicsr.utils.registry import ARCH_REGISTRY
except ImportError:
    # The model definitions only need torch. When BasicSR is not installed,
    # fall back to a registry whose `register` leaves the class untouched so
    # the module can still be imported and the networks instantiated.
    class _FallbackRegistry:
        """Minimal stand-in offering a no-op `register` decorator."""

        def register(self, obj=None):
            if obj is None:
                return lambda target: target
            return obj

    ARCH_REGISTRY = _FallbackRegistry()


def _make_pair(value):
    """Return `(value, value)` for an int; return anything else unchanged."""
    if isinstance(value, int):
        value = (value,) * 2
    return value


def conv_layer(in_channels, out_channels, kernel_size, bias=True):
    """
    Re-write convolution layer for adaptive `padding`.

    Parameters
    ----------
    in_channels: int
        number of input channels.
    out_channels: int
        number of output channels.
    kernel_size: int or tuple
        kernel size; an int is expanded to a pair.
    bias: bool
        forwarded to `nn.Conv2d`.

    Returns
    -------
    nn.Conv2d
        convolution padded by `(k - 1) / 2` (truncated) per spatial axis.
    """
    kernel_size = _make_pair(kernel_size)
    pad_h = int((kernel_size[0] - 1) / 2)
    pad_w = int((kernel_size[1] - 1) / 2)
    return nn.Conv2d(in_channels,
                     out_channels,
                     kernel_size,
                     padding=(pad_h, pad_w),
                     bias=bias)


def activation(act_type, inplace=True, neg_slope=0.05, n_prelu=1):
    """
    Activation functions for ['relu', 'lrelu', 'prelu'].

    Parameters
    ----------
    act_type: str
        one of ['relu', 'lrelu', 'prelu'] (case-insensitive).
    inplace: bool
        whether to use inplace operator.
    neg_slope: float
        slope of negative region for `lrelu` or `prelu`.
    n_prelu: int
        `num_parameters` for `prelu`.

    Raises
    ------
    NotImplementedError
        if `act_type` is not one of the supported names.
    """
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type == 'lrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError(
            'activation layer [{:s}] is not found'.format(act_type))
    return layer


def sequential(*args):
    """
    Modules will be added to the a Sequential Container in the order they
    are passed.

    Parameters
    ----------
    args: Definition of Modules in order.

    Returns
    -------
    nn.Module
        the single argument itself when exactly one is given, otherwise an
        `nn.Sequential`; nested `nn.Sequential` arguments are flattened and
        arguments that are not `nn.Module` instances are skipped.

    Raises
    ------
    NotImplementedError
        if the only argument is an `OrderedDict`.
    """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError(
                'sequential does not support OrderedDict input.')
        return args[0]
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            modules.extend(module.children())
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


def pixelshuffle_block(in_channels,
                       out_channels,
                       upscale_factor=2,
                       kernel_size=3):
    """
    Upsample features according to `upscale_factor`.

    Builds a convolution producing `out_channels * upscale_factor ** 2`
    channels followed by `nn.PixelShuffle(upscale_factor)`.
    """
    conv = conv_layer(in_channels,
                      out_channels * (upscale_factor ** 2),
                      kernel_size)
    pixel_shuffle = nn.PixelShuffle(upscale_factor)
    return sequential(conv, pixel_shuffle)


class Conv3XC(nn.Module):
    """
    Re-parameterizable 3x3 convolution.

    In training mode the output is `conv(pad(x)) + sk(x)`, where `conv` is a
    1x1 -> 3x3 -> 1x1 stack and `sk` is a 1x1 skip convolution. On the first
    forward pass in eval mode the four convolutions are folded into the
    single 3x3 `eval_conv`, and `conv` / `sk` are deleted from the module.

    NOTE: because `conv` and `sk` are deleted after folding, a module that
    has run once in eval mode can no longer run in training mode.

    Parameters
    ----------
    c_in: int
        input channels.
    c_out: int
        output channels.
    gain1: int
        channel expansion factor inside the 1x1 -> 3x3 -> 1x1 stack.
    gain2: int
        accepted for interface compatibility; not used.
    s: int
        stride of the 3x3 convolution and of the skip convolution.
    bias: bool
        whether every convolution (including `eval_conv`) has a bias.
    relu: bool
        if True, apply leaky ReLU (negative slope 0.05) to the output.
    """

    def __init__(self, c_in, c_out, gain1=1, gain2=0, s=1, bias=True,
                 relu=False):
        super(Conv3XC, self).__init__()
        self.weight_concat = None
        self.bias_concat = None
        self.update_params_flag = False
        self.stride = s
        self.has_relu = relu
        gain = gain1

        # Construction order (sk, conv, eval_conv) is kept so that seeded
        # initialisation and state_dict ordering stay the same.
        self.sk = nn.Conv2d(in_channels=c_in, out_channels=c_out,
                            kernel_size=1, padding=0, stride=s, bias=bias)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=c_in, out_channels=c_in * gain,
                      kernel_size=1, padding=0, bias=bias),
            nn.Conv2d(in_channels=c_in * gain, out_channels=c_out * gain,
                      kernel_size=3, stride=s, padding=0, bias=bias),
            nn.Conv2d(in_channels=c_out * gain, out_channels=c_out,
                      kernel_size=1, padding=0, bias=bias),
        )

        self.eval_conv = nn.Conv2d(in_channels=c_in, out_channels=c_out,
                                   kernel_size=3, padding=1, stride=s,
                                   bias=bias)
        self.eval_conv.weight.requires_grad = False
        # With bias=False, nn.Conv2d stores `bias` as None.
        if self.eval_conv.bias is not None:
            self.eval_conv.bias.requires_grad = False
        self.updated = False

    def update_params(self):
        """Fold `conv` and `sk` into `eval_conv`; a no-op once done."""
        if self.updated:
            return

        w1 = self.conv[0].weight.data.clone().detach()
        w2 = self.conv[1].weight.data.clone().detach()
        w3 = self.conv[2].weight.data.clone().detach()

        # Compose 1x1 -> 3x3, then (1x1 -> 3x3) -> 1x1, into one 3x3 kernel.
        w = F.conv2d(w1.flip(2, 3).permute(1, 0, 2, 3), w2,
                     padding=2, stride=1).flip(2, 3).permute(1, 0, 2, 3)
        self.weight_concat = F.conv2d(
            w.flip(2, 3).permute(1, 0, 2, 3), w3,
            padding=0, stride=1).flip(2, 3).permute(1, 0, 2, 3)

        # Embed the 1x1 skip kernel at the centre of a 3x3 kernel.
        sk_w = self.sk.weight.data.clone().detach()
        target_kernel_size = 3
        pad = (target_kernel_size - 1) // 2
        sk_w = F.pad(sk_w, [pad, pad, pad, pad])
        self.weight_concat = self.weight_concat + sk_w
        self.eval_conv.weight.data = self.weight_concat

        # All convolutions share the same `bias` flag, so checking one
        # is sufficient.
        if self.conv[0].bias is not None:
            b1 = self.conv[0].bias.data.clone().detach()
            b2 = self.conv[1].bias.data.clone().detach()
            b3 = self.conv[2].bias.data.clone().detach()
            sk_b = self.sk.bias.data.clone().detach()

            b = (w2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) + b2
            self.bias_concat = (
                (w3 * b.reshape(1, -1, 1, 1)).sum((1, 2, 3)) + b3)
            self.bias_concat = self.bias_concat + sk_b
            self.eval_conv.bias.data = self.bias_concat

        delattr(self, 'conv')  # to exclude rep convs from parameter counts
        delattr(self, 'sk')
        self.updated = True

    def forward(self, x):
        if self.training:
            # Zero-pad before the 1x1 conv; the 3x3 conv itself is unpadded.
            pad = 1
            x_pad = F.pad(x, (pad, pad, pad, pad), "constant", 0)
            out = self.conv(x_pad) + self.sk(x)
        else:
            self.update_params()
            out = self.eval_conv(x)

        if self.has_relu:
            out = F.leaky_relu(out, negative_slope=0.05)
        return out


class SPAB(nn.Module):
    """
    Block of three `Conv3XC` layers with a sigmoid-based attention term.

    Parameters
    ----------
    in_channels: int
        input channels.
    mid_channels: int or None
        channels after the first two convolutions; defaults to
        `in_channels`.
    out_channels: int or None
        output channels; defaults to `in_channels`.
    bias: bool
        accepted but not forwarded to the `Conv3XC` layers, which use their
        own default.
    """

    def __init__(self, in_channels, mid_channels=None, out_channels=None,
                 bias=False):
        super(SPAB, self).__init__()
        if mid_channels is None:
            mid_channels = in_channels
        if out_channels is None:
            out_channels = in_channels

        self.in_channels = in_channels
        self.c1_r = Conv3XC(in_channels, mid_channels, gain1=2, s=1)
        self.c2_r = Conv3XC(mid_channels, mid_channels, gain1=2, s=1)
        self.c3_r = Conv3XC(mid_channels, out_channels, gain1=2, s=1)
        self.act1 = torch.nn.SiLU(inplace=True)
        # act2 is constructed but not used in forward().
        self.act2 = activation('lrelu', neg_slope=0.1, inplace=True)

    def forward(self, x):
        """
        Returns
        -------
        tuple
            `(out, out1, sim_att)`. `act1` is in-place, so `out1` holds the
            activated output of the first convolution.
        """
        out1 = self.c1_r(x)
        out1_act = self.act1(out1)

        out2 = self.c2_r(out1_act)
        out2_act = self.act1(out2)

        out3 = self.c3_r(out2_act)

        sim_att = torch.sigmoid(out3) - 0.5
        out = (out3 + x) * sim_att
        return out, out1, sim_att


@ARCH_REGISTRY.register()
class SPAN(nn.Module):
    """
    Swift Parameter-free Attention Network for Efficient Super-Resolution

    Parameters
    ----------
    num_in_ch: int
        input image channels.
    num_out_ch: int
        output image channels.
    feature_channels: int
        width of the feature maps.
    upscale: int
        upscale factor of the pixel-shuffle upsampler.
    bias: bool
        passed to each `SPAB` block.
    img_range: float
        scale applied to the mean-subtracted input.
    rgb_mean: tuple
        three-element per-channel mean subtracted from the input.
    """

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 feature_channels=48,
                 upscale=4,
                 bias=True,
                 img_range=255.,
                 rgb_mean=(0.4488, 0.4371, 0.4040)
                 ):
        super(SPAN, self).__init__()

        in_channels = num_in_ch
        out_channels = num_out_ch
        self.img_range = img_range
        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)

        self.conv_1 = Conv3XC(in_channels, feature_channels, gain1=2, s=1)
        self.block_1 = SPAB(feature_channels, bias=bias)
        self.block_2 = SPAB(feature_channels, bias=bias)
        self.block_3 = SPAB(feature_channels, bias=bias)
        self.block_4 = SPAB(feature_channels, bias=bias)
        self.block_5 = SPAB(feature_channels, bias=bias)
        self.block_6 = SPAB(feature_channels, bias=bias)

        self.conv_cat = conv_layer(feature_channels * 4, feature_channels,
                                   kernel_size=1, bias=True)
        self.conv_2 = Conv3XC(feature_channels, feature_channels,
                              gain1=2, s=1)

        self.upsampler = pixelshuffle_block(feature_channels, out_channels,
                                            upscale_factor=upscale)

    def forward(self, x):
        # `mean` is a plain attribute, so match it to the input on each call.
        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        out_feature = self.conv_1(x)

        out_b1, _, att1 = self.block_1(out_feature)
        out_b2, _, att2 = self.block_2(out_b1)
        out_b3, _, att3 = self.block_3(out_b2)

        out_b4, _, att4 = self.block_4(out_b3)
        out_b5, _, att5 = self.block_5(out_b4)
        out_b6, out_b5_2, att6 = self.block_6(out_b5)

        out_b6 = self.conv_2(out_b6)
        # Fuse the shallow features, the last block's output, the first
        # block's output and block_6's first-conv features.
        out = self.conv_cat(
            torch.cat([out_feature, out_b6, out_b1, out_b5_2], 1))
        output = self.upsampler(out)

        return output


if __name__ == '__main__':
    import numpy as np
    from scripts.test_direct_metrics import test_direct_metrics

    test_size = 'HD'
    # test_size = 'FHD'
    # test_size = '4K'

    height = 720 if test_size == 'HD' else 1080 if test_size == 'FHD' else 2160
    width = 1280 if test_size == 'HD' else 1920 if test_size == 'FHD' else 3840
    upsampling_factor = 2
    batch_size = 1

    model_kwargs = {
        'num_in_ch': 3,
        'num_out_ch': 3,
        'feature_channels': 48,
        'upscale': upsampling_factor
    }
    shape = (batch_size, 3,
             height // upsampling_factor, width // upsampling_factor)
    model = SPAN(**model_kwargs)
    test_direct_metrics(model, shape)

Please let me know if you need any further information or clarification.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants