Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fengwu ghr #114

Merged
merged 9 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
*.txt
# pixi environments
.pixi
.vscode/
2 changes: 1 addition & 1 deletion environment_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- pandas
- pip
- pyg
- python=3.12
- python
- pytorch
- cpuonly
- pytorch-cluster
Expand Down
1 change: 1 addition & 0 deletions graph_weather/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Models"""

from .fengwu_ghr.layers import MetaModel,ImageMetaModel
from .layers.assimilator_decoder import AssimilatorDecoder
from .layers.assimilator_encoder import AssimilatorEncoder
from .layers.decoder import Decoder
Expand Down
273 changes: 273 additions & 0 deletions graph_weather/models/fengwu_ghr/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
from scipy.interpolate import griddata
from torch_geometric.nn import knn
from torch_geometric.utils import scatter
import numpy as np
from scipy.interpolate import griddata, interpn
import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn


# helpers


def pair(t):
    """Return ``t`` as a 2-element size specification.

    A tuple is returned unchanged, a list is converted to a tuple, and any
    other value (typically a single int) is duplicated into ``(t, t)``.
    """
    if isinstance(t, tuple):
        return t
    if isinstance(t, list):
        # A list such as [721, 1440] is a size pair as well; without this
        # branch it would be duplicated into ([721, 1440], [721, 1440]).
        return tuple(t)
    return (t, t)


def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor,
                    k: int = 3, num_workers: int = 1):
    """Interpolate features from source positions onto target positions.

    For each row of ``pos_y`` the ``k`` nearest rows of ``pos_x`` are looked
    up with ``knn`` and the matching rows of ``x`` are averaged using
    inverse-squared-distance weights.

    Args:
        x: Features at the source positions, indexed along dim 0 by source
            point.
        pos_x: Coordinates of the source points.
        pos_y: Coordinates of the target points.
        k: Number of neighbours used per target point.
        num_workers: Forwarded to ``knn``.

    Returns:
        Interpolated features with ``pos_y.size(0)`` rows.
    """
    # Neighbour search and weights are treated as constants: no gradient is
    # propagated to the coordinates.
    with torch.no_grad():
        assign_index = knn(pos_x, pos_y, k, num_workers=num_workers)
        y_idx, x_idx = assign_index[0], assign_index[1]
        diff = pos_x[x_idx] - pos_y[y_idx]
        if not diff.is_floating_point():
            # Integer coordinates would give integer squared distances, and
            # the 1e-16 clamp below must not be rounded to zero for
            # coincident points (1 / 0 -> inf, inf / inf -> NaN).
            diff = diff.to(torch.get_default_dtype())
        squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
        weights = 1.0 / torch.clamp(squared_distance, min=1e-16)

    # Normalised weighted sum over each target point's neighbours.
    den = scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum')
    y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum')

    return y / den


def grid_interpolate(lat_lons: list, z: torch.Tensor,
                     height, width,
                     method: str = "cubic"):
    """Resample scattered point values onto a regular ``height`` x ``width`` grid.

    Args:
        lat_lons: ``(lat, lon)`` coordinates of the ``n`` scattered points.
        z: Values at those points, shape ``(b, n, c)``.
        height: Number of grid rows (latitude axis).
        width: Number of grid columns (longitude axis).
        method: Interpolation method forwarded to ``scipy.interpolate.griddata``.

    Returns:
        Tensor of shape ``(b, c, height, width)`` with the dtype and device of
        ``z``. Grid cells outside the convex hull of ``lat_lons`` are 0.
    """
    # The interpolation itself runs in scipy on the CPU.
    # The target mesh is a flat rectangle, not a sphere. Per review this is
    # intended: the result has to be consumable as an image by the model.

    # Cell-centre coordinates of the target grid.
    # NOTE(review): rows span 0..180 and columns 0..360; this presumably
    # expects latitudes on a 0..180 scale - confirm against the callers,
    # since cells outside the hull of lat_lons are silently filled with 0.
    xi = np.arange(0.5, width, 1) / width * 360
    yi = np.arange(0.5, height, 1) / height * 180
    xi, yi = np.meshgrid(xi, yi)  # both of shape (height, width)

    points = np.asarray(lat_lons, dtype=float)
    # scipy needs a plain CPU array without autograd history; (b n c) -> (n b c).
    values = z.detach().cpu().permute(1, 0, 2).contiguous().numpy()

    # Query in (lat, lon) order so it matches the column order of lat_lons.
    grid = griddata(points, values, (yi, xi), fill_value=0, method=method)

    # (h w b c) -> (b c h w)
    grid = np.ascontiguousarray(np.transpose(grid, (2, 3, 0, 1)))
    return torch.as_tensor(grid, dtype=z.dtype, device=z.device)


def grid_extrapolate(lat_lons, z,
                     height, width,
                     method: str = "cubic"):
    """Sample a regular ``height`` x ``width`` grid at scattered points.

    Args:
        lat_lons: ``(lat, lon)`` coordinates of the ``n`` query points.
        z: Gridded values, shape ``(b, c, height, width)``.
        height: Number of grid rows (latitude axis).
        width: Number of grid columns (longitude axis).
        method: Interpolation method forwarded to ``scipy.interpolate.interpn``.

    Returns:
        Tensor of shape ``(b, n, c)`` with the dtype and device of ``z``.
        Out-of-range queries do not raise (``bounds_error=False``); they get
        scipy's default fill value.
    """
    # Cell-centre coordinates of the source grid.
    xi = np.arange(0.5, width, 1) / width * 360
    yi = np.arange(0.5, height, 1) / height * 180

    # scipy needs a plain CPU array; (b c h w) -> (h w b c).
    values = z.detach().cpu().permute(2, 3, 0, 1).contiguous().numpy()
    queries = np.asarray(lat_lons, dtype=float)

    # Grid axes are given as (rows, cols) = (yi, xi): this matches both the
    # (h, w, ...) layout of `values` and the (lat, lon) order of the queries.
    sampled = interpn((yi, xi), values, queries,
                      bounds_error=False,
                      method=method)

    # (n b c) -> (b n c)
    sampled = np.ascontiguousarray(np.transpose(sampled, (1, 0, 2)))
    return torch.as_tensor(sampled, dtype=z.dtype, device=z.device)


def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
    """Build a fixed 2D sine/cosine positional embedding.

    Args:
        h: Number of rows of the position grid.
        w: Number of columns of the position grid.
        dim: Embedding size; must be a multiple of 4.
        temperature: Base of the geometric frequency progression.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(h * w, dim)``. Rows are ordered row-major over the
        grid; columns are ``[sin(x), cos(x), sin(y), cos(y)]``, each
        ``dim // 4`` wide.
    """
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")

    n_freq = dim // 4
    # With a single frequency (dim == 4) the denominator n_freq - 1 is 0 and
    # 0 / 0 would fill the whole embedding with NaN; clamp it to 1 instead.
    omega = torch.arange(n_freq) / max(n_freq - 1, 1)
    omega = 1.0 / (temperature**omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)


# classes


class FeedForward(nn.Module):
    """Pre-norm MLP block: LayerNorm -> Linear -> GELU -> Linear.

    The output has the same last dimension (``dim``) as the input.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        out = self.net(x)
        return out


class Attention(nn.Module):
    """Pre-norm multi-head self-attention over a ``(batch, tokens, dim)`` input."""

    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)

        # One bias-free projection yields queries, keys and values together.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x):
        """Return the attention output, same shape as ``x``."""
        batch, tokens, _ = x.shape
        normed = self.norm(x)

        # Split the projection into q/k/v and move heads in front of tokens:
        # (b, n, h*d) -> (b, h, n, d).
        q, k, v = (
            part.reshape(batch, tokens, self.heads, -1).transpose(1, 2)
            for part in self.to_qkv(normed).chunk(3, dim=-1)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)
        mixed = torch.matmul(weights, v)

        # Merge heads back: (b, h, n, d) -> (b, n, h*d).
        merged = mixed.transpose(1, 2).reshape(batch, tokens, -1)
        return self.to_out(merged)


class Transformer(nn.Module):
    """Stack of ``depth`` residual (Attention, FeedForward) pairs with a final LayerNorm."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        Attention(dim, heads=heads, dim_head=dim_head),
                        FeedForward(dim, mlp_dim),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x):
        """Run every block with residual connections, then normalise."""
        for block in self.layers:
            attention, feed_forward = block
            x = x + attention(x)
            x = x + feed_forward(x)
        return self.norm(x)


class ImageMetaModel(nn.Module):
    """Patch-based transformer that maps an image to an image of the same shape.

    The input ``(b, c, H, W)`` is cut into non-overlapping patches, each patch
    is embedded into ``channels * patch_height * patch_width`` features, a
    fixed sin/cos positional embedding is added, the tokens go through the
    transformer, and the result is folded back to ``(b, c, H, W)``.

    Args:
        image_size: Image size as an int or ``(height, width)``.
        patch_size: Patch size as an int or ``(height, width)``; must divide
            the image size.
        depth: Number of transformer blocks.
        heads: Number of attention heads.
        mlp_dim: Hidden size of the feed-forward blocks.
        channels: Number of image channels.
        dim_head: Feature size per attention head.
    """

    def __init__(self, *,
                 image_size,
                 patch_size, depth,
                 heads, mlp_dim,
                 channels=3, dim_head=64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert (
            image_height % patch_height == 0 and image_width % patch_width == 0
        ), "Image dimensions must be divisible by the patch size."

        grid_height = image_height // patch_height
        grid_width = image_width // patch_width

        # The token size equals the flattened patch size so that the output
        # tokens can be folded straight back into an image.
        patch_dim = channels * patch_height * patch_width
        dim = patch_dim
        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)",
                p_h=patch_height, p_w=patch_width
            ),
            nn.LayerNorm(patch_dim),  # TODO Do we need this?
            nn.Linear(patch_dim, dim),  # TODO Do we need this?
            nn.LayerNorm(dim),  # TODO Do we need this?
        )

        # Registered as a non-persistent buffer: it follows the module across
        # .to()/.cuda() calls instead of being re-transferred on every
        # forward, and stays out of the state_dict so checkpoints keep the
        # same keys.
        self.register_buffer(
            "pos_embedding",
            posemb_sincos_2d(h=grid_height, w=grid_width, dim=dim),
            persistent=False,
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.reshaper = nn.Sequential(
            Rearrange(
                "b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)",
                h=grid_height,
                w=grid_width,
                p_h=patch_height,
                p_w=patch_width,
            )
        )

    def forward(self, x):
        """Transform an image batch ``(b, c, H, W)`` into one of the same shape."""
        x = self.to_patch_embedding(x)
        # Out-of-place add; the cast is a no-op when device/dtype already match.
        x = x + self.pos_embedding.to(device=x.device, dtype=x.dtype)

        x = self.transformer(x)
        x = self.reshaper(x)

        return x


class MetaModel(nn.Module):
    """Wrap ``ImageMetaModel`` so it can run on scattered lat/lon points.

    Point features are k-NN interpolated onto a regular grid, processed as an
    image, and interpolated back onto the original points.

    Args:
        lat_lons: ``(lat, lon)`` coordinates of the input points.
            NOTE(review): the grid below assumes latitudes in [-90, 90) and
            longitudes in [0, 360) degrees, as in the tests - confirm.
        patch_size: Patch size of the image model.
        depth: Number of transformer blocks.
        heads: Number of attention heads.
        mlp_dim: Hidden size of the feed-forward blocks.
        resolution: Grid size as an int or ``(height, width)``.
        channels: Number of feature channels per point.
        dim_head: Feature size per attention head.
        interp_method: Currently unused; kept so existing callers still work.
    """

    def __init__(self, lat_lons: list, *,
                 patch_size, depth,
                 heads, mlp_dim,
                 resolution=(721, 1440),
                 channels=3, dim_head=64,
                 interp_method='cubic'):
        super().__init__()
        self.resolution = pair(resolution)
        height, width = self.resolution

        pos_x = torch.as_tensor(lat_lons, dtype=torch.float32)
        # Express the grid nodes in the same degree coordinates as lat_lons;
        # raw pixel indices would make the nearest-neighbour search compare
        # two unrelated coordinate systems.
        grid_lat = torch.arange(height, dtype=torch.float32) / height * 180 - 90
        grid_lon = torch.arange(width, dtype=torch.float32) / width * 360
        pos_y = torch.cartesian_prod(grid_lat, grid_lon)  # row-major (h w)

        # Non-persistent buffers follow .to()/.cuda() without adding
        # state_dict keys.
        self.register_buffer("pos_x", pos_x, persistent=False)
        self.register_buffer("pos_y", pos_y, persistent=False)

        self.image_model = ImageMetaModel(image_size=self.resolution,
                                          patch_size=patch_size,
                                          depth=depth,
                                          heads=heads,
                                          mlp_dim=mlp_dim,
                                          channels=channels,
                                          dim_head=dim_head)

    def forward(self, x):
        """Map point features ``(b, n, c)`` to an output of the same shape."""
        b, n, c = x.shape
        height, width = self.resolution

        # Points -> grid. Batch and channels are folded into one feature axis
        # so a single interpolation handles all of them.
        x = rearrange(x, "b n c -> n (b c)")
        x = knn_interpolate(x, self.pos_x, self.pos_y)
        # pos_y enumerates rows first, so h is the grid height and w its width.
        x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c,
                      h=height,
                      w=width)

        x = self.image_model(x)

        # Grid -> points.
        x = rearrange(x, "b c h w -> (h w) (b c)")
        x = knn_interpolate(x, self.pos_y, self.pos_x)
        x = rearrange(x, "n (b c) -> b n c", b=b, c=c)

        return x


class MetaModel2(nn.Module):
    """Scratch prototype of the point-to-grid interpolation step.

    The constructor pushes a random ``(3, n, 7)`` tensor through
    ``knn_interpolate`` and prints the resulting shape. It stores no
    sub-modules and defines no ``forward``; all keyword arguments other than
    ``resolution`` are accepted but unused.
    """

    def __init__(self, lat_lons: list, *,
                 patch_size, depth,
                 heads, mlp_dim,
                 resolution=(721, 1440),
                 channels=3, dim_head=64,
                 interp_method='cubic'):
        super().__init__()
        grid_size = pair(resolution)
        batch, features = 3, 7

        # Random stand-in for real point features, flattened to one feature axis.
        sample = torch.randn((batch, len(lat_lons), features))
        sample = rearrange(sample, "b n d -> n (b d)")

        source_pos = torch.tensor(lat_lons)
        target_pos = torch.cartesian_prod(
            torch.arange(0, grid_size[0], 1),
            torch.arange(0, grid_size[1], 1)
        )

        sample = knn_interpolate(sample, source_pos, target_pos)
        sample = rearrange(sample, "m (b d) -> b m d", b=batch, d=features)
        print(sample.shape)
57 changes: 52 additions & 5 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
import torch

from graph_weather import GraphWeatherAssimilator, GraphWeatherForecaster
from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Decoder, Encoder, Processor
from graph_weather.models import (
AssimilatorDecoder,
AssimilatorEncoder,
Decoder,
Encoder,
Processor,
MetaModel,
ImageMetaModel
)
from graph_weather.models.losses import NormalizedMSELoss


Expand Down Expand Up @@ -135,7 +143,8 @@ def test_assimilator_model():
for lat in range(-90, 90, 5):
for lon in range(0, 360, 5):
output_lat_lons.append((lat, lon))
model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24)
model = GraphWeatherAssimilator(
output_lat_lons=output_lat_lons, analysis_dim=24)

features = torch.randn((1, len(obs_lat_lons), 2))
lat_lon_heights = torch.tensor(obs_lat_lons)
Expand All @@ -149,7 +158,8 @@ def test_forecaster_and_loss():
for lat in range(-90, 90, 5):
for lon in range(0, 360, 5):
lat_lons.append((lat, lon))
criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
criterion = NormalizedMSELoss(
lat_lons=lat_lons, feature_variance=torch.randn((78,)))
model = GraphWeatherForecaster(lat_lons)
# Add in auxiliary features
features = torch.randn((2, len(lat_lons), 78 + 24))
Expand Down Expand Up @@ -190,7 +200,8 @@ def test_forecaster_and_loss_grad_checkpoint():
for lat in range(-90, 90, 5):
for lon in range(0, 360, 5):
lat_lons.append((lat, lon))
criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
criterion = NormalizedMSELoss(
lat_lons=lat_lons, feature_variance=torch.randn((78,)))
model = GraphWeatherForecaster(lat_lons, use_checkpointing=True)
# Add in auxiliary features
features = torch.randn((2, len(lat_lons), 78 + 24))
Expand Down Expand Up @@ -221,4 +232,40 @@ def test_normalized_loss():

assert not torch.isnan(loss)
# Since feature_variance = out**2 and target = 0, we expect loss = weights
assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())
assert torch.isclose(
loss, criterion.weights.expand_as(out.mean(-1)).mean())


def test_image_meta_model():
    """ImageMetaModel returns a finite image with the same shape as its input."""
    batch = 2
    channels = 3
    # NOTE(review): 900 / 10 = 90 patches per side, i.e. 8100 tokens and an
    # 8100 x 8100 attention matrix per sample - consider a smaller size if
    # this test is slow or memory-hungry in CI.
    size = 900
    image = torch.randn((batch, channels, size, size))
    model = ImageMetaModel(image_size=size,
                           patch_size=10,
                           depth=1, heads=1, mlp_dim=7,
                           channels=channels)

    out = model(image)
    # A single finite check replaces the duplicated NaN assertion and also
    # rejects +/-inf.
    assert torch.isfinite(out).all()
    assert out.size() == (batch, channels, size, size)


def test_meta_model():
    """MetaModel maps (batch, points, channels) features to the same shape."""
    lat_lons = [
        (lat, lon)
        for lat in range(-90, 90, 5)
        for lon in range(0, 360, 5)
    ]
    batch, channels = 2, 3

    model = MetaModel(lat_lons,
                      resolution=4, patch_size=2,
                      depth=1, heads=1, mlp_dim=7, channels=channels)
    features = torch.randn((batch, len(lat_lons), channels))

    out = model(features)
    # TODO: the NaN check is disabled in this test; re-enable it once the
    # model output is confirmed to be NaN-free.
    # assert not torch.isnan(out).any()
    assert out.size() == (batch, len(lat_lons), channels)
Loading