diff --git a/.gitignore b/.gitignore
index 9b8aab0b..d248bf98 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@
 *.txt
 # pixi environments
 .pixi
+.vscode/
diff --git a/environment_cpu.yml b/environment_cpu.yml
index db087fc2..e854dc84 100644
--- a/environment_cpu.yml
+++ b/environment_cpu.yml
@@ -9,7 +9,7 @@ dependencies:
   - pandas
   - pip
   - pyg
-  - python=3.12
+  - python
   - pytorch
   - cpuonly
   - pytorch-cluster
diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py
index a18cda87..72d222a8 100644
--- a/graph_weather/models/__init__.py
+++ b/graph_weather/models/__init__.py
@@ -1,5 +1,6 @@
 """Models"""
 
+from .fengwu_ghr.layers import MetaModel
 from .layers.assimilator_decoder import AssimilatorDecoder
 from .layers.assimilator_encoder import AssimilatorEncoder
 from .layers.decoder import Decoder
diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
new file mode 100644
index 00000000..cd81218d
--- /dev/null
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -0,0 +1,138 @@
+import torch
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from torch import nn
+
+# helpers
+
+
+def pair(t):
+    return t if isinstance(t, tuple) else (t, t)
+
+
+def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
+    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
+    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
+    omega = torch.arange(dim // 4) / (dim // 4 - 1)
+    omega = 1.0 / (temperature**omega)
+
+    y = y.flatten()[:, None] * omega[None, :]
+    x = x.flatten()[:, None] * omega[None, :]
+    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
+    return pe.type(dtype)
+
+
+# classes
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, dim),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads=8, dim_head=64):
+        super().__init__()
+        inner_dim = dim_head * heads
+        self.heads = heads
+        self.scale = dim_head**-0.5
+        self.norm = nn.LayerNorm(dim)
+
+        self.attend = nn.Softmax(dim=-1)
+
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x):
+        x = self.norm(x)
+
+        qkv = self.to_qkv(x).chunk(3, dim=-1)
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
+
+        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+        attn = self.attend(dots)
+
+        out = torch.matmul(attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        return self.to_out(out)
+
+
+class Transformer(nn.Module):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)]
+                )
+            )
+
+    def forward(self, x):
+        for attn, ff in self.layers:
+            x = attn(x) + x
+            x = ff(x) + x
+        return self.norm(x)
+
+
+class MetaModel(nn.Module):
+    def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64):
+        super().__init__()
+        image_height, image_width = pair(image_size)
+        patch_height, patch_width = pair(patch_size)
+
+        assert (
+            image_height % patch_height == 0 and image_width % patch_width == 0
+        ), "Image dimensions must be divisible by the patch size."
+
+        patch_dim = channels * patch_height * patch_width
+        dim = patch_dim
+        self.to_patch_embedding = nn.Sequential(
+            Rearrange(
+                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width
+            ),
+            nn.LayerNorm(patch_dim),  # TODO Do we need this?
+            nn.Linear(patch_dim, dim),  # TODO Do we need this?
+            nn.LayerNorm(dim),  # TODO Do we need this?
+        )
+
+        self.pos_embedding = posemb_sincos_2d(
+            h=image_height // patch_height,
+            w=image_width // patch_width,
+            dim=dim,
+        )
+
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
+
+        self.reshaper = nn.Sequential(
+            Rearrange(
+                "b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)",
+                h=image_height // patch_height,
+                w=image_width // patch_width,
+                p_h=patch_height,
+                p_w=patch_width,
+            )
+        )
+
+    def forward(self, img):
+        device = img.device
+
+        x = self.to_patch_embedding(img)
+        x += self.pos_embedding.to(device, dtype=x.dtype)
+
+        x = self.transformer(x)
+
+        x = self.reshaper(x)
+
+        return x
diff --git a/tests/test_model.py b/tests/test_model.py
index 58904292..050f3e28 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -3,7 +3,14 @@ import torch
 
 from graph_weather import GraphWeatherAssimilator, GraphWeatherForecaster
 
-from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Decoder, Encoder, Processor
+from graph_weather.models import (
+    AssimilatorDecoder,
+    AssimilatorEncoder,
+    Decoder,
+    Encoder,
+    MetaModel,
+    Processor,
+)
 from graph_weather.models.losses import NormalizedMSELoss
 
 
@@ -222,3 +229,12 @@ def test_normalized_loss():
     assert not torch.isnan(loss)
     # Since feature_variance = out**2 and target = 0, we expect loss = weights
     assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())
+
+
+def test_meta_model():
+    model = MetaModel(image_size=100, patch_size=10, depth=1, heads=1, mlp_dim=7, channels=3)
+    features = torch.randn((1, 3, 100, 100))
+
+    out = model(features)
+    assert not torch.isnan(out).any()
+    assert out.size() == (1, 3, 100, 100)
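
A minimal usage sketch of the new MetaModel (illustrative only, mirroring test_meta_model above; the token width dim is derived internally as channels * patch_height * patch_width):

import torch

from graph_weather.models import MetaModel

# A 3-channel 100x100 field split into 10x10 patches gives
# (100 / 10) * (100 / 10) = 100 tokens of dim = 3 * 10 * 10 = 300.
model = MetaModel(image_size=100, patch_size=10, depth=1, heads=1, mlp_dim=7, channels=3)
features = torch.randn((1, 3, 100, 100))

out = model(features)
# to_patch_embedding maps (1, 3, 100, 100) -> (1, 100, 300); reshaper inverts
# the patchify Rearrange, so the input's spatial layout is recovered exactly.
assert out.size() == (1, 3, 100, 100)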