Skip to content

Commit

Permalink
Add position encoding for AxialAttention (#25)
Browse files Browse the repository at this point in the history
* Add position encoding

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jacobbieker and pre-commit-ci[bot] authored Mar 16, 2022
1 parent a0f3ea2 commit 9b8a1e2
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions metnet/models/metnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
from axial_attention import AxialAttention
from axial_attention import AxialAttention, AxialPositionalEmbedding
from huggingface_hub import PyTorchModelHubMixin

from metnet.layers import ConditionTime, ConvGRU, DownSampler, MetNetPreprocessor, TimeDistributed
Expand Down Expand Up @@ -60,6 +60,9 @@ def __init__(
self.temporal_enc = TemporalEncoder(
image_encoder.output_channels, hidden_dim, ks=kernel_size, n_layers=num_layers
)
self.position_embedding = AxialPositionalEmbedding(
dim=self.temporal_enc.out_channels, shape=(input_size // 4, input_size // 4)
)
self.temporal_agg = nn.Sequential(
*[
AxialAttention(dim=hidden_dim, dim_index=1, heads=8, num_dimensions=2)
Expand All @@ -82,7 +85,7 @@ def encode_timestep(self, x, fstep=1):

# Temporal Encoder
_, state = self.temporal_enc(self.drop(x))
return self.temporal_agg(state)
return self.temporal_agg(self.position_embedding(state))

def forward(self, imgs: torch.Tensor, lead_time: int = 0) -> torch.Tensor:
"""It takes a rank 5 tensor
Expand All @@ -96,6 +99,7 @@ def forward(self, imgs: torch.Tensor, lead_time: int = 0) -> torch.Tensor:
class TemporalEncoder(nn.Module):
def __init__(self, in_channels, out_channels=384, ks=3, n_layers=1):
super().__init__()
self.out_channels = out_channels
self.rnn = ConvGRU(in_channels, out_channels, (ks, ks), n_layers, batch_first=True)

def forward(self, x):
Expand Down

0 comments on commit 9b8a1e2

Please sign in to comment.