This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Add Swapping out Conv2D for CoordConv/AntiAliased Conv2D in models #60

Merged · 6 commits · Jul 22, 2021
61 changes: 32 additions & 29 deletions satflow/models/attention_unet.py
@@ -16,14 +16,17 @@ def __init__(
loss: Union[str, torch.nn.Module] = "mse",
lr: float = 0.001,
visualize: bool = False,
conv_type: str = "standard",
):
super().__init__()
self.lr = lr
self.visualize = visualize
self.input_channels = input_channels
self.forecast_steps = forecast_steps
self.channels_per_timestep = 12
self.model = AttU_Net(input_channels=input_channels, output_channels=forecast_steps)
self.model = AttU_Net(
input_channels=input_channels, output_channels=forecast_steps, conv_type=conv_type
)
assert loss in ["mse", "bce", "binary_crossentropy", "crossentropy", "focal"]
if loss == "mse":
self.criterion = F.mse_loss
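
The only functional change in this hunk is threading the new `conv_type` string from the Lightning wrapper into `AttU_Net`. Validation of that string is left to the downstream layer helpers; a guard in the spirit of the `loss` assert above might look like the following hedged sketch (the option names other than the `"standard"` default are assumptions, not taken from this PR):

```python
# Hypothetical guard mirroring the existing `assert loss in [...]` pattern above.
# The accepted conv_type strings are assumed, not confirmed by this diff.
VALID_CONV_TYPES = ("standard", "coord", "antialiased")


def check_conv_type(conv_type: str) -> str:
    assert conv_type in VALID_CONV_TYPES, f"conv_type {conv_type} not recognized"
    return conv_type
```
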
@@ -197,32 +200,32 @@ def visualize_step(self, x, y, y_hat, batch_idx, step):


class AttU_Net(nn.Module):
def __init__(self, input_channels=3, output_channels=1):
def __init__(self, input_channels=3, output_channels=1, conv_type: str = "standard"):
super(AttU_Net, self).__init__()

self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

self.Conv1 = conv_block(ch_in=input_channels, ch_out=64)
self.Conv2 = conv_block(ch_in=64, ch_out=128)
self.Conv3 = conv_block(ch_in=128, ch_out=256)
self.Conv4 = conv_block(ch_in=256, ch_out=512)
self.Conv5 = conv_block(ch_in=512, ch_out=1024)
self.Conv1 = conv_block(ch_in=input_channels, ch_out=64, conv_type=conv_type)
self.Conv2 = conv_block(ch_in=64, ch_out=128, conv_type=conv_type)
self.Conv3 = conv_block(ch_in=128, ch_out=256, conv_type=conv_type)
self.Conv4 = conv_block(ch_in=256, ch_out=512, conv_type=conv_type)
self.Conv5 = conv_block(ch_in=512, ch_out=1024, conv_type=conv_type)

self.Up5 = up_conv(ch_in=1024, ch_out=512)
self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256, conv_type=conv_type)
self.Up_conv5 = conv_block(ch_in=1024, ch_out=512, conv_type=conv_type)

self.Up4 = up_conv(ch_in=512, ch_out=256)
self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128, conv_type=conv_type)
self.Up_conv4 = conv_block(ch_in=512, ch_out=256, conv_type=conv_type)

self.Up3 = up_conv(ch_in=256, ch_out=128)
self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64, conv_type=conv_type)
self.Up_conv3 = conv_block(ch_in=256, ch_out=128, conv_type=conv_type)

self.Up2 = up_conv(ch_in=128, ch_out=64)
self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32, conv_type=conv_type)
self.Up_conv2 = conv_block(ch_in=128, ch_out=64, conv_type=conv_type)

self.Conv_1x1 = nn.Conv2d(64, output_channels, kernel_size=1, stride=1, padding=0)

Expand Down Expand Up @@ -269,37 +272,37 @@ def forward(self, x):


class R2AttU_Net(nn.Module):
def __init__(self, input_channels=3, output_channels=1, t=2):
def __init__(self, input_channels=3, output_channels=1, t=2, conv_type: str = "standard"):
super(R2AttU_Net, self).__init__()

self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.Upsample = nn.Upsample(scale_factor=2)

self.RRCNN1 = RRCNN_block(ch_in=input_channels, ch_out=64, t=t)
self.RRCNN1 = RRCNN_block(ch_in=input_channels, ch_out=64, t=t, conv_type=conv_type)

self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t, conv_type=conv_type)

self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t, conv_type=conv_type)

self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t, conv_type=conv_type)

self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)
self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t, conv_type=conv_type)

self.Up5 = up_conv(ch_in=1024, ch_out=512)
self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)
self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256, conv_type=conv_type)
self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t, conv_type=conv_type)

self.Up4 = up_conv(ch_in=512, ch_out=256)
self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)
self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128, conv_type=conv_type)
self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t, conv_type=conv_type)

self.Up3 = up_conv(ch_in=256, ch_out=128)
self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)
self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64, conv_type=conv_type)
self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t, conv_type=conv_type)

self.Up2 = up_conv(ch_in=128, ch_out=64)
self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)
self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32, conv_type=conv_type)
self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t, conv_type=conv_type)

self.Conv_1x1 = nn.Conv2d(64, output_channels, kernel_size=1, stride=1, padding=0)

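
The `conv_block`, `up_conv`, `Attention_block`, and `RRCNN_block` helpers that now accept `conv_type` are defined elsewhere in satflow and are not shown in this diff. The sketch below illustrates the kind of layer factory they presumably call: `"standard"` maps to a plain `nn.Conv2d`, `"coord"` to a CoordConv-style layer that concatenates normalized x/y coordinate channels before convolving, and the anti-aliased branch is stubbed with a plain `nn.Conv2d` to keep the example dependency-free. All names here are hypothetical, not the repository's actual API.

```python
import torch
import torch.nn as nn


class CoordConv2d(nn.Module):
    """Sketch of CoordConv: append normalized x/y coordinate channels, then convolve."""

    def __init__(self, in_channels: int, out_channels: int, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels + 2, out_channels, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, _, h, w = x.shape
        ys = torch.linspace(-1, 1, h, device=x.device).view(1, 1, h, 1).expand(b, 1, h, w)
        xs = torch.linspace(-1, 1, w, device=x.device).view(1, 1, 1, w).expand(b, 1, h, w)
        return self.conv(torch.cat([x, ys, xs], dim=1))


def get_conv_layer(conv_type: str = "standard"):
    """Hypothetical dispatcher; the real satflow helper may differ."""
    if conv_type == "standard":
        return nn.Conv2d
    if conv_type == "coord":
        return CoordConv2d
    if conv_type == "antialiased":
        # The PR pairs this option with anti-aliased (BlurPool-style) downsampling;
        # a plain Conv2d stands in here to avoid an extra dependency.
        return nn.Conv2d
    raise ValueError(f"conv_type {conv_type} not recognized")
```

With such a factory in place, every convolution in the models above can be swapped at construction time, e.g. `AttU_Net(input_channels=12, output_channels=24, conv_type="coord")`.
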
26 changes: 20 additions & 6 deletions satflow/models/conv_lstm.py
@@ -25,6 +25,7 @@ def __init__(
visualize: bool = False,
loss: Union[str, torch.nn.Module] = "mse",
pretrained: bool = False,
conv_type: str = "standard",
):
super(EncoderDecoderConvLSTM, self).__init__()
self.forecast_steps = forecast_steps
@@ -42,7 +43,7 @@ def __init__(
raise ValueError(f"loss {loss} not recognized")
self.lr = lr
self.visualize = visualize
self.module = ConvLSTM(input_channels, hidden_dim, out_channels)
self.model = ConvLSTM(input_channels, hidden_dim, out_channels, conv_type=conv_type)
self.save_hyperparameters()

@classmethod
@@ -56,7 +57,7 @@ def from_config(cls, config):
)

def forward(self, x, future_seq=0, hidden_state=None):
return self.module.forward(x, future_seq, hidden_state)
return self.model.forward(x, future_seq, hidden_state)

def configure_optimizers(self):
# DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w)
@@ -120,7 +121,7 @@ def visualize_step(self, x, y, y_hat, batch_idx, step="train"):


class ConvLSTM(torch.nn.Module):
def __init__(self, input_channels, hidden_dim, out_channels):
def __init__(self, input_channels, hidden_dim, out_channels, conv_type: str = "standard"):
super().__init__()
""" ARCHITECTURE

@@ -131,22 +132,35 @@ def __init__(self, input_channels, hidden_dim, out_channels):

"""
self.encoder_1_convlstm = ConvLSTMCell(
input_dim=input_channels, hidden_dim=hidden_dim, kernel_size=(3, 3), bias=True
input_dim=input_channels,
hidden_dim=hidden_dim,
kernel_size=(3, 3),
bias=True,
conv_type=conv_type,
)

self.encoder_2_convlstm = ConvLSTMCell(
input_dim=hidden_dim, hidden_dim=hidden_dim, kernel_size=(3, 3), bias=True
input_dim=hidden_dim,
hidden_dim=hidden_dim,
kernel_size=(3, 3),
bias=True,
conv_type=conv_type,
)

self.decoder_1_convlstm = ConvLSTMCell(
input_dim=hidden_dim,
hidden_dim=hidden_dim,
kernel_size=(3, 3),
bias=True, # nf + 1
conv_type=conv_type,
)

self.decoder_2_convlstm = ConvLSTMCell(
input_dim=hidden_dim, hidden_dim=hidden_dim, kernel_size=(3, 3), bias=True
input_dim=hidden_dim,
hidden_dim=hidden_dim,
kernel_size=(3, 3),
bias=True,
conv_type=conv_type,
)

self.decoder_CNN = nn.Conv3d(
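
`ConvLSTMCell`, imported from elsewhere in satflow, is where the new `conv_type` argument ultimately lands in this file: the cell builds a single convolution that maps the concatenated input and hidden state to its four gates. Below is a hedged, self-contained sketch of such a cell, with the convolution class injected the way a `get_conv_layer(conv_type)`-style factory would supply it; satflow's actual `ConvLSTMCell` may differ.

```python
import torch
import torch.nn as nn


class ConvLSTMCellSketch(nn.Module):
    """Sketch of a ConvLSTM cell; not satflow's actual implementation."""

    def __init__(self, input_dim, hidden_dim, kernel_size=(3, 3), bias=True, conv_layer=nn.Conv2d):
        super().__init__()
        self.hidden_dim = hidden_dim
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        # This is the single convolution the conv_type swap targets: a factory such as
        # get_conv_layer(conv_type) would supply conv_layer (CoordConv, anti-aliased
        # conv, or plain Conv2d) instead of the default used here.
        self.conv = conv_layer(
            input_dim + hidden_dim, 4 * hidden_dim, kernel_size=kernel_size, padding=padding, bias=bias
        )

    def forward(self, x, state):
        h_cur, c_cur = state
        gates = self.conv(torch.cat([x, h_cur], dim=1))
        i, f, o, g = torch.split(gates, self.hidden_dim, dim=1)
        c_next = torch.sigmoid(f) * c_cur + torch.sigmoid(i) * torch.tanh(g)
        h_next = torch.sigmoid(o) * torch.tanh(c_next)
        return h_next, c_next


# Quick shape check with arbitrary sizes:
cell = ConvLSTMCellSketch(input_dim=12, hidden_dim=64)
h = c = torch.zeros(2, 64, 32, 32)
h, c = cell(torch.randn(2, 12, 32, 32), (h, c))
print(h.shape)  # torch.Size([2, 64, 32, 32])
```
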