From ae342a3039e907c4eb76058dafa8192d153990c6 Mon Sep 17 00:00:00 2001
From: Jacob Bieker
Date: Thu, 22 Jul 2021 11:16:15 +0100
Subject: [PATCH 1/6] Add configurable 2D convolution methods to the generic
 GAN network definitions and the RUnet/Attention U-Net models

Makes it easy to swap between standard Conv2D layers, CoordConv layers
and, soon, anti-aliased Conv2D layers.
---
 satflow/models/gan/discriminators.py | 71 +++++++++++++++++++---------
 satflow/models/gan/generators.py     | 62 ++++++++++++++++++------
 satflow/models/layers/RUnetLayers.py | 38 ++++++++-------
 satflow/models/utils.py              | 15 ++++++
 4 files changed, 133 insertions(+), 53 deletions(-)
 create mode 100644 satflow/models/utils.py
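Notes (kept out of the applied diff): a minimal sketch of how the new
get_conv_layer() helper is used. The channel sizes are invented for
illustration; the helper returns a class whose call signature mirrors
nn.Conv2d, which is what every call site in this patch relies on
(CoordConv is assumed to forward its arguments to an inner nn.Conv2d):

    import torch
    from satflow.models.utils import get_conv_layer

    # "standard" -> torch.nn.Conv2d, "coord" -> CoordConv, and
    # "antialiased" -> torch.nn.Conv2d for now (see the TODO in utils.py)
    conv2d = get_conv_layer("coord")
    layer = conv2d(12, 64, kernel_size=3, stride=1, padding=1)
    out = layer(torch.randn(1, 12, 32, 32))  # -> (1, 64, 32, 32)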
diff --git a/satflow/models/gan/discriminators.py b/satflow/models/gan/discriminators.py
index bd298554..b45b41b3 100644
--- a/satflow/models/gan/discriminators.py
+++ b/satflow/models/gan/discriminators.py
@@ -1,11 +1,20 @@
 import functools
 import torch
 from torch import nn as nn
-
+from satflow.models.utils import get_conv_layer
 from satflow.models.gan.common import get_norm_layer, init_net
 
 
-def define_D(input_nc, ndf, netD, n_layers_D=3, norm="batch", init_type="normal", init_gain=0.02):
+def define_discriminator(
+    input_nc,
+    ndf,
+    netD,
+    n_layers_D=3,
+    norm="batch",
+    init_type="normal",
+    init_gain=0.02,
+    conv_type: str = "standard",
+):
     """Create a discriminator
 
     Parameters:
@@ -36,15 +45,20 @@
     """
     net = None
    norm_layer = get_norm_layer(norm_type=norm)
-
     if netD == "basic":  # default PatchGAN classifier
-        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
+        net = NLayerDiscriminator(
+            input_nc, ndf, n_layers=3, norm_layer=norm_layer, conv_type=conv_type
+        )
     elif netD == "n_layers":  # more options
-        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
+        net = NLayerDiscriminator(
+            input_nc, ndf, n_layers_D, norm_layer=norm_layer, conv_type=conv_type
+        )
     elif netD == "pixel":  # classify if each pixel is real or fake
-        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
+        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, conv_type=conv_type)
     elif netD == "enhanced":
-        net = CloudGANDiscriminator(input_channels=input_nc, num_filters=ndf, num_stages=3)
+        net = CloudGANDiscriminator(
+            input_channels=input_nc, num_filters=ndf, num_stages=3, conv_type=conv_type
+        )
     else:
         raise NotImplementedError("Discriminator model name [%s] is not recognized" % netD)
     return init_net(net, init_type, init_gain)
@@ -122,7 +136,9 @@ def __call__(self, prediction, target_is_real):
 class NLayerDiscriminator(nn.Module):
     """Defines a PatchGAN discriminator"""
 
-    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
+    def __init__(
+        self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, conv_type: str = "standard"
+    ):
         """Construct a PatchGAN discriminator
 
         Parameters:
@@ -139,10 +155,12 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
         else:
             use_bias = norm_layer == nn.InstanceNorm2d
 
+        conv2d = get_conv_layer(conv_type)
+
         kw = 4
         padw = 1
         sequence = [
-            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
+            conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
             nn.LeakyReLU(0.2, True),
         ]
         nf_mult = 1
@@ -151,7 +169,7 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
             nf_mult_prev = nf_mult
             nf_mult = min(2 ** n, 8)
             sequence += [
-                nn.Conv2d(
+                conv2d(
                     ndf * nf_mult_prev,
                     ndf * nf_mult,
                     kernel_size=kw,
@@ -166,7 +184,7 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
         nf_mult_prev = nf_mult
         nf_mult = min(2 ** n_layers, 8)
         sequence += [
-            nn.Conv2d(
+            conv2d(
                 ndf * nf_mult_prev,
                 ndf * nf_mult,
                 kernel_size=kw,
@@ -179,7 +197,7 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
         ]
 
         sequence += [
-            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
+            conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
         ]  # output 1 channel prediction map
         self.model = nn.Sequential(*sequence)
 
@@ -191,7 +209,7 @@ def forward(self, input):
 class PixelDiscriminator(nn.Module):
     """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
 
-    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
+    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, conv_type: str = "standard"):
         """Construct a 1x1 PatchGAN discriminator
 
         Parameters:
@@ -207,13 +225,15 @@ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
         else:
             use_bias = norm_layer == nn.InstanceNorm2d
 
+        conv2d = get_conv_layer(conv_type)
+
         self.net = [
-            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
+            conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
             nn.LeakyReLU(0.2, True),
-            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
+            conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
             norm_layer(ndf * 2),
             nn.LeakyReLU(0.2, True),
-            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
+            conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
         ]
         self.net = nn.Sequential(*self.net)
 
@@ -224,9 +244,9 @@ def forward(self, input):
 
 
 class CloudGANBlock(nn.Module):
-    def __init__(self, input_channels):
+    def __init__(self, input_channels, conv2d: torch.nn.Module):
         super().__init__()
-        self.conv = torch.nn.Conv2d(input_channels, input_channels * 2, kernel_size=(3, 3))
+        self.conv = conv2d(input_channels, input_channels * 2, kernel_size=(3, 3))
         self.relu = torch.nn.ReLU()
         self.pool = torch.nn.MaxPool2d(kernel_size=(2, 2))
 
@@ -240,14 +260,19 @@ def forward(self, x):
 class CloudGANDiscriminator(nn.Module):
     """Defines a discriminator based off https://www.climatechange.ai/papers/icml2021/54/slides.pdf"""
 
-    def __init__(self, input_channels: int = 12, num_filters: int = 64, num_stages: int = 3):
+    def __init__(
+        self,
+        input_channels: int = 12,
+        num_filters: int = 64,
+        num_stages: int = 3,
+        conv_type: str = "standard",
+    ):
         super().__init__()
-        self.conv_1 = torch.nn.Conv2d(
-            input_channels, num_filters, kernel_size=1, stride=1, padding=0
-        )
+        conv2d = get_conv_layer(conv_type)
+        self.conv_1 = conv2d(input_channels, num_filters, kernel_size=1, stride=1, padding=0)
         self.stages = []
         for stage in range(num_stages):
-            self.stages.append(CloudGANBlock(num_filters))
+            self.stages.append(CloudGANBlock(num_filters, conv2d))
             num_filters = num_filters * 2
         self.stages = torch.nn.Sequential(*self.stages)
         self.flatten = torch.nn.Flatten()
diff --git a/satflow/models/gan/generators.py b/satflow/models/gan/generators.py
index 3d230fea..941283f6 100644
--- a/satflow/models/gan/generators.py
+++ b/satflow/models/gan/generators.py
@@ -4,6 +4,7 @@
 from torch import nn as nn
 from typing import Union
 from satflow.models.gan.common import get_norm_layer, init_net
+from satflow.models.utils import get_conv_layer
 
 
 def define_G(
@@ -81,6 +82,7 @@ def __init__(
         use_dropout=False,
         n_blocks=6,
         padding_type="reflect",
+        conv_type: str = "standard",
     ):
         """Construct a Resnet-based generator
@@ -99,10 +101,10 @@ def __init__(
             use_bias = norm_layer.func == nn.InstanceNorm2d
         else:
             use_bias = norm_layer == nn.InstanceNorm2d
-
+        conv2d = get_conv_layer(conv_type)
         model = [
             nn.ReflectionPad2d(3),
-            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
+            conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
             norm_layer(ngf),
             nn.ReLU(True),
         ]
@@ -111,7 +113,7 @@ def __init__(
         for i in range(n_downsampling):  # add downsampling layers
             mult = 2 ** i
             model += [
-                nn.Conv2d(
+                conv2d(
                     ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias
                 ),
                 norm_layer(ngf * mult * 2),
@@ -160,7 +162,9 @@ def forward(self, input):
 class ResnetBlock(nn.Module):
     """Define a Resnet block"""
 
-    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+    def __init__(
+        self, dim, padding_type, norm_layer, use_dropout, use_bias, conv_type: str = "standard"
+    ):
         """Initialize the Resnet block
 
         A resnet block is a conv block with skip connections
@@ -169,11 +173,14 @@ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
         Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
         """
         super(ResnetBlock, self).__init__()
+        conv2d = get_conv_layer(conv_type)
         self.conv_block = self.build_conv_block(
-            dim, padding_type, norm_layer, use_dropout, use_bias
+            dim, padding_type, norm_layer, use_dropout, use_bias, conv2d
         )
 
-    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+    def build_conv_block(
+        self, dim, padding_type, norm_layer, use_dropout, use_bias, conv2d: torch.nn.Module
+    ):
         """Construct a convolutional block.
 
         Parameters:
@@ -197,7 +204,7 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias)
             raise NotImplementedError("padding [%s] is not implemented" % padding_type)
 
         conv_block += [
-            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
+            conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
             norm_layer(dim),
             nn.ReLU(True),
         ]
@@ -214,7 +221,7 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias)
         else:
             raise NotImplementedError("padding [%s] is not implemented" % padding_type)
         conv_block += [
-            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
+            conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
             norm_layer(dim),
         ]
 
@@ -230,7 +237,14 @@ class UnetGenerator(nn.Module):
     """Create a Unet-based generator"""
 
     def __init__(
-        self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False
+        self,
+        input_nc,
+        output_nc,
+        num_downs,
+        ngf=64,
+        norm_layer=nn.BatchNorm2d,
+        use_dropout=False,
+        conv_type: str = "standard",
     ):
         """Construct a Unet generator
         Parameters:
@@ -246,8 +260,15 @@ def __init__(
         """
         super(UnetGenerator, self).__init__()
         # construct unet structure
+        conv2d = get_conv_layer(conv_type)
         unet_block = UnetSkipConnectionBlock(
-            ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True
+            ngf * 8,
+            ngf * 8,
+            input_nc=None,
+            submodule=None,
+            norm_layer=norm_layer,
+            innermost=True,
+            conv2d=conv2d,
         )  # add the innermost layer
         for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
             unet_block = UnetSkipConnectionBlock(
@@ -257,16 +278,27 @@ def __init__(
                 submodule=unet_block,
                 norm_layer=norm_layer,
                 use_dropout=use_dropout,
+                conv2d=conv2d,
             )
         # gradually reduce the number of filters from ngf * 8 to ngf
         unet_block = UnetSkipConnectionBlock(
-            ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer
+            ngf * 4,
+            ngf * 8,
+            input_nc=None,
+            submodule=unet_block,
+            norm_layer=norm_layer,
+            conv2d=conv2d,
         )
         unet_block = UnetSkipConnectionBlock(
-            ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer
+            ngf * 2,
+            ngf * 4,
+            input_nc=None,
+            submodule=unet_block,
+            norm_layer=norm_layer,
+            conv2d=conv2d,
         )
         unet_block = UnetSkipConnectionBlock(
-            ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer
+            ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, conv2d=conv2d
         )
         self.model = UnetSkipConnectionBlock(
             output_nc,
@@ -275,6 +307,7 @@ def __init__(
             submodule=unet_block,
             outermost=True,
             norm_layer=norm_layer,
+            conv2d=conv2d,
         )  # add the outermost layer
 
     def forward(self, input):
@@ -298,6 +331,7 @@ def __init__(
         innermost=False,
         norm_layer=nn.BatchNorm2d,
         use_dropout=False,
+        conv2d=torch.nn.Module,
     ):
         """Construct a Unet submodule with skip connections.
 
@@ -319,7 +353,7 @@ def __init__(
             use_bias = norm_layer == nn.InstanceNorm2d
         if input_nc is None:
             input_nc = outer_nc
-        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
+        downconv = conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
         downrelu = nn.LeakyReLU(0.2, True)
         downnorm = norm_layer(inner_nc)
         uprelu = nn.ReLU(True)
diff --git a/satflow/models/layers/RUnetLayers.py b/satflow/models/layers/RUnetLayers.py
index 4dbaf856..56b75551 100644
--- a/satflow/models/layers/RUnetLayers.py
+++ b/satflow/models/layers/RUnetLayers.py
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+from satflow.models.utils import get_conv_layer
 from torch.nn import init
 
 
@@ -33,13 +33,14 @@ def init_func(m):
 
 
 class conv_block(nn.Module):
-    def __init__(self, ch_in, ch_out):
+    def __init__(self, ch_in, ch_out, conv_type: str = "standard"):
         super(conv_block, self).__init__()
+        conv2d = get_conv_layer(conv_type)
         self.conv = nn.Sequential(
-            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
             nn.BatchNorm2d(ch_out),
             nn.ReLU(inplace=True),
-            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
             nn.BatchNorm2d(ch_out),
             nn.ReLU(inplace=True),
         )
@@ -50,11 +51,12 @@ def forward(self, x):
 
 
 class up_conv(nn.Module):
-    def __init__(self, ch_in, ch_out):
+    def __init__(self, ch_in, ch_out, conv_type: str = "standard"):
         super(up_conv, self).__init__()
+        conv2d = get_conv_layer(conv_type)
         self.up = nn.Sequential(
             nn.Upsample(scale_factor=2),
-            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
             nn.BatchNorm2d(ch_out),
             nn.ReLU(inplace=True),
         )
@@ -65,12 +67,13 @@ def forward(self, x):
 
 
 class Recurrent_block(nn.Module):
-    def __init__(self, ch_out, t=2):
+    def __init__(self, ch_out, t=2, conv_type: str = "standard"):
         super(Recurrent_block, self).__init__()
+        conv2d = get_conv_layer(conv_type)
         self.t = t
         self.ch_out = ch_out
         self.conv = nn.Sequential(
-            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
             nn.BatchNorm2d(ch_out),
             nn.ReLU(inplace=True),
         )
@@ -86,10 +89,11 @@ def forward(self, x):
 
 
 class RRCNN_block(nn.Module):
-    def __init__(self, ch_in, ch_out, t=2):
+    def __init__(self, ch_in, ch_out, t=2, conv_type: str = "standard"):
         super(RRCNN_block, self).__init__()
+        conv2d = get_conv_layer(conv_type)
         self.RCNN = nn.Sequential(Recurrent_block(ch_out, t=t), Recurrent_block(ch_out, t=t))
-        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)
+        self.Conv_1x1 = conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)
 
     def forward(self, x):
         x = self.Conv_1x1(x)
@@ -98,10 +102,11 @@ def forward(self, x):
 
 
 class single_conv(nn.Module):
-    def __init__(self, ch_in, ch_out):
+    def __init__(self, ch_in, ch_out, conv_type: str = "standard"):
         super(single_conv, self).__init__()
+        conv2d = get_conv_layer(conv_type)
         self.conv = nn.Sequential(
-            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
             nn.BatchNorm2d(ch_out),
             nn.ReLU(inplace=True),
         )
@@ -112,20 +117,21 @@ def forward(self, x):
 
 
 class Attention_block(nn.Module):
-    def __init__(self, F_g, F_l, F_int):
+    def __init__(self, F_g, F_l, F_int, conv_type: str = "standard"):
         super(Attention_block, self).__init__()
+        conv2d = get_conv_layer(conv_type)
         self.W_g = nn.Sequential(
-            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+            conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
             nn.BatchNorm2d(F_int),
         )
 
         self.W_x = nn.Sequential(
-            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+            conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
             nn.BatchNorm2d(F_int),
         )
 
         self.psi = nn.Sequential(
-            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
+            conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
             nn.BatchNorm2d(1),
             nn.Sigmoid(),
         )
diff --git a/satflow/models/utils.py b/satflow/models/utils.py
new file mode 100644
index 00000000..138ddb68
--- /dev/null
+++ b/satflow/models/utils.py
@@ -0,0 +1,15 @@
+import torch
+from satflow.models.layers import CoordConv
+
+
+def get_conv_layer(conv_type: str = "standard"):
+    if conv_type == "standard":
+        conv2d = torch.nn.Conv2d
+    elif conv_type == "coord":
+        conv2d = CoordConv
+    elif conv_type == "antialiased":
+        # TODO Add anti-aliased coordconv here
+        conv2d = torch.nn.Conv2d
+    else:
+        raise ValueError(f"{conv_type} is not a recognized Conv method")
+    return conv2d

From 479019a185a3b379b685fd8e6a2d2ffa404eb268 Mon Sep 17 00:00:00 2001
From: Jacob Bieker
Date: Thu, 22 Jul 2021 11:25:51 +0100
Subject: [PATCH 2/6] Add updated conv layers for most other non-GAN models

---
 satflow/models/attention_unet.py  | 61 ++++++++++++++++---------------
 satflow/models/conv_lstm.py       | 26 ++++++++++---
 satflow/models/layers/ConvLSTM.py |  6 ++-
 satflow/models/metnet.py          | 12 +++---
 satflow/models/runet.py           |  7 +++-
 5 files changed, 68 insertions(+), 44 deletions(-)
diff --git a/satflow/models/attention_unet.py b/satflow/models/attention_unet.py
index 1bf191eb..327a35ba 100644
--- a/satflow/models/attention_unet.py
+++ b/satflow/models/attention_unet.py
@@ -16,6 +16,7 @@ def __init__(
         loss: Union[str, torch.nn.Module] = "mse",
         lr: float = 0.001,
         visualize: bool = False,
+        conv_type: str = "standard",
     ):
         super().__init__()
         self.lr = lr
@@ -23,7 +24,9 @@ def __init__(
         self.input_channels = input_channels
         self.forecast_steps = forecast_steps
         self.channels_per_timestep = 12
-        self.model = AttU_Net(input_channels=input_channels, output_channels=forecast_steps)
+        self.model = AttU_Net(
+            input_channels=input_channels, output_channels=forecast_steps, conv_type=conv_type
+        )
         assert loss in ["mse", "bce", "binary_crossentropy", "crossentropy", "focal"]
         if loss == "mse":
             self.criterion = F.mse_loss
@@ -197,32 +200,32 @@ def visualize_step(self, x, y, y_hat, batch_idx, step):
 
 
 class AttU_Net(nn.Module):
-    def __init__(self, input_channels=3, output_channels=1):
+    def __init__(self, input_channels=3, output_channels=1, conv_type: str = "standard"):
         super(AttU_Net, self).__init__()
 
         self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
 
-        self.Conv1 = conv_block(ch_in=input_channels, ch_out=64)
-        self.Conv2 = conv_block(ch_in=64, ch_out=128)
-        self.Conv3 = conv_block(ch_in=128, ch_out=256)
-        self.Conv4 = conv_block(ch_in=256, ch_out=512)
-        self.Conv5 = conv_block(ch_in=512, ch_out=1024)
+        self.Conv1 = conv_block(ch_in=input_channels, ch_out=64, conv_type=conv_type)
+        self.Conv2 = conv_block(ch_in=64, ch_out=128, conv_type=conv_type)
+        self.Conv3 = conv_block(ch_in=128, ch_out=256, conv_type=conv_type)
+        self.Conv4 = conv_block(ch_in=256, ch_out=512, conv_type=conv_type)
+        self.Conv5 = conv_block(ch_in=512, ch_out=1024, conv_type=conv_type)
 
         self.Up5 = up_conv(ch_in=1024, ch_out=512)
-        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
-        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
+        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256, conv_type=conv_type)
+        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512, conv_type=conv_type)
 
         self.Up4 = up_conv(ch_in=512, ch_out=256)
-        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
-        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
+        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128, conv_type=conv_type)
+        self.Up_conv4 = conv_block(ch_in=512, ch_out=256, conv_type=conv_type)
 
         self.Up3 = up_conv(ch_in=256, ch_out=128)
-        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
-        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
+        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64, conv_type=conv_type)
+        self.Up_conv3 = conv_block(ch_in=256, ch_out=128, conv_type=conv_type)
 
         self.Up2 = up_conv(ch_in=128, ch_out=64)
-        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
-        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
+        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32, conv_type=conv_type)
+        self.Up_conv2 = conv_block(ch_in=128, ch_out=64, conv_type=conv_type)
 
         self.Conv_1x1 = nn.Conv2d(64, output_channels, kernel_size=1, stride=1, padding=0)
 
@@ -269,37 +272,37 @@ def forward(self, x):
 
 
 class R2AttU_Net(nn.Module):
-    def __init__(self, input_channels=3, output_channels=1, t=2):
+    def __init__(self, input_channels=3, output_channels=1, t=2, conv_type: str = "standard"):
         super(R2AttU_Net, self).__init__()
 
         self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
         self.Upsample = nn.Upsample(scale_factor=2)
 
-        self.RRCNN1 = RRCNN_block(ch_in=input_channels, ch_out=64, t=t)
+        self.RRCNN1 = RRCNN_block(ch_in=input_channels, ch_out=64, t=t, conv_type=conv_type)
 
-        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
+        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t, conv_type=conv_type)
 
-        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
+        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t, conv_type=conv_type)
 
-        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
+        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t, conv_type=conv_type)
 
-        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)
+        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t, conv_type=conv_type)
 
         self.Up5 = up_conv(ch_in=1024, ch_out=512)
-        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
-        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)
+        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256, conv_type=conv_type)
+        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t, conv_type=conv_type)
 
         self.Up4 = up_conv(ch_in=512, ch_out=256)
-        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
-        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)
+        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128, conv_type=conv_type)
+        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t, conv_type=conv_type)
 
         self.Up3 = up_conv(ch_in=256, ch_out=128)
-        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
-        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)
+        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64, conv_type=conv_type)
+        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t, conv_type=conv_type)
 
         self.Up2 = up_conv(ch_in=128, ch_out=64)
-        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
-        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)
+        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32, conv_type=conv_type)
+        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t, conv_type=conv_type)
 
         self.Conv_1x1 = nn.Conv2d(64, output_channels, kernel_size=1, stride=1, padding=0)
 
diff --git a/satflow/models/conv_lstm.py b/satflow/models/conv_lstm.py
index f15d9327..ccc362e1 100644
--- a/satflow/models/conv_lstm.py
+++ b/satflow/models/conv_lstm.py
@@ -25,6 +25,7 @@ def __init__(
         visualize: bool = False,
         loss: Union[str, torch.nn.Module] = "mse",
         pretrained: bool = False,
+        conv_type: str = "standard",
     ):
         super(EncoderDecoderConvLSTM, self).__init__()
         self.forecast_steps = forecast_steps
@@ -42,7 +43,7 @@ def __init__(
             raise ValueError(f"loss {loss} not recognized")
         self.lr = lr
         self.visualize = visualize
-        self.module = ConvLSTM(input_channels, hidden_dim, out_channels)
+        self.model = ConvLSTM(input_channels, hidden_dim, out_channels, conv_type=conv_type)
         self.save_hyperparameters()
 
     @classmethod
@@ -56,7 +57,7 @@ def from_config(cls, config):
         )
 
     def forward(self, x, future_seq=0, hidden_state=None):
-        return self.module.forward(x, future_seq, hidden_state)
+        return self.model.forward(x, future_seq, hidden_state)
 
     def configure_optimizers(self):
         # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w)
@@ -120,7 +121,7 @@ def visualize_step(self, x, y, y_hat, batch_idx, step="train"):
 
 
 class ConvLSTM(torch.nn.Module):
-    def __init__(self, input_channels, hidden_dim, out_channels):
+    def __init__(self, input_channels, hidden_dim, out_channels, conv_type: str = "standard"):
         super().__init__()
         """ ARCHITECTURE
 
@@ -131,11 +132,19 @@ def __init__(self, input_channels, hidden_dim, out_channels):
         """
         self.encoder_1_convlstm = ConvLSTMCell(
-            input_dim=input_channels, hidden_dim=hidden_dim, kernel_size=(3, 3), bias=True
+            input_dim=input_channels,
+            hidden_dim=hidden_dim,
+            kernel_size=(3, 3),
+            bias=True,
+            conv_type=conv_type,
         )
 
         self.encoder_2_convlstm = ConvLSTMCell(
-            input_dim=hidden_dim, hidden_dim=hidden_dim, kernel_size=(3, 3), bias=True
+            input_dim=hidden_dim,
+            hidden_dim=hidden_dim,
+            kernel_size=(3, 3),
+            bias=True,
+            conv_type=conv_type,
         )
 
         self.decoder_1_convlstm = ConvLSTMCell(
@@ -143,10 +152,15 @@ def __init__(self, input_channels, hidden_dim, out_channels):
             hidden_dim=hidden_dim,
             kernel_size=(3, 3),
             bias=True,  # nf + 1
+            conv_type=conv_type,
         )
 
         self.decoder_2_convlstm = ConvLSTMCell(
-            input_dim=hidden_dim, hidden_dim=hidden_dim, kernel_size=(3, 3), bias=True
+            input_dim=hidden_dim,
+            hidden_dim=hidden_dim,
+            kernel_size=(3, 3),
+            bias=True,
+            conv_type=conv_type,
         )
 
         self.decoder_CNN = nn.Conv3d(
diff --git a/satflow/models/layers/ConvLSTM.py b/satflow/models/layers/ConvLSTM.py
index e74df844..16a70ab0 100644
--- a/satflow/models/layers/ConvLSTM.py
+++ b/satflow/models/layers/ConvLSTM.py
@@ -1,9 +1,10 @@
 import torch
 import torch.nn as nn
+from satflow.models.utils import get_conv_layer
 
 
 class ConvLSTMCell(nn.Module):
-    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
+    def __init__(self, input_dim, hidden_dim, kernel_size, bias, conv_type: str = "standard"):
         """
         Initialize ConvLSTM cell.
 
@@ -27,8 +28,9 @@ def __init__(self, input_dim, hidden_dim, kernel_size, bias):
         self.kernel_size = kernel_size
         self.padding = kernel_size[0] // 2, kernel_size[1] // 2
         self.bias = bias
+        conv2d = get_conv_layer(conv_type)
 
-        self.conv = nn.Conv2d(
+        self.conv = conv2d(
             in_channels=self.input_dim + self.hidden_dim,
             out_channels=4 * self.hidden_dim,
             kernel_size=self.kernel_size,
diff --git a/satflow/models/metnet.py b/satflow/models/metnet.py
index 8285862a..786442f9 100644
--- a/satflow/models/metnet.py
+++ b/satflow/models/metnet.py
@@ -6,6 +6,7 @@
 
 from satflow.models.base import register_model
 from satflow.models.layers import ConvGRU, TimeDistributed
+from satflow.models.utils import get_conv_layer
 from axial_attention import AxialAttention
 from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
 import numpy as np
@@ -13,17 +14,18 @@
 
 
 class DownSampler(nn.Module):
-    def __init__(self, in_channels):
+    def __init__(self, in_channels, conv_type: str = "standard"):
         super().__init__()
+        conv2d = get_conv_layer(conv_type=conv_type)
         self.module = nn.Sequential(
-            nn.Conv2d(in_channels, 160, 3, padding=1),
+            conv2d(in_channels, 160, 3, padding=1),
             nn.MaxPool2d((2, 2), stride=2),
             nn.BatchNorm2d(160),
-            nn.Conv2d(160, 256, 3, padding=1),
+            conv2d(160, 256, 3, padding=1),
             nn.BatchNorm2d(256),
-            nn.Conv2d(256, 256, 3, padding=1),
+            conv2d(256, 256, 3, padding=1),
             nn.BatchNorm2d(256),
-            nn.Conv2d(256, 256, 3, padding=1),
+            conv2d(256, 256, 3, padding=1),
             nn.MaxPool2d((2, 2), stride=2),
         )
 
diff --git a/satflow/models/runet.py b/satflow/models/runet.py
index dd3aacf3..15534bcd 100644
--- a/satflow/models/runet.py
+++ b/satflow/models/runet.py
@@ -17,10 +17,13 @@ def __init__(
         loss: Union[str, torch.nn.Module] = "mse",
         lr: float = 0.001,
         visualize: bool = False,
+        conv_type: str = "standard",
     ):
         self.input_channels = input_channels
         self.forecast_steps = forecast_steps
-        self.module = R2U_Net(img_ch=input_channels, output_ch=forecast_steps, t=recurrent_steps)
+        self.module = R2U_Net(
+            img_ch=input_channels, output_ch=forecast_steps, t=recurrent_steps, conv_type=conv_type
+        )
         super().__init__()
         self.lr = lr
         self.input_channels = input_channels
@@ -111,7 +114,7 @@ def visualize_step(self, x, y, y_hat, batch_idx, step="train"):
 
 
 class R2U_Net(nn.Module):
-    def __init__(self, img_ch=3, output_ch=1, t=2):
+    def __init__(self, img_ch=3, output_ch=1, t=2, conv_type: str = "standard"):
         super(R2U_Net, self).__init__()
 
         self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

From 5b6c9fa094c41746ab688d732ca4582d6f16cac9 Mon Sep 17 00:00:00 2001
From: Jacob Bieker
Date: Thu, 22 Jul 2021 11:58:04 +0100
Subject: [PATCH 3/6] Add antialiased blocks to GAN generators

Only layers with a stride of 2 are changed, following the directions in
the antialiased-cnns repo.
---
 satflow/models/gan/discriminators.py | 54 ++++++++++++++++++++--------
 satflow/models/gan/generators.py     | 52 ++++++++++++++++++++++-----
 2 files changed, 82 insertions(+), 24 deletions(-)
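Notes: the substitution pattern used throughout this patch, shown in
isolation. A stride-2 conv is split into a stride-1 conv (so the
nonlinearity sees the full-resolution signal) followed by a BlurPool that
does the actual downsampling; both paths halve the spatial size. Channel
counts here are illustrative only:

    import torch
    import torch.nn as nn
    import antialiased_cnns

    strided = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
        nn.ReLU(True),
    )
    blurred = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
        nn.ReLU(True),
        antialiased_cnns.BlurPool(128, stride=2),  # low-pass filter, then subsample
    )
    x = torch.randn(1, 64, 32, 32)
    assert strided(x).shape == blurred(x).shape == (1, 128, 16, 16)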
diff --git a/satflow/models/gan/discriminators.py b/satflow/models/gan/discriminators.py
index b45b41b3..07b5e16a 100644
--- a/satflow/models/gan/discriminators.py
+++ b/satflow/models/gan/discriminators.py
@@ -3,6 +3,7 @@
 from torch import nn as nn
 from satflow.models.utils import get_conv_layer
 from satflow.models.gan.common import get_norm_layer, init_net
+import antialiased_cnns
 
 
 def define_discriminator(
@@ -168,18 +169,34 @@ def __init__(
         for n in range(1, n_layers):  # gradually increase the number of filters
             nf_mult_prev = nf_mult
             nf_mult = min(2 ** n, 8)
-            sequence += [
-                conv2d(
-                    ndf * nf_mult_prev,
-                    ndf * nf_mult,
-                    kernel_size=kw,
-                    stride=2,
-                    padding=padw,
-                    bias=use_bias,
-                ),
-                norm_layer(ndf * nf_mult),
-                nn.LeakyReLU(0.2, True),
-            ]
+            if conv_type == "antialiased":
+                block = [
+                    conv2d(
+                        ndf * nf_mult_prev,
+                        ndf * nf_mult,
+                        kernel_size=kw,
+                        stride=1,
+                        padding=padw,
+                        bias=use_bias,
+                    ),
+                    norm_layer(ndf * nf_mult),
+                    nn.LeakyReLU(0.2, True),
+                    antialiased_cnns.BlurPool(ndf * nf_mult, stride=2),
+                ]
+            else:
+                block = [
+                    conv2d(
+                        ndf * nf_mult_prev,
+                        ndf * nf_mult,
+                        kernel_size=kw,
+                        stride=2,
+                        padding=padw,
+                        bias=use_bias,
+                    ),
+                    norm_layer(ndf * nf_mult),
+                    nn.LeakyReLU(0.2, True),
+                ]
+            sequence += block
 
         nf_mult_prev = nf_mult
         nf_mult = min(2 ** n_layers, 8)
@@ -244,16 +261,23 @@ def forward(self, input):
 
 
 class CloudGANBlock(nn.Module):
-    def __init__(self, input_channels, conv2d: torch.nn.Module):
+    def __init__(self, input_channels, conv_type: str = "standard"):
         super().__init__()
+        conv2d = get_conv_layer(conv_type)
         self.conv = conv2d(input_channels, input_channels * 2, kernel_size=(3, 3))
         self.relu = torch.nn.ReLU()
-        self.pool = torch.nn.MaxPool2d(kernel_size=(2, 2))
+        if conv_type == "antialiased":
+            self.pool = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=1)
+            self.blurpool = antialiased_cnns.BlurPool(input_channels * 2, stride=2)
+        else:
+            self.pool = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2)
+            self.blurpool = torch.nn.Identity()
 
     def forward(self, x):
         x = self.conv(x)
         x = self.relu(x)
         x = self.pool(x)
+        x = self.blurpool(x)
         return x
 
 
@@ -272,7 +296,7 @@ def __init__(
         self.conv_1 = conv2d(input_channels, num_filters, kernel_size=1, stride=1, padding=0)
         self.stages = []
         for stage in range(num_stages):
-            self.stages.append(CloudGANBlock(num_filters, conv2d))
+            self.stages.append(CloudGANBlock(num_filters, conv_type))
             num_filters = num_filters * 2
         self.stages = torch.nn.Sequential(*self.stages)
         self.flatten = torch.nn.Flatten()
diff --git a/satflow/models/gan/generators.py b/satflow/models/gan/generators.py
index 941283f6..16d8c076 100644
--- a/satflow/models/gan/generators.py
+++ b/satflow/models/gan/generators.py
@@ -5,6 +5,7 @@
 from typing import Union
 from satflow.models.gan.common import get_norm_layer, init_net
 from satflow.models.utils import get_conv_layer
+import antialiased_cnns
 
 
 def define_G(
@@ -112,13 +113,35 @@ def __init__(
         n_downsampling = 2
         for i in range(n_downsampling):  # add downsampling layers
             mult = 2 ** i
-            model += [
-                conv2d(
-                    ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias
-                ),
-                norm_layer(ngf * mult * 2),
-                nn.ReLU(True),
-            ]
+            if conv_type == "antialiased":
+                block = [
+                    conv2d(
+                        ngf * mult,
+                        ngf * mult * 2,
+                        kernel_size=3,
+                        stride=1,
+                        padding=1,
+                        bias=use_bias,
+                    ),
+                    norm_layer(ngf * mult * 2),
+                    nn.ReLU(True),
+                    antialiased_cnns.BlurPool(ngf * mult * 2, stride=2),
+                ]
+            else:
+                block = [
+                    conv2d(
+                        ngf * mult,
+                        ngf * mult * 2,
+                        kernel_size=3,
+                        stride=2,
+                        padding=1,
+                        bias=use_bias,
+                    ),
+                    norm_layer(ngf * mult * 2),
+                    nn.ReLU(True),
+                ]
+
+            model += block
 
         mult = 2 ** n_downsampling
         for i in range(n_blocks):  # add ResNet blocks
@@ -331,7 +354,7 @@ def __init__(
         innermost=False,
         norm_layer=nn.BatchNorm2d,
         use_dropout=False,
-        conv2d=torch.nn.Module,
+        conv_type: str = "standard",
     ):
         """Construct a Unet submodule with skip connections.
 
@@ -353,7 +376,18 @@ def __init__(
             use_bias = norm_layer == nn.InstanceNorm2d
         if input_nc is None:
             input_nc = outer_nc
-        downconv = conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
+        conv2d = get_conv_layer(conv_type)
+        if conv_type == "antialiased":
+            antialiased = True
+            downconv = conv2d(
+                input_nc, inner_nc, kernel_size=4, stride=1, padding=1, bias=use_bias
+            )
+            blurpool = antialiased_cnns.BlurPool(inner_nc, stride=2)
+        else:
+            antialiased = False
+            downconv = conv2d(
+                input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
+            )
         downrelu = nn.LeakyReLU(0.2, True)
         downnorm = norm_layer(inner_nc)
         uprelu = nn.ReLU(True)

From bbb662d1bec0766e5ee308180304da355311a1e5 Mon Sep 17 00:00:00 2001
From: Jacob Bieker
Date: Thu, 22 Jul 2021 12:10:32 +0100
Subject: [PATCH 4/6] Add antialiased operations to MetNet and the Unet for
 GANs

---
 satflow/models/gan/generators.py | 26 +++++++++++++++++---------
 satflow/models/metnet.py         | 12 ++++++++++--
 2 files changed, 27 insertions(+), 11 deletions(-)
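Notes: the same idea applied to pooling. MaxPool2d keeps its kernel but
runs with stride 1 (dense max evaluation), and BlurPool takes over the
stride-2 subsampling, so the DownSampler output size is unchanged. A
sketch with an arbitrary input size:

    import torch
    import torch.nn as nn
    import antialiased_cnns

    x = torch.randn(1, 160, 32, 32)
    baseline = nn.MaxPool2d((2, 2), stride=2)
    antialiased = nn.Sequential(
        nn.MaxPool2d((2, 2), stride=1),            # evaluate max densely
        antialiased_cnns.BlurPool(160, stride=2),  # then blur + subsample
    )
    assert baseline(x).shape == antialiased(x).shape == (1, 160, 16, 16)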
diff --git a/satflow/models/gan/generators.py b/satflow/models/gan/generators.py
index 16d8c076..e600a002 100644
--- a/satflow/models/gan/generators.py
+++ b/satflow/models/gan/generators.py
@@ -283,7 +283,6 @@ def __init__(
         """
         super(UnetGenerator, self).__init__()
         # construct unet structure
-        conv2d = get_conv_layer(conv_type)
         unet_block = UnetSkipConnectionBlock(
             ngf * 8,
             ngf * 8,
@@ -291,7 +290,7 @@ def __init__(
             submodule=None,
             norm_layer=norm_layer,
             innermost=True,
-            conv2d=conv2d,
+            conv_type=conv_type,
         )  # add the innermost layer
         for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
             unet_block = UnetSkipConnectionBlock(
@@ -301,7 +300,7 @@ def __init__(
                 submodule=unet_block,
                 norm_layer=norm_layer,
                 use_dropout=use_dropout,
-                conv2d=conv2d,
+                conv_type=conv_type,
             )
         # gradually reduce the number of filters from ngf * 8 to ngf
         unet_block = UnetSkipConnectionBlock(
@@ -310,7 +309,7 @@ def __init__(
             input_nc=None,
             submodule=unet_block,
             norm_layer=norm_layer,
-            conv2d=conv2d,
+            conv_type=conv_type,
         )
         unet_block = UnetSkipConnectionBlock(
             ngf * 2,
@@ -318,10 +317,15 @@ def __init__(
             input_nc=None,
             submodule=unet_block,
             norm_layer=norm_layer,
-            conv2d=conv2d,
+            conv_type=conv_type,
         )
         unet_block = UnetSkipConnectionBlock(
-            ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, conv2d=conv2d
+            ngf,
+            ngf * 2,
+            input_nc=None,
+            submodule=unet_block,
+            norm_layer=norm_layer,
+            conv_type=conv_type,
         )
         self.model = UnetSkipConnectionBlock(
             output_nc,
@@ -330,7 +334,7 @@ def __init__(
             submodule=unet_block,
             outermost=True,
             norm_layer=norm_layer,
-            conv2d=conv2d,
+            conv_type=conv_type,
         )  # add the outermost layer
 
     def forward(self, input):
@@ -402,14 +406,18 @@ def __init__(
             upconv = nn.ConvTranspose2d(
                 inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
             )
-            down = [downrelu, downconv]
+            down = [downrelu, downconv, blurpool] if antialiased else [downrelu, downconv]
             up = [uprelu, upconv, upnorm]
             model = down + up
         else:
             upconv = nn.ConvTranspose2d(
                 inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
             )
-            down = [downrelu, downconv, downnorm]
+            down = (
+                [downrelu, downconv, downnorm, blurpool]
+                if antialiased
+                else [downrelu, downconv, downnorm]
+            )
             up = [uprelu, upconv, upnorm]
 
             if use_dropout:
diff --git a/satflow/models/metnet.py b/satflow/models/metnet.py
index 786442f9..c5d5c57e 100644
--- a/satflow/models/metnet.py
+++ b/satflow/models/metnet.py
@@ -11,22 +11,30 @@
 from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
 import numpy as np
 import torchvision
+import antialiased_cnns
 
 
 class DownSampler(nn.Module):
     def __init__(self, in_channels, conv_type: str = "standard"):
         super().__init__()
         conv2d = get_conv_layer(conv_type=conv_type)
+        if conv_type == "antialiased":
+            antialiased = True
+        else:
+            antialiased = False
+
         self.module = nn.Sequential(
             conv2d(in_channels, 160, 3, padding=1),
-            nn.MaxPool2d((2, 2), stride=2),
+            nn.MaxPool2d((2, 2), stride=1 if antialiased else 2),
+            antialiased_cnns.BlurPool(160, stride=2) if antialiased else nn.Identity(),
             nn.BatchNorm2d(160),
             conv2d(160, 256, 3, padding=1),
             nn.BatchNorm2d(256),
             conv2d(256, 256, 3, padding=1),
             nn.BatchNorm2d(256),
             conv2d(256, 256, 3, padding=1),
-            nn.MaxPool2d((2, 2), stride=2),
+            nn.MaxPool2d((2, 2), stride=1 if antialiased else 2),
+            antialiased_cnns.BlurPool(256, stride=2) if antialiased else nn.Identity(),
         )
 
     def forward(self, x):

From bb885f09d219e19b53736bfe65aa9fc16cca0ec0 Mon Sep 17 00:00:00 2001
From: Jacob Bieker
Date: Thu, 22 Jul 2021 12:24:42 +0100
Subject: [PATCH 5/6] Add antialiasing for RUnet

---
 satflow/models/layers/RUnetLayers.py |  5 +++-
 satflow/models/runet.py              | 42 ++++++++++++++++++++--------
 2 files changed, 35 insertions(+), 12 deletions(-)
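Notes: with the blur stages wired in, R2U_Net downsamples by the same
factor in both modes, so output shapes are unchanged. A smoke test of the
sort used while developing (sizes are arbitrary; the spatial size should
be divisible by 16 to survive the four pooling stages):

    import torch
    from satflow.models.runet import R2U_Net

    for conv_type in ("standard", "antialiased"):
        model = R2U_Net(img_ch=12, output_ch=24, t=2, conv_type=conv_type)
        out = model(torch.randn(1, 12, 64, 64))
        assert out.shape == (1, 24, 64, 64)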
diff --git a/satflow/models/layers/RUnetLayers.py b/satflow/models/layers/RUnetLayers.py
index 56b75551..7ebbf02c 100644
--- a/satflow/models/layers/RUnetLayers.py
+++ b/satflow/models/layers/RUnetLayers.py
@@ -92,7 +92,10 @@ class RRCNN_block(nn.Module):
     def __init__(self, ch_in, ch_out, t=2, conv_type: str = "standard"):
         super(RRCNN_block, self).__init__()
         conv2d = get_conv_layer(conv_type)
-        self.RCNN = nn.Sequential(Recurrent_block(ch_out, t=t), Recurrent_block(ch_out, t=t))
+        self.RCNN = nn.Sequential(
+            Recurrent_block(ch_out, t=t, conv_type=conv_type),
+            Recurrent_block(ch_out, t=t, conv_type=conv_type),
+        )
         self.Conv_1x1 = conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)
 
     def forward(self, x):
diff --git a/satflow/models/runet.py b/satflow/models/runet.py
index 15534bcd..d2a3e7ce 100644
--- a/satflow/models/runet.py
+++ b/satflow/models/runet.py
@@ -1,3 +1,5 @@
+import antialiased_cnns
+
 from satflow.models.layers.RUnetLayers import *
 import pytorch_lightning as pl
 import torchvision
@@ -116,31 +118,45 @@ def visualize_step(self, x, y, y_hat, batch_idx, step="train"):
 class R2U_Net(nn.Module):
     def __init__(self, img_ch=3, output_ch=1, t=2, conv_type: str = "standard"):
         super(R2U_Net, self).__init__()
+        if conv_type == "antialiased":
+            self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=1)
+            self.antialiased = True
+        else:
+            self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+            self.antialiased = False
 
-        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
         self.Upsample = nn.Upsample(scale_factor=2)
 
-        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
-
-        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
+        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t, conv_type=conv_type)
+        self.Blur1 = antialiased_cnns.BlurPool(64, stride=2) if self.antialiased else nn.Identity()
+        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t, conv_type=conv_type)
+        self.Blur2 = (
+            antialiased_cnns.BlurPool(128, stride=2) if self.antialiased else nn.Identity()
+        )
 
-        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
+        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t, conv_type=conv_type)
+        self.Blur3 = (
+            antialiased_cnns.BlurPool(256, stride=2) if self.antialiased else nn.Identity()
+        )
 
-        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
+        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t, conv_type=conv_type)
+        self.Blur4 = (
+            antialiased_cnns.BlurPool(512, stride=2) if self.antialiased else nn.Identity()
+        )
 
-        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)
+        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t, conv_type=conv_type)
 
         self.Up5 = up_conv(ch_in=1024, ch_out=512)
-        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)
+        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t, conv_type=conv_type)
 
         self.Up4 = up_conv(ch_in=512, ch_out=256)
-        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)
+        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t, conv_type=conv_type)
 
         self.Up3 = up_conv(ch_in=256, ch_out=128)
-        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)
+        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t, conv_type=conv_type)
 
         self.Up2 = up_conv(ch_in=128, ch_out=64)
-        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)
+        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t, conv_type=conv_type)
 
         self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
 
@@ -149,15 +165,19 @@ def forward(self, x):
         x1 = self.RRCNN1(x)
 
         x2 = self.Maxpool(x1)
+        x2 = self.Blur1(x2)
         x2 = self.RRCNN2(x2)
 
         x3 = self.Maxpool(x2)
+        x3 = self.Blur2(x3)
         x3 = self.RRCNN3(x3)
 
         x4 = self.Maxpool(x3)
+        x4 = self.Blur3(x4)
         x4 = self.RRCNN4(x4)
 
         x5 = self.Maxpool(x4)
+        x5 = self.Blur4(x5)
         x5 = self.RRCNN5(x5)
 
         # decoding + concat path

From 69a8abfd7480864c79a1cfa04bb39566509dc542 Mon Sep 17 00:00:00 2001
From: Jacob Bieker
Date: Thu, 22 Jul 2021 12:32:29 +0100
Subject: [PATCH 6/6] Fix import

---
 satflow/models/runet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/satflow/models/runet.py b/satflow/models/runet.py
index d2a3e7ce..5d698991 100644
--- a/satflow/models/runet.py
+++ b/satflow/models/runet.py
@@ -1,5 +1,5 @@
 import antialiased_cnns
-
+import torch.nn.functional as F
 from satflow.models.layers.RUnetLayers import *
 import pytorch_lightning as pl
 import torchvision