Skip to content

Commit

Permalink
Fengwu ghr: initial (#108)
Browse files Browse the repository at this point in the history
* fengwu_ghr: initial

* fengwu_ghr: fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
rnwzd and pre-commit-ci[bot] authored Jun 5, 2024
1 parent 5bb5f4f commit e140814
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
*.txt
# pixi environments
.pixi
.vscode/
2 changes: 1 addition & 1 deletion environment_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- pandas
- pip
- pyg
- python=3.12
- python
- pytorch
- cpuonly
- pytorch-cluster
Expand Down
1 change: 1 addition & 0 deletions graph_weather/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Models"""

from .fengwu_ghr.layers import MetaModel
from .layers.assimilator_decoder import AssimilatorDecoder
from .layers.assimilator_encoder import AssimilatorEncoder
from .layers.decoder import Decoder
Expand Down
138 changes: 138 additions & 0 deletions graph_weather/models/fengwu_ghr/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn

# helpers


def pair(t):
return t if isinstance(t, tuple) else (t, t)


def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature**omega)

y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)


# classes


class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)

def forward(self, x):
return self.net(x)


class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head**-0.5
self.norm = nn.LayerNorm(dim)

self.attend = nn.Softmax(dim=-1)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)

def forward(self, x):
x = self.norm(x)

qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

attn = self.attend(dots)

out = torch.matmul(attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)


class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)]
)
)

def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)


class MetaModel(nn.Module):
def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)

assert (
image_height % patch_height == 0 and image_width % patch_width == 0
), "Image dimensions must be divisible by the patch size."

patch_dim = channels * patch_height * patch_width
dim = patch_dim
self.to_patch_embedding = nn.Sequential(
Rearrange(
"b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width
),
nn.LayerNorm(patch_dim), # TODO Do we need this?
nn.Linear(patch_dim, dim), # TODO Do we need this?
nn.LayerNorm(dim), # TODO Do we need this?
)

self.pos_embedding = posemb_sincos_2d(
h=image_height // patch_height,
w=image_width // patch_width,
dim=dim,
)

self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

self.reshaper = nn.Sequential(
Rearrange(
"b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)",
h=image_height // patch_height,
w=image_width // patch_width,
p_h=patch_height,
p_w=patch_width,
)
)

def forward(self, img):
device = img.device

x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype=x.dtype)

x = self.transformer(x)

x = self.reshaper(x)

return x
19 changes: 18 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
import torch

from graph_weather import GraphWeatherAssimilator, GraphWeatherForecaster
from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Decoder, Encoder, Processor
from graph_weather.models import (
AssimilatorDecoder,
AssimilatorEncoder,
Decoder,
Encoder,
Processor,
MetaModel,
)
from graph_weather.models.losses import NormalizedMSELoss


Expand Down Expand Up @@ -222,3 +229,13 @@ def test_normalized_loss():
assert not torch.isnan(loss)
# Since feature_variance = out**2 and target = 0, we expect loss = weights
assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())


def test_meta_model():
model = MetaModel(image_size=100, patch_size=10, depth=1, heads=1, mlp_dim=7, channels=3)
features = torch.randn((1, 3, 100, 100))

out = model(features)
assert not torch.isnan(out).any()
assert not torch.isnan(out).any()
assert out.size() == (1, 3, 100, 100)

0 comments on commit e140814

Please sign in to comment.