Commit: wrapper meta model
rnwzd committed Jun 21, 2024
1 parent 87d1ffd commit 21d84c7
Showing 3 changed files with 171 additions and 33 deletions.
2 changes: 1 addition & 1 deletion graph_weather/models/__init__.py
@@ -1,6 +1,6 @@
"""Models"""

from .fengwu_ghr.layers import ImageMetaModel, MetaModel
from .fengwu_ghr.layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel
from .layers.assimilator_decoder import AssimilatorDecoder
from .layers.assimilator_encoder import AssimilatorEncoder
from .layers.decoder import Decoder
126 changes: 102 additions & 24 deletions graph_weather/models/fengwu_ghr/layers.py
@@ -76,7 +76,8 @@ def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
+        q, k, v = map(lambda t: rearrange(
+            t, "b n (h d) -> b h n d", h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
@@ -95,7 +96,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
-                    [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)]
+                    [Attention(dim, heads=heads, dim_head=dim_head),
+                     FeedForward(dim, mlp_dim)]
                )
            )
@@ -107,29 +109,31 @@ def forward(self, x):


class ImageMetaModel(nn.Module):
-    def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64):
+    def __init__(self, *, image_size,
+                 patch_size, depth, heads,
+                 mlp_dim, channels, dim_head):
        super().__init__()
-        image_height, image_width = pair(image_size)
-        patch_height, patch_width = pair(patch_size)
+        self.image_height, self.image_width = pair(image_size)
+        self.patch_height, self.patch_width = pair(patch_size)

        assert (
-            image_height % patch_height == 0 and image_width % patch_width == 0
+            self.image_height % self.patch_height == 0 and self.image_width % self.patch_width == 0
        ), "Image dimensions must be divisible by the patch size."

-        patch_dim = channels * patch_height * patch_width
+        patch_dim = channels * self.patch_height * self.patch_width
        dim = patch_dim
        self.to_patch_embedding = nn.Sequential(
            Rearrange(
-                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width
+                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=self.patch_height, p_w=self.patch_width
            ),
            nn.LayerNorm(patch_dim),  # TODO Do we need this?
            nn.Linear(patch_dim, dim),  # TODO Do we need this?
            nn.LayerNorm(dim),  # TODO Do we need this?
        )

        self.pos_embedding = posemb_sincos_2d(
-            h=image_height // patch_height,
-            w=image_width // patch_width,
+            h=self.image_height // self.patch_height,
+            w=self.image_width // self.patch_width,
            dim=dim,
        )

@@ -138,10 +142,10 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3,
        self.reshaper = nn.Sequential(
            Rearrange(
                "b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)",
-                h=image_height // patch_height,
-                w=image_width // patch_width,
-                p_h=patch_height,
-                p_w=patch_width,
+                h=self.image_height // self.patch_height,
+                w=self.image_width // self.patch_width,
+                p_h=self.patch_height,
+                p_w=self.patch_width,
            )
        )

@@ -158,33 +162,53 @@ def forward(self, x):
        return x


+class WrapperImageModel(nn.Module):
+    def __init__(self, image_meta_model: ImageMetaModel, scale_factor):
+        super().__init__()
+        s_h, s_w = pair(scale_factor)
+        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
+                                 s_h=s_h, s_w=s_w)
+        self.image_meta_model = image_meta_model
+        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
+                                   s_h=s_h, s_w=s_w)
+
+    def forward(self, x):
+        x = self.batcher(x)
+        x = self.image_meta_model(x)
+        x = self.debatcher(x)
+        return x
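The batcher/debatcher pair added above is the core of the wrapper: the first Rearrange splits each scaled-up image into s_h * s_w interleaved sub-grids stacked along the batch axis, so the wrapped ImageMetaModel always sees its native resolution, and the second inverts the mapping exactly. A minimal round-trip sketch (sizes invented for illustration; not part of the diff):

import torch
from einops.layers.torch import Rearrange

s_h, s_w = 2, 2  # illustrative scale factors
batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w)
debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w)

x = torch.randn(1, 3, 8, 8)               # one 8x8 image
tiles = batcher(x)                        # (4, 3, 4, 4): four interleaved 4x4 sub-grids
assert torch.equal(debatcher(tiles), x)   # lossless round trip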


class MetaModel(nn.Module):
    def __init__(
        self,
        lat_lons: list,
        *,
+        image_size,
        patch_size,
        depth,
        heads,
        mlp_dim,
-        image_size=(721, 1440),
-        channels=3,
+        channels,
        dim_head=64
    ):
        super().__init__()
-        self.image_size = pair(image_size)
+        self.i_h, self.i_w = pair(image_size)

        self.pos_x = torch.tensor(lat_lons)
        self.pos_y = torch.cartesian_prod(
            (
-                torch.arange(-self.image_size[0] / 2, self.image_size[0] / 2, 1)
-                / self.image_size[0]
+                torch.arange(-self.i_h / 2,
+                             self.i_h / 2, 1)
+                / self.i_h
                * 180
            ).to(torch.long),
-            (torch.arange(0, self.image_size[1], 1) / self.image_size[1] * 360).to(torch.long),
+            (torch.arange(0, self.i_w, 1) /
+             self.i_w * 360).to(torch.long),
        )

-        self.image_model = ImageMetaModel(
+        self.image_meta_model = ImageMetaModel(
            image_size=image_size,
            patch_size=patch_size,
            depth=depth,
@@ -200,11 +224,65 @@ def forward(self, x):
        x = rearrange(x, "b n c -> n (b c)")
        x = knn_interpolate(x, self.pos_x, self.pos_y)
        x = rearrange(
-            x, "(w h) (b c) -> b c w h", b=b, c=c, w=self.image_size[0], h=self.image_size[1]
+            x, "(h w) (b c) -> b c h w", b=b, c=c,
+            h=self.i_h, w=self.i_w
        )
-        x = self.image_model(x)
+        x = self.image_meta_model(x)

-        x = rearrange(x, "b c w h -> (w h) (b c)")
+        x = rearrange(x, "b c h w -> (h w) (b c)")
        x = knn_interpolate(x, self.pos_y, self.pos_x)
        x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
        return x
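MetaModel's constructor pairs each pixel of its internal image with a (lat, lon) coordinate so that knn_interpolate can map between the irregular lat_lons nodes and the regular grid. A quick standalone check of that grid construction (a 4x8 grid chosen for readability; not part of the diff):

import torch

i_h, i_w = 4, 8
lat = (torch.arange(-i_h / 2, i_h / 2, 1) / i_h * 180).to(torch.long)
lon = (torch.arange(0, i_w, 1) / i_w * 360).to(torch.long)
pos_y = torch.cartesian_prod(lat, lon)
print(lat)          # tensor([-90, -45,   0,  45])
print(lon)          # tensor([  0,  45,  90, 135, 180, 225, 270, 315])
print(pos_y.shape)  # torch.Size([32, 2]) -- one (lat, lon) row per pixel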


+class WrapperMetaModel(nn.Module):
+    def __init__(
+        self,
+        lat_lons: list,
+        meta_model: MetaModel,
+        scale_factor
+    ):
+        super().__init__()
+        self.image_meta_model = meta_model.image_meta_model
+
+        s_h, s_w = pair(scale_factor)
+        self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w
+        self.pos_x = torch.tensor(lat_lons)
+        self.pos_y = torch.cartesian_prod(
+            (
+                torch.arange(-self.i_h / 2,
+                             self.i_h / 2, 1)
+                / self.i_h
+                * 180
+            ).to(torch.long),
+            (torch.arange(0, self.i_w, 1) /
+             self.i_w * 360).to(torch.long),
+        )
+
+        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
+                                 s_h=s_h, s_w=s_w)
+        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
+                                   s_h=s_h, s_w=s_w)
+
+    def forward(self, x):
+        b, n, c = x.shape
+
+        x = rearrange(x, "b n c -> n (b c)")
+        x = knn_interpolate(x, self.pos_x, self.pos_y)
+        x = rearrange(
+            x, "(h w) (b c) -> b c h w", b=b, c=c,
+            h=self.i_h, w=self.i_w
+        )
+
+        x = self.batcher(x)
+        x = self.image_meta_model(x)
+        x = self.debatcher(x)
+
+        x = rearrange(x, "b c h w -> (h w) (b c)")
+        x = knn_interpolate(x, self.pos_y, self.pos_x)
+        x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
+
+        return x
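Putting the pieces together, here is an illustrative shape trace through WrapperMetaModel.forward, assuming an inner MetaModel built for a 4x8 grid, scale_factor 2, batch 1, and 2 channels (numbers invented for illustration; not part of the diff):

# x: (1, n, 2), where n = len(lat_lons) graph nodes
# knn_interpolate onto the scaled grid, then rearrange -> (1, 2, 8, 16)  [H = 4*2, W = 8*2]
# batcher   -> (4, 2, 4, 8)   four tiles at the inner model's native size
# image_meta_model keeps shape -> (4, 2, 4, 8)
# debatcher -> (1, 2, 8, 16)
# knn_interpolate back to the nodes -> (1, n, 2)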
76 changes: 68 additions & 8 deletions tests/test_model.py
@@ -9,8 +9,10 @@
    Decoder,
    Encoder,
    Processor,
-    MetaModel,
    ImageMetaModel,
+    MetaModel,
+    WrapperImageModel,
+    WrapperMetaModel
)
from graph_weather.models.losses import NormalizedMSELoss
from graph_weather.models.gencast.utils.noise import (
@@ -147,7 +149,8 @@ def test_assimilator_model():
    for lat in range(-90, 90, 5):
        for lon in range(0, 360, 5):
            output_lat_lons.append((lat, lon))
-    model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24)
+    model = GraphWeatherAssimilator(
+        output_lat_lons=output_lat_lons, analysis_dim=24)

    features = torch.randn((1, len(obs_lat_lons), 2))
    lat_lon_heights = torch.tensor(obs_lat_lons)
@@ -161,7 +164,8 @@ def test_forecaster_and_loss():
    for lat in range(-90, 90, 5):
        for lon in range(0, 360, 5):
            lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(
+        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
    model = GraphWeatherForecaster(lat_lons)
    # Add in auxiliary features
    features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -202,7 +206,8 @@ def test_forecaster_and_loss_grad_checkpoint():
    for lat in range(-90, 90, 5):
        for lon in range(0, 360, 5):
            lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(
+        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
    model = GraphWeatherForecaster(lat_lons, use_checkpointing=True)
    # Add in auxiliary features
    features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -233,7 +238,8 @@ def test_normalized_loss():

    assert not torch.isnan(loss)
    # Since feature_variance = out**2 and target = 0, we expect loss = weights
-    assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())
+    assert torch.isclose(
+        loss, criterion.weights.expand_as(out.mean(-1)).mean())


def test_image_meta_model():
@@ -243,13 +249,35 @@ def test_image_meta_model():
    patch_size = 2
    image = torch.randn((batch, channels, size, size))
    model = ImageMetaModel(
-        image_size=size, patch_size=patch_size, channels=channels, depth=1, heads=1, mlp_dim=7
+        image_size=size, patch_size=patch_size,
+        channels=channels, depth=1, heads=1, mlp_dim=7,
+        dim_head=64
    )

    out = model(image)
    assert not torch.isnan(out).any()
    assert not torch.isnan(out).any()
-    assert out.size() == (batch, channels, size, size)
+    assert out.size() == image.size()


+def test_wrapper_image_meta_model():
+    batch = 2
+    channels = 3
+    size = 4
+    patch_size = 2
+    model = ImageMetaModel(
+        image_size=size, patch_size=patch_size,
+        channels=channels, depth=1, heads=1, mlp_dim=7,
+        dim_head=64
+    )
+    scale_factor = 3
+    big_image = torch.randn((batch, channels,
+                             size * scale_factor, size * scale_factor))
+    big_model = WrapperImageModel(model, scale_factor)
+    out = big_model(big_image)
+    assert not torch.isnan(out).any()
+    assert not torch.isnan(out).any()
+    assert out.size() == big_image.size()


def test_meta_model():
@@ -270,10 +298,42 @@ def test_meta_model():
        heads=1,
        mlp_dim=7,
        channels=channels,
+        dim_head=64
    )
    features = torch.randn((batch, len(lat_lons), channels))

    out = model(features)
    assert not torch.isnan(out).any()
    assert not torch.isnan(out).any()
-    assert out.size() == (batch, len(lat_lons), channels)
+    assert out.size() == features.size()


+def test_wrapper_meta_model():
+    lat_lons = []
+    for lat in range(-90, 90, 5):
+        for lon in range(0, 360, 5):
+            lat_lons.append((lat, lon))
+
+    batch = 2
+    channels = 3
+    image_size = 20
+    patch_size = 4
+    scale_factor = 3
+    model = MetaModel(
+        lat_lons,
+        image_size=image_size,
+        patch_size=patch_size,
+        depth=1,
+        heads=1,
+        mlp_dim=7,
+        channels=channels,
+        dim_head=64
+    )
+
+    big_features = torch.randn((batch, len(lat_lons), channels))
+    big_model = WrapperMetaModel(lat_lons, model, scale_factor)
+    out = big_model(big_features)
+
+    assert not torch.isnan(out).any()
+    assert not torch.isnan(out).any()
+    assert out.size() == big_features.size()
