diff --git a/satflow/models/attention_unet.py b/satflow/models/attention_unet.py
index 1bf191eb..327a35ba 100644
--- a/satflow/models/attention_unet.py
+++ b/satflow/models/attention_unet.py
@@ -16,6 +16,7 @@ def __init__(
         loss: Union[str, torch.nn.Module] = "mse",
         lr: float = 0.001,
         visualize: bool = False,
+        conv_type: str = "standard",
     ):
         super().__init__()
         self.lr = lr
@@ -23,7 +24,9 @@ def __init__(
         self.input_channels = input_channels
         self.forecast_steps = forecast_steps
         self.channels_per_timestep = 12
-        self.model = AttU_Net(input_channels=input_channels, output_channels=forecast_steps)
+        self.model = AttU_Net(
+            input_channels=input_channels, output_channels=forecast_steps, conv_type=conv_type
+        )
         assert loss in ["mse", "bce", "binary_crossentropy", "crossentropy", "focal"]
         if loss == "mse":
             self.criterion = F.mse_loss
@@ -197,32 +200,32 @@ def visualize_step(self, x, y, y_hat, batch_idx, step):
 
 
 class AttU_Net(nn.Module):
-    def __init__(self, input_channels=3, output_channels=1):
+    def __init__(self, input_channels=3, output_channels=1, conv_type: str = "standard"):
         super(AttU_Net, self).__init__()
 
         self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
 
-        self.Conv1 = conv_block(ch_in=input_channels, ch_out=64)
-        self.Conv2 = conv_block(ch_in=64, ch_out=128)
-        self.Conv3 = conv_block(ch_in=128, ch_out=256)
-        self.Conv4 = conv_block(ch_in=256, ch_out=512)
-        self.Conv5 = conv_block(ch_in=512, ch_out=1024)
+        self.Conv1 = conv_block(ch_in=input_channels, ch_out=64, conv_type=conv_type)
+        self.Conv2 = conv_block(ch_in=64, ch_out=128, conv_type=conv_type)
+        self.Conv3 = conv_block(ch_in=128, ch_out=256, conv_type=conv_type)
+        self.Conv4 = conv_block(ch_in=256, ch_out=512, conv_type=conv_type)
+        self.Conv5 = conv_block(ch_in=512, ch_out=1024, conv_type=conv_type)
 
         self.Up5 = up_conv(ch_in=1024, ch_out=512)
-        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
-        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
+        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256, conv_type=conv_type)
+        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512, conv_type=conv_type)
 
         self.Up4 = up_conv(ch_in=512, ch_out=256)
-        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
-        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
+        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128, conv_type=conv_type)
+        self.Up_conv4 = conv_block(ch_in=512, ch_out=256, conv_type=conv_type)
 
         self.Up3 = up_conv(ch_in=256, ch_out=128)
-        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
-        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
+        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64, conv_type=conv_type)
+        self.Up_conv3 = conv_block(ch_in=256, ch_out=128, conv_type=conv_type)
 
         self.Up2 = up_conv(ch_in=128, ch_out=64)
-        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
-        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
+        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32, conv_type=conv_type)
+        self.Up_conv2 = conv_block(ch_in=128, ch_out=64, conv_type=conv_type)
 
         self.Conv_1x1 = nn.Conv2d(64, output_channels, kernel_size=1, stride=1, padding=0)
@@ -269,37 +272,37 @@ def forward(self, x):
 
 
 class R2AttU_Net(nn.Module):
-    def __init__(self, input_channels=3, output_channels=1, t=2):
+    def __init__(self, input_channels=3, output_channels=1, t=2, conv_type: str = "standard"):
         super(R2AttU_Net, self).__init__()
 
         self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
         self.Upsample = nn.Upsample(scale_factor=2)
 
-        self.RRCNN1 = RRCNN_block(ch_in=input_channels, ch_out=64, t=t)
+        self.RRCNN1 = RRCNN_block(ch_in=input_channels, ch_out=64, t=t, conv_type=conv_type)
 
-        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
+        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t, conv_type=conv_type)
 
-        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
+        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t, conv_type=conv_type)
 
-        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
+        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t, conv_type=conv_type)
 
-        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)
+        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t, conv_type=conv_type)
 
         self.Up5 = up_conv(ch_in=1024, ch_out=512)
-        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
-        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)
+        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256, conv_type=conv_type)
+        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t, conv_type=conv_type)
 
         self.Up4 = up_conv(ch_in=512, ch_out=256)
-        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
-        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)
+        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128, conv_type=conv_type)
+        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t, conv_type=conv_type)
 
         self.Up3 = up_conv(ch_in=256, ch_out=128)
-        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
-        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)
+        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64, conv_type=conv_type)
+        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t, conv_type=conv_type)
 
         self.Up2 = up_conv(ch_in=128, ch_out=64)
-        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
-        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)
+        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32, conv_type=conv_type)
+        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t, conv_type=conv_type)
 
         self.Conv_1x1 = nn.Conv2d(64, output_channels, kernel_size=1, stride=1, padding=0)
diff --git a/satflow/models/conv_lstm.py b/satflow/models/conv_lstm.py
index f15d9327..ccc362e1 100644
--- a/satflow/models/conv_lstm.py
+++ b/satflow/models/conv_lstm.py
@@ -25,6 +25,7 @@ def __init__(
         visualize: bool = False,
         loss: Union[str, torch.nn.Module] = "mse",
         pretrained: bool = False,
+        conv_type: str = "standard",
     ):
         super(EncoderDecoderConvLSTM, self).__init__()
         self.forecast_steps = forecast_steps
@@ -42,7 +43,7 @@ def __init__(
             raise ValueError(f"loss {loss} not recognized")
         self.lr = lr
         self.visualize = visualize
-        self.module = ConvLSTM(input_channels, hidden_dim, out_channels)
+        self.model = ConvLSTM(input_channels, hidden_dim, out_channels, conv_type=conv_type)
         self.save_hyperparameters()
 
     @classmethod
@@ -56,7 +57,7 @@ def from_config(cls, config):
         )
 
     def forward(self, x, future_seq=0, hidden_state=None):
-        return self.module.forward(x, future_seq, hidden_state)
+        return self.model.forward(x, future_seq, hidden_state)
 
     def configure_optimizers(self):
         # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w)
@@ -120,7 +121,7 @@ def visualize_step(self, x, y, y_hat, batch_idx, step="train"):
 
 
 class ConvLSTM(torch.nn.Module):
-    def __init__(self, input_channels, hidden_dim, out_channels):
+    def __init__(self, input_channels, hidden_dim, out_channels, conv_type: str = "standard"):
         super().__init__()
         """ ARCHITECTURE
@@ -131,11 +132,19 @@ def __init__(self, input_channels, hidden_dim, out_channels):
         """
 
         self.encoder_1_convlstm = ConvLSTMCell(
-            input_dim=input_channels, hidden_dim=hidden_dim, kernel_size=(3, 3), bias=True
+            input_dim=input_channels,
+            hidden_dim=hidden_dim,
+            kernel_size=(3, 3),
+            bias=True,
+            conv_type=conv_type,
         )
 
         self.encoder_2_convlstm = ConvLSTMCell(
-            input_dim=hidden_dim, hidden_dim=hidden_dim, kernel_size=(3, 3), bias=True
+            input_dim=hidden_dim,
+            hidden_dim=hidden_dim,
+            kernel_size=(3, 3),
+            bias=True,
+            conv_type=conv_type,
         )
 
         self.decoder_1_convlstm = ConvLSTMCell(
@@ -143,10 +152,15 @@ def __init__(self, input_channels, hidden_dim, out_channels):
             hidden_dim=hidden_dim,
             kernel_size=(3, 3),
             bias=True,  # nf + 1
+            conv_type=conv_type,
         )
 
         self.decoder_2_convlstm = ConvLSTMCell(
-            input_dim=hidden_dim, hidden_dim=hidden_dim, kernel_size=(3, 3), bias=True
+            input_dim=hidden_dim,
+            hidden_dim=hidden_dim,
+            kernel_size=(3, 3),
+            bias=True,
+            conv_type=conv_type,
         )
 
         self.decoder_CNN = nn.Conv3d(
diff --git a/satflow/models/gan/discriminators.py b/satflow/models/gan/discriminators.py
index bd298554..07b5e16a 100644
--- a/satflow/models/gan/discriminators.py
+++ b/satflow/models/gan/discriminators.py
@@ -1,11 +1,21 @@
 import functools
 import torch
 from torch import nn as nn
-
+from satflow.models.utils import get_conv_layer
 from satflow.models.gan.common import get_norm_layer, init_net
-
-
-def define_D(input_nc, ndf, netD, n_layers_D=3, norm="batch", init_type="normal", init_gain=0.02):
+import antialiased_cnns
+
+
+def define_discriminator(
+    input_nc,
+    ndf,
+    netD,
+    n_layers_D=3,
+    norm="batch",
+    init_type="normal",
+    init_gain=0.02,
+    conv_type: str = "standard",
+):
     """Create a discriminator
 
     Parameters:
@@ -36,15 +46,20 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm="batch", init_type="normal"
     """
     net = None
     norm_layer = get_norm_layer(norm_type=norm)
-
     if netD == "basic":  # default PatchGAN classifier
-        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
+        net = NLayerDiscriminator(
+            input_nc, ndf, n_layers=3, norm_layer=norm_layer, conv_type=conv_type
+        )
     elif netD == "n_layers":  # more options
-        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
+        net = NLayerDiscriminator(
+            input_nc, ndf, n_layers_D, norm_layer=norm_layer, conv_type=conv_type
+        )
     elif netD == "pixel":  # classify if each pixel is real or fake
-        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
+        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, conv_type=conv_type)
     elif netD == "enhanced":
-        net = CloudGANDiscriminator(input_channels=input_nc, num_filters=ndf, num_stages=3)
+        net = CloudGANDiscriminator(
+            input_channels=input_nc, num_filters=ndf, num_stages=3, conv_type=conv_type
+        )
     else:
         raise NotImplementedError("Discriminator model name [%s] is not recognized" % netD)
     return init_net(net, init_type, init_gain)
@@ -122,7 +137,9 @@ def __call__(self, prediction, target_is_real):
 class NLayerDiscriminator(nn.Module):
     """Defines a PatchGAN discriminator"""
 
-    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
+    def __init__(
+        self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, conv_type: str = "standard"
+    ):
         """Construct a PatchGAN discriminator
 
         Parameters:
@@ -139,10 +156,12 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
         else:
             use_bias = norm_layer == nn.InstanceNorm2d
 
+        conv2d = get_conv_layer(conv_type)
+
         kw = 4
         padw = 1
         sequence = [
-            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
+            conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
             nn.LeakyReLU(0.2, True),
         ]
         nf_mult = 1
@@ -150,23 +169,39 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
         for n in range(1, n_layers):  # gradually increase the number of filters
             nf_mult_prev = nf_mult
             nf_mult = min(2 ** n, 8)
-            sequence += [
-                nn.Conv2d(
-                    ndf * nf_mult_prev,
-                    ndf * nf_mult,
-                    kernel_size=kw,
-                    stride=2,
-                    padding=padw,
-                    bias=use_bias,
-                ),
-                norm_layer(ndf * nf_mult),
-                nn.LeakyReLU(0.2, True),
-            ]
+            if conv_type == "antialiased":
+                block = [
+                    conv2d(
+                        ndf * nf_mult_prev,
+                        ndf * nf_mult,
+                        kernel_size=kw,
+                        stride=1,
+                        padding=padw,
+                        bias=use_bias,
+                    ),
+                    norm_layer(ndf * nf_mult),
+                    nn.LeakyReLU(0.2, True),
+                    antialiased_cnns.BlurPool(ndf * nf_mult, stride=2),
+                ]
+            else:
+                block = [
+                    conv2d(
+                        ndf * nf_mult_prev,
+                        ndf * nf_mult,
+                        kernel_size=kw,
+                        stride=2,
+                        padding=padw,
+                        bias=use_bias,
+                    ),
+                    norm_layer(ndf * nf_mult),
+                    nn.LeakyReLU(0.2, True),
+                ]
+            sequence += block
 
         nf_mult_prev = nf_mult
         nf_mult = min(2 ** n_layers, 8)
         sequence += [
-            nn.Conv2d(
+            conv2d(
                 ndf * nf_mult_prev,
                 ndf * nf_mult,
                 kernel_size=kw,
@@ -179,7 +214,7 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
         ]
 
         sequence += [
-            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
+            conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
         ]  # output 1 channel prediction map
         self.model = nn.Sequential(*sequence)
@@ -191,7 +226,7 @@ def forward(self, input):
 class PixelDiscriminator(nn.Module):
     """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
 
-    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
+    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, conv_type: str = "standard"):
         """Construct a 1x1 PatchGAN discriminator
 
         Parameters:
@@ -207,13 +242,15 @@ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
         else:
             use_bias = norm_layer == nn.InstanceNorm2d
 
+        conv2d = get_conv_layer(conv_type)
+
         self.net = [
-            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
+            conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
             nn.LeakyReLU(0.2, True),
-            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
+            conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
             norm_layer(ndf * 2),
             nn.LeakyReLU(0.2, True),
-            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
+            conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
         ]
 
         self.net = nn.Sequential(*self.net)
@@ -224,30 +261,42 @@ def forward(self, input):
 
 
 class CloudGANBlock(nn.Module):
-    def __init__(self, input_channels):
+    def __init__(self, input_channels, conv_type: str = "standard"):
         super().__init__()
-        self.conv = torch.nn.Conv2d(input_channels, input_channels * 2, kernel_size=(3, 3))
+        conv2d = get_conv_layer(conv_type)
+        self.conv = conv2d(input_channels, input_channels * 2, kernel_size=(3, 3))
         self.relu = torch.nn.ReLU()
-        self.pool = torch.nn.MaxPool2d(kernel_size=(2, 2))
+        if conv_type == "antialiased":
+            self.pool = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=1)
+            self.blurpool = antialiased_cnns.BlurPool(input_channels * 2, stride=2)
+        else:
+            self.pool = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2)
+            self.blurpool = torch.nn.Identity()
 
     def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.pool(x)
+        x = self.blurpool(x)
        return x
 
 
 class CloudGANDiscriminator(nn.Module):
     """Defines a discriminator based off https://www.climatechange.ai/papers/icml2021/54/slides.pdf"""
 
-    def __init__(self, input_channels: int = 12, num_filters: int = 64, num_stages: int = 3):
+    def __init__(
+        self,
+        input_channels: int = 12,
+        num_filters: int = 64,
+        num_stages: int = 3,
+        conv_type: str = "standard",
+    ):
         super().__init__()
-        self.conv_1 = torch.nn.Conv2d(
-            input_channels, num_filters, kernel_size=1, stride=1, padding=0
-        )
+        conv2d = get_conv_layer(conv_type)
+        self.conv_1 = conv2d(input_channels, num_filters, kernel_size=1, stride=1, padding=0)
         self.stages = []
         for stage in range(num_stages):
-            self.stages.append(CloudGANBlock(num_filters))
+            self.stages.append(CloudGANBlock(num_filters, conv_type))
             num_filters = num_filters * 2
         self.stages = torch.nn.Sequential(*self.stages)
         self.flatten = torch.nn.Flatten()
diff --git a/satflow/models/gan/generators.py b/satflow/models/gan/generators.py
index 3d230fea..e600a002 100644
--- a/satflow/models/gan/generators.py
+++ b/satflow/models/gan/generators.py
@@ -4,6 +4,8 @@
 from torch import nn as nn
 from typing import Union
 from satflow.models.gan.common import get_norm_layer, init_net
+from satflow.models.utils import get_conv_layer
+import antialiased_cnns
 
 
 def define_G(
@@ -81,6 +83,7 @@ def __init__(
         use_dropout=False,
         n_blocks=6,
         padding_type="reflect",
+        conv_type: str = "standard",
     ):
         """Construct a Resnet-based generator
 
@@ -99,10 +102,10 @@ def __init__(
             use_bias = norm_layer.func == nn.InstanceNorm2d
         else:
             use_bias = norm_layer == nn.InstanceNorm2d
-
+        conv2d = get_conv_layer(conv_type)
         model = [
             nn.ReflectionPad2d(3),
-            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
+            conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
             norm_layer(ngf),
             nn.ReLU(True),
         ]
@@ -110,13 +113,35 @@ def __init__(
         n_downsampling = 2
         for i in range(n_downsampling):  # add downsampling layers
             mult = 2 ** i
-            model += [
-                nn.Conv2d(
-                    ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias
-                ),
-                norm_layer(ngf * mult * 2),
-                nn.ReLU(True),
-            ]
+            if conv_type == "antialiased":
+                block = [
+                    conv2d(
+                        ngf * mult,
+                        ngf * mult * 2,
+                        kernel_size=3,
+                        stride=1,
+                        padding=1,
+                        bias=use_bias,
+                    ),
+                    norm_layer(ngf * mult * 2),
+                    nn.ReLU(True),
+                    antialiased_cnns.BlurPool(ngf * mult * 2, stride=2),
+                ]
+            else:
+                block = [
+                    conv2d(
+                        ngf * mult,
+                        ngf * mult * 2,
+                        kernel_size=3,
+                        stride=2,
+                        padding=1,
+                        bias=use_bias,
+                    ),
+                    norm_layer(ngf * mult * 2),
+                    nn.ReLU(True),
+                ]
+
+            model += block
 
         mult = 2 ** n_downsampling
         for i in range(n_blocks):  # add ResNet blocks
@@ -160,7 +185,9 @@ def forward(self, input):
 class ResnetBlock(nn.Module):
     """Define a Resnet block"""
 
-    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+    def __init__(
+        self, dim, padding_type, norm_layer, use_dropout, use_bias, conv_type: str = "standard"
+    ):
         """Initialize the Resnet block
 
         A resnet block is a conv block with skip connections
@@ -169,11 +196,14 @@ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
         Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
         """
         super(ResnetBlock, self).__init__()
+        conv2d = get_conv_layer(conv_type)
         self.conv_block = self.build_conv_block(
-            dim, padding_type, norm_layer, use_dropout, use_bias
+            dim, padding_type, norm_layer, use_dropout, use_bias, conv2d
         )
 
-    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+    def build_conv_block(
+        self, dim, padding_type, norm_layer, use_dropout, use_bias, conv2d: torch.nn.Module
+    ):
         """Construct a convolutional block.
 
         Parameters:
@@ -197,7 +227,7 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias)
             raise NotImplementedError("padding [%s] is not implemented" % padding_type)
 
         conv_block += [
-            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
+            conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
             norm_layer(dim),
             nn.ReLU(True),
         ]
@@ -214,7 +244,7 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias)
         else:
             raise NotImplementedError("padding [%s] is not implemented" % padding_type)
         conv_block += [
-            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
+            conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
             norm_layer(dim),
         ]
 
@@ -230,7 +260,14 @@ class UnetGenerator(nn.Module):
     """Create a Unet-based generator"""
 
     def __init__(
-        self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False
+        self,
+        input_nc,
+        output_nc,
+        num_downs,
+        ngf=64,
+        norm_layer=nn.BatchNorm2d,
+        use_dropout=False,
+        conv_type: str = "standard",
     ):
         """Construct a Unet generator
         Parameters:
@@ -247,7 +284,13 @@ def __init__(
         super(UnetGenerator, self).__init__()
         # construct unet structure
         unet_block = UnetSkipConnectionBlock(
-            ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True
+            ngf * 8,
+            ngf * 8,
+            input_nc=None,
+            submodule=None,
+            norm_layer=norm_layer,
+            innermost=True,
+            conv_type=conv_type,
         )  # add the innermost layer
         for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
             unet_block = UnetSkipConnectionBlock(
@@ -257,16 +300,32 @@ def __init__(
                 submodule=unet_block,
                 norm_layer=norm_layer,
                 use_dropout=use_dropout,
+                conv_type=conv_type,
             )
         # gradually reduce the number of filters from ngf * 8 to ngf
         unet_block = UnetSkipConnectionBlock(
-            ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer
+            ngf * 4,
+            ngf * 8,
+            input_nc=None,
+            submodule=unet_block,
+            norm_layer=norm_layer,
+            conv_type=conv_type,
         )
         unet_block = UnetSkipConnectionBlock(
-            ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer
+            ngf * 2,
+            ngf * 4,
+            input_nc=None,
+            submodule=unet_block,
+            norm_layer=norm_layer,
+            conv_type=conv_type,
         )
         unet_block = UnetSkipConnectionBlock(
-            ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer
+            ngf,
+            ngf * 2,
+            input_nc=None,
+            submodule=unet_block,
+            norm_layer=norm_layer,
+            conv_type=conv_type,
         )
         self.model = UnetSkipConnectionBlock(
             output_nc,
@@ -275,6 +334,7 @@ def __init__(
             submodule=unet_block,
             outermost=True,
             norm_layer=norm_layer,
+            conv_type=conv_type,
         )  # add the outermost layer
 
     def forward(self, input):
@@ -298,6 +358,7 @@ def __init__(
         innermost=False,
         norm_layer=nn.BatchNorm2d,
         use_dropout=False,
+        conv_type: str = "standard",
     ):
         """Construct a Unet submodule with skip connections.
 
@@ -319,7 +380,18 @@ def __init__(
             use_bias = norm_layer == nn.InstanceNorm2d
         if input_nc is None:
             input_nc = outer_nc
-        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
+        conv2d = get_conv_layer(conv_type)
+        if conv_type == "antialiased":
+            antialiased = True
+            downconv = conv2d(
+                input_nc, inner_nc, kernel_size=4, stride=1, padding=1, bias=use_bias
+            )
+            blurpool = antialiased_cnns.BlurPool(inner_nc, stride=2)
+        else:
+            antialiased = False
+            downconv = conv2d(
+                input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
+            )
         downrelu = nn.LeakyReLU(0.2, True)
         downnorm = norm_layer(inner_nc)
         uprelu = nn.ReLU(True)
@@ -334,14 +406,18 @@ def __init__(
             upconv = nn.ConvTranspose2d(
                 inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
             )
-            down = [downrelu, downconv]
+            down = [downrelu, downconv, blurpool] if antialiased else [downrelu, downconv]
             up = [uprelu, upconv, upnorm]
             model = down + up
         else:
             upconv = nn.ConvTranspose2d(
                 inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
             )
-            down = [downrelu, downconv, downnorm]
+            down = (
+                [downrelu, downconv, downnorm, blurpool]
+                if antialiased
+                else [downrelu, downconv, downnorm]
+            )
             up = [uprelu, upconv, upnorm]
 
             if use_dropout:
diff --git a/satflow/models/layers/ConvLSTM.py b/satflow/models/layers/ConvLSTM.py
index e74df844..16a70ab0 100644
--- a/satflow/models/layers/ConvLSTM.py
+++ b/satflow/models/layers/ConvLSTM.py
@@ -1,9 +1,10 @@
 import torch
 import torch.nn as nn
+from satflow.models.utils import get_conv_layer
 
 
 class ConvLSTMCell(nn.Module):
-    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
+    def __init__(self, input_dim, hidden_dim, kernel_size, bias, conv_type: str = "standard"):
         """
         Initialize ConvLSTM cell.
 
@@ -27,8 +28,9 @@ def __init__(self, input_dim, hidden_dim, kernel_size, bias):
         self.kernel_size = kernel_size
         self.padding = kernel_size[0] // 2, kernel_size[1] // 2
         self.bias = bias
+        conv2d = get_conv_layer(conv_type)
 
-        self.conv = nn.Conv2d(
+        self.conv = conv2d(
             in_channels=self.input_dim + self.hidden_dim,
             out_channels=4 * self.hidden_dim,
             kernel_size=self.kernel_size,
diff --git a/satflow/models/layers/RUnetLayers.py b/satflow/models/layers/RUnetLayers.py
index 4dbaf856..7ebbf02c 100644
--- a/satflow/models/layers/RUnetLayers.py
+++ b/satflow/models/layers/RUnetLayers.py
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+from satflow.models.utils import get_conv_layer
 from torch.nn import init
@@ -33,13 +33,14 @@ def init_func(m):
 
 
 class conv_block(nn.Module):
-    def __init__(self, ch_in, ch_out):
+    def __init__(self, ch_in, ch_out, conv_type: str = "standard"):
         super(conv_block, self).__init__()
+        conv2d = get_conv_layer(conv_type)
         self.conv = nn.Sequential(
-            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
             nn.BatchNorm2d(ch_out),
             nn.ReLU(inplace=True),
-            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
             nn.BatchNorm2d(ch_out),
             nn.ReLU(inplace=True),
         )
@@ -50,11 +51,12 @@ def forward(self, x):
 
 
 class up_conv(nn.Module):
-    def __init__(self, ch_in, ch_out):
+    def __init__(self, ch_in, ch_out, conv_type: str = "standard"):
         super(up_conv, self).__init__()
+        conv2d = get_conv_layer(conv_type)
         self.up = nn.Sequential(
             nn.Upsample(scale_factor=2),
-            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
             nn.BatchNorm2d(ch_out),
             nn.ReLU(inplace=True),
         )
@@ -65,12 +67,13 @@ def forward(self, x):
 
 
 class Recurrent_block(nn.Module):
-    def __init__(self, ch_out, t=2):
+    def __init__(self, ch_out, t=2, conv_type: str = "standard"):
         super(Recurrent_block, self).__init__()
+        conv2d = get_conv_layer(conv_type)
         self.t = t
         self.ch_out = ch_out
         self.conv = nn.Sequential(
-            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
             nn.BatchNorm2d(ch_out),
             nn.ReLU(inplace=True),
         )
@@ -86,10 +89,14 @@ def forward(self, x):
 
 
 class RRCNN_block(nn.Module):
-    def __init__(self, ch_in, ch_out, t=2):
+    def __init__(self, ch_in, ch_out, t=2, conv_type: str = "standard"):
         super(RRCNN_block, self).__init__()
-        self.RCNN = nn.Sequential(Recurrent_block(ch_out, t=t), Recurrent_block(ch_out, t=t))
-        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)
+        conv2d = get_conv_layer(conv_type)
+        self.RCNN = nn.Sequential(
+            Recurrent_block(ch_out, t=t, conv_type=conv_type),
+            Recurrent_block(ch_out, t=t, conv_type=conv_type),
+        )
+        self.Conv_1x1 = conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)
 
     def forward(self, x):
         x = self.Conv_1x1(x)
@@ -98,10 +105,11 @@ def forward(self, x):
 
 
 class single_conv(nn.Module):
-    def __init__(self, ch_in, ch_out):
+    def __init__(self, ch_in, ch_out, conv_type: str = "standard"):
         super(single_conv, self).__init__()
+        conv2d = get_conv_layer(conv_type)
         self.conv = nn.Sequential(
-            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
+            conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
             nn.BatchNorm2d(ch_out),
             nn.ReLU(inplace=True),
         )
@@ -112,20 +120,21 @@ def forward(self, x):
 
 
 class Attention_block(nn.Module):
-    def __init__(self, F_g, F_l, F_int):
+    def __init__(self, F_g, F_l, F_int, conv_type: str = "standard"):
         super(Attention_block, self).__init__()
+        conv2d = get_conv_layer(conv_type)
         self.W_g = nn.Sequential(
-            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+            conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
             nn.BatchNorm2d(F_int),
         )
 
         self.W_x = nn.Sequential(
-            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+            conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
             nn.BatchNorm2d(F_int),
         )
 
         self.psi = nn.Sequential(
-            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
+            conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
             nn.BatchNorm2d(1),
             nn.Sigmoid(),
         )
diff --git a/satflow/models/metnet.py b/satflow/models/metnet.py
index 8285862a..c5d5c57e 100644
--- a/satflow/models/metnet.py
+++ b/satflow/models/metnet.py
@@ -6,25 +6,35 @@
 from satflow.models.base import register_model
 from satflow.models.layers import ConvGRU, TimeDistributed
+from satflow.models.utils import get_conv_layer
 from axial_attention import AxialAttention
 from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
 import numpy as np
 import torchvision
+import antialiased_cnns
 
 
 class DownSampler(nn.Module):
-    def __init__(self, in_channels):
+    def __init__(self, in_channels, conv_type: str = "standard"):
         super().__init__()
+        conv2d = get_conv_layer(conv_type=conv_type)
+        if conv_type == "antialiased":
+            antialiased = True
+        else:
+            antialiased = False
+
         self.module = nn.Sequential(
-            nn.Conv2d(in_channels, 160, 3, padding=1),
-            nn.MaxPool2d((2, 2), stride=2),
+            conv2d(in_channels, 160, 3, padding=1),
+            nn.MaxPool2d((2, 2), stride=1 if antialiased else 2),
+            antialiased_cnns.BlurPool(160, stride=2) if antialiased else nn.Identity(),
             nn.BatchNorm2d(160),
-            nn.Conv2d(160, 256, 3, padding=1),
+            conv2d(160, 256, 3, padding=1),
             nn.BatchNorm2d(256),
-            nn.Conv2d(256, 256, 3, padding=1),
+            conv2d(256, 256, 3, padding=1),
             nn.BatchNorm2d(256),
-            nn.Conv2d(256, 256, 3, padding=1),
-            nn.MaxPool2d((2, 2), stride=2),
+            conv2d(256, 256, 3, padding=1),
+            nn.MaxPool2d((2, 2), stride=1 if antialiased else 2),
+            antialiased_cnns.BlurPool(256, stride=2) if antialiased else nn.Identity(),
         )
 
     def forward(self, x):
diff --git a/satflow/models/runet.py b/satflow/models/runet.py
index dd3aacf3..5d698991 100644
--- a/satflow/models/runet.py
+++ b/satflow/models/runet.py
@@ -1,3 +1,5 @@
+import antialiased_cnns
+import torch.nn.functional as F
 from satflow.models.layers.RUnetLayers import *
 import pytorch_lightning as pl
 import torchvision
@@ -17,10 +19,13 @@ def __init__(
         loss: Union[str, torch.nn.Module] = "mse",
         lr: float = 0.001,
         visualize: bool = False,
+        conv_type: str = "standard",
     ):
         self.input_channels = input_channels
         self.forecast_steps = forecast_steps
-        self.module = R2U_Net(img_ch=input_channels, output_ch=forecast_steps, t=recurrent_steps)
+        self.module = R2U_Net(
+            img_ch=input_channels, output_ch=forecast_steps, t=recurrent_steps, conv_type=conv_type
+        )
         super().__init__()
         self.lr = lr
         self.input_channels = input_channels
@@ -111,33 +116,47 @@ def visualize_step(self, x, y, y_hat, batch_idx, step="train"):
 
 
 class R2U_Net(nn.Module):
-    def __init__(self, img_ch=3, output_ch=1, t=2):
+    def __init__(self, img_ch=3, output_ch=1, t=2, conv_type: str = "standard"):
         super(R2U_Net, self).__init__()
+        if conv_type == "antialiased":
+            self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=1)
+            self.antialiased = True
+        else:
+            self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+            self.antialiased = False
 
-        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
         self.Upsample = nn.Upsample(scale_factor=2)
 
-        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
-
-        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
+        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t, conv_type=conv_type)
+        self.Blur1 = antialiased_cnns.BlurPool(64, stride=2) if self.antialiased else nn.Identity()
+        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t, conv_type=conv_type)
+        self.Blur2 = (
+            antialiased_cnns.BlurPool(128, stride=2) if self.antialiased else nn.Identity()
+        )
 
-        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
+        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t, conv_type=conv_type)
+        self.Blur3 = (
+            antialiased_cnns.BlurPool(256, stride=2) if self.antialiased else nn.Identity()
+        )
 
-        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
+        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t, conv_type=conv_type)
+        self.Blur4 = (
+            antialiased_cnns.BlurPool(512, stride=2) if self.antialiased else nn.Identity()
+        )
 
-        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)
+        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t, conv_type=conv_type)
 
         self.Up5 = up_conv(ch_in=1024, ch_out=512)
-        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)
+        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t, conv_type=conv_type)
 
         self.Up4 = up_conv(ch_in=512, ch_out=256)
-        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)
+        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t, conv_type=conv_type)
 
         self.Up3 = up_conv(ch_in=256, ch_out=128)
-        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)
+        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t, conv_type=conv_type)
 
         self.Up2 = up_conv(ch_in=128, ch_out=64)
-        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)
+        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t, conv_type=conv_type)
 
         self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
@@ -146,15 +165,19 @@ def forward(self, x):
         x1 = self.RRCNN1(x)
 
         x2 = self.Maxpool(x1)
+        x2 = self.Blur1(x2)
         x2 = self.RRCNN2(x2)
 
         x3 = self.Maxpool(x2)
+        x3 = self.Blur2(x3)
         x3 = self.RRCNN3(x3)
 
         x4 = self.Maxpool(x3)
+        x4 = self.Blur3(x4)
         x4 = self.RRCNN4(x4)
 
         x5 = self.Maxpool(x4)
+        x5 = self.Blur4(x5)
         x5 = self.RRCNN5(x5)
 
         # decoding + concat path
diff --git a/satflow/models/utils.py b/satflow/models/utils.py
new file mode 100644
index 00000000..138ddb68
--- /dev/null
+++ b/satflow/models/utils.py
@@ -0,0 +1,15 @@
+import torch
+from satflow.models.layers import CoordConv
+
+
+def get_conv_layer(conv_type: str = "standard"):
+    if conv_type == "standard":
+        conv2d = torch.nn.Conv2d
+    elif conv_type == "coord":
+        conv2d = CoordConv
+    elif conv_type == "antialiased":
+        # TODO Add anti-aliased coordconv here
+        conv2d = torch.nn.Conv2d
+    else:
+        raise ValueError(f"{conv_type} is not a recognized Conv method")
+    return conv2d