Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fengwu ghr training #125

Merged
merged 54 commits into from
Sep 23, 2024
Merged
Changes from 1 commit
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
cb70551
fengwu_ghr: initial
rnwzd May 29, 2024
9eaf70d
fengwu_ghr: fixes
rnwzd May 29, 2024
4f3d4c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 29, 2024
8c60fb7
Interpolate initial
rnwzd Jun 6, 2024
725421d
ImageMetaModel
rnwzd Jun 11, 2024
c57a27e
MetaModel initial
rnwzd Jun 11, 2024
3d2a17d
tested metamodel
rnwzd Jun 14, 2024
48e7d0a
Merge branch 'main' into fengwu_ghr
rnwzd Jun 17, 2024
87d1ffd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2024
21d84c7
wrapper meta model
rnwzd Jun 21, 2024
07a8d0f
RES
rnwzd Jul 1, 2024
cd84968
load RES state_dict
rnwzd Jul 2, 2024
e0f60ca
Merge branch 'main' of https://github.com/openclimatefix/graph_weathe…
rnwzd Jul 2, 2024
b15110f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 2, 2024
1146db9
bug fix
rnwzd Jul 2, 2024
499ef7d
Merge branch 'fengwu_ghr' of https://github.com/openclimatefix/graph_…
rnwzd Jul 2, 2024
fe82edc
Merge branch 'main' of https://github.com/openclimatefix/graph_weathe…
rnwzd Jul 2, 2024
325fd0e
bug fix
rnwzd Jul 2, 2024
cfa9c3f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 2, 2024
2fadf97
env yml fix
rnwzd Jul 29, 2024
0f46d7a
Merge branch 'fengwu_ghr' of https://github.com/openclimatefix/graph_…
rnwzd Jul 29, 2024
257b353
fengwu_ghr: initial
rnwzd May 29, 2024
04d4776
Merge branch 'fengwu_ghr' of https://github.com/openclimatefix/graph_…
rnwzd Jul 29, 2024
f72c610
test_wrapper_meta_model
rnwzd Jul 29, 2024
8a8ac64
tests fix
rnwzd Jul 30, 2024
6f0c61b
parent 743cf9704eea03353e02351cde52add7233437d6
rnwzd May 29, 2024
1ae62eb
Merge branch 'fengwu_ghr' of https://github.com/openclimatefix/graph_…
rnwzd Jul 30, 2024
a855c6a
fengwu_ghr: initial
rnwzd May 29, 2024
47d0d48
fengwu_ghr: initial
rnwzd May 29, 2024
7a1d562
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 29, 2024
e397941
Interpolate initial
rnwzd Jun 6, 2024
80a73ee
ImageMetaModel
rnwzd Jun 11, 2024
127d8ff
MetaModel initial
rnwzd Jun 11, 2024
b59e54d
tested metamodel
rnwzd Jun 14, 2024
19e73a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2024
e54016f
wrapper meta model
rnwzd Jun 21, 2024
dae738e
RES
rnwzd Jul 1, 2024
92acbee
load RES state_dict
rnwzd Jul 2, 2024
5115dd4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 2, 2024
8a6a062
Merge branch 'fengwu_ghr' of https://github.com/openclimatefix/graph_…
rnwzd Jul 30, 2024
362acdc
added gcsfs to env yml
rnwzd Aug 1, 2024
4dc0dc5
__init__.py imports
rnwzd Aug 7, 2024
31ee1e9
MetaModel long coordinates
rnwzd Aug 15, 2024
2970757
knn_interpolate gpu patch
rnwzd Aug 15, 2024
9f84835
era5 training
rnwzd Sep 3, 2024
b9c1e30
era5 training bugfix
rnwzd Sep 9, 2024
eba1335
lora training
rnwzd Sep 18, 2024
0cfbf40
Merge branch 'main' of https://github.com/openclimatefix/graph_weathe…
rnwzd Sep 18, 2024
7915c9b
pkg does not exist
rnwzd Sep 18, 2024
085ae70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2024
8944050
env.yml bugfix
rnwzd Sep 19, 2024
968b908
Update environment_cpu.yml
jacobbieker Sep 23, 2024
7eb323a
Update environment_cuda.yml
jacobbieker Sep 23, 2024
384d99e
Update environment_cuda.yml
jacobbieker Sep 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fengwu_ghr: initial
rnwzd committed May 29, 2024
commit cb7055190831d203aaa77cdfdef4a7066bbc4029
1 change: 1 addition & 0 deletions graph_weather/models/__init__.py
Original file line number Diff line number Diff line change
@@ -5,3 +5,4 @@
from .layers.decoder import Decoder
from .layers.encoder import Encoder
from .layers.processor import Processor
from .fengwu_ghr.layers import MetaModel
133 changes: 133 additions & 0 deletions graph_weather/models/fengwu_ghr/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import torch
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers


def pair(t):
return t if isinstance(t, tuple) else (t, t)


def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)

y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)

# classes


class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)

def forward(self, x):
return self.net(x)


class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)

self.attend = nn.Softmax(dim=-1)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)

def forward(self, x):
x = self.norm(x)

qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(
t, 'b n (h d) -> b h n d', h=self.heads), qkv)

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

attn = self.attend(dots)

out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)


class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads=heads, dim_head=dim_head),
FeedForward(dim, mlp_dim)
]))

def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)


class MetaModel(nn.Module):
def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)

assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

patch_dim = channels * patch_height * patch_width
dim = patch_dim
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)",
p_h=patch_height, p_w=patch_width),
nn.LayerNorm(patch_dim), # TODO Do we need this?
nn.Linear(patch_dim, dim), # TODO Do we need this?
nn.LayerNorm(dim), # TODO Do we need this?
)

self.pos_embedding = posemb_sincos_2d(
h=image_height // patch_height,
w=image_width // patch_width,
dim=dim,
)

self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

self.reshaper = nn.Sequential(
Rearrange("b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)",
h=image_height // patch_height,
w=image_width // patch_width,
p_h=patch_height, p_w=patch_width)
)

def forward(self, img):
device = img.device

x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype=x.dtype)

x = self.transformer(x)

print(x.shape)
x = self.reshaper(x)

return x
14 changes: 13 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
import torch

from graph_weather import GraphWeatherAssimilator, GraphWeatherForecaster
from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Decoder, Encoder, Processor
from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Decoder, Encoder, Processor, MetaModel
from graph_weather.models.losses import NormalizedMSELoss


@@ -222,3 +222,15 @@ def test_normalized_loss():
assert not torch.isnan(loss)
# Since feature_variance = out**2 and target = 0, we expect loss = weights
assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())


def test_meta_model():
model = MetaModel(image_size=100,patch_size=10,
depth=1, heads=1, mlp_dim=7,
channels=3 )
features = torch.randn((1,3, 100,100) )

out = model(features)
assert not torch.isnan(out).any()
assert not torch.isnan(out).any()
assert out.size() == (1,3, 100,100)