diff --git a/import_deepmind_weights.py b/import_deepmind_weights.py new file mode 100644 index 0000000..a225418 --- /dev/null +++ b/import_deepmind_weights.py @@ -0,0 +1,11 @@ +import tensorflow as tf +import tensorflow_hub +import torch +from nowcasting_gan import DGMR +import os +import fsspec + +module = tensorflow_hub.load("/home/jacob/256x256/") +print(module) +print(module.signatures) +sig_model = module.signatures["default"] diff --git a/nowcasting_gan/__init__.py b/nowcasting_gan/__init__.py index 2a1626c..b9cd0b7 100644 --- a/nowcasting_gan/__init__.py +++ b/nowcasting_gan/__init__.py @@ -1,4 +1,4 @@ -from .nowcasting_gan import NowcastingGAN -from .generators import NowcastingSampler, NowcastingGenerator -from .discriminators import NowcastingSpatialDiscriminator, NowcastingTemporalDiscriminator +from .dgmr import DGMR +from .generators import Sampler, Generator +from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator from .common import LatentConditioningStack, ContextConditioningStack diff --git a/nowcasting_gan/common.py b/nowcasting_gan/common.py index cd3d641..2188b0a 100644 --- a/nowcasting_gan/common.py +++ b/nowcasting_gan/common.py @@ -1,16 +1,24 @@ from typing import Tuple + +import einops import torch -from torch.distributions import uniform +import torch.nn.functional as F +from torch.distributions import normal from torch.nn.utils import spectral_norm from torch.nn.modules.pixelshuffle import PixelUnshuffle -from torch.nn.functional import interpolate from nowcasting_gan.layers.utils import get_conv_layer -from nowcasting_gan.layers import SelfAttention2d +from nowcasting_gan.layers import AttentionLayer class GBlock(torch.nn.Module): + """Residual generator block without upsampling""" + def __init__( - self, input_channels: int = 12, output_channels: int = 12, conv_type: str = "standard" + self, + input_channels: int = 12, + output_channels: int = 12, + conv_type: str = "standard", + spectral_normalized_eps=0.0001, ): """ G Block from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf @@ -20,45 +28,120 @@ def __init__( conv_type: Type of convolution desired, see satflow/models/utils.py for options """ super().__init__() + self.output_channels = output_channels self.bn1 = torch.nn.BatchNorm2d(input_channels) self.bn2 = torch.nn.BatchNorm2d(input_channels) self.relu = torch.nn.ReLU() # Upsample in the 1x1 conv2d = get_conv_layer(conv_type) - self.conv_1x1 = conv2d( - in_channels=input_channels, - out_channels=output_channels, - kernel_size=1, + self.conv_1x1 = spectral_norm( + conv2d( + in_channels=input_channels, + out_channels=output_channels, + kernel_size=1, + ), + eps=spectral_normalized_eps, ) - self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear") # Upsample 2D conv - self.first_conv_3x3 = torch.nn.ConvTranspose2d( - in_channels=input_channels, - out_channels=input_channels, - kernel_size=3, - stride=2, - padding=1, + self.first_conv_3x3 = spectral_norm( + conv2d( + in_channels=input_channels, + out_channels=input_channels, + kernel_size=3, + padding=1, + ), + eps=spectral_normalized_eps, ) - self.last_conv_3x3 = conv2d( - in_channels=input_channels, out_channels=output_channels, kernel_size=3, padding=1 + self.last_conv_3x3 = spectral_norm( + conv2d( + in_channels=input_channels, out_channels=output_channels, kernel_size=3, padding=1 + ), + eps=spectral_normalized_eps, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - # Branch 1 - x1 = self.upsample(x) - x1 = self.conv_1x1(x1) + # Optionally 
spectrally normalized 1x1 convolution
+        if x.shape[1] != self.output_channels:
+            sc = self.conv_1x1(x)
+        else:
+            sc = x

-        # Branch 2
         x2 = self.bn1(x)
         x2 = self.relu(x2)
-        x2 = self.first_conv_3x3(
-            x2, output_size=(2 * x.size()[-2], 2 * x.size()[-1])
-        )  # Make sure size is doubled
+        x2 = self.first_conv_3x3(x2)
         x2 = self.bn2(x2)
         x2 = self.relu(x2)
         x2 = self.last_conv_3x3(x2)

-        # Sum combine
-        x = x1 + x2
+        # Sum combine, residual connection
+        x = x2 + sc
         return x


+class UpsampleGBlock(torch.nn.Module):
+    """Residual generator block with upsampling"""
+
+    def __init__(
+        self,
+        input_channels: int = 12,
+        output_channels: int = 12,
+        conv_type: str = "standard",
+        spectral_normalized_eps=0.0001,
+    ):
+        """
+        G Block from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
+        Args:
+            input_channels: Number of input channels
+            output_channels: Number of output channels
+            conv_type: Type of convolution desired, see satflow/models/utils.py for options
+        """
+        super().__init__()
+        self.output_channels = output_channels
+        self.bn1 = torch.nn.BatchNorm2d(input_channels)
+        self.bn2 = torch.nn.BatchNorm2d(input_channels)
+        self.relu = torch.nn.ReLU()
+        # Upsample in the 1x1
+        conv2d = get_conv_layer(conv_type)
+        self.conv_1x1 = spectral_norm(
+            conv2d(
+                in_channels=input_channels,
+                out_channels=output_channels,
+                kernel_size=1,
+            ),
+            eps=spectral_normalized_eps,
+        )
+        self.upsample = torch.nn.Upsample(scale_factor=2, mode="nearest")
+        # Upsample 2D conv
+        self.first_conv_3x3 = spectral_norm(
+            conv2d(
+                in_channels=input_channels,
+                out_channels=input_channels,
+                kernel_size=3,
+                padding=1,
+            ),
+            eps=spectral_normalized_eps,
+        )
+        self.last_conv_3x3 = spectral_norm(
+            conv2d(
+                in_channels=input_channels, out_channels=output_channels, kernel_size=3, padding=1
+            ),
+            eps=spectral_normalized_eps,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Spectrally normalized 1x1 convolution
+        sc = self.upsample(x)
+        sc = self.conv_1x1(sc)
+
+        x2 = self.bn1(x)
+        x2 = self.relu(x2)
+        # Upsample
+        x2 = self.upsample(x2)
+        x2 = self.first_conv_3x3(x2)  # Make sure size is doubled
+        x2 = self.bn2(x2)
+        x2 = self.relu(x2)
+        x2 = self.last_conv_3x3(x2)
+        # Sum combine, residual connection
+        x = x2 + sc
         return x

@@ -81,59 +164,74 @@ def __init__(
             keep_same_output: Whether the output should have the same spatial dimensions as input, if False, downscales by 2
         """
         super().__init__()
+        self.input_channels = input_channels
+        self.output_channels = output_channels
         self.first_relu = first_relu
         self.keep_same_output = keep_same_output
         self.conv_type = conv_type
         conv2d = get_conv_layer(conv_type)
-        self.conv_1x1 = conv2d(
-            in_channels=input_channels,
-            out_channels=output_channels,
-            kernel_size=1,
+        if conv_type == "3d":
+            # 3D Average pooling
+            self.pooling = torch.nn.AvgPool3d(kernel_size=2, stride=2)
+        else:
+            self.pooling = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+        self.conv_1x1 = spectral_norm(
+            conv2d(
+                in_channels=input_channels,
+                out_channels=output_channels,
+                kernel_size=1,
+            )
        )
-        self.first_conv_3x3 = conv2d(
-            in_channels=input_channels,
-            out_channels=input_channels,
-            kernel_size=3,
-            padding=1,
+        self.first_conv_3x3 = spectral_norm(
+            conv2d(
+                in_channels=input_channels,
+                out_channels=output_channels,
+                kernel_size=3,
+                padding=1,
+            )
        )
-        self.last_conv_3x3 = conv2d(
-            in_channels=input_channels,
-            out_channels=output_channels,
-            kernel_size=3,
-            padding=1,
-            stride=1,
+        self.last_conv_3x3 = spectral_norm(
+            conv2d( 
in_channels=output_channels, + out_channels=output_channels, + kernel_size=3, + padding=1, + stride=1, + ) ) - if conv_type == "3d": - # Need spectrally normalized convolutions - self.conv_1x1 = spectral_norm(self.conv_1x1) - self.first_conv_3x3 = spectral_norm(self.first_conv_3x3) - self.last_conv_3x3 = spectral_norm(self.last_conv_3x3) # Downsample at end of 3x3 self.relu = torch.nn.ReLU() # Concatenate to double final channels and keep reduced spatial extent def forward(self, x: torch.Tensor) -> torch.Tensor: - x1 = self.conv_1x1(x) - if not self.keep_same_output: - x1 = interpolate( - x1, mode="trilinear" if self.conv_type == "3d" else "bilinear", scale_factor=0.5 - ) # Downscale by half + if self.input_channels != self.output_channels: + x1 = self.conv_1x1(x) + if not self.keep_same_output: + x1 = self.pooling(x1) + else: + x1 = x + if self.first_relu: x = self.relu(x) x = self.first_conv_3x3(x) x = self.relu(x) x = self.last_conv_3x3(x) + if not self.keep_same_output: - x = interpolate( - x, mode="trilinear" if self.conv_type == "3d" else "bilinear", scale_factor=0.5 - ) # Downscale by half + x = self.pooling(x) x = x1 + x # Sum the outputs should be half spatial and double channels return x class LBlock(torch.nn.Module): + """Residual block for the Latent Stack.""" + def __init__( - self, input_channels: int = 12, output_channels: int = 12, conv_type: str = "standard" + self, + input_channels: int = 12, + output_channels: int = 12, + kernel_size: int = 3, + conv_type: str = "standard", ): """ L-Block for increasing the number of channels in the input @@ -145,6 +243,8 @@ def __init__( """ super().__init__() # Output size should be channel_out - channel_in + self.input_channels = input_channels + self.output_channels = output_channels conv2d = get_conv_layer(conv_type) self.conv_1x1 = conv2d( in_channels=input_channels, @@ -153,25 +253,34 @@ def __init__( ) self.first_conv_3x3 = conv2d( - input_channels, out_channels=output_channels, kernel_size=3, padding=1, stride=1 + input_channels, + out_channels=output_channels, + kernel_size=kernel_size, + padding=1, + stride=1, ) self.relu = torch.nn.ReLU() self.last_conv_3x3 = conv2d( in_channels=output_channels, out_channels=output_channels, - kernel_size=3, + kernel_size=kernel_size, padding=1, stride=1, ) def forward(self, x) -> torch.Tensor: - x1 = self.conv_1x1(x) - x2 = self.first_conv_3x3(x) + if self.input_channels < self.output_channels: + sc = self.conv_1x1(x) + sc = torch.cat([x, sc], dim=1) + else: + sc = x + + x2 = self.relu(x) + x2 = self.first_conv_3x3(x2) x2 = self.relu(x2) x2 = self.last_conv_3x3(x2) - x = x2 + (torch.cat((x, x1), dim=1)) - return x + return x2 + sc class ContextConditioningStack(torch.nn.Module): @@ -196,7 +305,6 @@ def __init__( # Process each observation processed separately with 4 downsample blocks # Concatenate across channel dimension, and for each output, 3x3 spectrally normalized convolution to reduce # number of channels by 2, followed by ReLU - # TODO Not sure if a different block for each timestep, or same block used separately self.d1 = DBlock( input_channels=4 * input_channels, output_channels=((output_channels // 4) * input_channels) // num_context_steps, @@ -274,18 +382,23 @@ def forward( scale_2.append(s2) scale_3.append(s3) scale_4.append(s4) - scale_1 = torch.cat(scale_1, dim=1) # B, T, C, H, W and want along C dimension - scale_2 = torch.cat(scale_2, dim=1) # B, T, C, H, W and want along C dimension - scale_3 = torch.cat(scale_3, dim=1) # B, T, C, H, W and want along C dimension - scale_4 = 
torch.cat(scale_4, dim=1)  # B, T, C, H, W and want along C dimension
-        # TODO Figure out where extra channels come from, paper says concat outputs and divide channels by 2 gives 48,96,192,384 total, but this gives 8*4 = 32, 16*4 = 64
-        scale_1 = self.relu(self.conv1(scale_1))
-        scale_2 = self.relu(self.conv2(scale_2))
-        scale_3 = self.relu(self.conv3(scale_3))
-        scale_4 = self.relu(self.conv4(scale_4))
-
+        scale_1 = torch.stack(scale_1, dim=1)  # B, T, C, H, W and want along C dimension
+        scale_2 = torch.stack(scale_2, dim=1)  # B, T, C, H, W and want along C dimension
+        scale_3 = torch.stack(scale_3, dim=1)  # B, T, C, H, W and want along C dimension
+        scale_4 = torch.stack(scale_4, dim=1)  # B, T, C, H, W and want along C dimension
+        # Mixing layer
+        scale_1 = self._mixing_layer(scale_1, self.conv1)
+        scale_2 = self._mixing_layer(scale_2, self.conv2)
+        scale_3 = self._mixing_layer(scale_3, self.conv3)
+        scale_4 = self._mixing_layer(scale_4, self.conv4)
         return scale_1, scale_2, scale_3, scale_4

+    def _mixing_layer(self, inputs, conv_block):
+        # Convert from [batch_size, time, h, w, c] -> [batch_size, h, w, c * time]
+        # then perform convolution on the output while preserving the number of channels.
+        stacked_inputs = einops.rearrange(inputs, "b t c h w -> b (c t) h w")
+        return F.relu(conv_block(stacked_inputs))
+

 class LatentConditioningStack(torch.nn.Module):
     def __init__(
@@ -305,10 +418,12 @@ def __init__(
         super().__init__()
         self.shape = shape
         self.use_attention = use_attention
-        self.distribution = uniform.Uniform(low=torch.Tensor([0.0]), high=torch.Tensor([1.0]))
+        self.distribution = normal.Normal(loc=torch.Tensor([0.0]), scale=torch.Tensor([1.0]))

-        self.conv_3x3 = torch.nn.Conv2d(
-            in_channels=shape[0], out_channels=shape[0], kernel_size=3, padding=1
+        self.conv_3x3 = spectral_norm(
+            torch.nn.Conv2d(
+                in_channels=shape[0], out_channels=shape[0], kernel_size=(3, 3), padding=1
+            )
         )
         self.l_block1 = LBlock(input_channels=shape[0], output_channels=output_channels // 32)
         self.l_block2 = LBlock(
@@ -318,8 +433,8 @@ def __init__(
             input_channels=output_channels // 16, output_channels=output_channels // 4
         )
         if self.use_attention:
-            self.att_block = SelfAttention2d(
-                input_dims=output_channels // 4, output_dims=output_channels // 4
+            self.att_block = AttentionLayer(
+                input_channels=output_channels // 4, output_channels=output_channels // 4
             )
         self.l_block4 = LBlock(input_channels=output_channels // 4, output_channels=output_channels)

@@ -332,14 +447,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         Returns:
         """
+
+        # Independent draws from a Normal distribution
         z = self.distribution.sample(self.shape)
         # Batch is at end for some reason, reshape
         z = torch.permute(z, (3, 0, 1, 2)).type_as(x)
+
+        # 3x3 Convolution
         z = self.conv_3x3(z)
+
+        # 3 L Blocks to increase number of channels
         z = self.l_block1(z)
         z = self.l_block2(z)
         z = self.l_block3(z)

-        if self.use_attention:
-            z = self.att_block(z)
+        # Spatial attention module, applied only when enabled to match its conditional creation in __init__
+        if self.use_attention:
+            z = self.att_block(z)
+
+        # L Block to increase the number of channels to 768
         z = self.l_block4(z)
         return z
diff --git a/nowcasting_gan/nowcasting_gan.py b/nowcasting_gan/dgmr.py
similarity index 50%
rename from nowcasting_gan/nowcasting_gan.py
rename to nowcasting_gan/dgmr.py
index f872958..f777009
--- a/nowcasting_gan/nowcasting_gan.py
+++ b/nowcasting_gan/dgmr.py
@@ -1,24 +1,28 @@
 import torch
-from nowcasting_gan.losses import NowcastingLoss, GridCellLoss
+from nowcasting_gan.losses import (
+    NowcastingLoss,
+    GridCellLoss,
+    loss_hinge_disc,
+    loss_hinge_gen,
+ 
grid_cell_regularizer, +) import pytorch_lightning as pl import torchvision -from typing import List from nowcasting_gan.common import LatentConditioningStack, ContextConditioningStack -from nowcasting_gan.generators import NowcastingSampler, NowcastingGenerator -from nowcasting_gan.discriminators import ( - NowcastingSpatialDiscriminator, - NowcastingTemporalDiscriminator, -) +from nowcasting_gan.generators import Sampler, Generator +from nowcasting_gan.discriminators import Discriminator -class NowcastingGAN(pl.LightningModule): +class DGMR(pl.LightningModule): + """Deep Generative Model of Radar""" + def __init__( self, forecast_steps: int = 18, input_channels: int = 1, output_shape: int = 256, - gen_lr: float = 0.00005, - disc_lr: float = 0.0002, + gen_lr: float = 5e-5, + disc_lr: float = 2e-4, visualize: bool = False, pretrained: bool = False, conv_type: str = "standard", @@ -48,7 +52,7 @@ def __init__( input dimension into ConvGRU, also affects the number of channels for other linked inputs/outputs pretrained: """ - super(NowcastingGAN, self).__init__() + super(DGMR, self).__init__() self.gen_lr = gen_lr self.disc_lr = disc_lr self.beta1 = beta1 @@ -70,20 +74,13 @@ def __init__( shape=(8 * self.input_channels, output_shape // 32, output_shape // 32), output_channels=self.latent_channels, ) - self.sampler = NowcastingSampler( + self.sampler = Sampler( forecast_steps=forecast_steps, latent_channels=self.latent_channels, context_channels=self.context_channels, ) - self.generator = NowcastingGenerator( - self.conditioning_stack, self.latent_stack, self.sampler - ) - self.temporal_discriminator = NowcastingTemporalDiscriminator( - input_channels=input_channels, crop_size=output_shape // 2, conv_type=conv_type - ) - self.spatial_discriminator = NowcastingSpatialDiscriminator( - input_channels=input_channels, num_timesteps=8, conv_type=conv_type - ) + self.generator = Generator(self.conditioning_stack, self.latent_stack, self.sampler) + self.discriminator = Discriminator(input_channels) self.save_hyperparameters() self.global_iteration = 0 @@ -99,70 +96,45 @@ def forward(self, x): def training_step(self, batch, batch_idx): images, future_images = batch self.global_iteration += 1 - g_opt, d_opt_s, d_opt_t = self.optimizers() + g_opt, d_opt = self.optimizers() ########################## # Optimize Discriminator # ########################## # Two discriminator steps per generator step for _ in range(2): - # TODO Make sure this is meant to be the mean predictions, or to run it 6 times and then take mean? 
-            # Get the best prediction of the six
-            # mean_prediction = []
-            # for _ in range(self.num_samples):
-            #    mean_prediction.append(self(images))
-            # mean_prediction = self.average_tensors(mean_prediction)
-            mean_prediction = self(images)
-            # Get Spatial Loss
-            # Should go with lowest loss of the 6 predictions
-            # x should be the chosen 8 or so
-            spatial_real = self.spatial_discriminator(future_images)
-            spatial_fake = self.spatial_discriminator(mean_prediction)
-            spatial_loss = self.discriminator_loss(spatial_real, True) + self.discriminator_loss(
-                spatial_fake, False
-            )
-            # Get Temporal Loss
-            temporal_real = self.temporal_discriminator(torch.cat((images, future_images), 1))
-            temporal_fake = self.temporal_discriminator(torch.cat((images, mean_prediction), 1))
-            temporal_loss = self.discriminator_loss(temporal_real, True) + self.discriminator_loss(
-                temporal_fake, False
-            )
+            predictions = self(images)
+            # Cat along time dimension [B, C, T, H, W]
+            generated_sequence = torch.cat([images, predictions], dim=2)
+            real_sequence = torch.cat([images, future_images], dim=2)
+            # Cat along batch for the real+generated
+            concatenated_inputs = torch.cat([real_sequence, generated_sequence], dim=0)
+
+            concatenated_outputs = self.discriminator(concatenated_inputs)

-            # discriminator loss is the average of these
-            d_loss = spatial_loss + temporal_loss
-            d_opt_t.zero_grad()
-            d_opt_s.zero_grad()
-            self.manual_backward(d_loss)
-            d_opt_t.step()
-            d_opt_s.step()
+            # Split the scores back into the real and generated halves along the batch dimension
+            score_real, score_generated = torch.chunk(concatenated_outputs, 2, dim=0)
+            discriminator_loss = loss_hinge_disc(score_generated, score_real)
+            d_opt.zero_grad()
+            self.manual_backward(discriminator_loss)
+            d_opt.step()

         ######################
         # Optimize Generator #
         ######################
-        # TODO Do the 6 samples for this? 
-        mean_prediction = self(images)
-        # Get Spatial Loss
-        spatial_fake = self.spatial_discriminator(torch.cat((images, mean_prediction), 1))
-        spatial_loss = self.discriminator_loss(spatial_fake, True)
-
-        # Get Temporal Loss
-        temporal_fake = self.temporal_discriminator(torch.cat((images, mean_prediction), 1))
-        temporal_loss = self.discriminator_loss(temporal_fake, True)
-
-        # Grid Cell Loss
-        grid_loss = self.grid_regularizer(mean_prediction, future_images)
-
-        g_loss = spatial_loss + temporal_loss - (self.grid_lambda * grid_loss)
+        predictions = [self(images) for _ in range(6)]
+        grid_cell_reg = grid_cell_regularizer(torch.stack(predictions, dim=0), future_images)
+        # Concat along time dimension
+        generated_sequence = [torch.cat([images, x], dim=2) for x in predictions]
+        # The hinge loss expects discriminator scores, not raw frames, so score each generated sequence first
+        generated_scores = [self.discriminator(g) for g in generated_sequence]
+        generator_disc_loss = loss_hinge_gen(torch.cat(generated_scores, dim=0))
+        generator_loss = generator_disc_loss + self.grid_lambda * grid_cell_reg
         g_opt.zero_grad()
-        self.manual_backward(g_loss)
+        self.manual_backward(generator_loss)
         g_opt.step()

         self.log_dict(
             {
-                "train/d_loss": d_loss,
-                "train/temporal_loss": temporal_loss,
-                "train/spatial_loss": spatial_loss,
-                "train/g_loss": g_loss,
-                "train/grid_loss": grid_loss,
+                "train/d_loss": discriminator_loss,
+                "train/g_loss": generator_loss,
+                "train/grid_loss": grid_cell_reg,
             },
             prog_bar=True,
         )
@@ -175,63 +147,14 @@ def training_step(self, batch, batch_idx):
                 images, future_images, generated_images, self.global_iteration, step="train"
             )

-    def average_tensors(self, x: List[torch.Tensor]):
-        summed_tensor = torch.stack(x, dim=0)
-        summed_tensor = torch.mean(summed_tensor, dim=0)
-        return summed_tensor
-
-    def validation_step(self, batch, batch_idx):
-        images, future_images = batch
-
-        # First get the 6 samples to mean?
-        # TODO Make sure this is what the paper actually means, or is it run it 6 times then average output? 
- mean_prediction = self(images) - # Get Spatial Loss - # x should be the chosen 8 or so - spatial_real = self.spatial_discriminator(future_images) - spatial_fake = self.spatial_discriminator(mean_prediction) - spatial_loss = self.discriminator_loss(spatial_real, True) + self.discriminator_loss( - spatial_fake, False - ) - # Get Temporal Loss - temporal_real = self.temporal_discriminator(torch.cat((images, future_images), 1)) - temporal_fake = self.temporal_discriminator(torch.cat((images, mean_prediction), 1)) - temporal_loss = self.discriminator_loss(temporal_real, True) + self.discriminator_loss( - temporal_fake, False - ) - - # Grid Cell Loss - grid_loss = self.grid_regularizer(mean_prediction, future_images) - - # Generator Loss - g_s = self.discriminator_loss(spatial_fake, True) - g_t = self.discriminator_loss(temporal_fake, True) - g_loss = g_s + g_t - (self.grid_lambda * grid_loss) - - self.log_dict( - { - "val/d_loss": temporal_loss + spatial_loss, - "val/temporal_loss": temporal_loss, - "val/spatial_loss": spatial_loss, - "val/g_loss": g_loss, - "val/grid_loss": grid_loss, - }, - prog_bar=True, - ) - def configure_optimizers(self): b1 = self.beta1 b2 = self.beta2 opt_g = torch.optim.Adam(self.generator.parameters(), lr=self.gen_lr, betas=(b1, b2)) - opt_d_s = torch.optim.Adam( - self.spatial_discriminator.parameters(), lr=self.disc_lr, betas=(b1, b2) - ) - opt_d_t = torch.optim.Adam( - self.temporal_discriminator.parameters(), lr=self.disc_lr, betas=(b1, b2) - ) + opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.disc_lr, betas=(b1, b2)) - return [opt_g, opt_d_s, opt_d_t], [] + return [opt_g, opt_d], [] def visualize_step( self, x: torch.Tensor, y: torch.Tensor, y_hat: torch.Tensor, batch_idx: int, step: str diff --git a/nowcasting_gan/discriminators.py b/nowcasting_gan/discriminators.py index 916c9aa..00d01d0 100644 --- a/nowcasting_gan/discriminators.py +++ b/nowcasting_gan/discriminators.py @@ -1,15 +1,36 @@ import torch from torch.nn.modules.pixelshuffle import PixelUnshuffle from torch.nn.utils import spectral_norm -from torchvision.transforms import RandomCrop +import torch.nn.functional as F from nowcasting_gan.common import DBlock -class NowcastingTemporalDiscriminator(torch.nn.Module): +class Discriminator(torch.nn.Module): + def __init__( + self, + input_channels: int = 12, + num_spatial_frames: int = 8, + conv_type: str = "standard", + ): + super().__init__() + self.spatial_discriminator = SpatialDiscriminator( + input_channels=input_channels, num_timesteps=num_spatial_frames, conv_type=conv_type + ) + self.temporal_discriminator = TemporalDiscriminator( + input_channels=input_channels, conv_type=conv_type + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + spatial_loss = self.spatial_discriminator(x) + temporal_loss = self.temporal_discriminator(x) + + return torch.cat([spatial_loss, temporal_loss], dim=1) + + +class TemporalDiscriminator(torch.nn.Module): def __init__( self, input_channels: int = 12, - crop_size: int = 128, num_layers: int = 3, conv_type: str = "standard", ): @@ -23,7 +44,7 @@ def __init__( conv_type: Type of 2d convolutions to use, see satflow/models/utils.py for options """ super().__init__() - self.transform = RandomCrop(crop_size) + self.downsample = torch.nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) self.space2depth = PixelUnshuffle(downscale_factor=2) internal_chn = 48 self.d1 = DBlock( @@ -57,12 +78,16 @@ def __init__( self.fc = spectral_norm(torch.nn.Linear(2 * internal_chn * input_channels, 1)) self.relu = 
torch.nn.ReLU()
+        self.bn = torch.nn.BatchNorm1d(2 * internal_chn * input_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.transform(x)
+        x = self.downsample(x)
+        x = self.space2depth(x)
         # Have to move time and channels
         x = torch.permute(x, dims=(0, 2, 1, 3, 4))
+        # 2 residual 3D blocks to halve resolution of image, double number of channels and reduce
+        # number of time steps
         x = self.d1(x)
         x = self.d2(x)
         # Convert back to T x C x H x W
@@ -71,23 +96,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         representations = []
         for idx in range(x.size(1)):
             # Intermediate DBlocks
+            # Three residual D Blocks to halve the resolution of the image and double
+            # the number of channels.
             rep = x[:, idx, :, :, :]
             for d in self.intermediate_dblocks:
                 rep = d(rep)
+            # One more D Block without downsampling or increase in number of channels
             rep = self.d_last(rep)
-            # Sum-pool along width and height all 8 representations, pretty sure only the last output
-            rep = torch.sum(rep.view(rep.size(0), rep.size(1), -1), dim=2)
+
+            rep = torch.sum(F.relu(rep), dim=[2, 3])
+            rep = self.bn(rep)
+            rep = self.fc(rep)
             representations.append(rep)

         # The representations are summed together before the ReLU
-        x = torch.stack(representations, dim=0).sum(dim=0)  # Should be right shape? TODO Check
-        # ReLU the output
-        x = self.fc(x)
-        # x = self.relu(x)
+        x = torch.stack(representations, dim=1)
+        # Should be [Batch, N, 1]
+        x = torch.sum(x, keepdim=True, dim=1)
         return x


-class NowcastingSpatialDiscriminator(torch.nn.Module):
+class SpatialDiscriminator(torch.nn.Module):
     def __init__(
         self,
         input_channels: int = 12,
@@ -137,6 +167,7 @@ def __init__(
         # Spectrally normalized linear layer for binary classification
         self.fc = spectral_norm(torch.nn.Linear(2 * internal_chn * input_channels, 1))
         self.relu = torch.nn.ReLU()
+        self.bn = torch.nn.BatchNorm1d(2 * internal_chn * input_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # x should be the chosen 8 or so
@@ -150,14 +181,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             for d in self.intermediate_dblocks:
                 rep = d(rep)
             rep = self.d6(rep)  # 2x2
+            rep = torch.sum(F.relu(rep), dim=[2, 3])
+            rep = self.bn(rep)
+            rep = self.fc(rep)
+            """
+            Pseudocode from DeepMind
+            # Sum-pool the representations and feed to spectrally normalized lin. layer.
+            y = tf.reduce_sum(tf.nn.relu(y), axis=[1, 2])
+            y = layers.BatchNorm(calc_sigma=False)(y)
+            output_layer = layers.Linear(output_size=1)
+            output = output_layer(y)

-            # Sum-pool along width and height all 8 representations, pretty sure only the last output
-            rep = torch.sum(rep.view(rep.size(0), rep.size(1), -1), dim=2)
+            # Take the sum across the t samples. Note: we apply the ReLU to
+            # (1 - score_real) and (1 + score_generated) in the loss.
+            output = tf.reshape(output, [b, n, 1])
+            output = tf.reduce_sum(output, keepdims=True, axis=1)
+            return output
+            """
             representations.append(rep)

         # The representations are summed together before the ReLU
-        x = torch.stack(representations, dim=0).sum(dim=0)  # Should be right shape? 
TODO Check - # ReLU the output - x = self.fc(x) - # x = self.relu(x) + x = torch.stack(representations, dim=1) + # Should be [Batch, N, 1] + x = torch.sum(x, keepdim=True, dim=1) return x diff --git a/nowcasting_gan/generators.py b/nowcasting_gan/generators.py index 8260a43..3b3ab24 100644 --- a/nowcasting_gan/generators.py +++ b/nowcasting_gan/generators.py @@ -1,12 +1,18 @@ +import einops import torch +import torch.nn.functional as F from torch.nn.modules.pixelshuffle import PixelShuffle from torch.nn.utils import spectral_norm from typing import List -from nowcasting_gan.common import GBlock +from nowcasting_gan.common import GBlock, UpsampleGBlock from nowcasting_gan.layers import ConvGRU +import logging +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARN) -class NowcastingSampler(torch.nn.Module): + +class Sampler(torch.nn.Module): def __init__( self, forecast_steps: int = 18, @@ -25,65 +31,85 @@ def __init__( """ super().__init__() self.forecast_steps = forecast_steps + self.convGRU1 = ConvGRU( - input_channels=latent_channels, - hidden_channels=context_channels, - kernel_size=(3, 3), - padding=(1, 1), + input_channels=latent_channels + context_channels, + output_channels=context_channels, + kernel_size=3, + ) + self.gru_conv_1x1 = spectral_norm( + torch.nn.Conv2d( + in_channels=context_channels, out_channels=latent_channels, kernel_size=(1, 1) + ) + ) + self.g1 = GBlock(input_channels=latent_channels, output_channels=latent_channels) + self.up_g1 = UpsampleGBlock( + input_channels=latent_channels, output_channels=latent_channels // 2 ) - self.g1 = GBlock(input_channels=latent_channels, output_channels=latent_channels // 2) + self.convGRU2 = ConvGRU( - input_channels=latent_channels // 2, - hidden_channels=context_channels // 2, - kernel_size=(3, 3), - padding=(1, 1), + input_channels=latent_channels // 2 + context_channels // 2, + output_channels=context_channels // 2, + kernel_size=3, + ) + self.gru_conv_1x1_2 = spectral_norm( + torch.nn.Conv2d( + in_channels=context_channels // 2, + out_channels=latent_channels // 2, + kernel_size=(1, 1), + ) ) - self.g2 = GBlock(input_channels=latent_channels // 2, output_channels=latent_channels // 4) + self.g2 = GBlock(input_channels=latent_channels // 2, output_channels=latent_channels // 2) + self.up_g2 = UpsampleGBlock( + input_channels=latent_channels // 2, output_channels=latent_channels // 4 + ) + self.convGRU3 = ConvGRU( - input_channels=latent_channels // 4, - hidden_channels=context_channels // 4, - kernel_size=(3, 3), - padding=(1, 1), + input_channels=latent_channels // 4 + context_channels // 4, + output_channels=context_channels // 4, + kernel_size=3, + ) + self.gru_conv_1x1_3 = spectral_norm( + torch.nn.Conv2d( + in_channels=context_channels // 4, + out_channels=latent_channels // 4, + kernel_size=(1, 1), + ) + ) + self.g3 = GBlock(input_channels=latent_channels // 4, output_channels=latent_channels // 4) + self.up_g3 = UpsampleGBlock( + input_channels=latent_channels // 4, output_channels=latent_channels // 8 ) - self.g3 = GBlock(input_channels=latent_channels // 4, output_channels=latent_channels // 8) + self.convGRU4 = ConvGRU( - input_channels=latent_channels // 8, - hidden_channels=context_channels // 8, - kernel_size=(3, 3), - padding=(1, 1), + input_channels=latent_channels // 8 + context_channels // 8, + output_channels=context_channels // 8, + kernel_size=3, + ) + self.gru_conv_1x1_4 = spectral_norm( + torch.nn.Conv2d( + in_channels=context_channels // 8, + out_channels=latent_channels // 8, + 
kernel_size=(1, 1), + ) ) - self.g4 = GBlock(input_channels=latent_channels // 8, output_channels=latent_channels // 16) + self.g4 = GBlock(input_channels=latent_channels // 8, output_channels=latent_channels // 8) + self.up_g4 = UpsampleGBlock( + input_channels=latent_channels // 8, output_channels=latent_channels // 16 + ) + self.bn = torch.nn.BatchNorm2d(latent_channels // 16) self.relu = torch.nn.ReLU() self.conv_1x1 = spectral_norm( torch.nn.Conv2d( - in_channels=latent_channels // 16, out_channels=4 * output_channels, kernel_size=1 + in_channels=latent_channels // 16, + out_channels=4 * output_channels, + kernel_size=(1, 1), ) ) self.depth2space = PixelShuffle(upscale_factor=2) - # Now make copies of the entire stack, one for each future timestep - stacks = torch.nn.ModuleDict() - for i in range(forecast_steps): - stacks[f"forecast_{i}"] = torch.nn.ModuleList( - [ - self.convGRU1, - self.g1, - self.convGRU2, - self.g2, - self.convGRU3, - self.g3, - self.convGRU4, - self.g4, - self.bn, - self.relu, - self.conv_1x1, - self.depth2space, - ] - ) - self.stacks = stacks - def forward( self, conditioning_states: List[torch.Tensor], latent_dim: torch.Tensor ) -> torch.Tensor: @@ -99,68 +125,48 @@ def forward( """ # Iterate through each forecast step # Initialize with conditioning state for first one, output for second one - forecasts = [] - init_states = list( - conditioning_states - ) # [torch.unsqueeze(c, dim=1) for c in conditioning_states] - # Need to expand latent dim to the batch size - latent_dim = torch.cat(init_states[0].size()[0] * [latent_dim]) - latent_dim = torch.unsqueeze(latent_dim, dim=1) - for i in range(self.forecast_steps): - # Start at lowest one and go up, conditioning states - # ConvGRU1 - x = self.stacks[f"forecast_{i}"][0](latent_dim, hidden_state=init_states[3]) - # Update for next timestep - init_states[3] = torch.squeeze(x, dim=0) - # Reduce to 4D input - x = torch.squeeze(x, dim=0) - # GBlock1 - x = self.stacks[f"forecast_{i}"][1](x) - # Expand to 5D input - x = torch.unsqueeze(x, dim=1) - # ConvGRU2 - x = self.stacks[f"forecast_{i}"][2](x, hidden_state=init_states[2]) - # Update for next timestep - init_states[2] = torch.squeeze(x, dim=0) - # Reduce to 4D input - x = torch.squeeze(x, dim=0) - # GBlock2 - x = self.stacks[f"forecast_{i}"][3](x) - # Expand to 5D input - x = torch.unsqueeze(x, dim=1) - # ConvGRU3 - x = self.stacks[f"forecast_{i}"][4](x, hidden_state=init_states[1]) - # Update for next timestep - init_states[1] = torch.squeeze(x, dim=0) - # Reduce to 4D input - x = torch.squeeze(x, dim=0) - # GBlock3 - x = self.stacks[f"forecast_{i}"][5](x) - # Expand to 5D input - x = torch.unsqueeze(x, dim=1) - # ConvGRU4 - x = self.stacks[f"forecast_{i}"][6](x, hidden_state=init_states[0]) - # Update for next timestep - init_states[0] = torch.squeeze(x, dim=0) - # Reduce to 4D input - x = torch.squeeze(x, dim=0) - # GBlock4 - x = self.stacks[f"forecast_{i}"][7](x) - # BN - x = self.stacks[f"forecast_{i}"][8](x) - # ReLU - x = self.stacks[f"forecast_{i}"][9](x) - # Conv 1x1 - x = self.stacks[f"forecast_{i}"][10](x) - # Depth2Space - x = self.stacks[f"forecast_{i}"][11](x) - forecasts.append(x) + init_states = conditioning_states + # Expand latent dim to match batch size + latent_dim = einops.repeat( + latent_dim, "b c h w -> (repeat b) c h w", repeat=init_states[0].shape[0] + ) + hidden_states = [latent_dim] * self.forecast_steps + + # Layer 4 (bottom most) + hidden_states = self.convGRU1(hidden_states, init_states[3]) + hidden_states = [self.gru_conv_1x1(h) for h 
in hidden_states] + hidden_states = [self.g1(h) for h in hidden_states] + hidden_states = [self.up_g1(h) for h in hidden_states] + + # Layer 3. + hidden_states = self.convGRU2(hidden_states, init_states[2]) + hidden_states = [self.gru_conv_1x1_2(h) for h in hidden_states] + hidden_states = [self.g2(h) for h in hidden_states] + hidden_states = [self.up_g2(h) for h in hidden_states] + + # Layer 2. + hidden_states = self.convGRU3(hidden_states, init_states[1]) + hidden_states = [self.gru_conv_1x1_3(h) for h in hidden_states] + hidden_states = [self.g3(h) for h in hidden_states] + hidden_states = [self.up_g3(h) for h in hidden_states] + + # Layer 1 (top-most). + hidden_states = self.convGRU4(hidden_states, init_states[0]) + hidden_states = [self.gru_conv_1x1_4(h) for h in hidden_states] + hidden_states = [self.g4(h) for h in hidden_states] + hidden_states = [self.up_g4(h) for h in hidden_states] + + # Output layer. + hidden_states = [F.relu(self.bn(h)) for h in hidden_states] + hidden_states = [self.conv_1x1(h) for h in hidden_states] + hidden_states = [self.depth2space(h) for h in hidden_states] + # Convert forecasts to a torch Tensor - forecasts = torch.stack(forecasts, dim=1) + forecasts = torch.stack(hidden_states, dim=1) return forecasts -class NowcastingGenerator(torch.nn.Module): +class Generator(torch.nn.Module): def __init__( self, conditioning_stack: torch.nn.Module, diff --git a/nowcasting_gan/layers/Attention.py b/nowcasting_gan/layers/Attention.py index 1da2ecf..0613320 100644 --- a/nowcasting_gan/layers/Attention.py +++ b/nowcasting_gan/layers/Attention.py @@ -1,56 +1,81 @@ import torch import torch.nn as nn from torch.nn import functional as F +import einops -class SelfAttention2d(nn.Module): - r"""Self Attention Module as proposed in the paper `"Self-Attention Generative Adversarial - Networks by Han Zhang et. al." `_ - .. math:: attention = softmax((query(x))^T * key(x)) - .. math:: output = \gamma * value(x) * attention + x - where - - :math:`query` : 2D Convolution Operation - - :math:`key` : 2D Convolution Operation - - :math:`value` : 2D Convolution Operation - - :math:`x` : Input - Args: - input_dims (int): The input channel dimension in the input ``x``. - output_dims (int, optional): The output channel dimension. If ``None`` the output - channel value is computed as ``input_dims // 8``. So if the ``input_dims`` is **less - than 8** then the layer will give an error. - return_attn (bool, optional): Set it to ``True`` if you want the attention values to be - returned. - """ - - def __init__(self, input_dims, output_dims=None, return_attn=False): - output_dims = input_dims // 8 if output_dims is None else output_dims - if output_dims == 0: - raise Exception( - "The output dims corresponding to the input dims is 0. Increase the input\ - dims to 8 or more. Else specify output_dims" - ) - super(SelfAttention2d, self).__init__() - self.query = nn.Conv2d(input_dims, output_dims, 1) - self.key = nn.Conv2d(input_dims, output_dims, 1) - self.value = nn.Conv2d(input_dims, input_dims, 1) +def attention_einsum(q, k, v): + """Apply the attention operator to tensors of shape [h, w, c].""" + + # Reshape 3D tensors to 2D tensor with first dimension L = h x w. + k = einops.rearrange(k, "h w c -> (h w) c") # [h, w, c] -> [L, c] + v = einops.rearrange(v, "h w c -> (h w) c") # [h, w, c] -> [L, c] + + # Einstein summation corresponding to the query * key operation. 
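+    # For each query pixel (h, w) this contracts the channel axis against all
+    # L = h*w key positions; the softmax over L then gives the attention weights.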
+    beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1)
+
+    # Einstein summation corresponding to the attention * value operation.
+    out = torch.einsum("hwL, Lc->hwc", beta, v)
+    return out
+
+
+class AttentionLayer(torch.nn.Module):
+    """Attention Module"""
+
+    def __init__(self, input_channels: int, output_channels: int, ratio_kq=8, ratio_v=8):
+        super(AttentionLayer, self).__init__()
+
+        self.ratio_kq = ratio_kq
+        self.ratio_v = ratio_v
+        self.output_channels = output_channels
+        self.input_channels = input_channels
+
+        # Compute query, key and value using 1x1 convolutions.
+        self.query = torch.nn.Conv2d(
+            in_channels=input_channels,
+            out_channels=self.output_channels // self.ratio_kq,
+            kernel_size=(1, 1),
+            padding="valid",
+            bias=False,
+        )
+        self.key = torch.nn.Conv2d(
+            in_channels=input_channels,
+            out_channels=self.output_channels // self.ratio_kq,
+            kernel_size=(1, 1),
+            padding="valid",
+            bias=False,
+        )
+        self.value = torch.nn.Conv2d(
+            in_channels=input_channels,
+            out_channels=self.output_channels // self.ratio_v,
+            kernel_size=(1, 1),
+            padding="valid",
+            bias=False,
+        )
+
+        self.last_conv = torch.nn.Conv2d(
+            in_channels=self.output_channels // self.ratio_v,
+            out_channels=self.output_channels,
+            kernel_size=(1, 1),
+            padding="valid",
+            bias=False,
+        )
+
+        # Learnable gain parameter
         self.gamma = nn.Parameter(torch.zeros(1))
-        self.return_attn = return_attn
-
-    def forward(self, x):
-        r"""Computes the output of the Self Attention Layer
-        Args:
-            x (torch.Tensor): A 4D Tensor with the channel dimension same as ``input_dims``.
-        Returns:
-            A tuple of the ``output`` and the ``attention`` if ``return_attn`` is set to ``True``
-            else just the ``output`` tensor.
-        """
-        dims = (x.size(0), -1, x.size(2) * x.size(3))
-        out_query = self.query(x).view(dims)
-        out_key = self.key(x).view(dims).permute(0, 2, 1)
-        attn = F.softmax(torch.bmm(out_key, out_query), dim=-1)
-        out_value = self.value(x).view(dims)
-        out_value = torch.bmm(out_value, attn).view(x.size())
-        out = self.gamma * out_value + x
-        if self.return_attn:
-            return out, attn
-        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Compute query, key and value using 1x1 convolutions.
+        query = self.query(x)
+        key = self.key(x)
+        value = self.value(x)
+        # Apply the attention operation.
+        # TODO See if we can speed this up; ApplyAlongAxis isn't defined in the pseudocode
+        out = []
+        for b in range(x.shape[0]):
+            # Apply to each in batch; attention_einsum expects [h, w, c], so move channels last and back
+            out.append(
+                attention_einsum(
+                    query[b].permute(1, 2, 0), key[b].permute(1, 2, 0), value[b].permute(1, 2, 0)
+                ).permute(2, 0, 1)
+            )
+        out = torch.stack(out, dim=0)
+        out = self.gamma * self.last_conv(out)
+        # Residual connection.
+        return out + x
diff --git a/nowcasting_gan/layers/ConvGRU.py b/nowcasting_gan/layers/ConvGRU.py
index fde3597..cfe81bb 100644
--- a/nowcasting_gan/layers/ConvGRU.py
+++ b/nowcasting_gan/layers/ConvGRU.py
@@ -1,344 +1,95 @@
 import torch
-import torch.nn as nn
-import torch.nn.init as init
-import torch.nn.functional as functional
-from typing import Union, List, Tuple
+import torch.nn.functional as F
+from torch.nn.utils import spectral_norm

-# ------------------------------------------------------------------------------
-# One-dimensional Convolution Gated Recurrent Unit
-# ------------------------------------------------------------------------------
+class ConvGRUCell(torch.nn.Module):
+    """A ConvGRU implementation."""

-class ConvGRU1DCell(nn.Module):
+    def __init__(self, input_channels: int, output_channels: int, kernel_size=3, sn_eps=0.0001):
+        """Constructor. 
- # -------------------------------------------------------------------------- - # Initialization - # -------------------------------------------------------------------------- - - def __init__( - self, - input_channels: int, - hidden_channels: int, - kernel_size: int, - stride: int = 1, - padding: int = 0, - recurrent_kernel_size: int = 3, - ): - """ - One-Dimensional Convolutional Gated Recurrent Unit (ConvGRU1D) cell. - - The input-to-hidden convolution kernel can be defined arbitrarily using - the kernel_size, stride and padding parameters. The hidden-to-hidden - convolution kernel is forced to be unit-stride, with a padding assuming - an odd kernel size, in order to keep the number of features the same. - - The hidden state is initialized by default to a zero tensor of the - appropriate shape. - - Arguments: - input_channels {int} -- [Number of channels of the input tensor] - hidden_channels {int} -- [Number of channels of the hidden state] - kernel_size {int} -- [Size of the input-to-hidden convolving kernel] - - Keyword Arguments: - stride {int} -- [Stride of the input-to-hidden convolution] - (default: {1}) - padding {int} -- [Zero-padding added to both sides of the input] - (default: {0}) - recurrent_kernel_size {int} -- [Size of the hidden-to-hidden - convolving kernel] (default: {3}) - """ - super(ConvGRU1DCell, self).__init__() - # ---------------------------------------------------------------------- - self.kernel_size = kernel_size - self.stride = stride - self.h_channels = hidden_channels - self.padding_ih = padding - self.padding_hh = recurrent_kernel_size // 2 - # ---------------------------------------------------------------------- - self.weight_ih = nn.Parameter( - torch.ones(hidden_channels * 3, input_channels, kernel_size), - requires_grad=True, - ) - self.weight_hh = nn.Parameter( - torch.ones(hidden_channels * 3, input_channels, recurrent_kernel_size), - requires_grad=True, - ) - self.bias_ih = nn.Parameter(torch.zeros(hidden_channels * 3), requires_grad=True) - self.bias_hh = nn.Parameter(torch.zeros(hidden_channels * 3), requires_grad=True) - # ---------------------------------------------------------------------- - self.reset_parameters() - - def reset_parameters(self): - init.orthogonal_(self.weight_hh) - init.xavier_uniform_(self.weight_ih) - init.zeros_(self.bias_hh) - init.zeros_(self.bias_ih) - - # -------------------------------------------------------------------------- - # Processing - # -------------------------------------------------------------------------- - - def forward(self, input, hx=None): - output_size = ( - int((input.size(-1) - self.kernel_size + 2 * self.padding_ih) / self.stride) + 1 - ) - # Handle the case of no hidden state provided - if hx is None: - hx = torch.zeros(input.size(0), self.h_channels, output_size, device=input.device) - # Run the optimized convgru-cell - return _opt_convgrucell_1d( - input, - hx, - self.h_channels, - self.weight_ih, - self.weight_hh, - self.bias_ih, - self.bias_hh, - self.stride, - self.padding_ih, - self.padding_hh, - ) - - -# ------------------------------------------------------------------------------ -# Two-dimensional Convolution Gated Recurrent Unit -# ------------------------------------------------------------------------------ - - -class ConvGRU2DCell(nn.Module): - - # -------------------------------------------------------------------------- - # Initialization - # -------------------------------------------------------------------------- - - def __init__( - self, - input_channels: int, - 
hidden_channels: int, - kernel_size: Union[int, Tuple[int, int]], - stride: Union[int, Tuple[int, int]] = (1, 1), - padding: Union[int, Tuple[int, int]] = (0, 0), - recurrent_kernel_size: Union[int, Tuple[int, int]] = (3, 3), - ): - """ - Two-Dimensional Convolutional Gated Recurrent Unit (ConvGRU2D) cell. - - The input-to-hidden convolution kernel can be defined arbitrarily using - the kernel_size, stride and padding parameters. The hidden-to-hidden - convolution kernel is forced to be unit-stride, with a padding assuming - an odd kernel size in both dimensions, in order to keep the number of - features the same. - - The hidden state is initialized by default to a zero tensor of the - appropriate shape. - - Arguments: - input_channels {int} -- [Number of channels of the input tensor] - hidden_channels {int} -- [Number of channels of the hidden state] - kernel_size {int or tuple} -- [Size of the input-to-hidden - convolving kernel] - - Keyword Arguments: - stride {int or tuple} -- [Stride of the input-to-hidden convolution] - (default: {(1, 1)}) - padding {int or tuple} -- [Zero-padding added to both sides of the - input] (default: {0}) - recurrent_kernel_size {int or tuple} -- [Size of the hidden-to- - -hidden convolving kernel] - (default: {(3, 3)}) + Args: + kernel_size: kernel size of the convolutions. Default: 3. + sn_eps: constant for spectral normalization. Default: 1e-4. """ - super(ConvGRU2DCell, self).__init__() - # ---------------------------------------------------------------------- - # Handle int to tuple conversion - if isinstance(recurrent_kernel_size, int): - recurrent_kernel_size = (recurrent_kernel_size,) * 2 - if isinstance(kernel_size, int): - kernel_size = (kernel_size,) * 2 - if isinstance(stride, int): - stride = (stride,) * 2 - if isinstance(padding, int): - padding = (padding,) * 2 - # ---------------------------------------------------------------------- - # Save input parameters for later - self.kernel_size = kernel_size - self.stride = stride - self.h_channels = hidden_channels - self.padding_ih = padding - self.padding_hh = ( - recurrent_kernel_size[0] // 2, - recurrent_kernel_size[1] // 2, - ) - # ---------------------------------------------------------------------- - # Initialize the convolution kernels - self.weight_ih = nn.Parameter( - torch.ones( - hidden_channels * 3, - input_channels, - kernel_size[0], - kernel_size[1], + super().__init__() + self._kernel_size = kernel_size + self._sn_eps = sn_eps + self.read_gate_conv = spectral_norm( + torch.nn.Conv2d( + in_channels=input_channels, + out_channels=output_channels, + kernel_size=(kernel_size, kernel_size), + padding=1, ), - requires_grad=True, + eps=sn_eps, ) - self.weight_hh = nn.Parameter( - torch.ones( - hidden_channels * 3, - input_channels, - recurrent_kernel_size[0], - recurrent_kernel_size[1], + self.update_gate_conv = spectral_norm( + torch.nn.Conv2d( + in_channels=input_channels, + out_channels=output_channels, + kernel_size=(kernel_size, kernel_size), + padding=1, ), - requires_grad=True, - ) - self.bias_ih = nn.Parameter(torch.zeros(hidden_channels * 3), requires_grad=True) - self.bias_hh = nn.Parameter(torch.zeros(hidden_channels * 3), requires_grad=True) - # ---------------------------------------------------------------------- - self.reset_parameters() - - def reset_parameters(self): - init.orthogonal_(self.weight_hh) - init.xavier_uniform_(self.weight_ih) - init.zeros_(self.bias_hh) - init.zeros_(self.bias_ih) - - # 
-------------------------------------------------------------------------- - # Processing - # -------------------------------------------------------------------------- - - def forward(self, input, hx=None): - output_size = ( - int((input.size(-2) - self.kernel_size[0] + 2 * self.padding_ih[0]) / self.stride[0]) - + 1, - int((input.size(-1) - self.kernel_size[1] + 2 * self.padding_ih[1]) / self.stride[1]) - + 1, + eps=sn_eps, ) - # Handle the case of no hidden state provided - if hx is None: - hx = torch.zeros(input.size(0), self.h_channels, *output_size, device=input.device) - # Run the optimized convgru-cell - return _opt_convgrucell_2d( - input, - hx, - self.h_channels, - self.weight_ih, - self.weight_hh, - self.bias_ih, - self.bias_hh, - self.stride, - self.padding_ih, - self.padding_hh, + self.output_conv = spectral_norm( + torch.nn.Conv2d( + in_channels=input_channels, + out_channels=output_channels, + kernel_size=(kernel_size, kernel_size), + padding=1, + ), + eps=sn_eps, ) - -class ConvGRU(nn.Module): - def __init__( - self, - input_channels: int, - hidden_channels: int, - kernel_size: Union[int, Tuple[int, int]], - stride: Union[int, Tuple[int, int]] = (1, 1), - padding: Union[int, Tuple[int, int]] = (0, 0), - recurrent_kernel_size: Union[int, Tuple[int, int]] = (3, 3), - ): - """ - Recurrent wrapper for use with any rnn cell that takes as input a tensor - and a hidden state and returns an updated hidden state. This wrapper - returns the full sequence of hidden states. It assumes the first - dimension corresponds to the timesteps, and that the other dimensions - are directly compatible with the given rnn cell. - Implements a very basic truncated backpropagation through time - corresponding to the case k1=k2 (see 'An Efficient Gradient-Based - Algorithm for On-Line Training of Recurrent Network Trajectories', - Ronald J. Williams and Jing Pen, Neural Computation, vol. 2, - pp. 490-501, 1990). - Args: - rnn_cell (nn.Module): [The torch module that takes one timestep of - the input tensor and the hidden state and returns a new hidden - state] - truncation_steps (int, optional): [The maximum number of timesteps - to include in the backpropagation graph. This can help speed up - runtime on CPU and avoid vanishing gradient problems, however - it is mostly useful for very long sequences]. Defaults to None. + def forward(self, x, prev_state): """ - super(ConvGRU, self).__init__() + ConvGRU forward, returning the current+new state - self.rnn_cell = ConvGRU2DCell( - input_channels, hidden_channels, kernel_size, stride, padding, recurrent_kernel_size - ) + Args: + x: Input tensor + prev_state: Previous state - def forward(self, input, hidden_state=None): - output = [] - for step in range(input.size(1)): - # Compute current time-step - hidden_state = self.rnn_cell(input[:, step, :, :, :], hidden_state) - output.append(hidden_state) - # Stack the list of output hidden states into a tensor - output = torch.stack(output, 0) - return output + Returns: + New tensor plus the new state + """ + # Concatenate the inputs and previous state along the channel axis. + xh = torch.cat([x, prev_state], dim=1) + # Read gate of the GRU. + read_gate = F.sigmoid(self.read_gate_conv(xh)) -# -------------------------------------------------------------------------- -# Torchscript optimized cell functions -# -------------------------------------------------------------------------- + # Update gate of the GRU. + update_gate = F.sigmoid(self.update_gate_conv(xh)) + # Gate the inputs. 
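+        # The read gate rescales the previous state before it is concatenated with
+        # the input below; the input x itself passes through unscaled.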
+ gated_input = torch.cat([x, read_gate * prev_state], dim=1) -@torch.jit.script -def _opt_cell_end(hidden, ih_1, hh_1, ih_2, hh_2, ih_3, hh_3): - z = torch.sigmoid(ih_1 + hh_1) - r = torch.sigmoid(ih_2 + hh_2) - n = torch.tanh(ih_3 + r * hh_3) - out = (1 - z) * n + z * hidden - return out + # Gate the cell and state / outputs. + c = F.relu(self.output_conv(gated_input)) + out = update_gate * prev_state + (1.0 - update_gate) * c + new_state = out + return out, new_state -@torch.jit.script -def _opt_convgrucell_1d( - inputs, - hidden, - channels: int, - w_ih, - w_hh, - b_ih, - b_hh, - stride: int, - pad1: int, - pad2: int, -): - ih_output = functional.conv1d(inputs, w_ih, bias=b_ih, stride=stride, padding=pad1) - hh_output = functional.conv1d(hidden, w_hh, bias=b_hh, stride=1, padding=pad2) - output = _opt_cell_end( - hidden, - torch.narrow(ih_output, 1, 0, channels), - torch.narrow(hh_output, 1, 0, channels), - torch.narrow(ih_output, 1, channels, channels), - torch.narrow(hh_output, 1, channels, channels), - torch.narrow(ih_output, 1, 2 * channels, channels), - torch.narrow(hh_output, 1, 2 * channels, channels), - ) - return output +class ConvGRU(torch.nn.Module): + """ConvGRU Cell wrapper to replace tf.static_rnn in TF implementation""" -@torch.jit.script -def _opt_convgrucell_2d( - inputs, - hidden, - channels: int, - w_ih, - w_hh, - b_ih, - b_hh, - stride: List[int], - pad1: List[int], - pad2: List[int], -): - ih_output = functional.conv2d(inputs, w_ih, bias=b_ih, stride=stride, padding=pad1) - hh_output = functional.conv2d(hidden, w_hh, bias=b_hh, stride=1, padding=pad2) - output = _opt_cell_end( - hidden, - torch.narrow(ih_output, 1, 0, channels), - torch.narrow(hh_output, 1, 0, channels), - torch.narrow(ih_output, 1, channels, channels), - torch.narrow(hh_output, 1, channels, channels), - torch.narrow(ih_output, 1, 2 * channels, channels), - torch.narrow(hh_output, 1, 2 * channels, channels), - ) - return output + def __init__( + self, input_channels: int, output_channels: int, kernel_size: int = 3, sn_eps=0.0001 + ): + super().__init__() + self.cell = ConvGRUCell(input_channels, output_channels, kernel_size, sn_eps) + + def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor: + outputs = [] + for step in range(len(x)): + # Compute current timestep + output, hidden_state = self.cell(x[step], hidden_state) + outputs.append(output) + # Stack outputs to return as tensor + outputs = torch.stack(outputs, dim=0) + return outputs diff --git a/nowcasting_gan/layers/__init__.py b/nowcasting_gan/layers/__init__.py index 610aabb..ef61020 100644 --- a/nowcasting_gan/layers/__init__.py +++ b/nowcasting_gan/layers/__init__.py @@ -1,3 +1,3 @@ -from .Attention import SelfAttention2d +from .Attention import AttentionLayer from .ConvGRU import ConvGRU from .CoordConv import CoordConv diff --git a/nowcasting_gan/losses.py b/nowcasting_gan/losses.py index dc17ff7..6281f1d 100644 --- a/nowcasting_gan/losses.py +++ b/nowcasting_gan/losses.py @@ -255,6 +255,38 @@ def forward(self, logit, target): return loss +def loss_hinge_disc(score_generated, score_real): + """Discriminator hinge loss.""" + l1 = F.relu(1.0 - score_real) + loss = torch.mean(l1) + l2 = F.relu(1.0 + score_generated) + loss += torch.mean(l2) + return loss + + +def loss_hinge_gen(score_generated): + """Generator hinge loss.""" + loss = -torch.mean(score_generated) + return loss + + +def grid_cell_regularizer(generated_samples, batch_targets): + """Grid cell regularizer. 
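+
+    Averages the generated samples into a single prediction and penalises its
+    weighted L1 distance from the target, with weights given by the clipped
+    rainfall intensity.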
+
+    Args:
+      generated_samples: Tensor of size [n_samples, batch_size, 18, 256, 256, 1].
+      batch_targets: Tensor of size [batch_size, 18, 256, 256, 1].
+
+    Returns:
+      loss: A scalar loss tensor (the mean over all grid cells).
+    """
+    gen_mean = torch.mean(generated_samples, dim=0)
+    # TODO Possibly change clip here?
+    weights = torch.clip(batch_targets, 0.0, 24.0)
+    loss = torch.mean(torch.abs(gen_mean - batch_targets) * weights)
+    return loss
+
+
 def get_loss(loss: str = "mse", **kwargs) -> torch.nn.Module:
     if isinstance(loss, torch.nn.Module):
         return loss
diff --git a/pseudocode.zip b/pseudocode.zip
new file mode 100644
index 0000000..301e6a5
Binary files /dev/null and b/pseudocode.zip differ
diff --git a/pseudocode/LICENSE b/pseudocode/LICENSE
new file mode 100644
index 0000000..d651c26
--- /dev/null
+++ b/pseudocode/LICENSE
@@ -0,0 +1,7 @@
+Nowcasting DGMR Pseudocode (c) by DeepMind
+
+Nowcasting DGMR Pseudocode is licensed under a
+Creative Commons Attribution 4.0 International License.
+
+You should have received a copy of the license along with this
+work. If not, see <http://creativecommons.org/licenses/by/4.0/>.
diff --git a/pseudocode/README.md b/pseudocode/README.md
new file mode 100644
index 0000000..ded5f73
--- /dev/null
+++ b/pseudocode/README.md
@@ -0,0 +1,15 @@
+# Pseudocode for Precipitation Nowcasting using Deep Generative Models of Radar.
+
+This archive contains pseudocode for the DGMR nowcasting model. It is written in
+Python syntax, but is not runnable as it serves only for reference purposes.
+
+`discriminator.txt` - Describes the architecture of the various discriminators.
+
+`generator.txt` - Describes the architecture of the generator.
+
+`latent_stack.txt` - Describes the architecture of the fully-convolutional
+latent stack.
+
+`layers.txt` - Describes the various common modules.
+
+`train.txt` - Pseudocode for the training loop.
diff --git a/pseudocode/discriminator.txt b/pseudocode/discriminator.txt
new file mode 100644
index 0000000..524efcc
--- /dev/null
+++ b/pseudocode/discriminator.txt
@@ -0,0 +1,243 @@
+"""Discriminator implementation."""
+
+from . import layers
+import tensorflow.compat.v1 as tf
+
+
+class Discriminator(object):
+  """Discriminator."""
+
+  def __init__(self):
+    """Constructor."""
+    # Number of random time steps for the spatial discriminator.
+    self._num_spatial_frames = 8
+    # Input size ratio with respect to crop size for the temporal discriminator.
+    self._temporal_crop_ratio = 2
+    # As the input is the whole sequence of the event (including conditioning
+    # frames), the spatial discriminator needs to pick only the t > T+0.
+    self._num_conditioning_frames = 4
+    self._spatial_discriminator = SpatialDiscriminator()
+    self._temporal_discriminator = TemporalDiscriminator()
+
+  def __call__(self, frames):
+    """Build the discriminator.
+
+    Args:
+      frames: a tensor with a complete observation [b, 22, 256, 256, 1].
+
+    Returns:
+      A tensor with discriminator loss scalars [b, 2].
+    """
+    b, t, h, w, c = tf.shape(frames).as_list()
+
+    # Prepare the frames for spatial discriminator: pick 8 random time steps out
+    # of 18 lead time steps, and downsample from 256x256 to 128x128. 
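+    # Only lead-time frames (indices >= self._num_conditioning_frames) are
+    # candidates, so the conditioning frames are never shown to the spatial
+    # discriminator.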
+    target_frames_sel = tf.range(self._num_conditioning_frames, t)
+    permutation = tf.stack([
+        tf.random_shuffle(target_frames_sel)[:self._num_spatial_frames]
+        for _ in range(b)
+    ], 0)
+    frames_for_sd = tf.gather(frames, permutation, batch_dims=1)
+    frames_for_sd = tf.layers.average_pooling3d(
+        frames_for_sd, [1, 2, 2], [1, 2, 2], data_format='channels_last')
+
+    # Compute the average spatial discriminator score for each of 8 picked time
+    # steps.
+    sd_out = self._spatial_discriminator(frames_for_sd)
+
+    # Prepare the frames for temporal discriminator: choose the offset of a
+    # random crop of size 128x128 out of 256x256 and pick full sequence samples.
+    cr = self._temporal_crop_ratio
+    h_offset = tf.random_uniform([], 0, (cr - 1) * (h // cr), tf.int32)
+    w_offset = tf.random_uniform([], 0, (cr - 1) * (w // cr), tf.int32)
+    zero_offset = tf.zeros_like(w_offset)
+    begin_tensor = tf.stack(
+        [zero_offset, zero_offset, h_offset, w_offset, zero_offset], -1)
+    size_tensor = tf.constant([b, t, h // cr, w // cr, c])
+    frames_for_td = tf.slice(frames, begin_tensor, size_tensor)
+    frames_for_td.set_shape([b, t, h // cr, w // cr, c])
+
+    # Compute the average temporal discriminator score over length 5 sequences.
+    td_out = self._temporal_discriminator(frames_for_td)
+
+    return tf.concat([sd_out, td_out], 1)
+
+
+class DBlock(object):
+  """Convolutional residual block."""
+
+  def __init__(self, output_channels, kernel_size=3, downsample=True,
+               pre_activation=True, conv=layers.SNConv2D,
+               pooling=layers.downsample_avg_pool, activation=tf.nn.relu):
+    """Constructor for the D blocks of the DVD-GAN.
+
+    Args:
+      output_channels: Integer number of channels in the second convolution, and
+        number of channels in the residual 1x1 convolution module.
+      kernel_size: Integer kernel size of the convolutions.
+      downsample: Boolean: shall we use the average pooling layer?
+      pre_activation: Boolean: shall we apply pre-activation to inputs?
+      conv: TF module, either layers.Conv2D or a wrapper with spectral
+        normalisation layers.SNConv2D.
+      pooling: Average pooling layer. Default: layers.downsample_avg_pool.
+      activation: Activation at optional preactivation and first conv layers.
+    """
+    self._output_channels = output_channels
+    self._kernel_size = kernel_size
+    self._downsample = downsample
+    self._pre_activation = pre_activation
+    self._conv = conv
+    self._pooling = pooling
+    self._activation = activation
+
+  def __call__(self, inputs):
+    """Build the DBlock.
+
+    Args:
+      inputs: a tensor with a complete observation [b, 256, 256, 1]
+
+    Returns:
+      The output tensor of the residual block.
+    """
+    h0 = inputs
+
+    # Pre-activation.
+    if self._pre_activation:
+      h0 = self._activation(h0)
+
+    # First convolution.
+    input_channels = h0.shape.as_list()[-1]
+    h1 = self._conv(num_channels=input_channels,
+                    kernel_size=self._kernel_size)(h0)
+    h1 = self._activation(h1)
+
+    # Second convolution.
+    h2 = self._conv(num_channels=self._output_channels,
+                    kernel_size=self._kernel_size)(h1)
+
+    # Downsampling.
+    if self._downsample:
+      h2 = self._pooling(h2)
+
+    # The residual connection, make sure it has the same dimensionality
+    # with additional 1x1 convolution and downsampling if needed.
+    if input_channels != self._output_channels or self._downsample:
+      sc = self._conv(num_channels=self._output_channels,
+                      kernel_size=1)(inputs)
+      if self._downsample:
+        sc = self._pooling(sc)
+    else:
+      sc = inputs
+
+    # Residual connection.
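+    # Worked example: for a [b, 64, 64, 4] input with output_channels=48 and
+    # downsampling, h2 and sc are both [b, 32, 32, 48], so the sum below is
+    # well defined.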
+    return h2 + sc
+
+
+class SpatialDiscriminator(object):
+  """Spatial Discriminator."""
+
+  def __init__(self):
+    pass
+
+  def __call__(self, frames):
+    """Build the spatial discriminator.
+
+    Args:
+      frames: a tensor with a complete observation [b, n, 128, 128, 1].
+
+    Returns:
+      A tensor with discriminator loss scalars [b].
+    """
+    b, n, h, w, c = tf.shape(frames).as_list()
+
+    # Process each of the n inputs independently.
+    frames = tf.reshape(frames, [b * n, h, w, c])
+
+    # Space-to-depth stacking from 128x128x1 to 64x64x4.
+    frames = tf.nn.space_to_depth(frames, block_size=2)
+
+    # Five residual D Blocks to halve the resolution of the image and double
+    # the number of channels.
+    y = DBlock(output_channels=48, pre_activation=False)(frames)
+    y = DBlock(output_channels=96)(y)
+    y = DBlock(output_channels=192)(y)
+    y = DBlock(output_channels=384)(y)
+    y = DBlock(output_channels=768)(y)
+
+    # One more D Block without downsampling or increase in number of channels.
+    y = DBlock(output_channels=768, downsample=False)(y)
+
+    # Sum-pool the representations and feed to spectrally normalized lin. layer.
+    y = tf.reduce_sum(tf.nn.relu(y), axis=[1, 2])
+    y = layers.BatchNorm(calc_sigma=False)(y)
+    output_layer = layers.Linear(output_size=1)
+    output = output_layer(y)
+
+    # Take the sum across the t samples. Note: we apply the ReLU to
+    # (1 - score_real) and (1 + score_generated) in the loss.
+    output = tf.reshape(output, [b, n, 1])
+    output = tf.reduce_sum(output, keepdims=True, axis=1)
+    return output
+
+
+class TemporalDiscriminator(object):
+  """Temporal Discriminator."""
+
+  def __init__(self):
+    pass
+
+  def __call__(self, frames):
+    """Build the temporal discriminator.
+
+    Args:
+      frames: a tensor with a complete observation [b, ts, 128, 128, 1]
+
+    Returns:
+      A tensor with discriminator loss scalars [b].
+    """
+    b, ts, hs, ws, cs = tf.shape(frames).as_list()
+
+    # Process each of the ts inputs independently.
+    frames = tf.reshape(frames, [b * ts, hs, ws, cs])
+
+    # Space-to-depth stacking from 128x128x1 to 64x64x4.
+    frames = tf.nn.space_to_depth(frames, block_size=2)
+
+    # Stack back to sequences of length ts.
+    frames = tf.reshape(frames, [b, ts, hs, ws, cs])
+
+    # Two residual 3D Blocks to halve the resolution of the image, double
+    # the number of channels, and reduce the number of time steps.
+    y = DBlock(output_channels=48, conv=layers.SNConv3D,
+               pooling=layers.downsample_avg_pool3d,
+               pre_activation=False)(frames)
+    y = DBlock(output_channels=96, conv=layers.SNConv3D,
+               pooling=layers.downsample_avg_pool3d)(y)
+
+    # Get t < ts, h, w, and c, as we have downsampled in 3D.
+    _, t, h, w, c = y.shape.as_list()
+
+    # Process each of the t images independently.
+    # b t h w c -> (b x t) h w c
+    y = tf.reshape(y, [-1] + [h, w, c])
+
+    # Three residual D Blocks to halve the resolution of the image and double
+    # the number of channels.
+    y = DBlock(output_channels=192)(y)
+    y = DBlock(output_channels=384)(y)
+    y = DBlock(output_channels=768)(y)
+
+    # One more D Block without downsampling or increase in number of channels.
+    y = DBlock(output_channels=768, downsample=False)(y)
+
+    # Sum-pool the representations and feed to spectrally normalized lin. layer.
+    y = tf.reduce_sum(tf.nn.relu(y), axis=[1, 2])
+    y = layers.BatchNorm(calc_sigma=False)(y)
+    output_layer = layers.Linear(output_size=1)
+    output = output_layer(y)
+
+    # Take the sum across the t samples. Note: we apply the ReLU to
+    # (1 - score_real) and (1 + score_generated) in the loss.
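+    # Shape note: with ts = 22 cropped input frames, the two stride-2 3D
+    # poolings above leave t = 22 // 2 // 2 = 5 time steps, matching the
+    # "length 5 sequences" mentioned in the Discriminator above.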
+    output = tf.reshape(output, [b, t, 1])
+    scores = tf.reduce_sum(output, keepdims=True, axis=1)
+    return scores
diff --git a/pseudocode/generator.txt b/pseudocode/generator.txt
new file mode 100644
index 0000000..1e65a0a
--- /dev/null
+++ b/pseudocode/generator.txt
@@ -0,0 +1,269 @@
+"""Generator implementation."""
+
+import functools
+from . import discriminator
+from . import latent_stack
+from . import layers
+import tensorflow.compat.v1 as tf
+
+
+class Generator(object):
+  """Generator for the proposed model."""
+
+  def __init__(self, lead_time=90, time_delta=5):
+    """Constructor.
+
+    Args:
+      lead_time: last lead time for the generator to predict. Default: 90 min.
+      time_delta: time step between predictions. Default: 5 min.
+    """
+    self._cond_stack = ConditioningStack()
+    self._sampler = Sampler(lead_time, time_delta)
+
+  def __call__(self, inputs):
+    """Connect to a graph.
+
+    Args:
+      inputs: a batch of inputs of shape [batch_size, time, h, w, 1].
+    Returns:
+      predictions: a batch of predictions in the form
+      [batch_size, num_lead_times, h, w, 1].
+    """
+    _, _, height, width, _ = inputs.shape.as_list()
+    initial_states = self._cond_stack(inputs)
+    predictions = self._sampler(initial_states, [height, width])
+    return predictions
+
+  def get_variables(self):
+    """Get all variables of the module."""
+    pass
+
+
+class ConditioningStack(object):
+  """Conditioning Stack for the Generator."""
+
+  def __init__(self):
+    self._block1 = discriminator.DBlock(output_channels=48, downsample=True)
+    self._conv_mix1 = layers.SNConv2D(output_channels=48, kernel_size=3)
+    self._block2 = discriminator.DBlock(output_channels=96, downsample=True)
+    self._conv_mix2 = layers.SNConv2D(output_channels=96, kernel_size=3)
+    self._block3 = discriminator.DBlock(output_channels=192, downsample=True)
+    self._conv_mix3 = layers.SNConv2D(output_channels=192, kernel_size=3)
+    self._block4 = discriminator.DBlock(output_channels=384, downsample=True)
+    self._conv_mix4 = layers.SNConv2D(output_channels=384, kernel_size=3)
+
+  def __call__(self, inputs):
+    # Space to depth conversion of 256x256x1 radar to 128x128x4 hiddens.
+    h0 = batch_apply(
+        functools.partial(tf.nn.space_to_depth, block_size=2), inputs)
+
+    # Downsampling residual D Blocks.
+    h1 = time_apply(self._block1, h0)
+    h2 = time_apply(self._block2, h1)
+    h3 = time_apply(self._block3, h2)
+    h4 = time_apply(self._block4, h3)
+
+    # Spectrally normalized convolutions, followed by rectified linear units.
+    init_state_1 = self._mixing_layer(h1, self._conv_mix1)
+    init_state_2 = self._mixing_layer(h2, self._conv_mix2)
+    init_state_3 = self._mixing_layer(h3, self._conv_mix3)
+    init_state_4 = self._mixing_layer(h4, self._conv_mix4)
+
+    # Return a stack of conditioning representations of size 64x64x48, 32x32x96,
+    # 16x16x192 and 8x8x384.
+    return init_state_1, init_state_2, init_state_3, init_state_4
+
+  def _mixing_layer(self, inputs, conv_block):
+    # Convert from [batch_size, time, h, w, c] -> [batch_size, h, w, c * time]
+    # then perform convolution on the output while preserving number of c.
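+    # In the PyTorch port the same stacking can be written with einops
+    # (a sketch; assumes channels-first [b, t, c, h, w] tensors):
+    #   stacked = einops.rearrange(x, "b t c h w -> b (t c) h w")
+    #   return F.relu(conv_block(stacked))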
+    stacked_inputs = tf.concat(tf.unstack(inputs, axis=1), axis=-1)
+    return tf.nn.relu(conv_block(stacked_inputs))
+
+
+class Sampler(object):
+  """Sampler for the Generator."""
+
+  def __init__(self, lead_time=90, time_delta=5):
+    self._num_predictions = lead_time // time_delta
+    self._latent_stack = latent_stack.LatentCondStack()
+
+    self._conv_gru4 = ConvGRU()
+    self._conv4 = layers.SNConv2D(kernel_size=1, output_channels=768)
+    self._gblock4 = GBlock(output_channels=768)
+    self._g_up_block4 = UpsampleGBlock(output_channels=384)
+
+    self._conv_gru3 = ConvGRU()
+    self._conv3 = layers.SNConv2D(kernel_size=1, output_channels=384)
+    self._gblock3 = GBlock(output_channels=384)
+    self._g_up_block3 = UpsampleGBlock(output_channels=192)
+
+    self._conv_gru2 = ConvGRU()
+    self._conv2 = layers.SNConv2D(kernel_size=1, output_channels=192)
+    self._gblock2 = GBlock(output_channels=192)
+    self._g_up_block2 = UpsampleGBlock(output_channels=96)
+
+    self._conv_gru1 = ConvGRU()
+    self._conv1 = layers.SNConv2D(kernel_size=1, output_channels=96)
+    self._gblock1 = GBlock(output_channels=96)
+    self._g_up_block1 = UpsampleGBlock(output_channels=48)
+
+    self._bn = layers.BatchNorm()
+    self._output_conv = layers.SNConv2D(kernel_size=1, output_channels=4)
+
+  def __call__(self, initial_states, resolution):
+    init_state_1, init_state_2, init_state_3, init_state_4 = initial_states
+    batch_size = init_state_1.shape.as_list()[0]
+
+    # Latent conditioning stack.
+    z = self._latent_stack(batch_size, resolution)
+    hs = [z] * self._num_predictions
+
+    # Layer 4 (bottom-most).
+    hs, _ = tf.nn.static_rnn(self._conv_gru4, hs, init_state_4)
+    hs = [self._conv4(h) for h in hs]
+    hs = [self._gblock4(h) for h in hs]
+    hs = [self._g_up_block4(h) for h in hs]
+
+    # Layer 3.
+    hs, _ = tf.nn.static_rnn(self._conv_gru3, hs, init_state_3)
+    hs = [self._conv3(h) for h in hs]
+    hs = [self._gblock3(h) for h in hs]
+    hs = [self._g_up_block3(h) for h in hs]
+
+    # Layer 2.
+    hs, _ = tf.nn.static_rnn(self._conv_gru2, hs, init_state_2)
+    hs = [self._conv2(h) for h in hs]
+    hs = [self._gblock2(h) for h in hs]
+    hs = [self._g_up_block2(h) for h in hs]
+
+    # Layer 1 (top-most).
+    hs, _ = tf.nn.static_rnn(self._conv_gru1, hs, init_state_1)
+    hs = [self._conv1(h) for h in hs]
+    hs = [self._gblock1(h) for h in hs]
+    hs = [self._g_up_block1(h) for h in hs]
+
+    # Output layer.
+    hs = [tf.nn.relu(self._bn(h)) for h in hs]
+    hs = [self._output_conv(h) for h in hs]
+    hs = [tf.nn.depth_to_space(h, 2) for h in hs]
+
+    return tf.stack(hs, axis=1)
+
+
+class GBlock(object):
+  """Residual generator block without upsampling."""
+
+  def __init__(self, output_channels, sn_eps=0.0001):
+    self._conv1_3x3 = layers.SNConv2D(
+        output_channels, kernel_size=3, sn_eps=sn_eps)
+    self._bn1 = layers.BatchNorm()
+    self._conv2_3x3 = layers.SNConv2D(
+        output_channels, kernel_size=3, sn_eps=sn_eps)
+    self._bn2 = layers.BatchNorm()
+    self._output_channels = output_channels
+    self._sn_eps = sn_eps
+
+  def __call__(self, inputs):
+    input_channels = inputs.shape[-1]
+
+    # Optional spectrally normalized 1x1 convolution.
+    if input_channels != self._output_channels:
+      conv_1x1 = layers.SNConv2D(
+          self._output_channels, kernel_size=1, sn_eps=self._sn_eps)
+      sc = conv_1x1(inputs)
+    else:
+      sc = inputs
+
+    # Two-layer residual connection, with batch normalization, nonlinearity and
+    # 3x3 spectrally normalized convolution in each layer.
+    h = tf.nn.relu(self._bn1(inputs))
+    h = self._conv1_3x3(h)
+    h = tf.nn.relu(self._bn2(h))
+    h = self._conv2_3x3(h)
+
+    # Residual connection.
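+    # Note: in the Sampler above, each GBlock keeps its channel count fixed
+    # (e.g. 768 -> 768 in layer 4), so the 1x1 shortcut branch is usually
+    # skipped and sc is simply the input.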
+ return h + sc + + +class UpsampleGBlock(object): + """Upsampling residual generator block.""" + + def __init__(self, output_channels, sn_eps=0.0001): + self._conv_1x1 = layers.SNConv2D( + output_channels, kernel_size=1, sn_eps=sn_eps) + self._conv1_3x3 = layers.SNConv2D( + output_channels, kernel_size=3, sn_eps=sn_eps) + self._bn1 = layers.BatchNorm() + self._conv2_3x3 = layers.SNConv2D( + output_channels, kernel_size=3, sn_eps=sn_eps) + self._bn2 = layers.BatchNorm() + self._output_channels = output_channels + + def __call__(self, inputs): + # x2 upsampling and spectrally normalized 1x1 convolution. + sc = layers.upsample_nearest_neighbor(inputs, upsample_size=2) + sc = self._conv_1x1(sc) + + # Two-layer residual connection, with batch normalization, nonlinearity and + # 3x3 spectrally normalized convolution in each layer, and x2 upsampling in + # the first layer. + h = tf.nn.relu(self._bn1(inputs)) + h = layers.upsample_nearest_neighbor(h, upsample_size=2) + h = self._conv1_3x3(h) + h = tf.nn.relu(self._bn2(h)) + h = self._conv2_3x3(h) + + # Residual connection. + return h + sc + + +class ConvGRU(object): + """A ConvGRU implementation.""" + + def __init__(self, kernel_size=3, sn_eps=0.0001): + """Constructor. + + Args: + kernel_size: kernel size of the convolutions. Default: 3. + sn_eps: constant for spectral normalization. Default: 1e-4. + """ + self._kernel_size = kernel_size + self._sn_eps = sn_eps + + def __call__(self, inputs, prev_state): + + # Concatenate the inputs and previous state along the channel axis. + num_channels = prev_state.shape[-1] + xh = tf.concat([inputs, prev_state], axis=-1) + + # Read gate of the GRU. + read_gate_conv = layers.SNConv2D( + num_channels, self._kernel_size, sn_eps=self._sn_eps) + read_gate = tf.math.sigmoid(read_gate_conv(xh)) + + # Update gate of the GRU. + update_gate_conv = layers.SNConv2D( + num_channels, self._kernel_size, sn_eps=self._sn_eps) + update_gate = tf.math.sigmoid(update_gate_conv(xh)) + + # Gate the inputs. + gated_input = tf.concat([inputs, read_gate * prev_state], axis=-1) + + # Gate the cell and state / outputs. + output_conv = layers.SNConv2D( + num_channels, self._kernel_size, sn_eps=self._sn_eps) + c = tf.nn.relu(output_conv(gated_input)) + out = update_gate * prev_state + (1. - update_gate) * c + new_state = out + + return out, new_state + + +def time_apply(func, inputs): + """Apply function func on each element of inputs along the time axis.""" + return layers.ApplyAlongAxis(func, axis=1)(inputs) + + +def batch_apply(func, inputs): + """Apply function func on each element of inputs along the batch axis.""" + return layers.ApplyAlongAxis(func, axis=0)(inputs) diff --git a/pseudocode/latent_stack.txt b/pseudocode/latent_stack.txt new file mode 100644 index 0000000..e948200 --- /dev/null +++ b/pseudocode/latent_stack.txt @@ -0,0 +1,141 @@ +"""Latent Conditioning Stack.""" + +from . import layers +import tensorflow.compat.v1 as tf + + +class LatentCondStack(object): + """Latent Conditioning Stack for the Sampler.""" + + def __init__(self): + self._conv1 = layers.SNConv2D(output_channels=8, kernel_size=3) + self._lblock1 = LBlock(output_channels=24) + self._lblock2 = LBlock(output_channels=48) + self._lblock3 = LBlock(output_channels=192) + self._mini_attn_block = Attention(num_channels=192) + self._lblock4 = LBlock(output_channels=768) + + def __call__(self, batch_size, resolution=(256, 256)): + + # Independent draws from a Normal distribution. 
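+    # Equivalently in the PyTorch port (channels-first; a sketch):
+    #   z = torch.randn(batch_size, 8, resolution[0] // 32, resolution[1] // 32)
+    # For a 256x256 target this gives a [batch_size, 8, 8, 8] latent grid.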
+    h, w = resolution[0] // 32, resolution[1] // 32
+    z = tf.random.normal([batch_size, h, w, 8])
+
+    # 3x3 convolution.
+    z = self._conv1(z)
+
+    # Three L Blocks to increase the number of channels to 24, 48, 192.
+    z = self._lblock1(z)
+    z = self._lblock2(z)
+    z = self._lblock3(z)
+
+    # Spatial attention module.
+    z = self._mini_attn_block(z)
+
+    # L Block to increase the number of channels to 768.
+    z = self._lblock4(z)
+
+    return z
+
+
+class LBlock(object):
+  """Residual block for the Latent Stack."""
+
+  def __init__(self, output_channels, kernel_size=3, conv=layers.Conv2D,
+               activation=tf.nn.relu):
+    """Constructor for the L blocks of the latent conditioning stack.
+
+    Args:
+      output_channels: Integer number of channels in convolution operations in
+        the main branch, and number of channels in the output of the block.
+      kernel_size: Integer kernel size of the convolutions. Default: 3.
+      conv: TF module. Default: layers.Conv2D.
+      activation: Activation before the conv. layers. Default: tf.nn.relu.
+    """
+    self._output_channels = output_channels
+    self._kernel_size = kernel_size
+    self._conv = conv
+    self._activation = activation
+
+  def __call__(self, inputs):
+    """Build the LBlock.
+
+    Args:
+      inputs: a tensor with a complete observation [N 256 256 1]
+
+    Returns:
+      The output tensor of the residual block.
+    """
+
+    # Stack of two conv. layers and nonlinearities that increase the number of
+    # channels.
+    h0 = self._activation(inputs)
+    h1 = self._conv(num_channels=self._output_channels,
+                    kernel_size=self._kernel_size)(h0)
+    h1 = self._activation(h1)
+    h2 = self._conv(num_channels=self._output_channels,
+                    kernel_size=self._kernel_size)(h1)
+
+    # Prepare the residual connection branch.
+    input_channels = h0.shape.as_list()[-1]
+    if input_channels < self._output_channels:
+      sc = self._conv(num_channels=self._output_channels - input_channels,
+                      kernel_size=1)(inputs)
+      sc = tf.concat([inputs, sc], axis=-1)
+    else:
+      sc = inputs
+
+    # Residual connection.
+    return h2 + sc
+
+
+def attention_einsum(q, k, v):
+  """Apply the attention operator to tensors of shape [h, w, c]."""
+
+  # Reshape 3D tensors to 2D tensor with first dimension L = h x w.
+  k = tf.reshape(k, [-1, k.shape[-1]])  # [h, w, c] -> [L, c]
+  v = tf.reshape(v, [-1, v.shape[-1]])  # [h, w, c] -> [L, c]
+
+  # Einstein summation corresponding to the query * key operation.
+  beta = tf.nn.softmax(tf.einsum('hwc, Lc->hwL', q, k), axis=-1)
+
+  # Einstein summation corresponding to the attention * value operation.
+  out = tf.einsum('hwL, Lc->hwc', beta, v)
+  return out
+
+
+class Attention(object):
+  """Attention module."""
+
+  def __init__(self, num_channels, ratio_kq=8, ratio_v=8, conv=layers.Conv2D):
+    """Constructor."""
+    self._num_channels = num_channels
+    self._ratio_kq = ratio_kq
+    self._ratio_v = ratio_v
+    self._conv = conv
+
+    # Learnable gain parameter
+    self._gamma = tf.get_variable(
+        'miniattn_gamma', shape=[],
+        initializer=tf.initializers.zeros(tf.float32))
+
+  def __call__(self, tensor):
+    # Compute query, key and value using 1x1 convolutions.
+    query = self._conv(
+        output_channels=self._num_channels // self._ratio_kq,
+        kernel_size=1, padding='VALID', use_bias=False)(tensor)
+    key = self._conv(
+        output_channels=self._num_channels // self._ratio_kq,
+        kernel_size=1, padding='VALID', use_bias=False)(tensor)
+    value = self._conv(
+        output_channels=self._num_channels // self._ratio_v,
+        kernel_size=1, padding='VALID', use_bias=False)(tensor)
+
+    # Apply the attention operation.
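+    # Per-example PyTorch sketch of attention_einsum above (channels-last
+    # [h, w, c] tensors assumed):
+    #   k2 = k.reshape(-1, k.shape[-1])  # [L, c]
+    #   v2 = v.reshape(-1, v.shape[-1])  # [L, c]
+    #   beta = torch.softmax(torch.einsum("hwc,Lc->hwL", q, k2), dim=-1)
+    #   out = torch.einsum("hwL,Lc->hwc", beta, v2)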
+    out = layers.ApplyAlongAxis(attention_einsum, axis=0)(query, key, value)
+    out = self._gamma * self._conv(
+        output_channels=self._num_channels,
+        kernel_size=1, padding='VALID', use_bias=False)(out)
+
+    # Residual connection.
+    return out + tensor
diff --git a/pseudocode/layers.txt b/pseudocode/layers.txt
new file mode 100644
index 0000000..e948200
--- /dev/null
+++ b/pseudocode/layers.txt
@@ -0,0 +1,141 @@
+"""Latent Conditioning Stack."""
+
+from . import layers
+import tensorflow.compat.v1 as tf
+
+
+class LatentCondStack(object):
+  """Latent Conditioning Stack for the Sampler."""
+
+  def __init__(self):
+    self._conv1 = layers.SNConv2D(output_channels=8, kernel_size=3)
+    self._lblock1 = LBlock(output_channels=24)
+    self._lblock2 = LBlock(output_channels=48)
+    self._lblock3 = LBlock(output_channels=192)
+    self._mini_attn_block = Attention(num_channels=192)
+    self._lblock4 = LBlock(output_channels=768)
+
+  def __call__(self, batch_size, resolution=(256, 256)):
+
+    # Independent draws from a Normal distribution.
+    h, w = resolution[0] // 32, resolution[1] // 32
+    z = tf.random.normal([batch_size, h, w, 8])
+
+    # 3x3 convolution.
+    z = self._conv1(z)
+
+    # Three L Blocks to increase the number of channels to 24, 48, 192.
+    z = self._lblock1(z)
+    z = self._lblock2(z)
+    z = self._lblock3(z)
+
+    # Spatial attention module.
+    z = self._mini_attn_block(z)
+
+    # L Block to increase the number of channels to 768.
+    z = self._lblock4(z)
+
+    return z
+
+
+class LBlock(object):
+  """Residual block for the Latent Stack."""
+
+  def __init__(self, output_channels, kernel_size=3, conv=layers.Conv2D,
+               activation=tf.nn.relu):
+    """Constructor for the L blocks of the latent conditioning stack.
+
+    Args:
+      output_channels: Integer number of channels in convolution operations in
+        the main branch, and number of channels in the output of the block.
+      kernel_size: Integer kernel size of the convolutions. Default: 3.
+      conv: TF module. Default: layers.Conv2D.
+      activation: Activation before the conv. layers. Default: tf.nn.relu.
+    """
+    self._output_channels = output_channels
+    self._kernel_size = kernel_size
+    self._conv = conv
+    self._activation = activation
+
+  def __call__(self, inputs):
+    """Build the LBlock.
+
+    Args:
+      inputs: a tensor with a complete observation [N 256 256 1]
+
+    Returns:
+      The output tensor of the residual block.
+    """
+
+    # Stack of two conv. layers and nonlinearities that increase the number of
+    # channels.
+    h0 = self._activation(inputs)
+    h1 = self._conv(num_channels=self._output_channels,
+                    kernel_size=self._kernel_size)(h0)
+    h1 = self._activation(h1)
+    h2 = self._conv(num_channels=self._output_channels,
+                    kernel_size=self._kernel_size)(h1)
+
+    # Prepare the residual connection branch.
+    input_channels = h0.shape.as_list()[-1]
+    if input_channels < self._output_channels:
+      sc = self._conv(num_channels=self._output_channels - input_channels,
+                      kernel_size=1)(inputs)
+      sc = tf.concat([inputs, sc], axis=-1)
+    else:
+      sc = inputs
+
+    # Residual connection.
+    return h2 + sc
+
+
+def attention_einsum(q, k, v):
+  """Apply the attention operator to tensors of shape [h, w, c]."""
+
+  # Reshape 3D tensors to 2D tensor with first dimension L = h x w.
+  k = tf.reshape(k, [-1, k.shape[-1]])  # [h, w, c] -> [L, c]
+  v = tf.reshape(v, [-1, v.shape[-1]])  # [h, w, c] -> [L, c]
+
+  # Einstein summation corresponding to the query * key operation.
+  beta = tf.nn.softmax(tf.einsum('hwc, Lc->hwL', q, k), axis=-1)
+
+  # Einstein summation corresponding to the attention * value operation.
+ out = tf.einsum('hwL, Lc->hwc', beta, v) + return out + + +class Attention(object): + """Attention module.""" + + def __init__(self, num_channels, ratio_kq=8, ratio_v=8, conv=layers.Conv2D): + """Constructor.""" + self._num_channels = num_channels + self._ratio_kq = ratio_kq + self._ratio_v = ratio_v + self._conv = conv + + # Learnable gain parameter + self._gamma = tf.get_variable( + 'miniattn_gamma', shape=[], + initializer=tf.initializers.zeros(tf.float32)) + + def __call__(self, tensor): + # Compute query, key and value using 1x1 convolutions. + query = self._conv( + output_channels=self._num_channels // self._ratio_kq, + kernel_size=1, padding='VALID', use_bias=False)(tensor) + key = self._conv( + output_channels=self._num_channels // self._ratio_kq, + kernel_size=1, padding='VALID', use_bias=False)(tensor) + value = self._conv( + output_channels=self._num_channels // self._ratio_v, + kernel_size=1, padding='VALID', use_bias=False)(tensor) + + # Apply the attention operation. + out = layers.ApplyAlongAxis(attention_einsum, axis=0)(query, key, value) + out = self._gamma * self._conv( + output_channels=self._num_channels, + kernel_size=1, padding='VALID', use_bias=False)(out) + + # Residual connection. + return out + tensor diff --git a/pseudocode/train.txt b/pseudocode/train.txt new file mode 100644 index 0000000..4e2a57b --- /dev/null +++ b/pseudocode/train.txt @@ -0,0 +1,111 @@ +"""Pseudocode for the training loop, assuming the UK data. + +This code presents, as clearly as possible, the algorithmic logic behind +the generative method. It does not include some control dependencies and +initialization ops that are specific to the hardware architecture on which it is +run as well as specific dataset storage choices. +""" + +from . import discriminator +from . import generator +import tensorflow.compat.v1 as tf + + +def get_data_batch(batch_size): + """Returns data batch. + + This function should return a pair of (input sequence, target unroll sequence) + of image frames for a given batch size, with the following dimensions: + batch_inputs are of size [batch_size, 4, 256, 256, 1], + batch_targets are of size [batch_size, 18, 256, 256, 1]. + + Args: + batch_size: The batch size, int. + + Returns: + batch_inputs: + batch_targets: Data for training. + """ + del batch_size + # TO BE IMPLEMENTED + return None, None + + +def loss_hinge_disc(score_generated, score_real): + """Discriminator hinge loss.""" + l1 = tf.nn.relu(1. - score_real) + loss = tf.reduce_mean(l1) + l2 = tf.nn.relu(1. + score_generated) + loss += tf.reduce_mean(l2) + return loss + + +def loss_hinge_gen(score_generated): + """Generator hinge loss.""" + loss = -tf.reduce_mean(score_generated) + return loss + + +def grid_cell_regularizer(generated_samples, batch_targets): + """Grid cell regularizer. + + Args: + generated_samples: Tensor of size [n_samples, batch_size, 18, 256, 256, 1]. + batch_targets: Tensor of size [batch_size, 18, 256, 256, 1]. + + Returns: + loss: A tensor of shape [batch_size]. + """ + gen_mean = tf.reduce_mean(generated_samples, axis=0) + weights = tf.clip_by_value(batch_targets, 0.0, 24.0) + loss = tf.reduce_mean(tf.math.abs(gen_mean - batch_targets) * weights) + return loss + + +def train(): + """Pseudocode of training loop for the generative method.""" + batch_size = 16 + batch_inputs, batch_targets = get_data_batch(batch_size) + generator_obj = generator.Generator(lead_time=90, time_delta=5) + # the discriminator combines the spatial and temporal discriminators. 
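+  # For reference, a hypothetical eager-mode PyTorch version of one
+  # discriminator update (sketch only; optimiser and module names assumed):
+  #   preds = generator(batch_inputs)
+  #   gen_seq = torch.cat([batch_inputs, preds.detach()], dim=1)
+  #   real_seq = torch.cat([batch_inputs, batch_targets], dim=1)
+  #   score_real, score_generated = torch.chunk(
+  #       discriminator(torch.cat([real_seq, gen_seq], dim=0)), 2, dim=0)
+  #   d_loss = loss_hinge_disc(score_generated, score_real)
+  #   d_opt.zero_grad(); d_loss.backward(); d_opt.step()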
+  discriminator_obj = discriminator.Discriminator()
+
+  # calculate samples and targets for discriminator steps
+  batch_predictions = generator_obj(batch_inputs)
+  gen_sequence = tf.concat([batch_inputs, batch_predictions], axis=1)
+  real_sequence = tf.concat([batch_inputs, batch_targets], axis=1)
+  # Concatenate the real and generated samples along the batch dimension
+  concat_inputs = tf.concat([real_sequence, gen_sequence], axis=0)
+  concat_outputs = discriminator_obj(concat_inputs)
+  # And split back to scores for real and generated samples
+  score_real, score_generated = tf.split(concat_outputs, 2, axis=0)
+  disc_loss = loss_hinge_disc(score_generated, score_real)
+  disc_optimizer = tf.train.AdamOptimizer(
+      learning_rate=2E-4, beta1=0.0, beta2=0.999)
+  disc_step = disc_optimizer.minimize(
+      disc_loss, var_list=discriminator_obj.get_variables())
+
+  # make generator loss
+  num_samples_per_input = 6
+  gen_samples = [
+      generator_obj(batch_inputs) for _ in range(num_samples_per_input)]
+  grid_cell_reg = grid_cell_regularizer(tf.stack(gen_samples, axis=0),
+                                        batch_targets)
+  gen_sequences = [tf.concat([batch_inputs, x], axis=1) for x in gen_samples]
+  gen_disc_loss = loss_hinge_gen(
+      discriminator_obj(tf.concat(gen_sequences, axis=0)))
+  gen_loss = gen_disc_loss + 20.0 * grid_cell_reg
+  gen_optimizer = tf.train.AdamOptimizer(
+      learning_rate=5E-5, beta1=0.0, beta2=0.999)
+  gen_step = gen_optimizer.minimize(
+      gen_loss, var_list=generator_obj.get_variables())
+
+  num_training_steps = 500000
+  with tf.Session() as sess:
+    for _ in range(num_training_steps):
+      for _ in range(2):
+        sess.run(disc_step)
+      sess.run(gen_step)
+
+
+if __name__ == "__main__":
+  train()
diff --git a/tests/test_model.py b/tests/test_model.py
index 701c3ba..a20abe4 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,24 +1,220 @@
 import torch
-from nowcasting_gan import NowcastingGAN
+import torch.nn.functional as F
+from nowcasting_gan import (
+    DGMR,
+    Generator,
+    Discriminator,
+    TemporalDiscriminator,
+    SpatialDiscriminator,
+    Sampler,
+    LatentConditioningStack,
+    ContextConditioningStack,
+)
+from nowcasting_gan.layers import ConvGRU
+import einops
+
+
+def test_conv_gru():
+    model = ConvGRU(
+        input_channels=768 + 384,
+        output_channels=384,
+        kernel_size=3,
+    )
+    init_states = [torch.rand((2, 384, 32, 32)) for _ in range(4)]
+    # Expand latent dim to match batch size
+    x = torch.rand((2, 768, 32, 32))
+    hidden_states = [x] * 18
+    model.eval()
+    with torch.no_grad():
+        out = model(hidden_states, init_states[3])
+    assert out.size() == (18, 2, 384, 32, 32)
+    assert not torch.isnan(out).any(), "Output included NaNs"
+
+
+def test_latent_conditioning_stack():
+    model = LatentConditioningStack()
+    x = torch.rand((2, 4, 1, 128, 128))
+    out = model(x)
+    assert out.size() == (1, 768, 8, 8)
+    assert not torch.isnan(out).any(), "Output included NaNs"
+
+
+def test_context_conditioning_stack():
+    model = ContextConditioningStack()
+    x = torch.rand((2, 4, 1, 128, 128))
+    model.eval()
+    with torch.no_grad():
+        out = model(x)
+    assert len(out) == 4
+    assert out[0].size() == (2, 96, 32, 32)
+    assert out[1].size() == (2, 192, 16, 16)
+    assert out[2].size() == (2, 384, 8, 8)
+    assert out[3].size() == (2, 768, 4, 4)
+    assert not any(torch.isnan(out[i]).any() for i in range(len(out))), "Output included NaNs"
+
+
+def test_temporal_discriminator():
+    model = TemporalDiscriminator(input_channels=1)
+    x = torch.rand((2, 8, 1, 256, 256))
+    model.eval()
+    with torch.no_grad():
+        out = model(x)
+    assert out.shape == (2, 1, 1)
+    assert not torch.isnan(out).any()
+
+
+def test_spatial_discriminator():
+    model = SpatialDiscriminator(input_channels=1)
+    x = torch.rand((2, 18, 1, 128, 128))
+    model.eval()
+    with torch.no_grad():
+        out = model(x)
+    assert out.shape == (2, 1, 1)
+    assert not torch.isnan(out).any()
+
+
+def test_discriminator():
+    model = Discriminator(input_channels=1)
+    x = torch.rand((2, 18, 1, 256, 256))
+    model.eval()
+    with torch.no_grad():
+        out = model(x)
+    assert out.shape == (2, 2, 1)
+    assert not torch.isnan(out).any()
+
+
+def test_sampler():
+    input_channels = 1
+    conv_type = "standard"
+    context_channels = 384
+    latent_channels = 768
+    forecast_steps = 18
+    output_shape = 256
+    conditioning_stack = ContextConditioningStack(
+        input_channels=input_channels,
+        conv_type=conv_type,
+        output_channels=context_channels,
+    )
+    latent_stack = LatentConditioningStack(
+        shape=(8 * input_channels, output_shape // 32, output_shape // 32),
+        output_channels=latent_channels,
+    )
+    sampler = Sampler(
+        forecast_steps=forecast_steps,
+        latent_channels=latent_channels,
+        context_channels=context_channels,
+    )
+    latent_stack.eval()
+    conditioning_stack.eval()
+    sampler.eval()
+    x = torch.rand((2, 4, 1, 256, 256))
+    with torch.no_grad():
+        latent_dim = latent_stack(x)
+        assert not torch.isnan(latent_dim).any()
+        init_states = conditioning_stack(x)
+        assert not any(torch.isnan(init_states[i]).any() for i in range(len(init_states)))
+        # Expand latent dim to match batch size
+        latent_dim = einops.repeat(
+            latent_dim, "b c h w -> (repeat b) c h w", repeat=init_states[0].shape[0]
+        )
+        assert not torch.isnan(latent_dim).any()
+        hidden_states = [latent_dim] * forecast_steps
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = sampler.convGRU1(hidden_states, init_states[3])
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = [sampler.gru_conv_1x1(h) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = [sampler.g1(h) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = [sampler.up_g1(h) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        # Layer 3.
+        hidden_states = sampler.convGRU2(hidden_states, init_states[2])
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = [sampler.gru_conv_1x1_2(h) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = [sampler.g2(h) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = [sampler.up_g2(h) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+
+        # Layer 2.
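+        # Shape check (inferred from the 256x256 input): after the layer-3
+        # stage above, each of the 18 hidden states should be (2, 192, 32, 32);
+        # the layer-2 stage below upsamples them to (2, 96, 64, 64).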
+        hidden_states = sampler.convGRU3(hidden_states, init_states[1])
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = [sampler.gru_conv_1x1_3(h) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = [sampler.g3(h) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = [sampler.up_g3(h) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+
+        # Layer 1 (top-most).
+        hidden_states = sampler.convGRU4(hidden_states, init_states[0])
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = [sampler.gru_conv_1x1_4(h) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = [sampler.g4(h) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = [sampler.up_g4(h) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+
+        # Output layer.
+        hidden_states = [F.relu(sampler.bn(h)) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = [sampler.conv_1x1(h) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+        hidden_states = [sampler.depth2space(h) for h in hidden_states]
+        assert not any(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
+
+
+def test_generator():
+    input_channels = 1
+    conv_type = "standard"
+    context_channels = 384
+    latent_channels = 768
+    forecast_steps = 18
+    output_shape = 256
+    conditioning_stack = ContextConditioningStack(
+        input_channels=input_channels,
+        conv_type=conv_type,
+        output_channels=context_channels,
+    )
+    latent_stack = LatentConditioningStack(
+        shape=(8 * input_channels, output_shape // 32, output_shape // 32),
+        output_channels=latent_channels,
+    )
+    sampler = Sampler(
+        forecast_steps=forecast_steps,
+        latent_channels=latent_channels,
+        context_channels=context_channels,
+    )
+    model = Generator(
+        conditioning_stack=conditioning_stack, latent_stack=latent_stack, sampler=sampler
+    )
+    x = torch.rand((2, 4, 1, 256, 256))
+    model.eval()
+    with torch.no_grad():
+        out = model(x)
+    assert out.shape == (2, 18, 1, 256, 256)
+    assert not torch.isnan(out).any()
 
 
 def test_nowcasting_gan_creation():
-    model = NowcastingGAN(
-        forecast_steps=24,
+    model = DGMR(
+        forecast_steps=18,
         input_channels=1,
         output_shape=128,
         latent_channels=768,
-        context_channels=768,
+        context_channels=384,
         num_samples=3,
     )
-    x = torch.randn((2, 4, 1, 128, 128))
+    x = torch.rand((2, 4, 1, 128, 128))
     model.eval()
     with torch.no_grad():
         out = model(x)
-    # MetNet creates predictions for the center 1/4th
     assert out.size() == (
         2,
-        24,
+        18,
         1,
         128,
         128,