Skip to content

Commit

Permalink
[pre-commit.ci] pre-commit autoupdate (#9)
Browse files Browse the repository at this point in the history
* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/pre-commit/pre-commit-hooks: v4.1.0 → v4.3.0](pre-commit/pre-commit-hooks@v4.1.0...v4.3.0)
- [github.com/PyCQA/flake8: 4.0.1 → 5.0.4](PyCQA/flake8@4.0.1...5.0.4)
- [github.com/psf/black: 21.12b0 → 22.6.0](psf/black@21.12b0...22.6.0)
- [github.com/pre-commit/mirrors-prettier: v2.5.1 → v3.0.0-alpha.0](pre-commit/mirrors-prettier@v2.5.1...v3.0.0-alpha.0)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
pre-commit-ci[bot] authored Aug 30, 2022
1 parent 89e4274 commit ba961cd
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 16 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ default_language_version:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
rev: v4.3.0
hooks:
# list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace
Expand All @@ -24,7 +24,7 @@ repos:
"metnet",
]
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
rev: 5.0.4
hooks:
- id: flake8
args:
Expand All @@ -42,14 +42,14 @@ repos:
- id: isort
args: [--profile, black, --line-length, "100", "metnet"]
- repo: https://github.com/psf/black
rev: 21.12b0
rev: 22.6.0
hooks:
- id: black
args: [--line-length, "100"]

# yaml formatting
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v2.5.1
rev: v3.0.0-alpha.0
hooks:
- id: prettier
types: [yaml]
43 changes: 31 additions & 12 deletions metnet/models/metnet2.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,38 @@ def __init__(
# Go through each set of blocks and add conditioner
# Context Stack
for layer in self.context_block_one:
self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
self.time_conditioners.append(
ConditionWithTimeMetNet2(
forecast_steps=forecast_steps,
hidden_dim=lead_time_features,
num_feature_maps=layer.output_channels,
)
)
for layer in self.context_blocks:
self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
self.time_conditioners.append(
ConditionWithTimeMetNet2(
forecast_steps=forecast_steps,
hidden_dim=lead_time_features,
num_feature_maps=layer.output_channels,
)
)
if self.upsample_method != "interp":
for layer in self.upsample:
self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
self.time_conditioners.append(
ConditionWithTimeMetNet2(
forecast_steps=forecast_steps,
hidden_dim=lead_time_features,
num_feature_maps=layer.output_channels,
)
)
for layer in self.residual_block_three:
self.time_conditioners.append(ConditionWithTimeMetNet2(forecast_steps=forecast_steps, hidden_dim=lead_time_features, num_feature_maps=layer.output_channels))
self.time_conditioners.append(
ConditionWithTimeMetNet2(
forecast_steps=forecast_steps,
hidden_dim=lead_time_features,
num_feature_maps=layer.output_channels,
)
)
# Last layers are a Conv 1x1 with 4096 channels then softmax
self.head = nn.Conv2d(upsampler_channels, output_channels, kernel_size=(1, 1))

Expand Down Expand Up @@ -267,12 +291,7 @@ def forward(self, x: torch.Tensor, lead_time: int = 0):
class ConditionWithTimeMetNet2(nn.Module):
"""Compute Scale and bias for conditioning on time"""

def __init__(
self,
forecast_steps: int,
hidden_dim: int,
num_feature_maps: int
):
def __init__(self, forecast_steps: int, hidden_dim: int, num_feature_maps: int):
"""
Compute the scale and bias factors for conditioning convolutional blocks on the forecast time
Expand All @@ -288,7 +307,7 @@ def __init__(
self.lead_time_network = nn.ModuleList(
[
nn.Linear(in_features=forecast_steps, out_features=hidden_dim),
nn.Linear(in_features=hidden_dim, out_features=2* num_feature_maps),
nn.Linear(in_features=hidden_dim, out_features=2 * num_feature_maps),
]
)

Expand Down Expand Up @@ -316,4 +335,4 @@ def forward(self, x: torch.Tensor, timestep: int) -> [torch.Tensor, torch.Tensor
scales_and_biases = einops.rearrange(
scales_and_biases, "b (block sb) -> b block sb", block=self.num_feature_maps, sb=2
)
return scales_and_biases[:,:,0],scales_and_biases[:,:,1]
return scales_and_biases[:, :, 0], scales_and_biases[:, :, 1]

0 comments on commit ba961cd

Please sign in to comment.