diff --git a/metnet/models/metnet.py b/metnet/models/metnet.py
index df0757f..0fa5093 100644
--- a/metnet/models/metnet.py
+++ b/metnet/models/metnet.py
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from axial_attention import AxialAttention
+from axial_attention import AxialAttention, AxialPositionalEmbedding
 from huggingface_hub import PyTorchModelHubMixin
 
 from metnet.layers import ConditionTime, ConvGRU, DownSampler, MetNetPreprocessor, TimeDistributed
@@ -60,6 +60,9 @@ def __init__(
         self.temporal_enc = TemporalEncoder(
             image_encoder.output_channels, hidden_dim, ks=kernel_size, n_layers=num_layers
         )
+        self.position_embedding = AxialPositionalEmbedding(
+            dim=self.temporal_enc.out_channels, shape=(input_size // 4, input_size // 4)
+        )
         self.temporal_agg = nn.Sequential(
             *[
                 AxialAttention(dim=hidden_dim, dim_index=1, heads=8, num_dimensions=2)
@@ -82,7 +85,7 @@ def encode_timestep(self, x, fstep=1):
 
         # Temporal Encoder
         _, state = self.temporal_enc(self.drop(x))
-        return self.temporal_agg(state)
+        return self.temporal_agg(self.position_embedding(state))
 
     def forward(self, imgs: torch.Tensor, lead_time: int = 0) -> torch.Tensor:
         """It takes a rank 5 tensor
@@ -96,6 +99,7 @@ def forward(self, imgs: torch.Tensor, lead_time: int = 0) -> torch.Tensor:
 class TemporalEncoder(nn.Module):
     def __init__(self, in_channels, out_channels=384, ks=3, n_layers=1):
         super().__init__()
+        self.out_channels = out_channels
         self.rnn = ConvGRU(in_channels, out_channels, (ks, ks), n_layers, batch_first=True)
 
     def forward(self, x):
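
For reference, a minimal standalone sketch of what the new layer does, assuming the lucidrains axial_attention package and hypothetical sizes (hidden_dim=384, a 28x28 grid standing in for input_size // 4): AxialPositionalEmbedding broadcast-adds one learned embedding per spatial axis onto the ConvGRU state before the axial-attention stack, leaving the tensor shape unchanged.

import torch
from axial_attention import AxialAttention, AxialPositionalEmbedding

hidden_dim, grid = 384, 28  # hypothetical sizes; in the diff, grid = input_size // 4

# Learned per-axis embeddings, added over a (batch, channel, h, w) tensor;
# emb_dim_index defaults to 1, matching dim_index=1 in the attention below
pos_emb = AxialPositionalEmbedding(dim=hidden_dim, shape=(grid, grid))
attn = AxialAttention(dim=hidden_dim, dim_index=1, heads=8, num_dimensions=2)

state = torch.randn(2, hidden_dim, grid, grid)  # stand-in for the ConvGRU hidden state
out = attn(pos_emb(state))
print(out.shape)  # shape is preserved: torch.Size([2, 384, 28, 28])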