Update FiLM layer to be per feature map
Instead of one scale and bias per block, each conditioned layer now has its own per-feature-map scale and bias

#major
jacobbieker authored Jun 28, 2022
1 parent 349f2e4 commit ffbee79
Showing 4 changed files with 33 additions and 24 deletions.
metnet/layers/ConvLSTM.py: 2 changes (1 addition & 1 deletion)
@@ -131,7 +131,7 @@ def __init__(
batchnorm: Whether to use batch norm
"""
super(ConvLSTM, self).__init__()

+ self.output_channels = hidden_dim
# Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
metnet/layers/DilatedCondConv.py: 2 changes (2 additions & 0 deletions)
@@ -16,6 +16,7 @@ def __init__(
activation: nn.Module = nn.ReLU(),
):
super().__init__()
+ self.output_channels = output_channels
self.dilated_conv_one = nn.Conv2d(
in_channels=input_channels,
out_channels=output_channels,
@@ -64,6 +65,7 @@ def __init__(
activation: nn.Module = nn.ReLU(),
):
super().__init__()
+ self.output_channels = output_channels
self.dilated_conv_one = nn.ConvTranspose2d(
in_channels=input_channels,
out_channels=output_channels,
metnet/layers/LeadTimeConditioner.py: 5 changes (2 additions & 3 deletions)
@@ -23,7 +23,6 @@ def forward(self, x: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
Returns:
Input tensor with the scale multiplied to it and bias added
"""
- # TODO Make this a vector of scale and bias for each channel, rather than one for all
- scale = scale.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand_as(x)
- bias = bias.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand_as(x)
+ scale = scale.unsqueeze(2).unsqueeze(3).expand_as(x)
+ bias = bias.unsqueeze(2).unsqueeze(3).expand_as(x)
return (scale * x) + bias
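
A minimal sketch of what the new per-feature-map FiLM application does, with assumed shapes (x as (batch, channels, height, width), scale and bias as (batch, channels)); the tensors and sizes below are illustrative only, not part of the commit:

import torch

x = torch.randn(2, 384, 16, 16)  # (batch, channels, height, width), example sizes
scale = torch.randn(2, 384)      # one scale value per feature map
bias = torch.randn(2, 384)       # one bias value per feature map

# unsqueeze(2).unsqueeze(3) reshapes (B, C) to (B, C, 1, 1) so expand_as can
# broadcast each channel's scale and bias across every spatial position,
# instead of the previous single scale and bias shared by the whole block.
out = scale.unsqueeze(2).unsqueeze(3).expand_as(x) * x + bias.unsqueeze(2).unsqueeze(3).expand_as(x)
assert out.shape == x.shape
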
metnet/models/metnet2.py: 48 changes (28 additions & 20 deletions)
@@ -109,13 +109,6 @@ def __init__(
if upsample_method != "interp"
else total_number_of_conv_blocks
)
- # TODO Only add this conditioning if adding to ConvLSTM
- # total_number_of_conv_blocks += num_input_timesteps
- self.ct = ConditionWithTimeMetNet2(
-     forecast_steps,
-     total_blocks_to_condition=total_number_of_conv_blocks,
-     hidden_dim=lead_time_features,
- )

# ConvLSTM with 13 timesteps, 128 LSTM channels, 18 encoder blocks, 384 encoder channels,
self.conv_lstm = ConvLSTM(
@@ -199,6 +192,18 @@ def __init__(
]
)

+ self.time_conditioners = nn.ModuleList()
+ # Go through each set of blocks and add conditioner
+ # Context Stack
+ for layer in self.context_block_one:
+     self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
+ for layer in self.context_blocks:
+     self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
+ if self.upsample_method != "interp":
+     for layer in self.upsample:
+         self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
+ for layer in self.residual_block_three:
+     self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
# Last layers are a Conv 1x1 with 4096 channels then softmax
self.head = nn.Conv2d(upsampler_channels, output_channels, kernel_size=(1, 1))

@@ -216,7 +221,6 @@ def forward(self, x: torch.Tensor, lead_time: int = 0):
# Compute all timesteps, probably can be parallelized
x = self.image_encoder(x)
# Compute scale and bias
- scale_and_bias = self.ct(x, lead_time)
block_num = 0

# ConvLSTM
@@ -226,11 +230,11 @@ def forward(self, x: torch.Tensor, lead_time: int = 0):

# Context Stack
for layer in self.context_block_one:
- scale, bias = scale_and_bias[:, block_num]
+ scale, bias = self.time_conditioners[block_num](res, lead_time)
res = layer(res, scale, bias)
block_num += 1
for layer in self.context_blocks:
- scale, bias = scale_and_bias[:, block_num]
+ scale, bias = self.time_conditioners[block_num](res, lead_time)
res = layer(res, scale, bias)
block_num += 1

@@ -243,13 +247,13 @@ def forward(self, x: torch.Tensor, lead_time: int = 0):
res = self.upsampler_changer(res)
else:
for layer in self.upsample:
- scale, bias = scale_and_bias[:, block_num]
+ scale, bias = self.time_conditioners[block_num](res, lead_time)
res = layer(res, scale, bias)
block_num += 1

# Shallow network
for layer in self.residual_block_three:
- scale, bias = scale_and_bias[:, block_num]
+ scale, bias = self.time_conditioners[block_num](res, lead_time)
res = layer(res, scale, bias)
block_num += 1

@@ -267,36 +271,40 @@ def __init__(
self,
forecast_steps: int,
hidden_dim: int,
- total_blocks_to_condition: int,
+ num_feature_maps: int
):
"""
Compute the scale and bias factors for conditioning convolutional blocks on the forecast time
Args:
forecast_steps: Number of forecast steps
hidden_dim: Hidden dimension size
- total_blocks_to_condition: Total number of scale, biases that are needed to be computed
+ num_feature_maps: Max number of channels in the blocks, to generate enough scale+bias values
+     This means extra values will be generated, but keeps implementation simpler
"""
super().__init__()
self.forecast_steps = forecast_steps
- self.total_blocks = total_blocks_to_condition
+ self.num_feature_maps = num_feature_maps
self.lead_time_network = nn.ModuleList(
[
nn.Linear(in_features=forecast_steps, out_features=hidden_dim),
- nn.Linear(in_features=hidden_dim, out_features=total_blocks_to_condition * 2),
+ nn.Linear(in_features=hidden_dim, out_features=2* num_feature_maps),
]
)

- def forward(self, x: torch.Tensor, timestep: int) -> torch.Tensor:
+ def forward(self, x: torch.Tensor, timestep: int) -> [torch.Tensor, torch.Tensor]:
"""
Get the scale and bias for the conditioning layers
+ From the FiLM paper, each feature map (i.e. channel) has its own scale and bias layer, so needs
+ a scale and bias for each feature map to be generated
Args:
x: The Tensor that is used
timestep: Index of the timestep to use, between 0 and forecast_steps
Returns:
- Tensor of shape (Batch, blocks, 2)
+ 2 Tensors of shape (Batch, num_feature_maps)
"""
# One hot encode the timestep
timesteps = torch.zeros(x.size()[0], self.forecast_steps, dtype=x.dtype)
@@ -306,6 +314,6 @@ def forward(self, x: torch.Tensor, timestep: int) -> torch.Tensor:
timesteps = layer(timesteps)
scales_and_biases = timesteps
scales_and_biases = einops.rearrange(
- scales_and_biases, "b (block sb) -> b block sb", block=self.total_blocks, sb=2
+ scales_and_biases, "b (block sb) -> b block sb", block=self.num_feature_maps, sb=2
)
- return scales_and_biases
+ return scales_and_biases[:,:,0],scales_and_biases[:,:,1]
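
For reference, a minimal standalone sketch of the conditioning path as it stands after this commit. The class name LeadTimeScaleBias is hypothetical, nn.Sequential stands in for the ModuleList used in the real module, and forecast_steps=8, hidden_dim=32, num_feature_maps=384 and the input sizes are example values only:

import einops
import torch
from torch import nn

class LeadTimeScaleBias(nn.Module):
    """Maps a lead-time index to one (scale, bias) pair per feature map."""

    def __init__(self, forecast_steps: int, hidden_dim: int, num_feature_maps: int):
        super().__init__()
        self.forecast_steps = forecast_steps
        self.num_feature_maps = num_feature_maps
        # One-hot lead time -> hidden features -> two values per feature map
        self.net = nn.Sequential(
            nn.Linear(forecast_steps, hidden_dim),
            nn.Linear(hidden_dim, 2 * num_feature_maps),
        )

    def forward(self, x: torch.Tensor, timestep: int):
        # One-hot encode the requested lead time for each sample in the batch
        one_hot = torch.zeros(x.size(0), self.forecast_steps, dtype=x.dtype)
        one_hot[:, timestep] = 1
        out = einops.rearrange(
            self.net(one_hot), "b (c sb) -> b c sb", c=self.num_feature_maps, sb=2
        )
        return out[:, :, 0], out[:, :, 1]  # per-feature-map scale and bias

# One conditioner is built per conditioned layer, sized by that layer's
# output_channels, and indexed in forward() by the running block_num counter.
conditioner = LeadTimeScaleBias(forecast_steps=8, hidden_dim=32, num_feature_maps=384)
res = torch.randn(2, 384, 16, 16)
scale, bias = conditioner(res, timestep=3)  # each has shape (batch, 384)
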
