From cb7055190831d203aaa77cdfdef4a7066bbc4029 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Wed, 29 May 2024 10:58:50 +0200 Subject: [PATCH 1/3] fengwu_ghr: initial --- graph_weather/models/__init__.py | 1 + graph_weather/models/fengwu_ghr/layers.py | 133 ++++++++++++++++++++++ tests/test_model.py | 14 ++- 3 files changed, 147 insertions(+), 1 deletion(-) create mode 100644 graph_weather/models/fengwu_ghr/layers.py diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index a18cda87..289d3724 100644 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -5,3 +5,4 @@ from .layers.decoder import Decoder from .layers.encoder import Encoder from .layers.processor import Processor +from .fengwu_ghr.layers import MetaModel diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py new file mode 100644 index 00000000..62425651 --- /dev/null +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -0,0 +1,133 @@ +import torch +from torch import nn + +from einops import rearrange +from einops.layers.torch import Rearrange + +# helpers + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + + +def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") + assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" + omega = torch.arange(dim // 4) / (dim // 4 - 1) + omega = 1.0 / (temperature ** omega) + + y = y.flatten()[:, None] * omega[None, :] + x = x.flatten()[:, None] * omega[None, :] + pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) + return pe.type(dtype) + +# classes + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, dim), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + def __init__(self, dim, heads=8, dim_head=64): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.scale = dim_head ** -0.5 + self.norm = nn.LayerNorm(dim) + + self.attend = nn.Softmax(dim=-1) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x): + x = self.norm(x) + + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange( + t, 'b n (h d) -> b h n d', h=self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Attention(dim, heads=heads, dim_head=dim_head), + FeedForward(dim, mlp_dim) + ])) + + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return self.norm(x) + + +class MetaModel(nn.Module): + def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64): + super().__init__() + image_height, image_width = pair(image_size) + patch_height, patch_width = pair(patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' + + patch_dim = channels * patch_height * patch_width + dim = patch_dim + self.to_patch_embedding = nn.Sequential( + Rearrange("b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", + p_h=patch_height, p_w=patch_width), + nn.LayerNorm(patch_dim), # TODO Do we need this? + nn.Linear(patch_dim, dim), # TODO Do we need this? + nn.LayerNorm(dim), # TODO Do we need this? + ) + + self.pos_embedding = posemb_sincos_2d( + h=image_height // patch_height, + w=image_width // patch_width, + dim=dim, + ) + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) + + self.reshaper = nn.Sequential( + Rearrange("b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)", + h=image_height // patch_height, + w=image_width // patch_width, + p_h=patch_height, p_w=patch_width) + ) + + def forward(self, img): + device = img.device + + x = self.to_patch_embedding(img) + x += self.pos_embedding.to(device, dtype=x.dtype) + + x = self.transformer(x) + + print(x.shape) + x = self.reshaper(x) + + return x diff --git a/tests/test_model.py b/tests/test_model.py index 58904292..2ef43cc8 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -3,7 +3,7 @@ import torch from graph_weather import GraphWeatherAssimilator, GraphWeatherForecaster -from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Decoder, Encoder, Processor +from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Decoder, Encoder, Processor, MetaModel from graph_weather.models.losses import NormalizedMSELoss @@ -222,3 +222,15 @@ def test_normalized_loss(): assert not torch.isnan(loss) # Since feature_variance = out**2 and target = 0, we expect loss = weights assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean()) + + +def test_meta_model(): + model = MetaModel(image_size=100,patch_size=10, + depth=1, heads=1, mlp_dim=7, + channels=3 ) + features = torch.randn((1,3, 100,100) ) + + out = model(features) + assert not torch.isnan(out).any() + assert not torch.isnan(out).any() + assert out.size() == (1,3, 100,100) From 9eaf70d5dad163afa5311c0819ce9f363abd82cb Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Wed, 29 May 2024 11:00:25 +0200 Subject: [PATCH 2/3] fengwu_ghr: fixes --- .gitignore | 1 + environment_cpu.yml | 2 +- graph_weather/models/fengwu_ghr/layers.py | 3 +-- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 9b8aab0b..d248bf98 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ *.txt # pixi environments .pixi +.vscode/ diff --git a/environment_cpu.yml b/environment_cpu.yml index db087fc2..e854dc84 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -9,7 +9,7 @@ dependencies: - pandas - pip - pyg - - python=3.12 + - python - pytorch - cpuonly - pytorch-cluster diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 62425651..9aeaf508 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -126,8 +126,7 @@ def forward(self, img): x += self.pos_embedding.to(device, dtype=x.dtype) x = self.transformer(x) - - print(x.shape) + x = self.reshaper(x) return x From 4f3d4c1774fded109433219dd56b7de2ef2c27c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 May 2024 09:11:13 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/models/__init__.py | 2 +- graph_weather/models/fengwu_ghr/layers.py | 50 +++++++++++++---------- tests/test_model.py | 19 +++++---- 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index 289d3724..72d222a8 100644 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -1,8 +1,8 @@ """Models""" +from .fengwu_ghr.layers import MetaModel from .layers.assimilator_decoder import AssimilatorDecoder from .layers.assimilator_encoder import AssimilatorEncoder from .layers.decoder import Decoder from .layers.encoder import Encoder from .layers.processor import Processor -from .fengwu_ghr.layers import MetaModel diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 9aeaf508..cd81218d 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -1,8 +1,7 @@ import torch -from torch import nn - from einops import rearrange from einops.layers.torch import Rearrange +from torch import nn # helpers @@ -15,13 +14,14 @@ def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" omega = torch.arange(dim // 4) / (dim // 4 - 1) - omega = 1.0 / (temperature ** omega) + omega = 1.0 / (temperature**omega) y = y.flatten()[:, None] * omega[None, :] x = x.flatten()[:, None] * omega[None, :] pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) return pe.type(dtype) + # classes @@ -44,7 +44,7 @@ def __init__(self, dim, heads=8, dim_head=64): super().__init__() inner_dim = dim_head * heads self.heads = heads - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.norm = nn.LayerNorm(dim) self.attend = nn.Softmax(dim=-1) @@ -56,15 +56,14 @@ def forward(self, x): x = self.norm(x) qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange( - t, 'b n (h d) -> b h n d', h=self.heads), qkv) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale attn = self.attend(dots) out = torch.matmul(attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) @@ -74,10 +73,11 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim): self.norm = nn.LayerNorm(dim) self.layers = nn.ModuleList([]) for _ in range(depth): - self.layers.append(nn.ModuleList([ - Attention(dim, heads=heads, dim_head=dim_head), - FeedForward(dim, mlp_dim) - ])) + self.layers.append( + nn.ModuleList( + [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)] + ) + ) def forward(self, x): for attn, ff in self.layers: @@ -92,16 +92,19 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, image_height, image_width = pair(image_size) patch_height, patch_width = pair(patch_size) - assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' + assert ( + image_height % patch_height == 0 and image_width % patch_width == 0 + ), "Image dimensions must be divisible by the patch size." patch_dim = channels * patch_height * patch_width dim = patch_dim self.to_patch_embedding = nn.Sequential( - Rearrange("b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", - p_h=patch_height, p_w=patch_width), - nn.LayerNorm(patch_dim), # TODO Do we need this? - nn.Linear(patch_dim, dim), # TODO Do we need this? - nn.LayerNorm(dim), # TODO Do we need this? + Rearrange( + "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width + ), + nn.LayerNorm(patch_dim), # TODO Do we need this? + nn.Linear(patch_dim, dim), # TODO Do we need this? + nn.LayerNorm(dim), # TODO Do we need this? ) self.pos_embedding = posemb_sincos_2d( @@ -113,10 +116,13 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) self.reshaper = nn.Sequential( - Rearrange("b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)", - h=image_height // patch_height, - w=image_width // patch_width, - p_h=patch_height, p_w=patch_width) + Rearrange( + "b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)", + h=image_height // patch_height, + w=image_width // patch_width, + p_h=patch_height, + p_w=patch_width, + ) ) def forward(self, img): @@ -126,7 +132,7 @@ def forward(self, img): x += self.pos_embedding.to(device, dtype=x.dtype) x = self.transformer(x) - + x = self.reshaper(x) return x diff --git a/tests/test_model.py b/tests/test_model.py index 2ef43cc8..050f3e28 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -3,7 +3,14 @@ import torch from graph_weather import GraphWeatherAssimilator, GraphWeatherForecaster -from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Decoder, Encoder, Processor, MetaModel +from graph_weather.models import ( + AssimilatorDecoder, + AssimilatorEncoder, + Decoder, + Encoder, + Processor, + MetaModel, +) from graph_weather.models.losses import NormalizedMSELoss @@ -225,12 +232,10 @@ def test_normalized_loss(): def test_meta_model(): - model = MetaModel(image_size=100,patch_size=10, - depth=1, heads=1, mlp_dim=7, - channels=3 ) - features = torch.randn((1,3, 100,100) ) - + model = MetaModel(image_size=100, patch_size=10, depth=1, heads=1, mlp_dim=7, channels=3) + features = torch.randn((1, 3, 100, 100)) + out = model(features) assert not torch.isnan(out).any() assert not torch.isnan(out).any() - assert out.size() == (1,3, 100,100) + assert out.size() == (1, 3, 100, 100)