looking forward to the code release #2
Great work! Looking forward to the code release soon, and curious about the comparison with SPAN: https://github.com/hongyuanyu/SPAN

Comments
Thank you for your interest! Unfortunately, we cannot compare ours with SPAN, as the parameters in its released code and its paper do not match. The code we used is as follows:

```python
from collections import OrderedDict
import torch
from torch import nn as nn
import torch.nn.functional as F
from basicsr.utils.registry import ARCH_REGISTRY

def _make_pair(value):
    if isinstance(value, int):
        value = (value,) * 2
    return value

def conv_layer(in_channels,
               out_channels,
               kernel_size,
               bias=True):
    """
    Convolution layer with adaptive `padding` derived from `kernel_size`.
    """
    kernel_size = _make_pair(kernel_size)
    padding = (int((kernel_size[0] - 1) / 2),
               int((kernel_size[1] - 1) / 2))
    return nn.Conv2d(in_channels,
                     out_channels,
                     kernel_size,
                     padding=padding,
                     bias=bias)
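
# Example: conv_layer(48, 48, kernel_size=3) yields padding=(1, 1), so odd
# kernel sizes preserve the spatial resolution.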

def activation(act_type, inplace=True, neg_slope=0.05, n_prelu=1):
    """
    Activation factory for ['relu', 'lrelu', 'prelu'].

    Parameters
    ----------
    act_type: str
        one of ['relu', 'lrelu', 'prelu'].
    inplace: bool
        whether to use the inplace variant of the operator.
    neg_slope: float
        slope of the negative region for `lrelu`, or the initial value for `prelu`.
    n_prelu: int
        `num_parameters` for `prelu`.
    """
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type == 'lrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError(
            'activation layer [{:s}] is not found'.format(act_type))
    return layer
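
# Example: activation('lrelu', neg_slope=0.1) returns nn.LeakyReLU(0.1, inplace=True).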

def sequential(*args):
    """
    Add modules to an nn.Sequential container in the order they are passed.

    Parameters
    ----------
    args: Definition of Modules in order.
    """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError(
                'sequential does not support OrderedDict input.')
        return args[0]
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)
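
# Example: sequential(conv_layer(3, 12, 3), nn.PixelShuffle(2)) returns one flat
# nn.Sequential; nested nn.Sequential arguments are unpacked into their children.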

def pixelshuffle_block(in_channels,
                       out_channels,
                       upscale_factor=2,
                       kernel_size=3):
    """
    Upsample features according to `upscale_factor`.
    """
    conv = conv_layer(in_channels,
                      out_channels * (upscale_factor ** 2),
                      kernel_size)
    pixel_shuffle = nn.PixelShuffle(upscale_factor)
    return sequential(conv, pixel_shuffle)
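
# Shape note: for upscale_factor = r, the conv expands (N, C, H, W) features to
# (N, out_channels * r**2, H, W) and nn.PixelShuffle rearranges them to
# (N, out_channels, r * H, r * W).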

class Conv3XC(nn.Module):

    def __init__(self, c_in, c_out, gain1=1, gain2=0, s=1, bias=True, relu=False):
        super(Conv3XC, self).__init__()
        self.weight_concat = None
        self.bias_concat = None
        self.update_params_flag = False
        self.stride = s
        self.has_relu = relu
        gain = gain1

        # Training-time branches: a 1x1 skip conv plus a 1x1 -> 3x3 -> 1x1 stack.
        self.sk = nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=1, padding=0, stride=s, bias=bias)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=c_in, out_channels=c_in * gain, kernel_size=1, padding=0, bias=bias),
            nn.Conv2d(in_channels=c_in * gain, out_channels=c_out * gain, kernel_size=3, stride=s, padding=0, bias=bias),
            nn.Conv2d(in_channels=c_out * gain, out_channels=c_out, kernel_size=1, padding=0, bias=bias),
        )

        # Inference-time conv that holds the fused weights.
        self.eval_conv = nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1, stride=s, bias=bias)
        self.eval_conv.weight.requires_grad = False
        self.eval_conv.bias.requires_grad = False
        self.updated = False

    def update_params(self):
        if not self.updated:
            w1 = self.conv[0].weight.data.clone().detach()
            b1 = self.conv[0].bias.data.clone().detach()
            w2 = self.conv[1].weight.data.clone().detach()
            b2 = self.conv[1].bias.data.clone().detach()
            w3 = self.conv[2].weight.data.clone().detach()
            b3 = self.conv[2].bias.data.clone().detach()

            # Fold the 1x1 -> 3x3 -> 1x1 stack into a single 3x3 kernel.
            w = F.conv2d(w1.flip(2, 3).permute(1, 0, 2, 3), w2, padding=2, stride=1).flip(2, 3).permute(1, 0, 2, 3)
            b = (w2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) + b2
            self.weight_concat = F.conv2d(w.flip(2, 3).permute(1, 0, 2, 3), w3, padding=0, stride=1).flip(2, 3).permute(1, 0, 2, 3)
            self.bias_concat = (w3 * b.reshape(1, -1, 1, 1)).sum((1, 2, 3)) + b3

            # Zero-pad the 1x1 skip kernel to 3x3 and merge it into the fused kernel.
            sk_w = self.sk.weight.data.clone().detach()
            sk_b = self.sk.bias.data.clone().detach()
            target_kernel_size = 3
            H_pixels_to_pad = (target_kernel_size - 1) // 2
            W_pixels_to_pad = (target_kernel_size - 1) // 2
            sk_w = F.pad(sk_w, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad])

            self.weight_concat = self.weight_concat + sk_w
            self.bias_concat = self.bias_concat + sk_b
            self.eval_conv.weight.data = self.weight_concat
            self.eval_conv.bias.data = self.bias_concat

            delattr(self, 'conv')  # exclude the reparameterized convs from parameter counts
            delattr(self, 'sk')
            self.updated = True

    def forward(self, x):
        if self.training:
            pad = 1
            x_pad = F.pad(x, (pad, pad, pad, pad), "constant", 0)
            out = self.conv(x_pad) + self.sk(x)
        else:
            self.update_params()
            out = self.eval_conv(x)

        if self.has_relu:
            out = F.leaky_relu(out, negative_slope=0.05)
        return out
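
# Sanity-check sketch for the reparameterization (illustrative, not part of the
# original file): in eval mode the fused eval_conv should reproduce the
# training-time conv(x_pad) + sk(x) output up to floating-point error, e.g.
#
#     m = Conv3XC(8, 8, gain1=2)
#     x = torch.randn(1, 8, 16, 16)
#     with torch.no_grad():
#         y_train = m.conv(F.pad(x, (1, 1, 1, 1))) + m.sk(x)
#         y_eval = m.eval()(x)  # first eval forward calls update_params()
#     assert torch.allclose(y_train, y_eval, atol=1e-4)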

class SPAB(nn.Module):

    def __init__(self,
                 in_channels,
                 mid_channels=None,
                 out_channels=None,
                 bias=False):
        super(SPAB, self).__init__()
        if mid_channels is None:
            mid_channels = in_channels
        if out_channels is None:
            out_channels = in_channels

        self.in_channels = in_channels
        self.c1_r = Conv3XC(in_channels, mid_channels, gain1=2, s=1)
        self.c2_r = Conv3XC(mid_channels, mid_channels, gain1=2, s=1)
        self.c3_r = Conv3XC(mid_channels, out_channels, gain1=2, s=1)
        self.act1 = torch.nn.SiLU(inplace=True)
        self.act2 = activation('lrelu', neg_slope=0.1, inplace=True)

    def forward(self, x):
        out1 = self.c1_r(x)
        out1_act = self.act1(out1)
        out2 = self.c2_r(out1_act)
        out2_act = self.act1(out2)
        out3 = self.c3_r(out2_act)

        sim_att = torch.sigmoid(out3) - 0.5
        out = (out3 + x) * sim_att
        return out, out1, sim_att
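
# Note on the "parameter-free attention": sigmoid(out3) - 0.5 lies in (-0.5, 0.5),
# so it acts as a signed gate on the residual sum (out3 + x) without adding any
# attention parameters.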

@ARCH_REGISTRY.register()
class SPAN(nn.Module):
    """
    Swift Parameter-free Attention Network for Efficient Super-Resolution.
    """

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 feature_channels=48,
                 upscale=4,
                 bias=True,
                 img_range=255.,
                 rgb_mean=(0.4488, 0.4371, 0.4040)):
        super(SPAN, self).__init__()

        in_channels = num_in_ch
        out_channels = num_out_ch
        self.img_range = img_range
        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)

        self.conv_1 = Conv3XC(in_channels, feature_channels, gain1=2, s=1)
        self.block_1 = SPAB(feature_channels, bias=bias)
        self.block_2 = SPAB(feature_channels, bias=bias)
        self.block_3 = SPAB(feature_channels, bias=bias)
        self.block_4 = SPAB(feature_channels, bias=bias)
        self.block_5 = SPAB(feature_channels, bias=bias)
        self.block_6 = SPAB(feature_channels, bias=bias)

        self.conv_cat = conv_layer(feature_channels * 4, feature_channels, kernel_size=1, bias=True)
        self.conv_2 = Conv3XC(feature_channels, feature_channels, gain1=2, s=1)
        self.upsampler = pixelshuffle_block(feature_channels, out_channels, upscale_factor=upscale)

    def forward(self, x):
        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        out_feature = self.conv_1(x)

        out_b1, _, att1 = self.block_1(out_feature)
        out_b2, _, att2 = self.block_2(out_b1)
        out_b3, _, att3 = self.block_3(out_b2)
        out_b4, _, att4 = self.block_4(out_b3)
        out_b5, _, att5 = self.block_5(out_b4)
        out_b6, out_b5_2, att6 = self.block_6(out_b5)

        out_b6 = self.conv_2(out_b6)
        out = self.conv_cat(torch.cat([out_feature, out_b6, out_b1, out_b5_2], 1))
        output = self.upsampler(out)
        return output

if __name__ == '__main__':
    from scripts.test_direct_metrics import test_direct_metrics

    test_size = 'HD'
    # test_size = 'FHD'
    # test_size = '4K'

    height = 720 if test_size == 'HD' else 1080 if test_size == 'FHD' else 2160
    width = 1280 if test_size == 'HD' else 1920 if test_size == 'FHD' else 3840
    upsampling_factor = 2
    batch_size = 1

    model_kwargs = {
        'num_in_ch': 3,
        'num_out_ch': 3,
        'feature_channels': 48,
        'upscale': upsampling_factor
    }
    shape = (batch_size, 3, height // upsampling_factor, width // upsampling_factor)

    model = SPAN(**model_kwargs)
    test_direct_metrics(model, shape)
```

Please let me know if you need any further information or clarification.
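
Note: `scripts.test_direct_metrics` is internal to the repo. As a rough standalone alternative, a minimal forward-pass check (a sketch assuming the `SPAN` definition above is importable and only `torch`/`basicsr` are installed; the timing here is illustrative, not the benchmarking protocol) could look like:

```python
import time

import torch

# Build the 2x HD configuration used above and switch to eval mode so the
# Conv3XC branches get fused on the first forward pass.
model = SPAN(num_in_ch=3, num_out_ch=3, feature_channels=48, upscale=2).eval()
x = torch.randn(1, 3, 360, 640)  # LR input for a 1280x720 (HD) target at 2x

with torch.no_grad():
    model(x)  # warm-up; triggers update_params() in every Conv3XC
    start = time.time()
    y = model(x)

print(f'output shape: {tuple(y.shape)}, runtime: {time.time() - start:.3f}s')
assert y.shape == (1, 3, 720, 1280)
```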