From ba961cd047b6da78d44e29ad08bc6886e555bcd0 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 30 Aug 2022 10:59:20 +0100
Subject: [PATCH] [pre-commit.ci] pre-commit autoupdate (#9)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/pre-commit/pre-commit-hooks: v4.1.0 → v4.3.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.1.0...v4.3.0)
- [github.com/PyCQA/flake8: 4.0.1 → 5.0.4](https://github.com/PyCQA/flake8/compare/4.0.1...5.0.4)
- [github.com/psf/black: 21.12b0 → 22.6.0](https://github.com/psf/black/compare/21.12b0...22.6.0)
- [github.com/pre-commit/mirrors-prettier: v2.5.1 → v3.0.0-alpha.0](https://github.com/pre-commit/mirrors-prettier/compare/v2.5.1...v3.0.0-alpha.0)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .pre-commit-config.yaml  |  8 ++++----
 metnet/models/metnet2.py | 43 +++++++++++++++++++++++++++++-----------
 2 files changed, 35 insertions(+), 16 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1a7ab5c..48c9026 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,7 +3,7 @@ default_language_version:

 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.1.0
+    rev: v4.3.0
     hooks:
       # list of supported hooks: https://pre-commit.com/hooks.html
       - id: trailing-whitespace
@@ -24,7 +24,7 @@ repos:
           "metnet",
         ]
   - repo: https://github.com/PyCQA/flake8
-    rev: 4.0.1
+    rev: 5.0.4
     hooks:
       - id: flake8
         args:
@@ -42,14 +42,14 @@ repos:
       - id: isort
         args: [--profile, black, --line-length, "100", "metnet"]
   - repo: https://github.com/psf/black
-    rev: 21.12b0
+    rev: 22.6.0
     hooks:
       - id: black
         args: [--line-length, "100"]

   # yaml formatting
   - repo: https://github.com/pre-commit/mirrors-prettier
-    rev: v2.5.1
+    rev: v3.0.0-alpha.0
     hooks:
       - id: prettier
         types: [yaml]
diff --git a/metnet/models/metnet2.py b/metnet/models/metnet2.py
index f76a12b..4fff887 100644
--- a/metnet/models/metnet2.py
+++ b/metnet/models/metnet2.py
@@ -196,14 +196,38 @@ def __init__(
         # Go through each set of blocks and add conditioner
         # Context Stack
         for layer in self.context_block_one:
-            self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
+            self.time_conditioners.append(
+                ConditionWithTimeMetNet2(
+                    forecast_steps=forecast_steps,
+                    hidden_dim=lead_time_features,
+                    num_feature_maps=layer.output_channels,
+                )
+            )
         for layer in self.context_blocks:
-            self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
+            self.time_conditioners.append(
+                ConditionWithTimeMetNet2(
+                    forecast_steps=forecast_steps,
+                    hidden_dim=lead_time_features,
+                    num_feature_maps=layer.output_channels,
+                )
+            )
         if self.upsample_method != "interp":
             for layer in self.upsample:
-                self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
+                self.time_conditioners.append(
+                    ConditionWithTimeMetNet2(
+                        forecast_steps=forecast_steps,
+                        hidden_dim=lead_time_features,
+                        num_feature_maps=layer.output_channels,
+                    )
+                )
         for layer in self.residual_block_three:
-            self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
+            self.time_conditioners.append(
+                ConditionWithTimeMetNet2(
+                    forecast_steps=forecast_steps,
+                    hidden_dim=lead_time_features,
+                    num_feature_maps=layer.output_channels,
+                )
+            )
         # Last layers are a Conv 1x1 with 4096 channels then softmax
         self.head = nn.Conv2d(upsampler_channels, output_channels, kernel_size=(1, 1))

@@ -267,12 +291,7 @@ def forward(self, x: torch.Tensor, lead_time: int = 0):
 class ConditionWithTimeMetNet2(nn.Module):
     """Compute Scale and bias for conditioning on time"""

-    def __init__(
-        self,
-        forecast_steps: int,
-        hidden_dim: int,
-        num_feature_maps: int
-    ):
+    def __init__(self, forecast_steps: int, hidden_dim: int, num_feature_maps: int):
         """
         Compute the scale and bias factors for conditioning convolutional blocks on the forecast time

@@ -288,7 +307,7 @@ def __init__(
         self.lead_time_network = nn.ModuleList(
             [
                 nn.Linear(in_features=forecast_steps, out_features=hidden_dim),
-                nn.Linear(in_features=hidden_dim, out_features=2* num_feature_maps),
+                nn.Linear(in_features=hidden_dim, out_features=2 * num_feature_maps),
             ]
         )

@@ -316,4 +335,4 @@ def forward(self, x: torch.Tensor, timestep: int) -> [torch.Tensor, torch.Tensor
         scales_and_biases = einops.rearrange(
             scales_and_biases, "b (block sb) -> b block sb", block=self.num_feature_maps, sb=2
         )
-        return scales_and_biases[:,:,0],scales_and_biases[:,:,1]
+        return scales_and_biases[:, :, 0], scales_and_biases[:, :, 1]
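
For context, the ConditionWithTimeMetNet2 module reformatted above performs per-channel scale-and-bias (FiLM-style) conditioning on the forecast lead time: a lead-time representation is passed through two linear layers to produce 2 * num_feature_maps values, which are split into a scale and a bias for each feature map of a convolutional block. The following is a minimal standalone sketch of that idea, not the repository's exact implementation; the class name LeadTimeConditioner, the one-hot encoding, the nn.Sequential wrapper, and the chunk-based split are illustrative assumptions.

# Minimal sketch (assumption): FiLM-style lead-time conditioning,
# mirroring the structure of ConditionWithTimeMetNet2 in the diff above.
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeadTimeConditioner(nn.Module):
    def __init__(self, forecast_steps: int, hidden_dim: int, num_feature_maps: int):
        super().__init__()
        self.forecast_steps = forecast_steps
        self.num_feature_maps = num_feature_maps
        # Two linear layers map a lead-time vector to 2 * C values
        # (one scale and one bias per feature map).
        self.net = nn.Sequential(
            nn.Linear(forecast_steps, hidden_dim),
            nn.Linear(hidden_dim, 2 * num_feature_maps),
        )

    def forward(self, x: torch.Tensor, timestep: int):
        # One-hot encode the lead time and broadcast it over the batch.
        onehot = F.one_hot(torch.tensor(timestep), self.forecast_steps).float()
        onehot = onehot.expand(x.shape[0], -1)
        scales_and_biases = self.net(onehot)                 # (B, 2 * C)
        scales, biases = scales_and_biases.chunk(2, dim=-1)  # (B, C) each
        return scales, biases


# Usage: condition a feature map of shape (B, C, H, W) on lead time 3.
x = torch.randn(2, 8, 16, 16)
cond = LeadTimeConditioner(forecast_steps=12, hidden_dim=32, num_feature_maps=8)
scale, bias = cond(x, timestep=3)
x = x * scale[:, :, None, None] + bias[:, :, None, None]

The sketch splits the conditioning vector with chunk into two contiguous halves, whereas the patched code interleaves (scale, bias) pairs per channel via einops.rearrange; either layout works as long as the scales and biases are read back consistently.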