diff --git a/metnet/layers/ConvLSTM.py b/metnet/layers/ConvLSTM.py
index f78b305..a823041 100644
--- a/metnet/layers/ConvLSTM.py
+++ b/metnet/layers/ConvLSTM.py
@@ -131,7 +131,7 @@ def __init__(
             batchnorm: Whether to use batch norm
         """
         super(ConvLSTM, self).__init__()
-
+        self.output_channels = hidden_dim
         # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
         kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
         hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
diff --git a/metnet/layers/DilatedCondConv.py b/metnet/layers/DilatedCondConv.py
index adf29d0..704bb0d 100644
--- a/metnet/layers/DilatedCondConv.py
+++ b/metnet/layers/DilatedCondConv.py
@@ -16,6 +16,7 @@ def __init__(
         activation: nn.Module = nn.ReLU(),
     ):
         super().__init__()
+        self.output_channels = output_channels
         self.dilated_conv_one = nn.Conv2d(
             in_channels=input_channels,
             out_channels=output_channels,
@@ -64,6 +65,7 @@ def __init__(
         activation: nn.Module = nn.ReLU(),
     ):
         super().__init__()
+        self.output_channels = output_channels
         self.dilated_conv_one = nn.ConvTranspose2d(
             in_channels=input_channels,
             out_channels=output_channels,
diff --git a/metnet/layers/LeadTimeConditioner.py b/metnet/layers/LeadTimeConditioner.py
index d0df177..4355b90 100644
--- a/metnet/layers/LeadTimeConditioner.py
+++ b/metnet/layers/LeadTimeConditioner.py
@@ -23,7 +23,6 @@ def forward(self, x: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor) -> t
         Returns:
             Input tensor with the scale multiplied to it and bias added
         """
-        # TODO Make this a vector of scale and bias for each channel, rather than one for all
-        scale = scale.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand_as(x)
-        bias = bias.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand_as(x)
+        scale = scale.unsqueeze(2).unsqueeze(3).expand_as(x)
+        bias = bias.unsqueeze(2).unsqueeze(3).expand_as(x)
         return (scale * x) + bias
diff --git a/metnet/models/metnet2.py b/metnet/models/metnet2.py
index a3a1dcb..f76a12b 100644
--- a/metnet/models/metnet2.py
+++ b/metnet/models/metnet2.py
@@ -109,13 +109,6 @@ def __init__(
             if upsample_method != "interp"
             else total_number_of_conv_blocks
         )
-        # TODO Only add this conditioning if adding to ConvLSTM
-        # total_number_of_conv_blocks += num_input_timesteps
-        self.ct = ConditionWithTimeMetNet2(
-            forecast_steps,
-            total_blocks_to_condition=total_number_of_conv_blocks,
-            hidden_dim=lead_time_features,
-        )
 
         # ConvLSTM with 13 timesteps, 128 LSTM channels, 18 encoder blocks, 384 encoder channels,
         self.conv_lstm = ConvLSTM(
@@ -199,6 +192,18 @@ def __init__(
             ]
         )
 
+        self.time_conditioners = nn.ModuleList()
+        # Go through each set of blocks and add conditioner
+        # Context Stack
+        for layer in self.context_block_one:
+            self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
+        for layer in self.context_blocks:
+            self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
+        if self.upsample_method != "interp":
+            for layer in self.upsample:
+                self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
+        for layer in self.residual_block_three:
+            self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
         # Last layers are a Conv 1x1 with 4096 channels then softmax
         self.head = nn.Conv2d(upsampler_channels, output_channels, kernel_size=(1, 1))
 
@@ -216,7 +221,6 @@ def forward(self, x: torch.Tensor, lead_time: int = 0):
         # Compute all timesteps, probably can be parallelized
         x = self.image_encoder(x)
         # Compute scale and bias
-        scale_and_bias = self.ct(x, lead_time)
         block_num = 0
 
         # ConvLSTM
@@ -226,11 +230,11 @@
 
         # Context Stack
         for layer in self.context_block_one:
-            scale, bias = scale_and_bias[:, block_num]
+            scale, bias = self.time_conditioners[block_num](res, lead_time)
             res = layer(res, scale, bias)
             block_num += 1
         for layer in self.context_blocks:
-            scale, bias = scale_and_bias[:, block_num]
+            scale, bias = self.time_conditioners[block_num](res, lead_time)
             res = layer(res, scale, bias)
             block_num += 1
 
@@ -243,13 +247,13 @@
             res = self.upsampler_changer(res)
         else:
             for layer in self.upsample:
-                scale, bias = scale_and_bias[:, block_num]
+                scale, bias = self.time_conditioners[block_num](res, lead_time)
                 res = layer(res, scale, bias)
                 block_num += 1
 
         # Shallow network
         for layer in self.residual_block_three:
-            scale, bias = scale_and_bias[:, block_num]
+            scale, bias = self.time_conditioners[block_num](res, lead_time)
             res = layer(res, scale, bias)
             block_num += 1
 
@@ -267,7 +271,7 @@
         self,
         forecast_steps: int,
         hidden_dim: int,
-        total_blocks_to_condition: int,
+        num_feature_maps: int
     ):
         """
         Compute the scale and bias factors for conditioning convolutional blocks on the forecast time
@@ -275,28 +279,32 @@
         Args:
            forecast_steps: Number of forecast steps
            hidden_dim: Hidden dimension size
-            total_blocks_to_condition: Total number of scale, biases that are needed to be computed
+            num_feature_maps: Max number of channels in the blocks, to generate enough scale+bias values
+                This means extra values will be generated, but keeps implementation simpler
         """
         super().__init__()
         self.forecast_steps = forecast_steps
-        self.total_blocks = total_blocks_to_condition
+        self.num_feature_maps = num_feature_maps
         self.lead_time_network = nn.ModuleList(
             [
                 nn.Linear(in_features=forecast_steps, out_features=hidden_dim),
-                nn.Linear(in_features=hidden_dim, out_features=total_blocks_to_condition * 2),
+                nn.Linear(in_features=hidden_dim, out_features=2 * num_feature_maps),
             ]
         )
 
-    def forward(self, x: torch.Tensor, timestep: int) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, timestep: int) -> [torch.Tensor, torch.Tensor]:
         """
         Get the scale and bias for the conditioning layers
 
+        From the FiLM paper, each feature map (i.e. channel) has its own scale and bias layer, so needs
+        a scale and bias for each feature map to be generated
+
         Args:
             x: The Tensor that is used
             timestep: Index of the timestep to use, between 0 and forecast_steps

         Returns:
-            Tensor of shape (Batch, blocks, 2)
+            2 Tensors of shape (Batch, num_feature_maps)
         """
         # One hot encode the timestep
         timesteps = torch.zeros(x.size()[0], self.forecast_steps, dtype=x.dtype)
@@ -306,6 +314,6 @@ def forward(self, x: torch.Tensor, timestep: int) -> torch.Tensor:
             timesteps = layer(timesteps)
         scales_and_biases = timesteps
         scales_and_biases = einops.rearrange(
-            scales_and_biases, "b (block sb) -> b block sb", block=self.total_blocks, sb=2
+            scales_and_biases, "b (block sb) -> b block sb", block=self.num_feature_maps, sb=2
         )
-        return scales_and_biases
+        return scales_and_biases[:, :, 0], scales_and_biases[:, :, 1]
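
For reference, below is a minimal, self-contained sketch of the FiLM-style lead-time conditioning this diff introduces: a one-hot lead-time vector is mapped through a small two-layer network to one (scale, bias) pair per feature map, and those per-channel factors are broadcast over the spatial dimensions of a block's output. The names LeadTimeFiLM and apply_film are hypothetical, chosen for illustration; they are not the repository's classes, and the hidden activation is an assumption.

# Illustrative sketch only, not the repository's actual modules.
import torch
import torch.nn as nn


class LeadTimeFiLM(nn.Module):
    """Maps a one-hot lead time to one (scale, bias) pair per feature map."""

    def __init__(self, forecast_steps: int, hidden_dim: int, num_feature_maps: int):
        super().__init__()
        self.forecast_steps = forecast_steps
        self.num_feature_maps = num_feature_maps
        self.net = nn.Sequential(
            nn.Linear(forecast_steps, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * num_feature_maps),
        )

    def forward(self, x: torch.Tensor, timestep: int):
        # One-hot encode the requested lead time for the whole batch
        one_hot = torch.zeros(x.size(0), self.forecast_steps, dtype=x.dtype, device=x.device)
        one_hot[:, timestep] = 1
        scales_and_biases = self.net(one_hot)                       # (B, 2 * C)
        scales_and_biases = scales_and_biases.view(-1, self.num_feature_maps, 2)
        return scales_and_biases[..., 0], scales_and_biases[..., 1]  # each (B, C)


def apply_film(x: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Broadcast the per-channel (B, C) factors over the spatial dims of (B, C, H, W)
    scale = scale.unsqueeze(2).unsqueeze(3).expand_as(x)
    bias = bias.unsqueeze(2).unsqueeze(3).expand_as(x)
    return scale * x + bias


if __name__ == "__main__":
    x = torch.randn(4, 16, 32, 32)                # (B, C, H, W) output of a conditioned block
    conditioner = LeadTimeFiLM(forecast_steps=12, hidden_dim=32, num_feature_maps=16)
    scale, bias = conditioner(x, timestep=3)      # one (scale, bias) pair per channel
    y = apply_film(x, scale, bias)
    print(y.shape)                                # torch.Size([4, 16, 32, 32])

Note that the diff itself goes further: instead of one shared conditioning network, it builds one ConditionWithTimeMetNet2 per conditioned block, sized to that block's output_channels, and collects them in the time_conditioners ModuleList so each block gets its own per-channel scale and bias at forward time.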