From cb7055190831d203aaa77cdfdef4a7066bbc4029 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Wed, 29 May 2024 10:58:50 +0200 Subject: [PATCH 01/45] fengwu_ghr: initial --- graph_weather/models/__init__.py | 1 + graph_weather/models/fengwu_ghr/layers.py | 133 ++++++++++++++++++++++ tests/test_model.py | 14 ++- 3 files changed, 147 insertions(+), 1 deletion(-) create mode 100644 graph_weather/models/fengwu_ghr/layers.py diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index a18cda87..289d3724 100644 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -5,3 +5,4 @@ from .layers.decoder import Decoder from .layers.encoder import Encoder from .layers.processor import Processor +from .fengwu_ghr.layers import MetaModel diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py new file mode 100644 index 00000000..62425651 --- /dev/null +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -0,0 +1,133 @@ +import torch +from torch import nn + +from einops import rearrange +from einops.layers.torch import Rearrange + +# helpers + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + + +def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") + assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" + omega = torch.arange(dim // 4) / (dim // 4 - 1) + omega = 1.0 / (temperature ** omega) + + y = y.flatten()[:, None] * omega[None, :] + x = x.flatten()[:, None] * omega[None, :] + pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) + return pe.type(dtype) + +# classes + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, dim), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + def __init__(self, dim, heads=8, dim_head=64): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.scale = dim_head ** -0.5 + self.norm = nn.LayerNorm(dim) + + self.attend = nn.Softmax(dim=-1) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x): + x = self.norm(x) + + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange( + t, 'b n (h d) -> b h n d', h=self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Attention(dim, heads=heads, dim_head=dim_head), + FeedForward(dim, mlp_dim) + ])) + + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return self.norm(x) + + +class MetaModel(nn.Module): + def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64): + super().__init__() + image_height, image_width = pair(image_size) + patch_height, patch_width = pair(patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' + + patch_dim = channels * patch_height * patch_width + dim = patch_dim + self.to_patch_embedding = nn.Sequential( + Rearrange("b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", + p_h=patch_height, p_w=patch_width), + nn.LayerNorm(patch_dim), # TODO Do we need this? + nn.Linear(patch_dim, dim), # TODO Do we need this? + nn.LayerNorm(dim), # TODO Do we need this? + ) + + self.pos_embedding = posemb_sincos_2d( + h=image_height // patch_height, + w=image_width // patch_width, + dim=dim, + ) + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) + + self.reshaper = nn.Sequential( + Rearrange("b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)", + h=image_height // patch_height, + w=image_width // patch_width, + p_h=patch_height, p_w=patch_width) + ) + + def forward(self, img): + device = img.device + + x = self.to_patch_embedding(img) + x += self.pos_embedding.to(device, dtype=x.dtype) + + x = self.transformer(x) + + print(x.shape) + x = self.reshaper(x) + + return x diff --git a/tests/test_model.py b/tests/test_model.py index 58904292..2ef43cc8 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -3,7 +3,7 @@ import torch from graph_weather import GraphWeatherAssimilator, GraphWeatherForecaster -from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Decoder, Encoder, Processor +from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Decoder, Encoder, Processor, MetaModel from graph_weather.models.losses import NormalizedMSELoss @@ -222,3 +222,15 @@ def test_normalized_loss(): assert not torch.isnan(loss) # Since feature_variance = out**2 and target = 0, we expect loss = weights assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean()) + + +def test_meta_model(): + model = MetaModel(image_size=100,patch_size=10, + depth=1, heads=1, mlp_dim=7, + channels=3 ) + features = torch.randn((1,3, 100,100) ) + + out = model(features) + assert not torch.isnan(out).any() + assert not torch.isnan(out).any() + assert out.size() == (1,3, 100,100) From 9eaf70d5dad163afa5311c0819ce9f363abd82cb Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Wed, 29 May 2024 11:00:25 +0200 Subject: [PATCH 02/45] fengwu_ghr: fixes --- .gitignore | 1 + environment_cpu.yml | 2 +- graph_weather/models/fengwu_ghr/layers.py | 3 +-- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 9b8aab0b..d248bf98 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ *.txt # pixi environments .pixi +.vscode/ diff --git a/environment_cpu.yml b/environment_cpu.yml index db087fc2..e854dc84 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -9,7 +9,7 @@ dependencies: - pandas - pip - pyg - - python=3.12 + - python - pytorch - cpuonly - pytorch-cluster diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 62425651..9aeaf508 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -126,8 +126,7 @@ def forward(self, img): x += self.pos_embedding.to(device, dtype=x.dtype) x = self.transformer(x) - - print(x.shape) + x = self.reshaper(x) return x From 4f3d4c1774fded109433219dd56b7de2ef2c27c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 May 2024 09:11:13 +0000 Subject: [PATCH 03/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/models/__init__.py | 2 +- graph_weather/models/fengwu_ghr/layers.py | 50 +++++++++++++---------- tests/test_model.py | 19 +++++---- 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index 289d3724..72d222a8 100644 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -1,8 +1,8 @@ """Models""" +from .fengwu_ghr.layers import MetaModel from .layers.assimilator_decoder import AssimilatorDecoder from .layers.assimilator_encoder import AssimilatorEncoder from .layers.decoder import Decoder from .layers.encoder import Encoder from .layers.processor import Processor -from .fengwu_ghr.layers import MetaModel diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 9aeaf508..cd81218d 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -1,8 +1,7 @@ import torch -from torch import nn - from einops import rearrange from einops.layers.torch import Rearrange +from torch import nn # helpers @@ -15,13 +14,14 @@ def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" omega = torch.arange(dim // 4) / (dim // 4 - 1) - omega = 1.0 / (temperature ** omega) + omega = 1.0 / (temperature**omega) y = y.flatten()[:, None] * omega[None, :] x = x.flatten()[:, None] * omega[None, :] pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) return pe.type(dtype) + # classes @@ -44,7 +44,7 @@ def __init__(self, dim, heads=8, dim_head=64): super().__init__() inner_dim = dim_head * heads self.heads = heads - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.norm = nn.LayerNorm(dim) self.attend = nn.Softmax(dim=-1) @@ -56,15 +56,14 @@ def forward(self, x): x = self.norm(x) qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange( - t, 'b n (h d) -> b h n d', h=self.heads), qkv) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale attn = self.attend(dots) out = torch.matmul(attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) @@ -74,10 +73,11 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim): self.norm = nn.LayerNorm(dim) self.layers = nn.ModuleList([]) for _ in range(depth): - self.layers.append(nn.ModuleList([ - Attention(dim, heads=heads, dim_head=dim_head), - FeedForward(dim, mlp_dim) - ])) + self.layers.append( + nn.ModuleList( + [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)] + ) + ) def forward(self, x): for attn, ff in self.layers: @@ -92,16 +92,19 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, image_height, image_width = pair(image_size) patch_height, patch_width = pair(patch_size) - assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' + assert ( + image_height % patch_height == 0 and image_width % patch_width == 0 + ), "Image dimensions must be divisible by the patch size." patch_dim = channels * patch_height * patch_width dim = patch_dim self.to_patch_embedding = nn.Sequential( - Rearrange("b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", - p_h=patch_height, p_w=patch_width), - nn.LayerNorm(patch_dim), # TODO Do we need this? - nn.Linear(patch_dim, dim), # TODO Do we need this? - nn.LayerNorm(dim), # TODO Do we need this? + Rearrange( + "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width + ), + nn.LayerNorm(patch_dim), # TODO Do we need this? + nn.Linear(patch_dim, dim), # TODO Do we need this? + nn.LayerNorm(dim), # TODO Do we need this? ) self.pos_embedding = posemb_sincos_2d( @@ -113,10 +116,13 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) self.reshaper = nn.Sequential( - Rearrange("b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)", - h=image_height // patch_height, - w=image_width // patch_width, - p_h=patch_height, p_w=patch_width) + Rearrange( + "b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)", + h=image_height // patch_height, + w=image_width // patch_width, + p_h=patch_height, + p_w=patch_width, + ) ) def forward(self, img): @@ -126,7 +132,7 @@ def forward(self, img): x += self.pos_embedding.to(device, dtype=x.dtype) x = self.transformer(x) - + x = self.reshaper(x) return x diff --git a/tests/test_model.py b/tests/test_model.py index 2ef43cc8..050f3e28 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -3,7 +3,14 @@ import torch from graph_weather import GraphWeatherAssimilator, GraphWeatherForecaster -from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Decoder, Encoder, Processor, MetaModel +from graph_weather.models import ( + AssimilatorDecoder, + AssimilatorEncoder, + Decoder, + Encoder, + Processor, + MetaModel, +) from graph_weather.models.losses import NormalizedMSELoss @@ -225,12 +232,10 @@ def test_normalized_loss(): def test_meta_model(): - model = MetaModel(image_size=100,patch_size=10, - depth=1, heads=1, mlp_dim=7, - channels=3 ) - features = torch.randn((1,3, 100,100) ) - + model = MetaModel(image_size=100, patch_size=10, depth=1, heads=1, mlp_dim=7, channels=3) + features = torch.randn((1, 3, 100, 100)) + out = model(features) assert not torch.isnan(out).any() assert not torch.isnan(out).any() - assert out.size() == (1,3, 100,100) + assert out.size() == (1, 3, 100, 100) From 8c60fb71e4de98560d601b0a295b2b957fbe9e3b Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Thu, 6 Jun 2024 15:07:10 +0200 Subject: [PATCH 04/45] Interpolate initial --- graph_weather/models/fengwu_ghr/layers.py | 79 ++++++++++++++++++++--- tests/test_model.py | 31 ++++++--- 2 files changed, 92 insertions(+), 18 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index cd81218d..036d9048 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -1,3 +1,5 @@ +import numpy as np +from scipy.interpolate import griddata, interpn import torch from einops import rearrange from einops.layers.torch import Rearrange @@ -10,6 +12,39 @@ def pair(t): return t if isinstance(t, tuple) else (t, t) +def grid_interpolate(lat_lons: list, z: torch.Tensor, + height, width, + method: str = "cubic"): + # TODO 1. CPU only + # 2. The mesh is a rectangle, not a sphere + + xi = np.arange(0.5, width, 1)/width*360 + yi = np.arange(0.5, height, 1)/height*180 + + xi, yi = np.meshgrid(xi, yi) + z = rearrange(z, "b n c -> n b c") + z = griddata( + lat_lons, z, (xi, yi), + fill_value=0, method=method) + z = rearrange(z, "h w b c -> b c h w") # hw ? + z = torch.tensor(z) + return z + +def grid_extrapolate(lat_lons, z, + height, width, + method: str = "cubic"): + xi = np.arange(0.5, width, 1)/width*360 + yi = np.arange(0.5, height, 1)/height*180 + z = rearrange(z, "b c h w -> h w b c") + z = z.detach().numpy() + z= interpn((xi,yi),z, lat_lons, + bounds_error=False, + method=method) + z = rearrange(z, "n b c -> b n c") + z = torch.tensor(z) + return z + + def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" @@ -56,7 +91,8 @@ def forward(self, x): x = self.norm(x) qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) + q, k, v = map(lambda t: rearrange( + t, "b n (h d) -> b h n d", h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale @@ -75,7 +111,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim): for _ in range(depth): self.layers.append( nn.ModuleList( - [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)] + [Attention(dim, heads=heads, dim_head=dim_head), + FeedForward(dim, mlp_dim)] ) ) @@ -87,20 +124,31 @@ def forward(self, x): class MetaModel(nn.Module): - def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64): + def __init__(self, lat_lons: list, *, + patch_size, depth, + heads, mlp_dim, + resolution=(721, 1440), + channels=3, dim_head=64, + interp_method='cubic'): super().__init__() - image_height, image_width = pair(image_size) + image_height, image_width = pair(resolution) patch_height, patch_width = pair(patch_size) assert ( image_height % patch_height == 0 and image_width % patch_width == 0 ), "Image dimensions must be divisible by the patch size." + # interpolate + self.interpolate = lambda z: grid_interpolate( + lat_lons, z, image_height, image_width, + method=interp_method) + patch_dim = channels * patch_height * patch_width dim = patch_dim self.to_patch_embedding = nn.Sequential( Rearrange( - "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width + "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", + p_h=patch_height, p_w=patch_width ), nn.LayerNorm(patch_dim), # TODO Do we need this? nn.Linear(patch_dim, dim), # TODO Do we need this? @@ -125,14 +173,27 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, ) ) - def forward(self, img): - device = img.device + # extrapolate + self.extrapolate = lambda z: grid_extrapolate( + lat_lons, z, image_height, image_width, + method=interp_method) + - x = self.to_patch_embedding(img) - x += self.pos_embedding.to(device, dtype=x.dtype) + def forward(self, x): + device = x.device + dtype = x.dtype + + x = self.interpolate(x.to("cpu")) + x = x.to(device, dtype=dtype) + + x = self.to_patch_embedding(x) + x += self.pos_embedding.to(device, dtype=dtype) x = self.transformer(x) x = self.reshaper(x) + x = self.extrapolate(x.to("cpu")) + x = x.to(device, dtype=dtype) + return x diff --git a/tests/test_model.py b/tests/test_model.py index 050f3e28..474a3f89 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -142,7 +142,8 @@ def test_assimilator_model(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): output_lat_lons.append((lat, lon)) - model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24) + model = GraphWeatherAssimilator( + output_lat_lons=output_lat_lons, analysis_dim=24) features = torch.randn((1, len(obs_lat_lons), 2)) lat_lon_heights = torch.tensor(obs_lat_lons) @@ -156,7 +157,8 @@ def test_forecaster_and_loss(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss( + lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) @@ -197,7 +199,8 @@ def test_forecaster_and_loss_grad_checkpoint(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss( + lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons, use_checkpointing=True) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) @@ -228,14 +231,24 @@ def test_normalized_loss(): assert not torch.isnan(loss) # Since feature_variance = out**2 and target = 0, we expect loss = weights - assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean()) + assert torch.isclose( + loss, criterion.weights.expand_as(out.mean(-1)).mean()) def test_meta_model(): - model = MetaModel(image_size=100, patch_size=10, depth=1, heads=1, mlp_dim=7, channels=3) - features = torch.randn((1, 3, 100, 100)) + lat_lons = [] + for lat in range(-90, 90, 5): + for lon in range(0, 360, 5): + lat_lons.append((lat, lon)) + + batch =2 + channels = 3 + model = MetaModel(lat_lons, + resolution=4, patch_size=2, + depth=1, heads=1, mlp_dim=7, channels=channels) + features = torch.randn((batch,len(lat_lons), channels)) out = model(features) - assert not torch.isnan(out).any() - assert not torch.isnan(out).any() - assert out.size() == (1, 3, 100, 100) + #assert not torch.isnan(out).any() + #assert not torch.isnan(out).any() + assert out.size() == (batch,len(lat_lons), channels) From 725421df408eafcd4bcea62a420b7a8e704ad299 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Tue, 11 Jun 2024 15:26:54 +0200 Subject: [PATCH 05/45] ImageMetaModel --- graph_weather/models/__init__.py | 2 +- graph_weather/models/fengwu_ghr/layers.py | 86 ++++++++++++++++------- tests/test_model.py | 27 +++++-- 3 files changed, 82 insertions(+), 33 deletions(-) diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index 72d222a8..fadc1d52 100644 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -1,6 +1,6 @@ """Models""" -from .fengwu_ghr.layers import MetaModel +from .fengwu_ghr.layers import MetaModel,ImageMetaModel from .layers.assimilator_decoder import AssimilatorDecoder from .layers.assimilator_encoder import AssimilatorEncoder from .layers.decoder import Decoder diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 036d9048..e3da7e31 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -5,12 +5,38 @@ from einops.layers.torch import Rearrange from torch import nn + # helpers def pair(t): return t if isinstance(t, tuple) else (t, t) +from torch_geometric.nn import knn +from torch_geometric.utils import scatter + + +def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, + k: int = 3, num_workers: int = 1): + with torch.no_grad(): + assign_index = knn(pos_x, pos_y, k, + num_workers=num_workers) + y_idx, x_idx = assign_index[0], assign_index[1] + diff = pos_x[x_idx] - pos_y[y_idx] + squared_distance = (diff * diff).sum(dim=-1, keepdim=True) + weights = 1.0 / torch.clamp(squared_distance, min=1e-16) + + + # print((x[x_idx]*weights).shape) + # print(weights.shape) + den = scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum') + # print(den.shape) + y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum') + + + y = y / den + + return y def grid_interpolate(lat_lons: list, z: torch.Tensor, height, width, @@ -30,6 +56,7 @@ def grid_interpolate(lat_lons: list, z: torch.Tensor, z = torch.tensor(z) return z + def grid_extrapolate(lat_lons, z, height, width, method: str = "cubic"): @@ -37,8 +64,8 @@ def grid_extrapolate(lat_lons, z, yi = np.arange(0.5, height, 1)/height*180 z = rearrange(z, "b c h w -> h w b c") z = z.detach().numpy() - z= interpn((xi,yi),z, lat_lons, - bounds_error=False, + z = interpn((xi, yi), z, lat_lons, + bounds_error=False, method=method) z = rearrange(z, "n b c -> b n c") z = torch.tensor(z) @@ -122,27 +149,20 @@ def forward(self, x): x = ff(x) + x return self.norm(x) - -class MetaModel(nn.Module): - def __init__(self, lat_lons: list, *, +class ImageMetaModel(nn.Module): + def __init__(self, *, + image_size, patch_size, depth, heads, mlp_dim, - resolution=(721, 1440), - channels=3, dim_head=64, - interp_method='cubic'): + channels=3, dim_head=64): super().__init__() - image_height, image_width = pair(resolution) + image_height, image_width = pair(image_size) patch_height, patch_width = pair(patch_size) assert ( image_height % patch_height == 0 and image_width % patch_width == 0 ), "Image dimensions must be divisible by the patch size." - # interpolate - self.interpolate = lambda z: grid_interpolate( - lat_lons, z, image_height, image_width, - method=interp_method) - patch_dim = channels * patch_height * patch_width dim = patch_dim self.to_patch_embedding = nn.Sequential( @@ -173,27 +193,39 @@ def __init__(self, lat_lons: list, *, ) ) - # extrapolate - self.extrapolate = lambda z: grid_extrapolate( - lat_lons, z, image_height, image_width, - method=interp_method) - - def forward(self, x): device = x.device dtype = x.dtype - x = self.interpolate(x.to("cpu")) - x = x.to(device, dtype=dtype) - x = self.to_patch_embedding(x) x += self.pos_embedding.to(device, dtype=dtype) x = self.transformer(x) - x = self.reshaper(x) - x = self.extrapolate(x.to("cpu")) - x = x.to(device, dtype=dtype) - return x + +class MetaModel(nn.Module): + def __init__(self, lat_lons: list, *, + patch_size, depth, + heads, mlp_dim, + resolution=(721, 1440), + channels=3, dim_head=64, + interp_method='cubic'): + super().__init__() + resolution = pair(resolution) + b=3 + n=len(lat_lons) + d=7 + x=torch.randn((b,n,d)) + x=rearrange(x,"b n d -> n (b d)") + + pos_x= torch.tensor(lat_lons) + pos_y = torch.cartesian_prod( + torch.arange(0.5,resolution[0],1), + torch.arange(0.5,resolution[1],1) + ) + x = knn_interpolate(x,pos_x,pos_y) + x = rearrange(x,"m (b d) -> b m d", b=b,d=d) + print(x.shape) + diff --git a/tests/test_model.py b/tests/test_model.py index 474a3f89..95ec75bf 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -10,6 +10,7 @@ Encoder, Processor, MetaModel, + ImageMetaModel ) from graph_weather.models.losses import NormalizedMSELoss @@ -235,20 +236,36 @@ def test_normalized_loss(): loss, criterion.weights.expand_as(out.mean(-1)).mean()) +def test_image_meta_model(): + batch = 2 + channels = 3 + size = 900 + image = torch.randn((batch, channels, size, size)) + model = ImageMetaModel(image_size=size, + patch_size=10, + depth=1, heads=1, mlp_dim=7, + channels=channels) + + out = model(image) + assert not torch.isnan(out).any() + assert not torch.isnan(out).any() + assert out.size() == (batch, channels,size,size) + + def test_meta_model(): lat_lons = [] for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - batch =2 + batch = 2 channels = 3 model = MetaModel(lat_lons, resolution=4, patch_size=2, depth=1, heads=1, mlp_dim=7, channels=channels) - features = torch.randn((batch,len(lat_lons), channels)) + features = torch.randn((batch, len(lat_lons), channels)) out = model(features) - #assert not torch.isnan(out).any() - #assert not torch.isnan(out).any() - assert out.size() == (batch,len(lat_lons), channels) + # assert not torch.isnan(out).any() + # assert not torch.isnan(out).any() + assert out.size() == (batch, len(lat_lons), channels) From c57a27ec2358d5cd2ed0883d1f406efa52806a7b Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Tue, 11 Jun 2024 15:59:59 +0200 Subject: [PATCH 06/45] MetaModel initial --- graph_weather/models/fengwu_ghr/layers.py | 78 +++++++++++++++++------ 1 file changed, 60 insertions(+), 18 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index e3da7e31..0a4b7c3c 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -1,3 +1,6 @@ +from scipy.interpolate import griddata +from torch_geometric.nn import knn +from torch_geometric.utils import scatter import numpy as np from scipy.interpolate import griddata, interpn import torch @@ -12,9 +15,6 @@ def pair(t): return t if isinstance(t, tuple) else (t, t) -from torch_geometric.nn import knn -from torch_geometric.utils import scatter - def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, k: int = 3, num_workers: int = 1): @@ -26,18 +26,17 @@ def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, squared_distance = (diff * diff).sum(dim=-1, keepdim=True) weights = 1.0 / torch.clamp(squared_distance, min=1e-16) - # print((x[x_idx]*weights).shape) # print(weights.shape) den = scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum') # print(den.shape) y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum') - - + y = y / den return y + def grid_interpolate(lat_lons: list, z: torch.Tensor, height, width, method: str = "cubic"): @@ -149,6 +148,7 @@ def forward(self, x): x = ff(x) + x return self.norm(x) + class ImageMetaModel(nn.Module): def __init__(self, *, image_size, @@ -205,7 +205,50 @@ def forward(self, x): return x + class MetaModel(nn.Module): + def __init__(self, lat_lons: list, *, + patch_size, depth, + heads, mlp_dim, + resolution=(721, 1440), + channels=3, dim_head=64, + interp_method='cubic'): + super().__init__() + self.resolution = pair(resolution) + + self.pos_x = torch.tensor(lat_lons) + self.pos_y = torch.cartesian_prod( + torch.arange(0, self.resolution[0], 1), + torch.arange(0, self.resolution[1], 1) + ) + + self.image_model = ImageMetaModel(image_size=resolution, + patch_size=patch_size, + depth=depth, + heads=heads, + mlp_dim=mlp_dim, + channels=channels, + dim_head=dim_head) + + def forward(self, x): + b, n, c = x.shape + + x = rearrange(x, "b n c -> n (b c)") + x = knn_interpolate(x, self.pos_x, self.pos_y) + x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, + w=self.resolution[0], + h=self.resolution[1]) + + x = self.image_model(x) + + x = rearrange(x, "b c h w -> (h w) (b c)") + x = knn_interpolate(x, self.pos_y, self.pos_x) + x = rearrange(x, "n (b c) -> b n c", b=b, c=c) + + return x + + +class MetaModel2(nn.Module): def __init__(self, lat_lons: list, *, patch_size, depth, heads, mlp_dim, @@ -214,18 +257,17 @@ def __init__(self, lat_lons: list, *, interp_method='cubic'): super().__init__() resolution = pair(resolution) - b=3 - n=len(lat_lons) - d=7 - x=torch.randn((b,n,d)) - x=rearrange(x,"b n d -> n (b d)") - - pos_x= torch.tensor(lat_lons) + b = 3 + n = len(lat_lons) + d = 7 + x = torch.randn((b, n, d)) + x = rearrange(x, "b n d -> n (b d)") + + pos_x = torch.tensor(lat_lons) pos_y = torch.cartesian_prod( - torch.arange(0.5,resolution[0],1), - torch.arange(0.5,resolution[1],1) + torch.arange(0, resolution[0], 1), + torch.arange(0, resolution[1], 1) ) - x = knn_interpolate(x,pos_x,pos_y) - x = rearrange(x,"m (b d) -> b m d", b=b,d=d) + x = knn_interpolate(x, pos_x, pos_y) + x = rearrange(x, "m (b d) -> b m d", b=b, d=d) print(x.shape) - From 3d2a17d62c15d93f29a22fc54f8ebc7ad4052d9e Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Fri, 14 Jun 2024 16:44:08 +0200 Subject: [PATCH 07/45] tested metamodel --- graph_weather/models/fengwu_ghr/layers.py | 90 +++-------------------- tests/test_model.py | 19 +++-- 2 files changed, 23 insertions(+), 86 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 0a4b7c3c..dc530eac 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -1,8 +1,6 @@ from scipy.interpolate import griddata from torch_geometric.nn import knn from torch_geometric.utils import scatter -import numpy as np -from scipy.interpolate import griddata, interpn import torch from einops import rearrange from einops.layers.torch import Rearrange @@ -17,7 +15,7 @@ def pair(t): def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, - k: int = 3, num_workers: int = 1): + k: int = 4, num_workers: int = 1): with torch.no_grad(): assign_index = knn(pos_x, pos_y, k, num_workers=num_workers) @@ -26,10 +24,7 @@ def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, squared_distance = (diff * diff).sum(dim=-1, keepdim=True) weights = 1.0 / torch.clamp(squared_distance, min=1e-16) - # print((x[x_idx]*weights).shape) - # print(weights.shape) den = scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum') - # print(den.shape) y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum') y = y / den @@ -37,40 +32,6 @@ def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, return y -def grid_interpolate(lat_lons: list, z: torch.Tensor, - height, width, - method: str = "cubic"): - # TODO 1. CPU only - # 2. The mesh is a rectangle, not a sphere - - xi = np.arange(0.5, width, 1)/width*360 - yi = np.arange(0.5, height, 1)/height*180 - - xi, yi = np.meshgrid(xi, yi) - z = rearrange(z, "b n c -> n b c") - z = griddata( - lat_lons, z, (xi, yi), - fill_value=0, method=method) - z = rearrange(z, "h w b c -> b c h w") # hw ? - z = torch.tensor(z) - return z - - -def grid_extrapolate(lat_lons, z, - height, width, - method: str = "cubic"): - xi = np.arange(0.5, width, 1)/width*360 - yi = np.arange(0.5, height, 1)/height*180 - z = rearrange(z, "b c h w -> h w b c") - z = z.detach().numpy() - z = interpn((xi, yi), z, lat_lons, - bounds_error=False, - method=method) - z = rearrange(z, "n b c -> b n c") - z = torch.tensor(z) - return z - - def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" @@ -210,19 +171,19 @@ class MetaModel(nn.Module): def __init__(self, lat_lons: list, *, patch_size, depth, heads, mlp_dim, - resolution=(721, 1440), - channels=3, dim_head=64, - interp_method='cubic'): + image_size=(721, 1440), + channels=3, dim_head=64): super().__init__() - self.resolution = pair(resolution) + self.image_size = pair(image_size) self.pos_x = torch.tensor(lat_lons) self.pos_y = torch.cartesian_prod( - torch.arange(0, self.resolution[0], 1), - torch.arange(0, self.resolution[1], 1) + (torch.arange(-self.image_size[0]/2, + self.image_size[0]/2, 1)/self.image_size[0]*180).to(torch.long), + (torch.arange(0, self.image_size[1], 1)/self.image_size[1]*360).to(torch.long) ) - self.image_model = ImageMetaModel(image_size=resolution, + self.image_model = ImageMetaModel(image_size=image_size, patch_size=patch_size, depth=depth, heads=heads, @@ -235,39 +196,12 @@ def forward(self, x): x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, - w=self.resolution[0], - h=self.resolution[1]) - + x = rearrange(x, "(w h) (b c) -> b c w h", b=b, c=c, + w=self.image_size[0], + h=self.image_size[1]) x = self.image_model(x) - x = rearrange(x, "b c h w -> (h w) (b c)") + x = rearrange(x, "b c w h -> (w h) (b c)") x = knn_interpolate(x, self.pos_y, self.pos_x) x = rearrange(x, "n (b c) -> b n c", b=b, c=c) - return x - - -class MetaModel2(nn.Module): - def __init__(self, lat_lons: list, *, - patch_size, depth, - heads, mlp_dim, - resolution=(721, 1440), - channels=3, dim_head=64, - interp_method='cubic'): - super().__init__() - resolution = pair(resolution) - b = 3 - n = len(lat_lons) - d = 7 - x = torch.randn((b, n, d)) - x = rearrange(x, "b n d -> n (b d)") - - pos_x = torch.tensor(lat_lons) - pos_y = torch.cartesian_prod( - torch.arange(0, resolution[0], 1), - torch.arange(0, resolution[1], 1) - ) - x = knn_interpolate(x, pos_x, pos_y) - x = rearrange(x, "m (b d) -> b m d", b=b, d=d) - print(x.shape) diff --git a/tests/test_model.py b/tests/test_model.py index 95ec75bf..c290b118 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -239,19 +239,20 @@ def test_normalized_loss(): def test_image_meta_model(): batch = 2 channels = 3 - size = 900 + size = 4 + patch_size = 2 image = torch.randn((batch, channels, size, size)) model = ImageMetaModel(image_size=size, - patch_size=10, - depth=1, heads=1, mlp_dim=7, - channels=channels) + patch_size=patch_size, + channels=channels, + depth=1, heads=1, mlp_dim=7 + ) out = model(image) assert not torch.isnan(out).any() assert not torch.isnan(out).any() assert out.size() == (batch, channels,size,size) - def test_meta_model(): lat_lons = [] for lat in range(-90, 90, 5): @@ -260,12 +261,14 @@ def test_meta_model(): batch = 2 channels = 3 + image_size=20 + patch_size=4 model = MetaModel(lat_lons, - resolution=4, patch_size=2, + image_size=image_size, patch_size=patch_size, depth=1, heads=1, mlp_dim=7, channels=channels) features = torch.randn((batch, len(lat_lons), channels)) out = model(features) - # assert not torch.isnan(out).any() - # assert not torch.isnan(out).any() + assert not torch.isnan(out).any() + assert not torch.isnan(out).any() assert out.size() == (batch, len(lat_lons), channels) From 87d1ffd22e34960bca767d71ca0325e368e59f19 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Jun 2024 17:16:17 +0000 Subject: [PATCH 08/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/models/__init__.py | 2 +- graph_weather/models/fengwu_ghr/layers.py | 80 ++++++++++++----------- tests/test_model.py | 43 ++++++------ 3 files changed, 65 insertions(+), 60 deletions(-) diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index fadc1d52..0083b16f 100644 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -1,6 +1,6 @@ """Models""" -from .fengwu_ghr.layers import MetaModel,ImageMetaModel +from .fengwu_ghr.layers import ImageMetaModel, MetaModel from .layers.assimilator_decoder import AssimilatorDecoder from .layers.assimilator_encoder import AssimilatorEncoder from .layers.decoder import Decoder diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 348f997b..06d15681 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -1,10 +1,9 @@ -from torch_geometric.nn import knn -from torch_geometric.utils import scatter import torch from einops import rearrange from einops.layers.torch import Rearrange from torch import nn - +from torch_geometric.nn import knn +from torch_geometric.utils import scatter # helpers @@ -13,18 +12,18 @@ def pair(t): return t if isinstance(t, tuple) else (t, t) -def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, - k: int = 4, num_workers: int = 1): +def knn_interpolate( + x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, k: int = 4, num_workers: int = 1 +): with torch.no_grad(): - assign_index = knn(pos_x, pos_y, k, - num_workers=num_workers) + assign_index = knn(pos_x, pos_y, k, num_workers=num_workers) y_idx, x_idx = assign_index[0], assign_index[1] diff = pos_x[x_idx] - pos_y[y_idx] squared_distance = (diff * diff).sum(dim=-1, keepdim=True) weights = 1.0 / torch.clamp(squared_distance, min=1e-16) - den = scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum') - y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum') + den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum") + y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum") y = y / den @@ -77,8 +76,7 @@ def forward(self, x): x = self.norm(x) qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange( - t, "b n (h d) -> b h n d", h=self.heads), qkv) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale @@ -97,8 +95,7 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim): for _ in range(depth): self.layers.append( nn.ModuleList( - [Attention(dim, heads=heads, dim_head=dim_head), - FeedForward(dim, mlp_dim)] + [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)] ) ) @@ -110,11 +107,7 @@ def forward(self, x): class ImageMetaModel(nn.Module): - def __init__(self, *, - image_size, - patch_size, depth, - heads, mlp_dim, - channels=3, dim_head=64): + def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64): super().__init__() image_height, image_width = pair(image_size) patch_height, patch_width = pair(patch_size) @@ -127,8 +120,7 @@ def __init__(self, *, dim = patch_dim self.to_patch_embedding = nn.Sequential( Rearrange( - "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", - p_h=patch_height, p_w=patch_width + "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width ), nn.LayerNorm(patch_dim), # TODO Do we need this? nn.Linear(patch_dim, dim), # TODO Do we need this? @@ -167,37 +159,49 @@ def forward(self, x): class MetaModel(nn.Module): - def __init__(self, lat_lons: list, *, - patch_size, depth, - heads, mlp_dim, - image_size=(721, 1440), - channels=3, dim_head=64): + def __init__( + self, + lat_lons: list, + *, + patch_size, + depth, + heads, + mlp_dim, + image_size=(721, 1440), + channels=3, + dim_head=64 + ): super().__init__() self.image_size = pair(image_size) self.pos_x = torch.tensor(lat_lons) self.pos_y = torch.cartesian_prod( - (torch.arange(-self.image_size[0]/2, - self.image_size[0]/2, 1)/self.image_size[0]*180).to(torch.long), - (torch.arange(0, self.image_size[1], 1)/self.image_size[1]*360).to(torch.long) + ( + torch.arange(-self.image_size[0] / 2, self.image_size[0] / 2, 1) + / self.image_size[0] + * 180 + ).to(torch.long), + (torch.arange(0, self.image_size[1], 1) / self.image_size[1] * 360).to(torch.long), ) - self.image_model = ImageMetaModel(image_size=image_size, - patch_size=patch_size, - depth=depth, - heads=heads, - mlp_dim=mlp_dim, - channels=channels, - dim_head=dim_head) + self.image_model = ImageMetaModel( + image_size=image_size, + patch_size=patch_size, + depth=depth, + heads=heads, + mlp_dim=mlp_dim, + channels=channels, + dim_head=dim_head, + ) def forward(self, x): b, n, c = x.shape x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange(x, "(w h) (b c) -> b c w h", b=b, c=c, - w=self.image_size[0], - h=self.image_size[1]) + x = rearrange( + x, "(w h) (b c) -> b c w h", b=b, c=c, w=self.image_size[0], h=self.image_size[1] + ) x = self.image_model(x) x = rearrange(x, "b c w h -> (w h) (b c)") diff --git a/tests/test_model.py b/tests/test_model.py index f7600943..5959349b 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -10,7 +10,7 @@ Encoder, Processor, MetaModel, - ImageMetaModel + ImageMetaModel, ) from graph_weather.models.losses import NormalizedMSELoss from graph_weather.models.gencast.utils.noise import ( @@ -147,8 +147,7 @@ def test_assimilator_model(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): output_lat_lons.append((lat, lon)) - model = GraphWeatherAssimilator( - output_lat_lons=output_lat_lons, analysis_dim=24) + model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24) features = torch.randn((1, len(obs_lat_lons), 2)) lat_lon_heights = torch.tensor(obs_lat_lons) @@ -162,8 +161,7 @@ def test_forecaster_and_loss(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss( - lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) @@ -204,8 +202,7 @@ def test_forecaster_and_loss_grad_checkpoint(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss( - lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons, use_checkpointing=True) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) @@ -236,8 +233,7 @@ def test_normalized_loss(): assert not torch.isnan(loss) # Since feature_variance = out**2 and target = 0, we expect loss = weights - assert torch.isclose( - loss, criterion.weights.expand_as(out.mean(-1)).mean()) + assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean()) def test_image_meta_model(): @@ -246,16 +242,15 @@ def test_image_meta_model(): size = 4 patch_size = 2 image = torch.randn((batch, channels, size, size)) - model = ImageMetaModel(image_size=size, - patch_size=patch_size, - channels=channels, - depth=1, heads=1, mlp_dim=7 - ) + model = ImageMetaModel( + image_size=size, patch_size=patch_size, channels=channels, depth=1, heads=1, mlp_dim=7 + ) out = model(image) assert not torch.isnan(out).any() assert not torch.isnan(out).any() - assert out.size() == (batch, channels,size,size) + assert out.size() == (batch, channels, size, size) + def test_meta_model(): lat_lons = [] @@ -265,14 +260,20 @@ def test_meta_model(): batch = 2 channels = 3 - image_size=20 - patch_size=4 - model = MetaModel(lat_lons, - image_size=image_size, patch_size=patch_size, - depth=1, heads=1, mlp_dim=7, channels=channels) + image_size = 20 + patch_size = 4 + model = MetaModel( + lat_lons, + image_size=image_size, + patch_size=patch_size, + depth=1, + heads=1, + mlp_dim=7, + channels=channels, + ) features = torch.randn((batch, len(lat_lons), channels)) out = model(features) assert not torch.isnan(out).any() assert not torch.isnan(out).any() - assert out.size() == (batch, len(lat_lons), channels) + assert out.size() == (batch, len(lat_lons), channels) From 21d84c785ece635a6b753d2da1dc5d140c3bfca4 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Fri, 21 Jun 2024 17:58:13 +0200 Subject: [PATCH 09/45] wrapper meta model --- graph_weather/models/__init__.py | 2 +- graph_weather/models/fengwu_ghr/layers.py | 126 +++++++++++++++++----- tests/test_model.py | 76 +++++++++++-- 3 files changed, 171 insertions(+), 33 deletions(-) diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index 0083b16f..ace964db 100644 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -1,6 +1,6 @@ """Models""" -from .fengwu_ghr.layers import ImageMetaModel, MetaModel +from .fengwu_ghr.layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel from .layers.assimilator_decoder import AssimilatorDecoder from .layers.assimilator_encoder import AssimilatorEncoder from .layers.decoder import Decoder diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 06d15681..f5dbda57 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -76,7 +76,8 @@ def forward(self, x): x = self.norm(x) qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) + q, k, v = map(lambda t: rearrange( + t, "b n (h d) -> b h n d", h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale @@ -95,7 +96,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim): for _ in range(depth): self.layers.append( nn.ModuleList( - [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)] + [Attention(dim, heads=heads, dim_head=dim_head), + FeedForward(dim, mlp_dim)] ) ) @@ -107,20 +109,22 @@ def forward(self, x): class ImageMetaModel(nn.Module): - def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64): + def __init__(self, *, image_size, + patch_size, depth, heads, + mlp_dim, channels, dim_head): super().__init__() - image_height, image_width = pair(image_size) - patch_height, patch_width = pair(patch_size) + self.image_height, self.image_width = pair(image_size) + self.patch_height, self.patch_width = pair(patch_size) assert ( - image_height % patch_height == 0 and image_width % patch_width == 0 + self.image_height % self.patch_height == 0 and self.image_width % self.patch_width == 0 ), "Image dimensions must be divisible by the patch size." - patch_dim = channels * patch_height * patch_width + patch_dim = channels * self.patch_height * self.patch_width dim = patch_dim self.to_patch_embedding = nn.Sequential( Rearrange( - "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width + "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=self.patch_height, p_w=self.patch_width ), nn.LayerNorm(patch_dim), # TODO Do we need this? nn.Linear(patch_dim, dim), # TODO Do we need this? @@ -128,8 +132,8 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, ) self.pos_embedding = posemb_sincos_2d( - h=image_height // patch_height, - w=image_width // patch_width, + h=self.image_height // self.patch_height, + w=self.image_width // self.patch_width, dim=dim, ) @@ -138,10 +142,10 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, self.reshaper = nn.Sequential( Rearrange( "b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)", - h=image_height // patch_height, - w=image_width // patch_width, - p_h=patch_height, - p_w=patch_width, + h=self.image_height // self.patch_height, + w=self.image_width // self.patch_width, + p_h=self.patch_height, + p_w=self.patch_width, ) ) @@ -158,33 +162,53 @@ def forward(self, x): return x +class WrapperImageModel(nn.Module): + def __init__(self, image_meta_model: ImageMetaModel, + scale_factor): + super().__init__() + s_h, s_w = pair(scale_factor) + self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", + s_h=s_h, s_w=s_w) + self.image_meta_model = image_meta_model + self.debatcher = Rearrange(" (b s_h s_w) c h w -> b c (h s_h) (w s_w)", + s_h=s_h, s_w=s_w) + + def forward(self, x): + x = self.batcher(x) + x = self.image_meta_model(x) + x = self.debatcher(x) + return x + + class MetaModel(nn.Module): def __init__( self, lat_lons: list, *, + image_size, patch_size, depth, heads, mlp_dim, - image_size=(721, 1440), - channels=3, + channels, dim_head=64 ): super().__init__() - self.image_size = pair(image_size) + self.i_h, self.i_w = pair(image_size) self.pos_x = torch.tensor(lat_lons) self.pos_y = torch.cartesian_prod( ( - torch.arange(-self.image_size[0] / 2, self.image_size[0] / 2, 1) - / self.image_size[0] + torch.arange(-self.i_h / 2, + self.i_h / 2, 1) + / self.i_h * 180 ).to(torch.long), - (torch.arange(0, self.image_size[1], 1) / self.image_size[1] * 360).to(torch.long), + (torch.arange(0, self.i_w, 1) / + self.i_w * 360).to(torch.long), ) - self.image_model = ImageMetaModel( + self.image_meta_model = ImageMetaModel( image_size=image_size, patch_size=patch_size, depth=depth, @@ -200,11 +224,65 @@ def forward(self, x): x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) x = rearrange( - x, "(w h) (b c) -> b c w h", b=b, c=c, w=self.image_size[0], h=self.image_size[1] + x, "(h w) (b c) -> b c h w", b=b, c=c, + h=self.i_h, w=self.i_w ) - x = self.image_model(x) + x = self.image_meta_model(x) - x = rearrange(x, "b c w h -> (w h) (b c)") + x = rearrange(x, "b c h w -> (h w) (b c)") x = knn_interpolate(x, self.pos_y, self.pos_x) x = rearrange(x, "n (b c) -> b n c", b=b, c=c) return x + + +class WrapperMetaModel(nn.Module): + def __init__( + self, + lat_lons: list, + meta_model: MetaModel, + scale_factor + ): + super().__init__() + self.image_meta_model = meta_model.image_meta_model + + s_h, s_w = pair(scale_factor) + self.i_h, self.i_w = meta_model.i_h*s_h, meta_model.i_w*s_w + self.pos_x = torch.tensor(lat_lons) + self.pos_y = torch.cartesian_prod( + ( + torch.arange(-self.i_h / 2, + self.i_h / 2, 1) + / self.i_h + * 180 + ).to(torch.long), + (torch.arange(0, self.i_w, 1) / + self.i_w * 360).to(torch.long), + ) + + + self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", + s_h=s_h, s_w=s_w) + + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", + s_h=s_h, s_w=s_w) + + def forward(self, x): + b, n, c = x.shape + + x = rearrange(x, "b n c -> n (b c)") + x = knn_interpolate(x, self.pos_x, self.pos_y) + x = rearrange( + x, "(h w) (b c) -> b c h w", b=b, c=c, + h=self.i_h, w=self.i_w + ) + + x = self.batcher(x) + x = self.image_meta_model(x) + x = self.debatcher(x) + + x = rearrange(x, "b c h w -> (h w) (b c)") + x = knn_interpolate(x, self.pos_y, self.pos_x) + x = rearrange(x, "n (b c) -> b n c", b=b, c=c) + + + return x diff --git a/tests/test_model.py b/tests/test_model.py index 5959349b..e19e831e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,8 +9,10 @@ Decoder, Encoder, Processor, - MetaModel, ImageMetaModel, + MetaModel, + WrapperImageModel, + WrapperMetaModel ) from graph_weather.models.losses import NormalizedMSELoss from graph_weather.models.gencast.utils.noise import ( @@ -147,7 +149,8 @@ def test_assimilator_model(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): output_lat_lons.append((lat, lon)) - model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24) + model = GraphWeatherAssimilator( + output_lat_lons=output_lat_lons, analysis_dim=24) features = torch.randn((1, len(obs_lat_lons), 2)) lat_lon_heights = torch.tensor(obs_lat_lons) @@ -161,7 +164,8 @@ def test_forecaster_and_loss(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss( + lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) @@ -202,7 +206,8 @@ def test_forecaster_and_loss_grad_checkpoint(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss( + lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons, use_checkpointing=True) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) @@ -233,7 +238,8 @@ def test_normalized_loss(): assert not torch.isnan(loss) # Since feature_variance = out**2 and target = 0, we expect loss = weights - assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean()) + assert torch.isclose( + loss, criterion.weights.expand_as(out.mean(-1)).mean()) def test_image_meta_model(): @@ -243,13 +249,35 @@ def test_image_meta_model(): patch_size = 2 image = torch.randn((batch, channels, size, size)) model = ImageMetaModel( - image_size=size, patch_size=patch_size, channels=channels, depth=1, heads=1, mlp_dim=7 + image_size=size, patch_size=patch_size, + channels=channels, depth=1, heads=1, mlp_dim=7, + dim_head=64 ) out = model(image) assert not torch.isnan(out).any() assert not torch.isnan(out).any() - assert out.size() == (batch, channels, size, size) + assert out.size() == image.size() + + +def test_wrapper_image_meta_model(): + batch = 2 + channels = 3 + size = 4 + patch_size = 2 + model = ImageMetaModel( + image_size=size, patch_size=patch_size, + channels=channels, depth=1, heads=1, mlp_dim=7, + dim_head=64 + ) + scale_factor = 3 + big_image = torch.randn((batch, channels, + size*scale_factor, size*scale_factor)) + big_model = WrapperImageModel(model, scale_factor) + out = big_model(big_image) + assert not torch.isnan(out).any() + assert not torch.isnan(out).any() + assert out.size() == big_image.size() def test_meta_model(): @@ -270,10 +298,42 @@ def test_meta_model(): heads=1, mlp_dim=7, channels=channels, + dim_head=64 ) features = torch.randn((batch, len(lat_lons), channels)) out = model(features) assert not torch.isnan(out).any() assert not torch.isnan(out).any() - assert out.size() == (batch, len(lat_lons), channels) + assert out.size() == features.size() + + +def test_wrapper_meta_model(): + lat_lons = [] + for lat in range(-90, 90, 5): + for lon in range(0, 360, 5): + lat_lons.append((lat, lon)) + + batch = 2 + channels = 3 + image_size = 20 + patch_size = 4 + scale_factor=3 + model = MetaModel( + lat_lons, + image_size=image_size, + patch_size=patch_size, + depth=1, + heads=1, + mlp_dim=7, + channels=channels, + dim_head=64 + ) + + big_features = torch.randn((batch, len(lat_lons), channels)) + big_model = WrapperMetaModel(lat_lons, model, scale_factor) + out = big_model(big_features) + + assert not torch.isnan(out).any() + assert not torch.isnan(out).any() + assert out.size() == big_features.size() From 07a8d0f6078355cd47f19b0fa95ecdb6b2ae2be5 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Mon, 1 Jul 2024 16:30:29 +0200 Subject: [PATCH 10/45] RES --- graph_weather/models/fengwu_ghr/layers.py | 72 ++++++++++++++++++----- 1 file changed, 57 insertions(+), 15 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index f5dbda57..ad82fcad 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -89,33 +89,64 @@ def forward(self, x): class Transformer(nn.Module): - def __init__(self, dim, depth, heads, dim_head, mlp_dim): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, h=None, w=None, scale_factor=None): super().__init__() + self.depth = depth + self.res = res self.norm = nn.LayerNorm(dim) self.layers = nn.ModuleList([]) - for _ in range(depth): + self.res_layers = nn.ModuleList([]) + for _ in range(self.depth): self.layers.append( nn.ModuleList( [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)] ) ) + if self.res: + assert h is not None and w is not None and scale_factor is not None, "If res=True, you must provide h, w and scale_factor" + s_h, s_w = pair(scale_factor) + self.res_layers.append( + nn.ModuleList( + [ # reshape to original shape window partition operation + # (b s_h s_w) (h w) d -> b (s_h h) (s_w w) d -> (b h w) (s_h s_w) d + Rearrange("(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d", + h=h, w=w, s_h=s_h, s_w=s_w + ), + # TODO ????? + Attention(dim, heads=heads, dim_head=dim_head), + # restore shape + Rearrange("(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d", + h=h, w=w, s_h=s_h, s_w=s_w + ), + ])) def forward(self, x): - for attn, ff in self.layers: + for i in range(self.depth): + attn, ff = self.layers[i] x = attn(x) + x x = ff(x) + x + if self.res: + reshape, loc_attn, restore = self.res_layers[i] + x = reshape(x) + x = loc_attn(x) + x + x = restore(x) return self.norm(x) class ImageMetaModel(nn.Module): def __init__(self, *, image_size, patch_size, depth, heads, - mlp_dim, channels, dim_head): + mlp_dim, channels, dim_head, + res=False, + scale_factor=None): super().__init__() self.image_height, self.image_width = pair(image_size) self.patch_height, self.patch_width = pair(patch_size) + s_h, s_w = pair(scale_factor) + if res: + assert scale_factor is not None, "If res=True, you must provide scale_factor" assert ( self.image_height % self.patch_height == 0 and self.image_width % self.patch_width == 0 ), "Image dimensions must be divisible by the patch size." @@ -137,7 +168,12 @@ def __init__(self, *, image_size, dim=dim, ) - self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, + res=res, + h=self.image_height // self.patch_height, + w=self.image_width // self.patch_width, + s_h=s_h, + s_w=s_w) self.reshaper = nn.Sequential( Rearrange( @@ -169,8 +205,13 @@ def __init__(self, image_meta_model: ImageMetaModel, s_h, s_w = pair(scale_factor) self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) - self.image_meta_model = image_meta_model - self.debatcher = Rearrange(" (b s_h s_w) c h w -> b c (h s_h) (w s_w)", + + imm_args = image_meta_model.vars().update( + {"res": True, "scale_factor": scale_factor}) + self.image_meta_model = ImageMetaModel(**imm_args) + self.image_meta_model.load(image_meta_model, strict=False) + + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) def forward(self, x): @@ -224,7 +265,7 @@ def forward(self, x): x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) x = rearrange( - x, "(h w) (b c) -> b c h w", b=b, c=c, + x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w ) x = self.image_meta_model(x) @@ -243,8 +284,6 @@ def __init__( scale_factor ): super().__init__() - self.image_meta_model = meta_model.image_meta_model - s_h, s_w = pair(scale_factor) self.i_h, self.i_w = meta_model.i_h*s_h, meta_model.i_w*s_w self.pos_x = torch.tensor(lat_lons) @@ -259,23 +298,27 @@ def __init__( self.i_w * 360).to(torch.long), ) - self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) + + imm_args = meta_model.image_meta_model.vars().update( + {"res": True, "scale_factor": scale_factor}) + self.image_meta_model = ImageMetaModel(**imm_args) + self.image_meta_model.load(meta_model.image_meta_model, strict=False) self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) def forward(self, x): b, n, c = x.shape - + x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) x = rearrange( - x, "(h w) (b c) -> b c h w", b=b, c=c, + x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w ) - + x = self.batcher(x) x = self.image_meta_model(x) x = self.debatcher(x) @@ -283,6 +326,5 @@ def forward(self, x): x = rearrange(x, "b c h w -> (h w) (b c)") x = knn_interpolate(x, self.pos_y, self.pos_x) x = rearrange(x, "n (b c) -> b n c", b=b, c=c) - return x From cd84968524f41ff654a16bcd1587ab2a1c59ae38 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Tue, 2 Jul 2024 11:15:32 +0200 Subject: [PATCH 11/45] load RES state_dict --- graph_weather/models/fengwu_ghr/layers.py | 45 ++++++++++++++++------- 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index ad82fcad..e42d7545 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -89,7 +89,7 @@ def forward(self, x): class Transformer(nn.Module): - def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, h=None, w=None, scale_factor=None): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=None, scale_factor=None): super().__init__() self.depth = depth self.res = res @@ -104,7 +104,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, h=None, w=No ) ) if self.res: - assert h is not None and w is not None and scale_factor is not None, "If res=True, you must provide h, w and scale_factor" + assert image_size is not None and scale_factor is not None, "If res=True, you must provide h, w and scale_factor" + h, w = pair(image_size) s_h, s_w = pair(scale_factor) self.res_layers.append( nn.ModuleList( @@ -139,8 +140,20 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels, dim_head, res=False, - scale_factor=None): + scale_factor=None, + **kwargs): super().__init__() + #TODO this can probably be done better + self.image_size = image_size + self.patch_size = patch_size + self.depth = depth + self.heads = heads + self.mlp_dim = mlp_dim + self.channels = channels + self.dim_head = dim_head + self.res = res + self.scale_factor = scale_factor + self.image_height, self.image_width = pair(image_size) self.patch_height, self.patch_width = pair(patch_size) s_h, s_w = pair(scale_factor) @@ -170,10 +183,12 @@ def __init__(self, *, image_size, self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, res=res, - h=self.image_height // self.patch_height, - w=self.image_width // self.patch_width, - s_h=s_h, - s_w=s_w) + image_size=( + self.image_height // self.patch_height, + self.image_width // self.patch_width), + scale_factor=( + s_h, + s_w)) self.reshaper = nn.Sequential( Rearrange( @@ -205,12 +220,13 @@ def __init__(self, image_meta_model: ImageMetaModel, s_h, s_w = pair(scale_factor) self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) - - imm_args = image_meta_model.vars().update( + + imm_args = vars(image_meta_model) + imm_args.update( {"res": True, "scale_factor": scale_factor}) self.image_meta_model = ImageMetaModel(**imm_args) - self.image_meta_model.load(image_meta_model, strict=False) - + self.image_meta_model.load_state_dict(image_meta_model.state_dict(), strict=False) + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) @@ -301,11 +317,12 @@ def __init__( self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) - imm_args = meta_model.image_meta_model.vars().update( + imm_args = vars(meta_model.image_meta_model) + imm_args.update( {"res": True, "scale_factor": scale_factor}) self.image_meta_model = ImageMetaModel(**imm_args) - self.image_meta_model.load(meta_model.image_meta_model, strict=False) - + self.image_meta_model.load_state_dict(meta_model.image_meta_model.state_dict(), strict=False) + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) From b15110f19b253e22a27abdaa7b8db066d2ff0703 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jul 2024 09:25:20 +0000 Subject: [PATCH 12/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/models/fengwu_ghr/layers.py | 150 +++++++++++----------- tests/test_model.py | 43 ++++--- 2 files changed, 98 insertions(+), 95 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index f03235e3..bbe19a85 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -76,8 +76,7 @@ def forward(self, x): x = self.norm(x) qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange( - t, "b n (h d) -> b h n d", h=self.heads), qkv) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale @@ -89,7 +88,9 @@ def forward(self, x): class Transformer(nn.Module): - def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=None, scale_factor=None): + def __init__( + self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=None, scale_factor=None + ): super().__init__() self.depth = depth self.res = res @@ -99,28 +100,39 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=N for _ in range(self.depth): self.layers.append( nn.ModuleList( - [Attention(dim, heads=heads, dim_head=dim_head), - FeedForward(dim, mlp_dim)] + [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)] ) ) if self.res: - assert image_size is not None and scale_factor is not None, "If res=True, you must provide h, w and scale_factor" + assert ( + image_size is not None and scale_factor is not None + ), "If res=True, you must provide h, w and scale_factor" h, w = pair(image_size) s_h, s_w = pair(scale_factor) self.res_layers.append( nn.ModuleList( [ # reshape to original shape window partition operation # (b s_h s_w) (h w) d -> b (s_h h) (s_w w) d -> (b h w) (s_h s_w) d - Rearrange("(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d", - h=h, w=w, s_h=s_h, s_w=s_w - ), + Rearrange( + "(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d", + h=h, + w=w, + s_h=s_h, + s_w=s_w, + ), # TODO ????? Attention(dim, heads=heads, dim_head=dim_head), # restore shape - Rearrange("(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d", - h=h, w=w, s_h=s_h, s_w=s_w - ), - ])) + Rearrange( + "(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d", + h=h, + w=w, + s_h=s_h, + s_w=s_w, + ), + ] + ) + ) def forward(self, x): for i in range(self.depth): @@ -136,14 +148,22 @@ def forward(self, x): class ImageMetaModel(nn.Module): - def __init__(self, *, image_size, - patch_size, depth, heads, - mlp_dim, channels, dim_head, - res=False, - scale_factor=None, - **kwargs): + def __init__( + self, + *, + image_size, + patch_size, + depth, + heads, + mlp_dim, + channels, + dim_head, + res=False, + scale_factor=None, + **kwargs + ): super().__init__() - #TODO this can probably be done better + # TODO this can probably be done better self.image_size = image_size self.patch_size = patch_size self.depth = depth @@ -168,7 +188,9 @@ def __init__(self, *, image_size, dim = patch_dim self.to_patch_embedding = nn.Sequential( Rearrange( - "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=self.patch_height, p_w=self.patch_width + "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", + p_h=self.patch_height, + p_w=self.patch_width, ), nn.LayerNorm(patch_dim), # TODO Do we need this? nn.Linear(patch_dim, dim), # TODO Do we need this? @@ -181,14 +203,19 @@ def __init__(self, *, image_size, dim=dim, ) - self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, - res=res, - image_size=( - self.image_height // self.patch_height, - self.image_width // self.patch_width), - scale_factor=( - s_h, - s_w)) + self.transformer = Transformer( + dim, + depth, + heads, + dim_head, + mlp_dim, + res=res, + image_size=( + self.image_height // self.patch_height, + self.image_width // self.patch_width, + ), + scale_factor=(s_h, s_w), + ) self.reshaper = nn.Sequential( Rearrange( @@ -203,6 +230,7 @@ def __init__(self, *, image_size, def forward(self, x): device = x.device dtype = x.dtype + def forward(self, x): device = x.device dtype = x.dtype @@ -219,21 +247,17 @@ def forward(self, x): class WrapperImageModel(nn.Module): - def __init__(self, image_meta_model: ImageMetaModel, - scale_factor): + def __init__(self, image_meta_model: ImageMetaModel, scale_factor): super().__init__() s_h, s_w = pair(scale_factor) - self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", - s_h=s_h, s_w=s_w) + self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) imm_args = vars(image_meta_model) - imm_args.update( - {"res": True, "scale_factor": scale_factor}) + imm_args.update({"res": True, "scale_factor": scale_factor}) self.image_meta_model = ImageMetaModel(**imm_args) self.image_meta_model.load_state_dict(image_meta_model.state_dict(), strict=False) - self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", - s_h=s_h, s_w=s_w) + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) def forward(self, x): x = self.batcher(x) @@ -260,14 +284,8 @@ def __init__( self.pos_x = torch.tensor(lat_lons) self.pos_y = torch.cartesian_prod( - ( - torch.arange(-self.i_h / 2, - self.i_h / 2, 1) - / self.i_h - * 180 - ).to(torch.long), - (torch.arange(0, self.i_w, 1) / - self.i_w * 360).to(torch.long), + (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), + (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), ) self.image_meta_model = ImageMetaModel( @@ -285,10 +303,7 @@ def forward(self, x): x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange( - x, "(h w) (b c) -> b c h w", b=b, c=c, - h=self.i_h, w=self.i_w - ) + x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w) x = self.image_meta_model(x) x = rearrange(x, "b c h w -> (h w) (b c)") @@ -298,48 +313,33 @@ def forward(self, x): class WrapperMetaModel(nn.Module): - def __init__( - self, - lat_lons: list, - meta_model: MetaModel, - scale_factor - ): + def __init__(self, lat_lons: list, meta_model: MetaModel, scale_factor): super().__init__() s_h, s_w = pair(scale_factor) - self.i_h, self.i_w = meta_model.i_h*s_h, meta_model.i_w*s_w + self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w self.pos_x = torch.tensor(lat_lons) self.pos_y = torch.cartesian_prod( - ( - torch.arange(-self.i_h / 2, - self.i_h / 2, 1) - / self.i_h - * 180 - ).to(torch.long), - (torch.arange(0, self.i_w, 1) / - self.i_w * 360).to(torch.long), + (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), + (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), ) - self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", - s_h=s_h, s_w=s_w) + self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) imm_args = vars(meta_model.image_meta_model) - imm_args.update( - {"res": True, "scale_factor": scale_factor}) + imm_args.update({"res": True, "scale_factor": scale_factor}) self.image_meta_model = ImageMetaModel(**imm_args) - self.image_meta_model.load_state_dict(meta_model.image_meta_model.state_dict(), strict=False) + self.image_meta_model.load_state_dict( + meta_model.image_meta_model.state_dict(), strict=False + ) - self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", - s_h=s_h, s_w=s_w) + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) def forward(self, x): b, n, c = x.shape x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange( - x, "(h w) (b c) -> b c h w", b=b, c=c, - h=self.i_h, w=self.i_w - ) + x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w) x = self.batcher(x) x = self.image_meta_model(x) diff --git a/tests/test_model.py b/tests/test_model.py index ac260b48..8f52116b 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -12,7 +12,7 @@ ImageMetaModel, MetaModel, WrapperImageModel, - WrapperMetaModel + WrapperMetaModel, ) from graph_weather.models.losses import NormalizedMSELoss @@ -151,8 +151,7 @@ def test_assimilator_model(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): output_lat_lons.append((lat, lon)) - model = GraphWeatherAssimilator( - output_lat_lons=output_lat_lons, analysis_dim=24) + model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24) features = torch.randn((1, len(obs_lat_lons), 2)) lat_lon_heights = torch.tensor(obs_lat_lons) @@ -166,8 +165,7 @@ def test_forecaster_and_loss(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss( - lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) @@ -208,8 +206,7 @@ def test_forecaster_and_loss_grad_checkpoint(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss( - lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons, use_checkpointing=True) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) @@ -240,8 +237,7 @@ def test_normalized_loss(): assert not torch.isnan(loss) # Since feature_variance = out**2 and target = 0, we expect loss = weights - assert torch.isclose( - loss, criterion.weights.expand_as(out.mean(-1)).mean()) + assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean()) def test_image_meta_model(): @@ -251,9 +247,13 @@ def test_image_meta_model(): patch_size = 2 image = torch.randn((batch, channels, size, size)) model = ImageMetaModel( - image_size=size, patch_size=patch_size, - channels=channels, depth=1, heads=1, mlp_dim=7, - dim_head=64 + image_size=size, + patch_size=patch_size, + channels=channels, + depth=1, + heads=1, + mlp_dim=7, + dim_head=64, ) out = model(image) @@ -268,13 +268,16 @@ def test_wrapper_image_meta_model(): size = 4 patch_size = 2 model = ImageMetaModel( - image_size=size, patch_size=patch_size, - channels=channels, depth=1, heads=1, mlp_dim=7, - dim_head=64 + image_size=size, + patch_size=patch_size, + channels=channels, + depth=1, + heads=1, + mlp_dim=7, + dim_head=64, ) scale_factor = 3 - big_image = torch.randn((batch, channels, - size*scale_factor, size*scale_factor)) + big_image = torch.randn((batch, channels, size * scale_factor, size * scale_factor)) big_model = WrapperImageModel(model, scale_factor) out = big_model(big_image) assert not torch.isnan(out).any() @@ -300,7 +303,7 @@ def test_meta_model(): heads=1, mlp_dim=7, channels=channels, - dim_head=64 + dim_head=64, ) features = torch.randn((batch, len(lat_lons), channels)) @@ -320,7 +323,7 @@ def test_wrapper_meta_model(): channels = 3 image_size = 20 patch_size = 4 - scale_factor=3 + scale_factor = 3 model = MetaModel( lat_lons, image_size=image_size, @@ -329,7 +332,7 @@ def test_wrapper_meta_model(): heads=1, mlp_dim=7, channels=channels, - dim_head=64 + dim_head=64, ) big_features = torch.randn((batch, len(lat_lons), channels)) From 1146db9de02985298ba0539f3a7c661e759fe31e Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Tue, 2 Jul 2024 14:14:11 +0200 Subject: [PATCH 13/45] bug fix --- graph_weather/models/fengwu_ghr/layers.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index f03235e3..e42d7545 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -200,15 +200,10 @@ def __init__(self, *, image_size, ) ) - def forward(self, x): - device = x.device - dtype = x.dtype def forward(self, x): device = x.device dtype = x.dtype - x = self.to_patch_embedding(x) - x += self.pos_embedding.to(device, dtype=dtype) x = self.to_patch_embedding(x) x += self.pos_embedding.to(device, dtype=dtype) From 325fd0e9b57f03a7d225e37d6bf7bd0de79b4038 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Tue, 2 Jul 2024 15:46:45 +0200 Subject: [PATCH 14/45] bug fix --- tests/test_model.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index c7b24bee..e5604fce 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -358,4 +358,36 @@ def test_gencast_loss(): preds = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) noise_levels = torch.rand((batch_size, 1)) targets = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) - assert loss.forward(preds, targets, noise_levels) is not None + assert loss.forward(preds, noise_levels, targets) is not None + + +def test_gencast_denoiser(): + grid_lat = np.arange(-90, 90, 1) + grid_lon = np.arange(0, 360, 1) + input_features_dim = 10 + output_features_dim = 5 + batch_size = 3 + + denoiser = Denoiser( + grid_lon=grid_lon, + grid_lat=grid_lat, + input_features_dim=input_features_dim, + output_features_dim=output_features_dim, + hidden_dims=[16, 32], + num_blocks=3, + num_heads=4, + splits=0, + num_hops=1, + device=torch.device("cpu"), + ).eval() + + corrupted_targets = torch.randn((batch_size, len(grid_lon), len(grid_lat), output_features_dim)) + prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), 2 * input_features_dim)) + noise_levels = torch.randn((batch_size, 1)) + + with torch.no_grad(): + preds = denoiser( + corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels + ) + + assert not torch.isnan(preds).any() \ No newline at end of file From cfa9c3f7e073e07e796dc6de6cc1a26670af2d27 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jul 2024 13:48:05 +0000 Subject: [PATCH 15/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index e5604fce..9b43cb16 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -390,4 +390,4 @@ def test_gencast_denoiser(): corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels ) - assert not torch.isnan(preds).any() \ No newline at end of file + assert not torch.isnan(preds).any() From 2fadf974fe5087d974ed4caa1fd5c1d3bab9afe3 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Mon, 29 Jul 2024 12:28:26 +0200 Subject: [PATCH 16/45] env yml fix --- environment_cuda.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/environment_cuda.yml b/environment_cuda.yml index 9f76251a..3d06cbb9 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -6,12 +6,12 @@ channels: - conda-forge - defaults dependencies: - - pytorch-cuda=12.1 + - pytorch-cuda - numcodecs - pandas - pip - pyg - - python=3.12 + - python - pytorch - pytorch-cluster - pytorch-scatter From 257b353fa30906629a215b49f6f3d38bc760febe Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Wed, 29 May 2024 10:58:50 +0200 Subject: [PATCH 17/45] fengwu_ghr: initial fengwu_ghr: fixes [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Interpolate initial ImageMetaModel MetaModel initial tested metamodel [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci wrapper meta model RES load RES state_dict bug fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci env yml fix --- environment_cuda.yml | 4 +- graph_weather/models/fengwu_ghr/layers.py | 109 ++++++++++++++++++++++ 2 files changed, 111 insertions(+), 2 deletions(-) diff --git a/environment_cuda.yml b/environment_cuda.yml index 9f76251a..3d06cbb9 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -6,12 +6,12 @@ channels: - conda-forge - defaults dependencies: - - pytorch-cuda=12.1 + - pytorch-cuda - numcodecs - pandas - pip - pyg - - python=3.12 + - python - pytorch - pytorch-cluster - pytorch-scatter diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index d129d2dd..2d032ab8 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -1,3 +1,6 @@ +from scipy.interpolate import griddata +from torch_geometric.nn import knn +from torch_geometric.utils import scatter import torch from einops import rearrange from einops.layers.torch import Rearrange @@ -12,6 +15,22 @@ def pair(t): return t if isinstance(t, tuple) else (t, t) +def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, + k: int = 4, num_workers: int = 1): + with torch.no_grad(): + assign_index = knn(pos_x, pos_y, k, num_workers=num_workers) + y_idx, x_idx = assign_index[0], assign_index[1] + diff = pos_x[x_idx] - pos_y[y_idx] + squared_distance = (diff * diff).sum(dim=-1, keepdim=True) + weights = 1.0 / torch.clamp(squared_distance, min=1e-16) + + den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum") + y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum") + + y = y / den + + return y + def knn_interpolate( x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, k: int = 4, num_workers: int = 1 ): @@ -344,3 +363,93 @@ def forward(self, x): x = rearrange(x, "n (b c) -> b n c", b=b, c=c) return x + +class MetaModel(nn.Module): + def __init__( + self, + lat_lons: list, + *, + image_size, + patch_size, + depth, + heads, + mlp_dim, + channels, + dim_head=64 + ): + super().__init__() + self.i_h, self.i_w = pair(image_size) + + self.pos_x = torch.tensor(lat_lons) + self.pos_y = torch.cartesian_prod( + (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), + (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), + ) + + self.image_meta_model = ImageMetaModel( + image_size=image_size, + patch_size=patch_size, + depth=depth, + heads=heads, + mlp_dim=mlp_dim, + channels=channels, + dim_head=dim_head, + ) + + def forward(self, x): + b, n, c = x.shape + + x = rearrange(x, "b n c -> n (b c)") + x = knn_interpolate(x, self.pos_x, self.pos_y) + x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w) + x = self.image_meta_model(x) + + x = rearrange(x, "b c h w -> (h w) (b c)") + x = knn_interpolate(x, self.pos_y, self.pos_x) + x = rearrange(x, "n (b c) -> b n c", b=b, c=c) + return x + + +class WrapperMetaModel(nn.Module): + def __init__( + self, + lat_lons: list, + meta_model: MetaModel, + scale_factor + ): + super().__init__() + s_h, s_w = pair(scale_factor) + self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w + self.pos_x = torch.tensor(lat_lons) + self.pos_y = torch.cartesian_prod( + (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), + (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), + ) + + self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) + + imm_args = vars(meta_model.image_meta_model) + imm_args.update({"res": True, "scale_factor": scale_factor}) + self.image_meta_model = ImageMetaModel(**imm_args) + self.image_meta_model.load_state_dict( + meta_model.image_meta_model.state_dict(), strict=False + ) + + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) + + def forward(self, x): + b, n, c = x.shape + + x = rearrange(x, "b n c -> n (b c)") + x = knn_interpolate(x, self.pos_x, self.pos_y) + x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w) + + x = self.batcher(x) + x = self.image_meta_model(x) + x = self.debatcher(x) + + x = rearrange(x, "b c h w -> (h w) (b c)") + x = knn_interpolate(x, self.pos_y, self.pos_x) + x = rearrange(x, "n (b c) -> b n c", b=b, c=c) + + return x From f72c6103a7c98434767dea1a6c2d747539c535e5 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Mon, 29 Jul 2024 16:09:52 +0200 Subject: [PATCH 18/45] test_wrapper_meta_model --- tests/test_model.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/test_model.py b/tests/test_model.py index 90da16b7..857f5c05 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -315,6 +315,35 @@ def test_meta_model(): assert out.size() == features.size() +def test_wrapper_meta_model(): + lat_lons = [] + for lat in range(-90, 90, 5): + for lon in range(0, 360, 5): + lat_lons.append((lat, lon)) + + batch = 2 + channels = 3 + image_size = 20 + patch_size = 4 + scale_factor=3 + model = MetaModel( + lat_lons, + image_size=image_size, + patch_size=patch_size, + depth=1, + heads=1, + mlp_dim=7, + channels=channels, + dim_head=64 + ) + + big_features = torch.randn((batch, len(lat_lons), channels)) + big_model = WrapperMetaModel(lat_lons, model, scale_factor) + out = big_model(big_features) + + assert not torch.isnan(out).any() + + def test_gencast_noise(): num_lon = 360 num_lat = 180 From 8a8ac64b84234cf9889f817c574c3c6f1c8b14ea Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Tue, 30 Jul 2024 14:22:53 +0200 Subject: [PATCH 19/45] tests fix --- tests/test_model.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 857f5c05..bfbc1503 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -260,7 +260,6 @@ def test_image_meta_model(): out = model(image) assert not torch.isnan(out).any() - assert not torch.isnan(out).any() assert out.size() == image.size() @@ -283,7 +282,6 @@ def test_wrapper_image_meta_model(): big_model = WrapperImageModel(model, scale_factor) out = big_model(big_image) assert not torch.isnan(out).any() - assert not torch.isnan(out).any() assert out.size() == big_image.size() @@ -311,7 +309,6 @@ def test_meta_model(): out = model(features) assert not torch.isnan(out).any() - assert not torch.isnan(out).any() assert out.size() == features.size() @@ -342,7 +339,7 @@ def test_wrapper_meta_model(): out = big_model(big_features) assert not torch.isnan(out).any() - + assert out.size() == big_features.size() def test_gencast_noise(): num_lon = 360 From 6f0c61b6a20ca1e8723aa5ab184ec08d114cad8c Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Wed, 29 May 2024 10:58:50 +0200 Subject: [PATCH 20/45] parent 743cf9704eea03353e02351cde52add7233437d6 author Lorenzo Breschi 1716973130 +0200 committer Lorenzo Breschi 1722343516 +0200 fengwu_ghr: initial fengwu_ghr: fixes [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Interpolate initial ImageMetaModel MetaModel initial tested metamodel [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci wrapper meta model RES load RES state_dict bug fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci env yml fix fengwu_ghr: initial [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Interpolate initial ImageMetaModel MetaModel initial tested metamodel [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci wrapper meta model RES load RES state_dict bug fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci env yml fix test_wrapper_meta_model tests fix --- environment_cuda.yml | 4 +- graph_weather/models/fengwu_ghr/layers.py | 222 ++++++++++++++++++++++ tests/test_model.py | 32 +++- 3 files changed, 253 insertions(+), 5 deletions(-) diff --git a/environment_cuda.yml b/environment_cuda.yml index 9f76251a..3d06cbb9 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -6,12 +6,12 @@ channels: - conda-forge - defaults dependencies: - - pytorch-cuda=12.1 + - pytorch-cuda - numcodecs - pandas - pip - pyg - - python=3.12 + - python - pytorch - pytorch-cluster - pytorch-scatter diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index d129d2dd..97a6684f 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -1,3 +1,6 @@ +from scipy.interpolate import griddata +from torch_geometric.nn import knn +from torch_geometric.utils import scatter import torch from einops import rearrange from einops.layers.torch import Rearrange @@ -12,6 +15,39 @@ def pair(t): return t if isinstance(t, tuple) else (t, t) +def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, + k: int = 4, num_workers: int = 1): + with torch.no_grad(): + assign_index = knn(pos_x, pos_y, k, + num_workers=num_workers) + y_idx, x_idx = assign_index[0], assign_index[1] + diff = pos_x[x_idx] - pos_y[y_idx] + squared_distance = (diff * diff).sum(dim=-1, keepdim=True) + weights = 1.0 / torch.clamp(squared_distance, min=1e-16) + + den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum") + y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum") + + y = y / den + + return y + +def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, + k: int = 4, num_workers: int = 1): + with torch.no_grad(): + assign_index = knn(pos_x, pos_y, k, num_workers=num_workers) + y_idx, x_idx = assign_index[0], assign_index[1] + diff = pos_x[x_idx] - pos_y[y_idx] + squared_distance = (diff * diff).sum(dim=-1, keepdim=True) + weights = 1.0 / torch.clamp(squared_distance, min=1e-16) + + den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum") + y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum") + + y = y / den + + return y + def knn_interpolate( x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, k: int = 4, num_workers: int = 1 ): @@ -344,3 +380,189 @@ def forward(self, x): x = rearrange(x, "n (b c) -> b n c", b=b, c=c) return x + +class MetaModel(nn.Module): + def __init__( + self, + lat_lons: list, + *, + image_size, + patch_size, + depth, + heads, + mlp_dim, + channels, + dim_head=64 + ): + super().__init__() + self.i_h, self.i_w = pair(image_size) + + self.pos_x = torch.tensor(lat_lons) + self.pos_y = torch.cartesian_prod( + (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), + (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), + ) + + self.image_meta_model = ImageMetaModel( + image_size=image_size, + patch_size=patch_size, + depth=depth, + heads=heads, + mlp_dim=mlp_dim, + channels=channels, + dim_head=dim_head, + ) + + def forward(self, x): + b, n, c = x.shape + + x = rearrange(x, "b n c -> n (b c)") + x = knn_interpolate(x, self.pos_x, self.pos_y) + x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w) + x = self.image_meta_model(x) + + x = rearrange(x, "b c h w -> (h w) (b c)") + x = knn_interpolate(x, self.pos_y, self.pos_x) + x = rearrange(x, "n (b c) -> b n c", b=b, c=c) + return x + + +class WrapperMetaModel(nn.Module): + def __init__( + self, + lat_lons: list, + meta_model: MetaModel, + scale_factor + ): + super().__init__() + s_h, s_w = pair(scale_factor) + self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w + self.pos_x = torch.tensor(lat_lons) + self.pos_y = torch.cartesian_prod( + (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), + (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), + ) + + self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) + + imm_args = vars(meta_model.image_meta_model) + imm_args.update({"res": True, "scale_factor": scale_factor}) + self.image_meta_model = ImageMetaModel(**imm_args) + self.image_meta_model.load_state_dict( + meta_model.image_meta_model.state_dict(), strict=False + ) + + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) + + def forward(self, x): + b, n, c = x.shape + + x = rearrange(x, "b n c -> n (b c)") + x = knn_interpolate(x, self.pos_x, self.pos_y) + x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w) + + x = self.batcher(x) + x = self.image_meta_model(x) + x = self.debatcher(x) + + x = rearrange(x, "b c h w -> (h w) (b c)") + x = knn_interpolate(x, self.pos_y, self.pos_x) + x = rearrange(x, "n (b c) -> b n c", b=b, c=c) + + return x + +class MetaModel(nn.Module): + def __init__( + self, + lat_lons: list, + *, + image_size, + patch_size, + depth, + heads, + mlp_dim, + channels, + dim_head=64 + ): + super().__init__() + self.i_h, self.i_w = pair(image_size) + + self.pos_x = torch.tensor(lat_lons) + self.pos_y = torch.cartesian_prod( + (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), + (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), + ) + + self.image_meta_model = ImageMetaModel( + image_size=image_size, + patch_size=patch_size, + depth=depth, + heads=heads, + mlp_dim=mlp_dim, + channels=channels, + dim_head=dim_head, + ) + + def forward(self, x): + b, n, c = x.shape + + x = rearrange(x, "b n c -> n (b c)") + x = knn_interpolate(x, self.pos_x, self.pos_y) + x = rearrange( + x, "(h w) (b c) -> b c h w", b=b, c=c, + h=self.i_h, w=self.i_w + ) + x = self.image_meta_model(x) + + x = rearrange(x, "b c h w -> (h w) (b c)") + x = knn_interpolate(x, self.pos_y, self.pos_x) + x = rearrange(x, "n (b c) -> b n c", b=b, c=c) + return x + + +class WrapperMetaModel(nn.Module): + def __init__( + self, + lat_lons: list, + meta_model: MetaModel, + scale_factor + ): + super().__init__() + s_h, s_w = pair(scale_factor) + self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w + self.pos_x = torch.tensor(lat_lons) + self.pos_y = torch.cartesian_prod( + (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), + (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), + ) + + self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) + + imm_args = vars(meta_model.image_meta_model) + imm_args.update({"res": True, "scale_factor": scale_factor}) + self.image_meta_model = ImageMetaModel(**imm_args) + self.image_meta_model.load_state_dict( + meta_model.image_meta_model.state_dict(), strict=False + ) + + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) + + def forward(self, x): + b, n, c = x.shape + + x = rearrange(x, "b n c -> n (b c)") + x = knn_interpolate(x, self.pos_x, self.pos_y) + x = rearrange( + x, "(h w) (b c) -> b c h w", b=b, c=c, + h=self.i_h, w=self.i_w + ) + + x = self.batcher(x) + x = self.image_meta_model(x) + x = self.debatcher(x) + + x = rearrange(x, "b c h w -> (h w) (b c)") + x = knn_interpolate(x, self.pos_y, self.pos_x) + x = rearrange(x, "n (b c) -> b n c", b=b, c=c) + + return x diff --git a/tests/test_model.py b/tests/test_model.py index 90da16b7..bfbc1503 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -260,7 +260,6 @@ def test_image_meta_model(): out = model(image) assert not torch.isnan(out).any() - assert not torch.isnan(out).any() assert out.size() == image.size() @@ -283,7 +282,6 @@ def test_wrapper_image_meta_model(): big_model = WrapperImageModel(model, scale_factor) out = big_model(big_image) assert not torch.isnan(out).any() - assert not torch.isnan(out).any() assert out.size() == big_image.size() @@ -311,10 +309,38 @@ def test_meta_model(): out = model(features) assert not torch.isnan(out).any() - assert not torch.isnan(out).any() assert out.size() == features.size() +def test_wrapper_meta_model(): + lat_lons = [] + for lat in range(-90, 90, 5): + for lon in range(0, 360, 5): + lat_lons.append((lat, lon)) + + batch = 2 + channels = 3 + image_size = 20 + patch_size = 4 + scale_factor=3 + model = MetaModel( + lat_lons, + image_size=image_size, + patch_size=patch_size, + depth=1, + heads=1, + mlp_dim=7, + channels=channels, + dim_head=64 + ) + + big_features = torch.randn((batch, len(lat_lons), channels)) + big_model = WrapperMetaModel(lat_lons, model, scale_factor) + out = big_model(big_features) + + assert not torch.isnan(out).any() + assert out.size() == big_features.size() + def test_gencast_noise(): num_lon = 360 num_lat = 180 From a855c6aebf5823414bcd28562f99d6d54f5ea216 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Wed, 29 May 2024 10:58:50 +0200 Subject: [PATCH 21/45] fengwu_ghr: initial fengwu_ghr: fixes [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Interpolate initial ImageMetaModel MetaModel initial tested metamodel [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci wrapper meta model RES load RES state_dict bug fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci env yml fix --- graph_weather/models/fengwu_ghr/layers.py | 116 +--------------------- 1 file changed, 3 insertions(+), 113 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 97a6684f..6b5264f1 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -1,6 +1,9 @@ from scipy.interpolate import griddata from torch_geometric.nn import knn from torch_geometric.utils import scatter +from scipy.interpolate import griddata +from torch_geometric.nn import knn +from torch_geometric.utils import scatter import torch from einops import rearrange from einops.layers.torch import Rearrange @@ -15,23 +18,6 @@ def pair(t): return t if isinstance(t, tuple) else (t, t) -def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, - k: int = 4, num_workers: int = 1): - with torch.no_grad(): - assign_index = knn(pos_x, pos_y, k, - num_workers=num_workers) - y_idx, x_idx = assign_index[0], assign_index[1] - diff = pos_x[x_idx] - pos_y[y_idx] - squared_distance = (diff * diff).sum(dim=-1, keepdim=True) - weights = 1.0 / torch.clamp(squared_distance, min=1e-16) - - den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum") - y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum") - - y = y / den - - return y - def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, k: int = 4, num_workers: int = 1): with torch.no_grad(): @@ -470,99 +456,3 @@ def forward(self, x): x = rearrange(x, "n (b c) -> b n c", b=b, c=c) return x - -class MetaModel(nn.Module): - def __init__( - self, - lat_lons: list, - *, - image_size, - patch_size, - depth, - heads, - mlp_dim, - channels, - dim_head=64 - ): - super().__init__() - self.i_h, self.i_w = pair(image_size) - - self.pos_x = torch.tensor(lat_lons) - self.pos_y = torch.cartesian_prod( - (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), - (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), - ) - - self.image_meta_model = ImageMetaModel( - image_size=image_size, - patch_size=patch_size, - depth=depth, - heads=heads, - mlp_dim=mlp_dim, - channels=channels, - dim_head=dim_head, - ) - - def forward(self, x): - b, n, c = x.shape - - x = rearrange(x, "b n c -> n (b c)") - x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange( - x, "(h w) (b c) -> b c h w", b=b, c=c, - h=self.i_h, w=self.i_w - ) - x = self.image_meta_model(x) - - x = rearrange(x, "b c h w -> (h w) (b c)") - x = knn_interpolate(x, self.pos_y, self.pos_x) - x = rearrange(x, "n (b c) -> b n c", b=b, c=c) - return x - - -class WrapperMetaModel(nn.Module): - def __init__( - self, - lat_lons: list, - meta_model: MetaModel, - scale_factor - ): - super().__init__() - s_h, s_w = pair(scale_factor) - self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w - self.pos_x = torch.tensor(lat_lons) - self.pos_y = torch.cartesian_prod( - (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), - (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), - ) - - self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) - - imm_args = vars(meta_model.image_meta_model) - imm_args.update({"res": True, "scale_factor": scale_factor}) - self.image_meta_model = ImageMetaModel(**imm_args) - self.image_meta_model.load_state_dict( - meta_model.image_meta_model.state_dict(), strict=False - ) - - self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) - - def forward(self, x): - b, n, c = x.shape - - x = rearrange(x, "b n c -> n (b c)") - x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange( - x, "(h w) (b c) -> b c h w", b=b, c=c, - h=self.i_h, w=self.i_w - ) - - x = self.batcher(x) - x = self.image_meta_model(x) - x = self.debatcher(x) - - x = rearrange(x, "b c h w -> (h w) (b c)") - x = knn_interpolate(x, self.pos_y, self.pos_x) - x = rearrange(x, "n (b c) -> b n c", b=b, c=c) - - return x From 47d0d481788767583315615d68117a314b30c0e1 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Wed, 29 May 2024 10:58:50 +0200 Subject: [PATCH 22/45] fengwu_ghr: initial --- graph_weather/models/__init__.py | 1 + graph_weather/models/fengwu_ghr/layers.py | 107 ---------------------- 2 files changed, 1 insertion(+), 107 deletions(-) diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index ace964db..ceacc486 100644 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -6,3 +6,4 @@ from .layers.decoder import Decoder from .layers.encoder import Encoder from .layers.processor import Processor +from .fengwu_ghr.layers import MetaModel diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 6b5264f1..8ba70d90 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -1,9 +1,3 @@ -from scipy.interpolate import griddata -from torch_geometric.nn import knn -from torch_geometric.utils import scatter -from scipy.interpolate import griddata -from torch_geometric.nn import knn -from torch_geometric.utils import scatter import torch from einops import rearrange from einops.layers.torch import Rearrange @@ -18,22 +12,6 @@ def pair(t): return t if isinstance(t, tuple) else (t, t) -def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, - k: int = 4, num_workers: int = 1): - with torch.no_grad(): - assign_index = knn(pos_x, pos_y, k, num_workers=num_workers) - y_idx, x_idx = assign_index[0], assign_index[1] - diff = pos_x[x_idx] - pos_y[y_idx] - squared_distance = (diff * diff).sum(dim=-1, keepdim=True) - weights = 1.0 / torch.clamp(squared_distance, min=1e-16) - - den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum") - y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum") - - y = y / den - - return y - def knn_interpolate( x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, k: int = 4, num_workers: int = 1 ): @@ -328,91 +306,6 @@ def forward(self, x): return x -class WrapperMetaModel(nn.Module): - def __init__(self, lat_lons: list, meta_model: MetaModel, scale_factor): - super().__init__() - s_h, s_w = pair(scale_factor) - self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w - self.pos_x = torch.tensor(lat_lons) - self.pos_y = torch.cartesian_prod( - (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), - (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), - ) - - self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) - - imm_args = vars(meta_model.image_meta_model) - imm_args.update({"res": True, "scale_factor": scale_factor}) - self.image_meta_model = ImageMetaModel(**imm_args) - self.image_meta_model.load_state_dict( - meta_model.image_meta_model.state_dict(), strict=False - ) - - self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) - - def forward(self, x): - b, n, c = x.shape - - x = rearrange(x, "b n c -> n (b c)") - x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w) - - x = self.batcher(x) - x = self.image_meta_model(x) - x = self.debatcher(x) - - x = rearrange(x, "b c h w -> (h w) (b c)") - x = knn_interpolate(x, self.pos_y, self.pos_x) - x = rearrange(x, "n (b c) -> b n c", b=b, c=c) - - return x - -class MetaModel(nn.Module): - def __init__( - self, - lat_lons: list, - *, - image_size, - patch_size, - depth, - heads, - mlp_dim, - channels, - dim_head=64 - ): - super().__init__() - self.i_h, self.i_w = pair(image_size) - - self.pos_x = torch.tensor(lat_lons) - self.pos_y = torch.cartesian_prod( - (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), - (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), - ) - - self.image_meta_model = ImageMetaModel( - image_size=image_size, - patch_size=patch_size, - depth=depth, - heads=heads, - mlp_dim=mlp_dim, - channels=channels, - dim_head=dim_head, - ) - - def forward(self, x): - b, n, c = x.shape - - x = rearrange(x, "b n c -> n (b c)") - x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w) - x = self.image_meta_model(x) - - x = rearrange(x, "b c h w -> (h w) (b c)") - x = knn_interpolate(x, self.pos_y, self.pos_x) - x = rearrange(x, "n (b c) -> b n c", b=b, c=c) - return x - - class WrapperMetaModel(nn.Module): def __init__( self, From 7a1d562ea29fbe2b8d1b2caf46f74f3536d94169 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 May 2024 09:11:13 +0000 Subject: [PATCH 23/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/models/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index ceacc486..ace964db 100644 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -6,4 +6,3 @@ from .layers.decoder import Decoder from .layers.encoder import Encoder from .layers.processor import Processor -from .fengwu_ghr.layers import MetaModel From e397941bb51912130e1ee75d0df8a77265102940 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Thu, 6 Jun 2024 15:07:10 +0200 Subject: [PATCH 24/45] Interpolate initial --- graph_weather/models/fengwu_ghr/layers.py | 5 ++++- tests/test_model.py | 9 ++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 8ba70d90..29eb606f 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -1,3 +1,5 @@ +import numpy as np +from scipy.interpolate import griddata, interpn import torch from einops import rearrange from einops.layers.torch import Rearrange @@ -76,7 +78,8 @@ def forward(self, x): x = self.norm(x) qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) + q, k, v = map(lambda t: rearrange( + t, "b n (h d) -> b h n d", h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale diff --git a/tests/test_model.py b/tests/test_model.py index bfbc1503..38795d36 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -153,7 +153,8 @@ def test_assimilator_model(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): output_lat_lons.append((lat, lon)) - model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24) + model = GraphWeatherAssimilator( + output_lat_lons=output_lat_lons, analysis_dim=24) features = torch.randn((1, len(obs_lat_lons), 2)) lat_lon_heights = torch.tensor(obs_lat_lons) @@ -167,7 +168,8 @@ def test_forecaster_and_loss(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss( + lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) @@ -208,7 +210,8 @@ def test_forecaster_and_loss_grad_checkpoint(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss( + lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons, use_checkpointing=True) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) From 80a73eed5960597e5972f041818d0132fcef9530 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Tue, 11 Jun 2024 15:26:54 +0200 Subject: [PATCH 25/45] ImageMetaModel --- graph_weather/models/fengwu_ghr/layers.py | 50 +++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 29eb606f..879dc008 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -13,6 +13,31 @@ def pair(t): return t if isinstance(t, tuple) else (t, t) +from torch_geometric.nn import knn +from torch_geometric.utils import scatter + + +def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, + k: int = 3, num_workers: int = 1): + with torch.no_grad(): + assign_index = knn(pos_x, pos_y, k, + num_workers=num_workers) + y_idx, x_idx = assign_index[0], assign_index[1] + diff = pos_x[x_idx] - pos_y[y_idx] + squared_distance = (diff * diff).sum(dim=-1, keepdim=True) + weights = 1.0 / torch.clamp(squared_distance, min=1e-16) + + + # print((x[x_idx]*weights).shape) + # print(weights.shape) + den = scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum') + # print(den.shape) + y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum') + + + y = y / den + + return y def knn_interpolate( x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, k: int = 4, num_workers: int = 1 @@ -352,3 +377,28 @@ def forward(self, x): x = rearrange(x, "n (b c) -> b n c", b=b, c=c) return x + +class MetaModel(nn.Module): + def __init__(self, lat_lons: list, *, + patch_size, depth, + heads, mlp_dim, + resolution=(721, 1440), + channels=3, dim_head=64, + interp_method='cubic'): + super().__init__() + resolution = pair(resolution) + b=3 + n=len(lat_lons) + d=7 + x=torch.randn((b,n,d)) + x=rearrange(x,"b n d -> n (b d)") + + pos_x= torch.tensor(lat_lons) + pos_y = torch.cartesian_prod( + torch.arange(0.5,resolution[0],1), + torch.arange(0.5,resolution[1],1) + ) + x = knn_interpolate(x,pos_x,pos_y) + x = rearrange(x,"m (b d) -> b m d", b=b,d=d) + print(x.shape) + From 127d8ff77fd855072b3f47e8efe61c321341fe26 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Tue, 11 Jun 2024 15:59:59 +0200 Subject: [PATCH 26/45] MetaModel initial --- graph_weather/models/fengwu_ghr/layers.py | 91 ++++++++++++++--------- 1 file changed, 56 insertions(+), 35 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 879dc008..887ff6e3 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -1,3 +1,6 @@ +from scipy.interpolate import griddata +from torch_geometric.nn import knn +from torch_geometric.utils import scatter import numpy as np from scipy.interpolate import griddata, interpn import torch @@ -13,9 +16,6 @@ def pair(t): return t if isinstance(t, tuple) else (t, t) -from torch_geometric.nn import knn -from torch_geometric.utils import scatter - def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, k: int = 3, num_workers: int = 1): @@ -27,36 +27,16 @@ def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, squared_distance = (diff * diff).sum(dim=-1, keepdim=True) weights = 1.0 / torch.clamp(squared_distance, min=1e-16) - # print((x[x_idx]*weights).shape) # print(weights.shape) den = scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum') # print(den.shape) y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum') - - - y = y / den - - return y - -def knn_interpolate( - x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, k: int = 4, num_workers: int = 1 -): - with torch.no_grad(): - assign_index = knn(pos_x, pos_y, k, num_workers=num_workers) - y_idx, x_idx = assign_index[0], assign_index[1] - diff = pos_x[x_idx] - pos_y[y_idx] - squared_distance = (diff * diff).sum(dim=-1, keepdim=True) - weights = 1.0 / torch.clamp(squared_distance, min=1e-16) - - den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum") - y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum") y = y / den return y - def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" @@ -379,6 +359,48 @@ def forward(self, x): return x class MetaModel(nn.Module): + def __init__(self, lat_lons: list, *, + patch_size, depth, + heads, mlp_dim, + resolution=(721, 1440), + channels=3, dim_head=64, + interp_method='cubic'): + super().__init__() + self.resolution = pair(resolution) + + self.pos_x = torch.tensor(lat_lons) + self.pos_y = torch.cartesian_prod( + torch.arange(0, self.resolution[0], 1), + torch.arange(0, self.resolution[1], 1) + ) + + self.image_model = ImageMetaModel(image_size=resolution, + patch_size=patch_size, + depth=depth, + heads=heads, + mlp_dim=mlp_dim, + channels=channels, + dim_head=dim_head) + + def forward(self, x): + b, n, c = x.shape + + x = rearrange(x, "b n c -> n (b c)") + x = knn_interpolate(x, self.pos_x, self.pos_y) + x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, + w=self.resolution[0], + h=self.resolution[1]) + + x = self.image_model(x) + + x = rearrange(x, "b c h w -> (h w) (b c)") + x = knn_interpolate(x, self.pos_y, self.pos_x) + x = rearrange(x, "n (b c) -> b n c", b=b, c=c) + + return x + + +class MetaModel2(nn.Module): def __init__(self, lat_lons: list, *, patch_size, depth, heads, mlp_dim, @@ -387,18 +409,17 @@ def __init__(self, lat_lons: list, *, interp_method='cubic'): super().__init__() resolution = pair(resolution) - b=3 - n=len(lat_lons) - d=7 - x=torch.randn((b,n,d)) - x=rearrange(x,"b n d -> n (b d)") - - pos_x= torch.tensor(lat_lons) + b = 3 + n = len(lat_lons) + d = 7 + x = torch.randn((b, n, d)) + x = rearrange(x, "b n d -> n (b d)") + + pos_x = torch.tensor(lat_lons) pos_y = torch.cartesian_prod( - torch.arange(0.5,resolution[0],1), - torch.arange(0.5,resolution[1],1) + torch.arange(0, resolution[0], 1), + torch.arange(0, resolution[1], 1) ) - x = knn_interpolate(x,pos_x,pos_y) - x = rearrange(x,"m (b d) -> b m d", b=b,d=d) + x = knn_interpolate(x, pos_x, pos_y) + x = rearrange(x, "m (b d) -> b m d", b=b, d=d) print(x.shape) - From b59e54d2d271fe2f2c0d9657e5c684b4820546f4 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Fri, 14 Jun 2024 16:44:08 +0200 Subject: [PATCH 27/45] tested metamodel --- graph_weather/models/fengwu_ghr/layers.py | 56 +++++------------------ 1 file changed, 12 insertions(+), 44 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 887ff6e3..a789eb84 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -1,8 +1,6 @@ from scipy.interpolate import griddata from torch_geometric.nn import knn from torch_geometric.utils import scatter -import numpy as np -from scipy.interpolate import griddata, interpn import torch from einops import rearrange from einops.layers.torch import Rearrange @@ -18,7 +16,7 @@ def pair(t): def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, - k: int = 3, num_workers: int = 1): + k: int = 4, num_workers: int = 1): with torch.no_grad(): assign_index = knn(pos_x, pos_y, k, num_workers=num_workers) @@ -27,10 +25,7 @@ def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, squared_distance = (diff * diff).sum(dim=-1, keepdim=True) weights = 1.0 / torch.clamp(squared_distance, min=1e-16) - # print((x[x_idx]*weights).shape) - # print(weights.shape) den = scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum') - # print(den.shape) y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum') y = y / den @@ -362,19 +357,19 @@ class MetaModel(nn.Module): def __init__(self, lat_lons: list, *, patch_size, depth, heads, mlp_dim, - resolution=(721, 1440), - channels=3, dim_head=64, - interp_method='cubic'): + image_size=(721, 1440), + channels=3, dim_head=64): super().__init__() - self.resolution = pair(resolution) + self.image_size = pair(image_size) self.pos_x = torch.tensor(lat_lons) self.pos_y = torch.cartesian_prod( - torch.arange(0, self.resolution[0], 1), - torch.arange(0, self.resolution[1], 1) + (torch.arange(-self.image_size[0]/2, + self.image_size[0]/2, 1)/self.image_size[0]*180).to(torch.long), + (torch.arange(0, self.image_size[1], 1)/self.image_size[1]*360).to(torch.long) ) - self.image_model = ImageMetaModel(image_size=resolution, + self.image_model = ImageMetaModel(image_size=image_size, patch_size=patch_size, depth=depth, heads=heads, @@ -387,39 +382,12 @@ def forward(self, x): x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, - w=self.resolution[0], - h=self.resolution[1]) - + x = rearrange(x, "(w h) (b c) -> b c w h", b=b, c=c, + w=self.image_size[0], + h=self.image_size[1]) x = self.image_model(x) - x = rearrange(x, "b c h w -> (h w) (b c)") + x = rearrange(x, "b c w h -> (w h) (b c)") x = knn_interpolate(x, self.pos_y, self.pos_x) x = rearrange(x, "n (b c) -> b n c", b=b, c=c) - return x - - -class MetaModel2(nn.Module): - def __init__(self, lat_lons: list, *, - patch_size, depth, - heads, mlp_dim, - resolution=(721, 1440), - channels=3, dim_head=64, - interp_method='cubic'): - super().__init__() - resolution = pair(resolution) - b = 3 - n = len(lat_lons) - d = 7 - x = torch.randn((b, n, d)) - x = rearrange(x, "b n d -> n (b d)") - - pos_x = torch.tensor(lat_lons) - pos_y = torch.cartesian_prod( - torch.arange(0, resolution[0], 1), - torch.arange(0, resolution[1], 1) - ) - x = knn_interpolate(x, pos_x, pos_y) - x = rearrange(x, "m (b d) -> b m d", b=b, d=d) - print(x.shape) From 19e73a2527f8deff092cd78a679b451d8929cfdd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Jun 2024 17:16:17 +0000 Subject: [PATCH 28/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/models/fengwu_ghr/layers.py | 66 +++++++++++++---------- tests/test_model.py | 9 ++-- 2 files changed, 40 insertions(+), 35 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index a789eb84..45fdfa6a 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -1,6 +1,3 @@ -from scipy.interpolate import griddata -from torch_geometric.nn import knn -from torch_geometric.utils import scatter import torch from einops import rearrange from einops.layers.torch import Rearrange @@ -15,18 +12,18 @@ def pair(t): return t if isinstance(t, tuple) else (t, t) -def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, - k: int = 4, num_workers: int = 1): +def knn_interpolate( + x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, k: int = 4, num_workers: int = 1 +): with torch.no_grad(): - assign_index = knn(pos_x, pos_y, k, - num_workers=num_workers) + assign_index = knn(pos_x, pos_y, k, num_workers=num_workers) y_idx, x_idx = assign_index[0], assign_index[1] diff = pos_x[x_idx] - pos_y[y_idx] squared_distance = (diff * diff).sum(dim=-1, keepdim=True) weights = 1.0 / torch.clamp(squared_distance, min=1e-16) - den = scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum') - y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum') + den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum") + y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum") y = y / den @@ -78,8 +75,7 @@ def forward(self, x): x = self.norm(x) qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange( - t, "b n (h d) -> b h n d", h=self.heads), qkv) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale @@ -354,37 +350,49 @@ def forward(self, x): return x class MetaModel(nn.Module): - def __init__(self, lat_lons: list, *, - patch_size, depth, - heads, mlp_dim, - image_size=(721, 1440), - channels=3, dim_head=64): + def __init__( + self, + lat_lons: list, + *, + patch_size, + depth, + heads, + mlp_dim, + image_size=(721, 1440), + channels=3, + dim_head=64 + ): super().__init__() self.image_size = pair(image_size) self.pos_x = torch.tensor(lat_lons) self.pos_y = torch.cartesian_prod( - (torch.arange(-self.image_size[0]/2, - self.image_size[0]/2, 1)/self.image_size[0]*180).to(torch.long), - (torch.arange(0, self.image_size[1], 1)/self.image_size[1]*360).to(torch.long) + ( + torch.arange(-self.image_size[0] / 2, self.image_size[0] / 2, 1) + / self.image_size[0] + * 180 + ).to(torch.long), + (torch.arange(0, self.image_size[1], 1) / self.image_size[1] * 360).to(torch.long), ) - self.image_model = ImageMetaModel(image_size=image_size, - patch_size=patch_size, - depth=depth, - heads=heads, - mlp_dim=mlp_dim, - channels=channels, - dim_head=dim_head) + self.image_model = ImageMetaModel( + image_size=image_size, + patch_size=patch_size, + depth=depth, + heads=heads, + mlp_dim=mlp_dim, + channels=channels, + dim_head=dim_head, + ) def forward(self, x): b, n, c = x.shape x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange(x, "(w h) (b c) -> b c w h", b=b, c=c, - w=self.image_size[0], - h=self.image_size[1]) + x = rearrange( + x, "(w h) (b c) -> b c w h", b=b, c=c, w=self.image_size[0], h=self.image_size[1] + ) x = self.image_model(x) x = rearrange(x, "b c w h -> (w h) (b c)") diff --git a/tests/test_model.py b/tests/test_model.py index 38795d36..bfbc1503 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -153,8 +153,7 @@ def test_assimilator_model(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): output_lat_lons.append((lat, lon)) - model = GraphWeatherAssimilator( - output_lat_lons=output_lat_lons, analysis_dim=24) + model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24) features = torch.randn((1, len(obs_lat_lons), 2)) lat_lon_heights = torch.tensor(obs_lat_lons) @@ -168,8 +167,7 @@ def test_forecaster_and_loss(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss( - lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) @@ -210,8 +208,7 @@ def test_forecaster_and_loss_grad_checkpoint(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss( - lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons, use_checkpointing=True) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) From e54016f98bf3e2eaa5451d10edbce7d843914559 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Fri, 21 Jun 2024 17:58:13 +0200 Subject: [PATCH 29/45] wrapper meta model --- graph_weather/models/fengwu_ghr/layers.py | 79 +++++++++++++++++++---- tests/test_model.py | 12 ++-- 2 files changed, 76 insertions(+), 15 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 45fdfa6a..ae9f1e87 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -75,7 +75,8 @@ def forward(self, x): x = self.norm(x) qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) + q, k, v = map(lambda t: rearrange( + t, "b n (h d) -> b h n d", h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale @@ -354,28 +355,30 @@ def __init__( self, lat_lons: list, *, + image_size, patch_size, depth, heads, mlp_dim, - image_size=(721, 1440), - channels=3, + channels, dim_head=64 ): super().__init__() - self.image_size = pair(image_size) + self.i_h, self.i_w = pair(image_size) self.pos_x = torch.tensor(lat_lons) self.pos_y = torch.cartesian_prod( ( - torch.arange(-self.image_size[0] / 2, self.image_size[0] / 2, 1) - / self.image_size[0] + torch.arange(-self.i_h / 2, + self.i_h / 2, 1) + / self.i_h * 180 ).to(torch.long), - (torch.arange(0, self.image_size[1], 1) / self.image_size[1] * 360).to(torch.long), + (torch.arange(0, self.i_w, 1) / + self.i_w * 360).to(torch.long), ) - self.image_model = ImageMetaModel( + self.image_meta_model = ImageMetaModel( image_size=image_size, patch_size=patch_size, depth=depth, @@ -391,11 +394,65 @@ def forward(self, x): x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) x = rearrange( - x, "(w h) (b c) -> b c w h", b=b, c=c, w=self.image_size[0], h=self.image_size[1] + x, "(h w) (b c) -> b c h w", b=b, c=c, + h=self.i_h, w=self.i_w ) - x = self.image_model(x) + x = self.image_meta_model(x) - x = rearrange(x, "b c w h -> (w h) (b c)") + x = rearrange(x, "b c h w -> (h w) (b c)") x = knn_interpolate(x, self.pos_y, self.pos_x) x = rearrange(x, "n (b c) -> b n c", b=b, c=c) return x + + +class WrapperMetaModel(nn.Module): + def __init__( + self, + lat_lons: list, + meta_model: MetaModel, + scale_factor + ): + super().__init__() + self.image_meta_model = meta_model.image_meta_model + + s_h, s_w = pair(scale_factor) + self.i_h, self.i_w = meta_model.i_h*s_h, meta_model.i_w*s_w + self.pos_x = torch.tensor(lat_lons) + self.pos_y = torch.cartesian_prod( + ( + torch.arange(-self.i_h / 2, + self.i_h / 2, 1) + / self.i_h + * 180 + ).to(torch.long), + (torch.arange(0, self.i_w, 1) / + self.i_w * 360).to(torch.long), + ) + + + self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", + s_h=s_h, s_w=s_w) + + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", + s_h=s_h, s_w=s_w) + + def forward(self, x): + b, n, c = x.shape + + x = rearrange(x, "b n c -> n (b c)") + x = knn_interpolate(x, self.pos_x, self.pos_y) + x = rearrange( + x, "(h w) (b c) -> b c h w", b=b, c=c, + h=self.i_h, w=self.i_w + ) + + x = self.batcher(x) + x = self.image_meta_model(x) + x = self.debatcher(x) + + x = rearrange(x, "b c h w -> (h w) (b c)") + x = knn_interpolate(x, self.pos_y, self.pos_x) + x = rearrange(x, "n (b c) -> b n c", b=b, c=c) + + + return x diff --git a/tests/test_model.py b/tests/test_model.py index bfbc1503..18f7d071 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -153,7 +153,8 @@ def test_assimilator_model(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): output_lat_lons.append((lat, lon)) - model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24) + model = GraphWeatherAssimilator( + output_lat_lons=output_lat_lons, analysis_dim=24) features = torch.randn((1, len(obs_lat_lons), 2)) lat_lon_heights = torch.tensor(obs_lat_lons) @@ -167,7 +168,8 @@ def test_forecaster_and_loss(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss( + lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) @@ -208,7 +210,8 @@ def test_forecaster_and_loss_grad_checkpoint(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss( + lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons, use_checkpointing=True) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) @@ -239,7 +242,8 @@ def test_normalized_loss(): assert not torch.isnan(loss) # Since feature_variance = out**2 and target = 0, we expect loss = weights - assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean()) + assert torch.isclose( + loss, criterion.weights.expand_as(out.mean(-1)).mean()) def test_image_meta_model(): From dae738e9a3e8fa7a6dd99236f01220d17104bde1 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Mon, 1 Jul 2024 16:30:29 +0200 Subject: [PATCH 30/45] RES --- graph_weather/models/fengwu_ghr/layers.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index ae9f1e87..4a7e1eb2 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -394,7 +394,7 @@ def forward(self, x): x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) x = rearrange( - x, "(h w) (b c) -> b c h w", b=b, c=c, + x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w ) x = self.image_meta_model(x) @@ -413,8 +413,6 @@ def __init__( scale_factor ): super().__init__() - self.image_meta_model = meta_model.image_meta_model - s_h, s_w = pair(scale_factor) self.i_h, self.i_w = meta_model.i_h*s_h, meta_model.i_w*s_w self.pos_x = torch.tensor(lat_lons) @@ -429,23 +427,27 @@ def __init__( self.i_w * 360).to(torch.long), ) - self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) + + imm_args = meta_model.image_meta_model.vars().update( + {"res": True, "scale_factor": scale_factor}) + self.image_meta_model = ImageMetaModel(**imm_args) + self.image_meta_model.load(meta_model.image_meta_model, strict=False) self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) def forward(self, x): b, n, c = x.shape - + x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) x = rearrange( - x, "(h w) (b c) -> b c h w", b=b, c=c, + x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w ) - + x = self.batcher(x) x = self.image_meta_model(x) x = self.debatcher(x) @@ -453,6 +455,5 @@ def forward(self, x): x = rearrange(x, "b c h w -> (h w) (b c)") x = knn_interpolate(x, self.pos_y, self.pos_x) x = rearrange(x, "n (b c) -> b n c", b=b, c=c) - return x From 92acbee2ba63121b18f4b16a76500f42e385d65d Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Tue, 2 Jul 2024 11:15:32 +0200 Subject: [PATCH 31/45] load RES state_dict --- graph_weather/models/fengwu_ghr/layers.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 4a7e1eb2..aa770252 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -104,9 +104,7 @@ def __init__( ) ) if self.res: - assert ( - image_size is not None and scale_factor is not None - ), "If res=True, you must provide h, w and scale_factor" + assert image_size is not None and scale_factor is not None, "If res=True, you must provide h, w and scale_factor" h, w = pair(image_size) s_h, s_w = pair(scale_factor) self.res_layers.append( @@ -430,11 +428,12 @@ def __init__( self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) - imm_args = meta_model.image_meta_model.vars().update( + imm_args = vars(meta_model.image_meta_model) + imm_args.update( {"res": True, "scale_factor": scale_factor}) self.image_meta_model = ImageMetaModel(**imm_args) - self.image_meta_model.load(meta_model.image_meta_model, strict=False) - + self.image_meta_model.load_state_dict(meta_model.image_meta_model.state_dict(), strict=False) + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) From 5115dd47851d90324a6f9389567d7f08919584e1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jul 2024 09:25:20 +0000 Subject: [PATCH 32/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/models/fengwu_ghr/layers.py | 59 +++++++---------------- tests/test_model.py | 16 +++--- 2 files changed, 24 insertions(+), 51 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index aa770252..3954cea4 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -75,8 +75,7 @@ def forward(self, x): x = self.norm(x) qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange( - t, "b n (h d) -> b h n d", h=self.heads), qkv) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale @@ -104,7 +103,9 @@ def __init__( ) ) if self.res: - assert image_size is not None and scale_factor is not None, "If res=True, you must provide h, w and scale_factor" + assert ( + image_size is not None and scale_factor is not None + ), "If res=True, you must provide h, w and scale_factor" h, w = pair(image_size) s_h, s_w = pair(scale_factor) self.res_layers.append( @@ -366,14 +367,8 @@ def __init__( self.pos_x = torch.tensor(lat_lons) self.pos_y = torch.cartesian_prod( - ( - torch.arange(-self.i_h / 2, - self.i_h / 2, 1) - / self.i_h - * 180 - ).to(torch.long), - (torch.arange(0, self.i_w, 1) / - self.i_w * 360).to(torch.long), + (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), + (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), ) self.image_meta_model = ImageMetaModel( @@ -391,10 +386,7 @@ def forward(self, x): x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange( - x, "(h w) (b c) -> b c h w", b=b, c=c, - h=self.i_h, w=self.i_w - ) + x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w) x = self.image_meta_model(x) x = rearrange(x, "b c h w -> (h w) (b c)") @@ -404,48 +396,33 @@ def forward(self, x): class WrapperMetaModel(nn.Module): - def __init__( - self, - lat_lons: list, - meta_model: MetaModel, - scale_factor - ): + def __init__(self, lat_lons: list, meta_model: MetaModel, scale_factor): super().__init__() s_h, s_w = pair(scale_factor) - self.i_h, self.i_w = meta_model.i_h*s_h, meta_model.i_w*s_w + self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w self.pos_x = torch.tensor(lat_lons) self.pos_y = torch.cartesian_prod( - ( - torch.arange(-self.i_h / 2, - self.i_h / 2, 1) - / self.i_h - * 180 - ).to(torch.long), - (torch.arange(0, self.i_w, 1) / - self.i_w * 360).to(torch.long), + (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), + (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), ) - self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", - s_h=s_h, s_w=s_w) + self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) imm_args = vars(meta_model.image_meta_model) - imm_args.update( - {"res": True, "scale_factor": scale_factor}) + imm_args.update({"res": True, "scale_factor": scale_factor}) self.image_meta_model = ImageMetaModel(**imm_args) - self.image_meta_model.load_state_dict(meta_model.image_meta_model.state_dict(), strict=False) + self.image_meta_model.load_state_dict( + meta_model.image_meta_model.state_dict(), strict=False + ) - self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", - s_h=s_h, s_w=s_w) + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) def forward(self, x): b, n, c = x.shape x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange( - x, "(h w) (b c) -> b c h w", b=b, c=c, - h=self.i_h, w=self.i_w - ) + x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w) x = self.batcher(x) x = self.image_meta_model(x) diff --git a/tests/test_model.py b/tests/test_model.py index 18f7d071..9e510081 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -153,8 +153,7 @@ def test_assimilator_model(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): output_lat_lons.append((lat, lon)) - model = GraphWeatherAssimilator( - output_lat_lons=output_lat_lons, analysis_dim=24) + model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24) features = torch.randn((1, len(obs_lat_lons), 2)) lat_lon_heights = torch.tensor(obs_lat_lons) @@ -168,8 +167,7 @@ def test_forecaster_and_loss(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss( - lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) @@ -210,8 +208,7 @@ def test_forecaster_and_loss_grad_checkpoint(): for lat in range(-90, 90, 5): for lon in range(0, 360, 5): lat_lons.append((lat, lon)) - criterion = NormalizedMSELoss( - lat_lons=lat_lons, feature_variance=torch.randn((78,))) + criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,))) model = GraphWeatherForecaster(lat_lons, use_checkpointing=True) # Add in auxiliary features features = torch.randn((2, len(lat_lons), 78 + 24)) @@ -242,8 +239,7 @@ def test_normalized_loss(): assert not torch.isnan(loss) # Since feature_variance = out**2 and target = 0, we expect loss = weights - assert torch.isclose( - loss, criterion.weights.expand_as(out.mean(-1)).mean()) + assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean()) def test_image_meta_model(): @@ -326,7 +322,7 @@ def test_wrapper_meta_model(): channels = 3 image_size = 20 patch_size = 4 - scale_factor=3 + scale_factor = 3 model = MetaModel( lat_lons, image_size=image_size, @@ -335,7 +331,7 @@ def test_wrapper_meta_model(): heads=1, mlp_dim=7, channels=channels, - dim_head=64 + dim_head=64, ) big_features = torch.randn((batch, len(lat_lons), channels)) From 362acdcb9039a53062c3303dd9caa94882ef0ba5 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Thu, 1 Aug 2024 10:41:16 +0200 Subject: [PATCH 33/45] added gcsfs to env yml --- environment_cpu.yml | 2 ++ environment_cuda.yml | 2 ++ 2 files changed, 4 insertions(+) diff --git a/environment_cpu.yml b/environment_cpu.yml index 5784bbb4..67ca0f60 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -25,6 +25,8 @@ dependencies: - h3-py - numpy - pyshtools + - gcsfs + - pytest - pip: - datasets - einops diff --git a/environment_cuda.yml b/environment_cuda.yml index 3d06cbb9..d77d63e9 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -26,6 +26,8 @@ dependencies: - h3-py - numpy - pyshtools + - gcsfs + - pytest - pip: - datasets - einops From 4dc0dc59cfefe6ba63ba9f4a5d8c1e012ad0a6c0 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Wed, 7 Aug 2024 11:36:48 +0200 Subject: [PATCH 34/45] __init__.py imports --- graph_weather/models/fengwu_ghr/__init__.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 graph_weather/models/fengwu_ghr/__init__.py diff --git a/graph_weather/models/fengwu_ghr/__init__.py b/graph_weather/models/fengwu_ghr/__init__.py new file mode 100644 index 00000000..4347fb69 --- /dev/null +++ b/graph_weather/models/fengwu_ghr/__init__.py @@ -0,0 +1,3 @@ +"""Main import for FengWu-GHR""" + +from .layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel From 31ee1e9c3b7aa4196baa0dff1b83877bfdfe4c56 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Thu, 15 Aug 2024 16:14:20 +0200 Subject: [PATCH 35/45] MetaModel long coordinates --- graph_weather/models/fengwu_ghr/layers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 8c73194d..f63919d3 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -227,6 +227,7 @@ def __init__( ) def forward(self, x): + assert x.shape[1] == self.channels, "Wrong number of channels" device = x.device dtype = x.dtype @@ -274,12 +275,12 @@ def __init__( super().__init__() self.i_h, self.i_w = pair(image_size) - self.pos_x = torch.tensor(lat_lons) + self.pos_x = torch.tensor(lat_lons).to(torch.long) self.pos_y = torch.cartesian_prod( (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), ) - + self.image_meta_model = ImageMetaModel( image_size=image_size, patch_size=patch_size, From 297075747be16702e2f0a612ec24860d95a72040 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Thu, 15 Aug 2024 16:35:00 +0200 Subject: [PATCH 36/45] knn_interpolate gpu patch --- graph_weather/models/fengwu_ghr/layers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index f63919d3..31a148c7 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -22,6 +22,9 @@ def knn_interpolate( squared_distance = (diff * diff).sum(dim=-1, keepdim=True) weights = 1.0 / torch.clamp(squared_distance, min=1e-16) + y_idx, x_idx = y_idx.to(x.device), x_idx.to(x.device) + weights = weights.to(x.device) + den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum") y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum") From 9f84835e4dd553ba32692195061e48fb490a330d Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Tue, 3 Sep 2024 11:55:05 +0000 Subject: [PATCH 37/45] era5 training --- .gitignore | 2 + train/era5.py | 185 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 train/era5.py diff --git a/.gitignore b/.gitignore index d248bf98..e450bd8f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ # pixi environments .pixi .vscode/ +checkpoints/ +lightning_logs/ diff --git a/train/era5.py b/train/era5.py new file mode 100644 index 00000000..343d5975 --- /dev/null +++ b/train/era5.py @@ -0,0 +1,185 @@ +import click +import xarray +import numpy as np +import pandas as pd +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import ModelCheckpoint +from torch.utils.data import DataLoader, Dataset + +from graph_weather.models import MetaModel +from graph_weather.models.losses import NormalizedMSELoss + +from einops import rearrange + + +class LitGraphForecaster(pl.LightningModule): + """ + LightningModule for graph-based weather forecasting. + + Attributes: + model (GraphWeatherForecaster): Graph weather forecaster model. + criterion (NormalizedMSELoss): Loss criterion for training. + lr : Learning rate for optimizer. + + Methods: + __init__: Initialize the LitGraphForecaster object. + forward: Forward pass of the model. + training_step: Training step. + configure_optimizers: Configure the optimizer for training. + """ + + def __init__( + self, + lat_lons: list, + *, + channels: int, + image_size, + patch_size=4, + depth=5, + heads=4, + mlp_dim=5, + feature_dim: int = 605, + lr: float = 3e-4, + + ): + """ + Initialize the LitGraphForecaster object with the required args. + + Args: + lat_lons : List of latitude and longitude values. + feature_dim : Dimensionality of the input features. + aux_dim : Dimensionality of auxiliary features. + hidden_dim : Dimensionality of hidden layers in the model. + num_blocks : Number of graph convolutional blocks in the model. + lr (float): Learning rate for optimizer. + """ + super().__init__() + self.model = MetaModel( + lat_lons, + image_size=image_size, + patch_size=patch_size, + depth=depth, + heads=heads, + mlp_dim=mlp_dim, + channels=channels + ) + self.criterion = NormalizedMSELoss( + lat_lons=lat_lons, feature_variance=np.ones((feature_dim,)) + ) + self.lr = lr + self.save_hyperparameters() + + def forward(self, x): + """ + Forward pass . + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + return self.model(x) + + def training_step(self, batch, batch_idx): + """ + Training step. + + Args: + batch (array): Batch of data containing input and output tensors. + batch_idx (int): Index of the current batch. + + Returns: + torch.Tensor: Loss tensor. + """ + x, y = batch[:, 0], batch[:, 1] + if torch.isnan(x).any() or torch.isnan(y).any(): + return None + y_hat = self.forward(x) + loss = self.criterion(y_hat, y) + self.log('loss', loss, prog_bar=True) + return loss + + def configure_optimizers(self): + """ + Configure the optimizer. + + Returns: + torch.optim.Optimizer: Optimizer instance. + """ + return torch.optim.AdamW(self.parameters(), lr=self.lr) + + +class Era5Dataset(Dataset): + """Era5 dataset.""" + + def __init__(self, xarr, transform=None): + """ + Arguments: + #TODO + """ + ds = np.asarray(xarr.to_array()) + ds = torch.from_numpy(ds) + ds -= ds.min(0, keepdim=True)[0] + ds /= ds.max(0, keepdim=True)[0] + ds = rearrange(ds, "C T H W -> T (H W) C") + self.ds = ds + + def __len__(self): + return len(self.ds) - 1 + + def __getitem__(self, index): + return self.ds[index:index+2] + + +if __name__ == "__main__": + + patch_size = 4 + grid_step = 20 + + reanalysis = xarray.open_zarr( + 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3', + storage_options=dict(token='anon'), + + ) + reanalysis = reanalysis.isel(time=slice(100, 400), longitude=slice( + 0, 1440, grid_step), latitude=slice(0, 721, grid_step)) + print(f'size: {reanalysis.nbytes / (1024 ** 3)} GiB') + + lat_lons = np.array( + np.meshgrid( + np.asarray(reanalysis["latitude"]).flatten(), + np.asarray(reanalysis["longitude"]).flatten(), + ) + ).T.reshape((-1, 2)) + + checkpoint_callback = ModelCheckpoint( + dirpath="./checkpoints", save_top_k=1, monitor="loss") + reanalysis = reanalysis[["2m_temperature", + "surface_pressure", + "10m_u_component_of_wind", + "10m_v_component_of_wind"]] + + shape = np.asarray(reanalysis.to_array()).shape + channels = shape[0] + + dset = DataLoader(Era5Dataset(reanalysis), batch_size=10, num_workers=8) + model = LitGraphForecaster(lat_lons=lat_lons, + channels=channels, + image_size=(721//grid_step, 1440//grid_step), + patch_size=patch_size, + depth=5, + heads=4, + mlp_dim=5) + trainer = pl.Trainer( + accelerator="gpu", + devices=-1, + max_epochs=1000, + precision="16-mixed", + callbacks=[checkpoint_callback], + log_every_n_steps=3 + + ) + + trainer.fit(model, dset) From b9c1e300d17c5d4510fe1afe1c54f75cbc079b7a Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Mon, 9 Sep 2024 11:42:34 +0000 Subject: [PATCH 38/45] era5 training bugfix --- train/era5.py | 50 +++++++++++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/train/era5.py b/train/era5.py index 343d5975..a83be4af 100644 --- a/train/era5.py +++ b/train/era5.py @@ -12,8 +12,10 @@ from einops import rearrange +from pathlib import Path -class LitGraphForecaster(pl.LightningModule): + +class LitFengWuGHR(pl.LightningModule): """ LightningModule for graph-based weather forecasting. @@ -23,7 +25,7 @@ class LitGraphForecaster(pl.LightningModule): lr : Learning rate for optimizer. Methods: - __init__: Initialize the LitGraphForecaster object. + __init__: Initialize the LitFengWuGHR object. forward: Forward pass of the model. training_step: Training step. configure_optimizers: Configure the optimizer for training. @@ -44,7 +46,7 @@ def __init__( ): """ - Initialize the LitGraphForecaster object with the required args. + Initialize the LitFengWuGHR object with the required args. Args: lat_lons : List of latitude and longitude values. @@ -135,16 +137,27 @@ def __getitem__(self, index): if __name__ == "__main__": + ckpt_path = Path("./checkpoints") patch_size = 4 grid_step = 20 + variables = ["2m_temperature", + "surface_pressure", + "10m_u_component_of_wind", + "10m_v_component_of_wind"] + + channels = len(variables) + ckpt_path.mkdir(parents=True, exist_ok=True) reanalysis = xarray.open_zarr( 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3', storage_options=dict(token='anon'), - ) - reanalysis = reanalysis.isel(time=slice(100, 400), longitude=slice( + + reanalysis = reanalysis.sel(time=slice('2020-01-01', '2021-01-01')) + reanalysis = reanalysis.isel(time=slice(100,107), longitude=slice( 0, 1440, grid_step), latitude=slice(0, 721, grid_step)) + + reanalysis = reanalysis[variables] print(f'size: {reanalysis.nbytes / (1024 ** 3)} GiB') lat_lons = np.array( @@ -155,27 +168,20 @@ def __getitem__(self, index): ).T.reshape((-1, 2)) checkpoint_callback = ModelCheckpoint( - dirpath="./checkpoints", save_top_k=1, monitor="loss") - reanalysis = reanalysis[["2m_temperature", - "surface_pressure", - "10m_u_component_of_wind", - "10m_v_component_of_wind"]] - - shape = np.asarray(reanalysis.to_array()).shape - channels = shape[0] + dirpath=ckpt_path, save_top_k=1, monitor="loss") dset = DataLoader(Era5Dataset(reanalysis), batch_size=10, num_workers=8) - model = LitGraphForecaster(lat_lons=lat_lons, - channels=channels, - image_size=(721//grid_step, 1440//grid_step), - patch_size=patch_size, - depth=5, - heads=4, - mlp_dim=5) + model = LitFengWuGHR(lat_lons=lat_lons, + channels=channels, + image_size=(721//grid_step, 1440//grid_step), + patch_size=patch_size, + depth=5, + heads=4, + mlp_dim=5) trainer = pl.Trainer( accelerator="gpu", devices=-1, - max_epochs=1000, + max_epochs=100, precision="16-mixed", callbacks=[checkpoint_callback], log_every_n_steps=3 @@ -183,3 +189,5 @@ def __getitem__(self, index): ) trainer.fit(model, dset) + + torch.save(model.state_dict(), ckpt_path / "best.pt") From eba1335b985b7578328320a9ff6a9feac8a8ae81 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi <58804597+rnwzd@users.noreply.github.com> Date: Wed, 18 Sep 2024 08:26:43 +0000 Subject: [PATCH 39/45] lora training --- graph_weather/models/__init__.py | 2 +- graph_weather/models/fengwu_ghr/__init__.py | 2 +- graph_weather/models/fengwu_ghr/layers.py | 42 +++++ train/era5.py | 4 +- train/lora.py | 163 ++++++++++++++++++++ 5 files changed, 209 insertions(+), 4 deletions(-) create mode 100644 train/lora.py diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index ace964db..1fdd58ff 100644 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -1,6 +1,6 @@ """Models""" -from .fengwu_ghr.layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel +from .fengwu_ghr.layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel, LoRAModule from .layers.assimilator_decoder import AssimilatorDecoder from .layers.assimilator_encoder import AssimilatorEncoder from .layers.decoder import Decoder diff --git a/graph_weather/models/fengwu_ghr/__init__.py b/graph_weather/models/fengwu_ghr/__init__.py index 4347fb69..1ea08d95 100644 --- a/graph_weather/models/fengwu_ghr/__init__.py +++ b/graph_weather/models/fengwu_ghr/__init__.py @@ -1,3 +1,3 @@ """Main import for FengWu-GHR""" -from .layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel +from .layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel, LoRAModule diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 31a148c7..6efa9572 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -357,3 +357,45 @@ def forward(self, x): x = rearrange(x, "n (b c) -> b n c", b=b, c=c) return x + + +class LoRALayer(nn.Module): + def __init__(self, linear_layer: nn.Module, r: int): + """ + Initialize LoRALayer. + + Args: + linear_layer (nn.Module): Linear layer to be transformed. + r (int): rank of the low-rank matrix. + """ + super().__init__() + out_features, in_features = linear_layer.weight.shape + + self.A = nn.Parameter(torch.randn(r, in_features)) + self.B = nn.Parameter(torch.zeros(out_features, r)) + self.linear_layer = linear_layer + + def forward(self, x): + out = self.linear_layer(x) + self.B @ self.A @ x + return out + + +class LoRAModule(nn.Module): + def __init__(self, model, r=4): + """ + Initialize LoRAModule. + + Args: + model (nn.Module): Model to be modified with LoRA layers. + r (int, optional): Rank of LoRA layers. Defaults to 4. + """ + super().__init__() + for name, layer in model.named_modules(): + layer.eval() + if isinstance(layer, nn.Linear): + lora_layer = LoRALayer(layer, r) + setattr(model, name, lora_layer) + self.model = model + + def forward(self, x): + return self.model(x) \ No newline at end of file diff --git a/train/era5.py b/train/era5.py index a83be4af..a5b69332 100644 --- a/train/era5.py +++ b/train/era5.py @@ -41,7 +41,7 @@ def __init__( depth=5, heads=4, mlp_dim=5, - feature_dim: int = 605, + feature_dim: int = 605, # TODO where does this come from? lr: float = 3e-4, ): @@ -190,4 +190,4 @@ def __getitem__(self, index): trainer.fit(model, dset) - torch.save(model.state_dict(), ckpt_path / "best.pt") + torch.save(model.model.state_dict(), ckpt_path / "best.pt") diff --git a/train/lora.py b/train/lora.py new file mode 100644 index 00000000..cc23a42d --- /dev/null +++ b/train/lora.py @@ -0,0 +1,163 @@ +import torch.nn as nn +import click +import xarray +import numpy as np +import pandas as pd +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import ModelCheckpoint +from torch.utils.data import DataLoader, Dataset + +from graph_weather.models import MetaModel, LoRAModule +from graph_weather.models.losses import NormalizedMSELoss + +from einops import rearrange + +from pathlib import Path + + +class LitLoRAFengWuGHR(pl.LightningModule): + def __init__( + self, + lat_lons: list, + single_step_model_state_dict: dict, + *, + time_step: int, + rank: int, + channels: int, + image_size, + patch_size=4, + depth=5, + heads=4, + mlp_dim=5, + feature_dim: int = 605, # TODO where does this come from? + lr: float = 3e-4, + ): + super().__init__() + assert time_step > 1, "Time step must be greater than 1. Remember that 1 is the simple model time step." + ssmodel = MetaModel( + lat_lons, + image_size=image_size, + patch_size=patch_size, + depth=depth, + heads=heads, + mlp_dim=mlp_dim, + channels=channels + ) + ssmodel.load_state_dict(single_step_model_state_dict) + self.models = nn.ModuleList([ssmodel] + + [LoRAModule(ssmodel, r=rank) for _ in range(2, time_step+1)]) + self.criterion = NormalizedMSELoss( + lat_lons=lat_lons, feature_variance=np.ones((feature_dim,)) + + ) + self.lr = lr + self.save_hyperparameters() + + def forward(self, x): + ys = [] + for t, model in enumerate(self.models): + x = model(x) + ys.append(x) + return torch.stack(ys, dim=1) + + def training_step(self, batch, batch_idx): + if torch.isnan(batch).any(): + return None + x, ys = batch[:, 0, ...], batch[:, 1:, ...] + + y_hat = self.forward(x) + loss = self.criterion(y_hat, ys) + self.log('loss', loss, prog_bar=True) + return loss + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.lr) + + +class Era5Dataset(Dataset): + + def __init__(self, xarr, time_step=1, transform=None): + assert time_step > 0, "Time step must be greater than 0." + ds = np.asarray(xarr.to_array()) + ds = torch.from_numpy(ds) + ds -= ds.min(0, keepdim=True)[0] + ds /= ds.max(0, keepdim=True)[0] + ds = rearrange(ds, "C T H W -> T (H W) C") + self.ds = ds + self.time_step = time_step + + def __len__(self): + return len(self.ds) - self.time_step + + def __getitem__(self, index): + return self.ds[index:index+time_step+1] + + +if __name__ == "__main__": + + ckpt_path = Path("./checkpoints") + ckpt_name = "best.pt" + patch_size = 4 + grid_step = 20 + time_step = 2 + rank = 4 + variables = ["2m_temperature", + "surface_pressure", + "10m_u_component_of_wind", + "10m_v_component_of_wind"] + + ############################################################### + + channels = len(variables) + ckpt_path.mkdir(parents=True, exist_ok=True) + + reanalysis = xarray.open_zarr( + 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3', + storage_options=dict(token='anon'), + ) + + reanalysis = reanalysis.sel(time=slice('2020-01-01', '2021-01-01')) + reanalysis = reanalysis.isel(time=slice(100, 111), longitude=slice( + 0, 1440, grid_step), latitude=slice(0, 721, grid_step)) + + reanalysis = reanalysis[variables] + print(f'size: {reanalysis.nbytes / (1024 ** 3)} GiB') + + lat_lons = np.array( + np.meshgrid( + np.asarray(reanalysis["latitude"]).flatten(), + np.asarray(reanalysis["longitude"]).flatten(), + ) + ).T.reshape((-1, 2)) + + checkpoint_callback = ModelCheckpoint( + dirpath=ckpt_path, save_top_k=1, monitor="loss") + + dset = DataLoader(Era5Dataset( + reanalysis, time_step=time_step), batch_size=10, num_workers=8) + + single_step_model_state_dict = torch.load(ckpt_path / ckpt_name) + + model = LitLoRAFengWuGHR(lat_lons=lat_lons, + single_step_model_state_dict=single_step_model_state_dict, + time_step=time_step, + rank=rank, + ########## + channels=channels, + image_size=(721//grid_step, 1440//grid_step), + patch_size=patch_size, + depth=5, + heads=4, + mlp_dim=5) + trainer = pl.Trainer( + accelerator="gpu", + devices=-1, + max_epochs=100, + precision="16-mixed", + callbacks=[checkpoint_callback], + log_every_n_steps=3, + strategy='ddp_find_unused_parameters_true' + ) + + trainer.fit(model, dset) From 7915c9b80b42a2cf64daf48bcc91c7d77555aa6a Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi <58804597+rnwzd@users.noreply.github.com> Date: Wed, 18 Sep 2024 13:06:17 +0000 Subject: [PATCH 40/45] pkg does not exist --- environment_cpu.yml | 2 +- environment_cuda.yml | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/environment_cpu.yml b/environment_cpu.yml index cfdc958d..c5e781d0 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -25,7 +25,6 @@ dependencies: - h3-py - numpy - pyshtools - - torch_harmonics - gcsfs - pytest - pip: @@ -39,3 +38,4 @@ dependencies: - click - trimesh - rtree + - torch-harmonics diff --git a/environment_cuda.yml b/environment_cuda.yml index a3c5a651..af1d7d3e 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -12,7 +12,6 @@ dependencies: - pip - pyg - python - - pytorch - pytorch-cluster - pytorch-scatter - pytorch-sparse @@ -26,7 +25,6 @@ dependencies: - h3-py - numpy - pyshtools - - torch_harmonics - gcsfs - pytest - pip: @@ -40,3 +38,4 @@ dependencies: - click - trimesh - rtree + - torch-harmonics From 085ae707294fb0d3b8031a7f9eedfa0637e314d1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Sep 2024 13:10:30 +0000 Subject: [PATCH 41/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/models/__init__.py | 8 +- graph_weather/models/fengwu_ghr/__init__.py | 2 +- graph_weather/models/fengwu_ghr/layers.py | 23 ++--- tests/test_model.py | 1 - train/era5.py | 75 ++++++++-------- train/lora.py | 94 +++++++++++---------- 6 files changed, 100 insertions(+), 103 deletions(-) diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index 1fdd58ff..91909420 100644 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -1,6 +1,12 @@ """Models""" -from .fengwu_ghr.layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel, LoRAModule +from .fengwu_ghr.layers import ( + ImageMetaModel, + LoRAModule, + MetaModel, + WrapperImageModel, + WrapperMetaModel, +) from .layers.assimilator_decoder import AssimilatorDecoder from .layers.assimilator_encoder import AssimilatorEncoder from .layers.decoder import Decoder diff --git a/graph_weather/models/fengwu_ghr/__init__.py b/graph_weather/models/fengwu_ghr/__init__.py index 1ea08d95..39b921fd 100644 --- a/graph_weather/models/fengwu_ghr/__init__.py +++ b/graph_weather/models/fengwu_ghr/__init__.py @@ -1,3 +1,3 @@ """Main import for FengWu-GHR""" -from .layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel, LoRAModule +from .layers import ImageMetaModel, LoRAModule, MetaModel, WrapperImageModel, WrapperMetaModel diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 6efa9572..38cf43ab 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -32,6 +32,7 @@ def knn_interpolate( return y + def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" @@ -262,6 +263,7 @@ def forward(self, x): x = self.debatcher(x) return x + class MetaModel(nn.Module): def __init__( self, @@ -283,7 +285,7 @@ def __init__( (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), ) - + self.image_meta_model = ImageMetaModel( image_size=image_size, patch_size=patch_size, @@ -299,10 +301,7 @@ def forward(self, x): x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange( - x, "(h w) (b c) -> b c h w", b=b, c=c, - h=self.i_h, w=self.i_w - ) + x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w) x = self.image_meta_model(x) x = rearrange(x, "b c h w -> (h w) (b c)") @@ -312,12 +311,7 @@ def forward(self, x): class WrapperMetaModel(nn.Module): - def __init__( - self, - lat_lons: list, - meta_model: MetaModel, - scale_factor - ): + def __init__(self, lat_lons: list, meta_model: MetaModel, scale_factor): super().__init__() s_h, s_w = pair(scale_factor) self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w @@ -343,10 +337,7 @@ def forward(self, x): x = rearrange(x, "b n c -> n (b c)") x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange( - x, "(h w) (b c) -> b c h w", b=b, c=c, - h=self.i_h, w=self.i_w - ) + x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w) x = self.batcher(x) x = self.image_meta_model(x) @@ -398,4 +389,4 @@ def __init__(self, model, r=4): self.model = model def forward(self, x): - return self.model(x) \ No newline at end of file + return self.model(x) diff --git a/tests/test_model.py b/tests/test_model.py index 0ff5a81d..9fae0a1d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -332,4 +332,3 @@ def test_wrapper_meta_model(): assert not torch.isnan(out).any() assert out.size() == big_features.size() - diff --git a/train/era5.py b/train/era5.py index a5b69332..ab5eee2d 100644 --- a/train/era5.py +++ b/train/era5.py @@ -1,19 +1,16 @@ -import click -import xarray +from pathlib import Path + import numpy as np -import pandas as pd import pytorch_lightning as pl import torch +import xarray +from einops import rearrange from pytorch_lightning.callbacks import ModelCheckpoint from torch.utils.data import DataLoader, Dataset from graph_weather.models import MetaModel from graph_weather.models.losses import NormalizedMSELoss -from einops import rearrange - -from pathlib import Path - class LitFengWuGHR(pl.LightningModule): """ @@ -37,13 +34,12 @@ def __init__( *, channels: int, image_size, - patch_size=4, - depth=5, - heads=4, - mlp_dim=5, - feature_dim: int = 605, # TODO where does this come from? + patch_size=4, + depth=5, + heads=4, + mlp_dim=5, + feature_dim: int = 605, # TODO where does this come from? lr: float = 3e-4, - ): """ Initialize the LitFengWuGHR object with the required args. @@ -64,7 +60,7 @@ def __init__( depth=depth, heads=heads, mlp_dim=mlp_dim, - channels=channels + channels=channels, ) self.criterion = NormalizedMSELoss( lat_lons=lat_lons, feature_variance=np.ones((feature_dim,)) @@ -100,7 +96,7 @@ def training_step(self, batch, batch_idx): return None y_hat = self.forward(x) loss = self.criterion(y_hat, y) - self.log('loss', loss, prog_bar=True) + self.log("loss", loss, prog_bar=True) return loss def configure_optimizers(self): @@ -132,7 +128,7 @@ def __len__(self): return len(self.ds) - 1 def __getitem__(self, index): - return self.ds[index:index+2] + return self.ds[index : index + 2] if __name__ == "__main__": @@ -140,25 +136,28 @@ def __getitem__(self, index): ckpt_path = Path("./checkpoints") patch_size = 4 grid_step = 20 - variables = ["2m_temperature", - "surface_pressure", - "10m_u_component_of_wind", - "10m_v_component_of_wind"] + variables = [ + "2m_temperature", + "surface_pressure", + "10m_u_component_of_wind", + "10m_v_component_of_wind", + ] channels = len(variables) ckpt_path.mkdir(parents=True, exist_ok=True) reanalysis = xarray.open_zarr( - 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3', - storage_options=dict(token='anon'), + "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3", + storage_options=dict(token="anon"), ) - reanalysis = reanalysis.sel(time=slice('2020-01-01', '2021-01-01')) - reanalysis = reanalysis.isel(time=slice(100,107), longitude=slice( - 0, 1440, grid_step), latitude=slice(0, 721, grid_step)) + reanalysis = reanalysis.sel(time=slice("2020-01-01", "2021-01-01")) + reanalysis = reanalysis.isel( + time=slice(100, 107), longitude=slice(0, 1440, grid_step), latitude=slice(0, 721, grid_step) + ) reanalysis = reanalysis[variables] - print(f'size: {reanalysis.nbytes / (1024 ** 3)} GiB') + print(f"size: {reanalysis.nbytes / (1024 ** 3)} GiB") lat_lons = np.array( np.meshgrid( @@ -167,27 +166,27 @@ def __getitem__(self, index): ) ).T.reshape((-1, 2)) - checkpoint_callback = ModelCheckpoint( - dirpath=ckpt_path, save_top_k=1, monitor="loss") + checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path, save_top_k=1, monitor="loss") dset = DataLoader(Era5Dataset(reanalysis), batch_size=10, num_workers=8) - model = LitFengWuGHR(lat_lons=lat_lons, - channels=channels, - image_size=(721//grid_step, 1440//grid_step), - patch_size=patch_size, - depth=5, - heads=4, - mlp_dim=5) + model = LitFengWuGHR( + lat_lons=lat_lons, + channels=channels, + image_size=(721 // grid_step, 1440 // grid_step), + patch_size=patch_size, + depth=5, + heads=4, + mlp_dim=5, + ) trainer = pl.Trainer( accelerator="gpu", devices=-1, max_epochs=100, precision="16-mixed", callbacks=[checkpoint_callback], - log_every_n_steps=3 - + log_every_n_steps=3, ) trainer.fit(model, dset) - + torch.save(model.model.state_dict(), ckpt_path / "best.pt") diff --git a/train/lora.py b/train/lora.py index cc23a42d..f829c797 100644 --- a/train/lora.py +++ b/train/lora.py @@ -1,20 +1,17 @@ -import torch.nn as nn -import click -import xarray +from pathlib import Path + import numpy as np -import pandas as pd import pytorch_lightning as pl import torch +import torch.nn as nn +import xarray +from einops import rearrange from pytorch_lightning.callbacks import ModelCheckpoint from torch.utils.data import DataLoader, Dataset -from graph_weather.models import MetaModel, LoRAModule +from graph_weather.models import LoRAModule, MetaModel from graph_weather.models.losses import NormalizedMSELoss -from einops import rearrange - -from pathlib import Path - class LitLoRAFengWuGHR(pl.LightningModule): def __init__( @@ -26,15 +23,17 @@ def __init__( rank: int, channels: int, image_size, - patch_size=4, - depth=5, - heads=4, - mlp_dim=5, + patch_size=4, + depth=5, + heads=4, + mlp_dim=5, feature_dim: int = 605, # TODO where does this come from? lr: float = 3e-4, ): super().__init__() - assert time_step > 1, "Time step must be greater than 1. Remember that 1 is the simple model time step." + assert ( + time_step > 1 + ), "Time step must be greater than 1. Remember that 1 is the simple model time step." ssmodel = MetaModel( lat_lons, image_size=image_size, @@ -42,14 +41,14 @@ def __init__( depth=depth, heads=heads, mlp_dim=mlp_dim, - channels=channels + channels=channels, ) ssmodel.load_state_dict(single_step_model_state_dict) - self.models = nn.ModuleList([ssmodel] + - [LoRAModule(ssmodel, r=rank) for _ in range(2, time_step+1)]) + self.models = nn.ModuleList( + [ssmodel] + [LoRAModule(ssmodel, r=rank) for _ in range(2, time_step + 1)] + ) self.criterion = NormalizedMSELoss( lat_lons=lat_lons, feature_variance=np.ones((feature_dim,)) - ) self.lr = lr self.save_hyperparameters() @@ -68,7 +67,7 @@ def training_step(self, batch, batch_idx): y_hat = self.forward(x) loss = self.criterion(y_hat, ys) - self.log('loss', loss, prog_bar=True) + self.log("loss", loss, prog_bar=True) return loss def configure_optimizers(self): @@ -91,7 +90,7 @@ def __len__(self): return len(self.ds) - self.time_step def __getitem__(self, index): - return self.ds[index:index+time_step+1] + return self.ds[index : index + time_step + 1] if __name__ == "__main__": @@ -102,10 +101,12 @@ def __getitem__(self, index): grid_step = 20 time_step = 2 rank = 4 - variables = ["2m_temperature", - "surface_pressure", - "10m_u_component_of_wind", - "10m_v_component_of_wind"] + variables = [ + "2m_temperature", + "surface_pressure", + "10m_u_component_of_wind", + "10m_v_component_of_wind", + ] ############################################################### @@ -113,16 +114,17 @@ def __getitem__(self, index): ckpt_path.mkdir(parents=True, exist_ok=True) reanalysis = xarray.open_zarr( - 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3', - storage_options=dict(token='anon'), + "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3", + storage_options=dict(token="anon"), ) - reanalysis = reanalysis.sel(time=slice('2020-01-01', '2021-01-01')) - reanalysis = reanalysis.isel(time=slice(100, 111), longitude=slice( - 0, 1440, grid_step), latitude=slice(0, 721, grid_step)) + reanalysis = reanalysis.sel(time=slice("2020-01-01", "2021-01-01")) + reanalysis = reanalysis.isel( + time=slice(100, 111), longitude=slice(0, 1440, grid_step), latitude=slice(0, 721, grid_step) + ) reanalysis = reanalysis[variables] - print(f'size: {reanalysis.nbytes / (1024 ** 3)} GiB') + print(f"size: {reanalysis.nbytes / (1024 ** 3)} GiB") lat_lons = np.array( np.meshgrid( @@ -131,25 +133,25 @@ def __getitem__(self, index): ) ).T.reshape((-1, 2)) - checkpoint_callback = ModelCheckpoint( - dirpath=ckpt_path, save_top_k=1, monitor="loss") + checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path, save_top_k=1, monitor="loss") - dset = DataLoader(Era5Dataset( - reanalysis, time_step=time_step), batch_size=10, num_workers=8) + dset = DataLoader(Era5Dataset(reanalysis, time_step=time_step), batch_size=10, num_workers=8) single_step_model_state_dict = torch.load(ckpt_path / ckpt_name) - model = LitLoRAFengWuGHR(lat_lons=lat_lons, - single_step_model_state_dict=single_step_model_state_dict, - time_step=time_step, - rank=rank, - ########## - channels=channels, - image_size=(721//grid_step, 1440//grid_step), - patch_size=patch_size, - depth=5, - heads=4, - mlp_dim=5) + model = LitLoRAFengWuGHR( + lat_lons=lat_lons, + single_step_model_state_dict=single_step_model_state_dict, + time_step=time_step, + rank=rank, + ########## + channels=channels, + image_size=(721 // grid_step, 1440 // grid_step), + patch_size=patch_size, + depth=5, + heads=4, + mlp_dim=5, + ) trainer = pl.Trainer( accelerator="gpu", devices=-1, @@ -157,7 +159,7 @@ def __getitem__(self, index): precision="16-mixed", callbacks=[checkpoint_callback], log_every_n_steps=3, - strategy='ddp_find_unused_parameters_true' + strategy="ddp_find_unused_parameters_true", ) trainer.fit(model, dset) From 894405089210a59b6d9298ec7c7718c96d6abdfd Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi <58804597+rnwzd@users.noreply.github.com> Date: Thu, 19 Sep 2024 12:54:52 +0000 Subject: [PATCH 42/45] env.yml bugfix --- environment_cuda.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment_cuda.yml b/environment_cuda.yml index af1d7d3e..1ecf55da 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -12,6 +12,7 @@ dependencies: - pip - pyg - python + - pytorch - pytorch-cluster - pytorch-scatter - pytorch-sparse From 968b9088e581a2c3dea61c1c583294ea77f01a96 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Mon, 23 Sep 2024 19:13:14 +0100 Subject: [PATCH 43/45] Update environment_cpu.yml --- environment_cpu.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment_cpu.yml b/environment_cpu.yml index c5e781d0..383d4a1e 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -28,6 +28,7 @@ dependencies: - gcsfs - pytest - pip: + - setuptools - datasets - einops - fsspec From 7eb323af33e6c6406b2b0889502f2215a8439a0e Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Mon, 23 Sep 2024 19:13:32 +0100 Subject: [PATCH 44/45] Update environment_cuda.yml --- environment_cuda.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment_cuda.yml b/environment_cuda.yml index 1ecf55da..ca69ecba 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -29,6 +29,7 @@ dependencies: - gcsfs - pytest - pip: + - setuptools - datasets - einops - fsspec From 384d99e6a07c1e1aee5f49a4e797f784f09b06f0 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Mon, 23 Sep 2024 19:33:14 +0100 Subject: [PATCH 45/45] Update environment_cuda.yml --- environment_cuda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment_cuda.yml b/environment_cuda.yml index ca69ecba..e1eb18d4 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -11,7 +11,7 @@ dependencies: - pandas - pip - pyg - - python + - python=3.12 - pytorch - pytorch-cluster - pytorch-scatter