diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py
index 0083b16f..ace964db 100644
--- a/graph_weather/models/__init__.py
+++ b/graph_weather/models/__init__.py
@@ -1,6 +1,6 @@
 """Models"""
 
-from .fengwu_ghr.layers import ImageMetaModel, MetaModel
+from .fengwu_ghr.layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel
 from .layers.assimilator_decoder import AssimilatorDecoder
 from .layers.assimilator_encoder import AssimilatorEncoder
 from .layers.decoder import Decoder
diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index 06d15681..f5dbda57 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -76,7 +76,8 @@ def forward(self, x):
         x = self.norm(x)
 
         qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
+        q, k, v = map(lambda t: rearrange(
+            t, "b n (h d) -> b h n d", h=self.heads), qkv)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
@@ -95,7 +96,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim):
         for _ in range(depth):
             self.layers.append(
                 nn.ModuleList(
-                    [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)]
+                    [Attention(dim, heads=heads, dim_head=dim_head),
+                     FeedForward(dim, mlp_dim)]
                 )
             )
 
@@ -107,20 +109,22 @@ def forward(self, x):
 
 class ImageMetaModel(nn.Module):
-    def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64):
+    def __init__(self, *, image_size,
+                 patch_size, depth, heads,
+                 mlp_dim, channels, dim_head):
         super().__init__()
-        image_height, image_width = pair(image_size)
-        patch_height, patch_width = pair(patch_size)
+        self.image_height, self.image_width = pair(image_size)
+        self.patch_height, self.patch_width = pair(patch_size)
 
         assert (
-            image_height % patch_height == 0 and image_width % patch_width == 0
+            self.image_height % self.patch_height == 0 and self.image_width % self.patch_width == 0
         ), "Image dimensions must be divisible by the patch size."
 
-        patch_dim = channels * patch_height * patch_width
+        patch_dim = channels * self.patch_height * self.patch_width
         dim = patch_dim
 
         self.to_patch_embedding = nn.Sequential(
             Rearrange(
-                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width
+                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=self.patch_height, p_w=self.patch_width
             ),
             nn.LayerNorm(patch_dim),  # TODO Do we need this?
             nn.Linear(patch_dim, dim),  # TODO Do we need this?
@@ -128,8 +132,8 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3,
         )
 
         self.pos_embedding = posemb_sincos_2d(
-            h=image_height // patch_height,
-            w=image_width // patch_width,
+            h=self.image_height // self.patch_height,
+            w=self.image_width // self.patch_width,
             dim=dim,
         )
 
@@ -138,10 +142,10 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3,
         self.reshaper = nn.Sequential(
             Rearrange(
                 "b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)",
-                h=image_height // patch_height,
-                w=image_width // patch_width,
-                p_h=patch_height,
-                p_w=patch_width,
+                h=self.image_height // self.patch_height,
+                w=self.image_width // self.patch_width,
+                p_h=self.patch_height,
+                p_w=self.patch_width,
             )
         )
 
@@ -158,33 +162,53 @@ def forward(self, x):
         return x
 
 
+class WrapperImageModel(nn.Module):
+    def __init__(self, image_meta_model: ImageMetaModel,
+                 scale_factor):
+        super().__init__()
+        s_h, s_w = pair(scale_factor)
+        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
+                                 s_h=s_h, s_w=s_w)
+        self.image_meta_model = image_meta_model
+        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
+                                   s_h=s_h, s_w=s_w)
+
+    def forward(self, x):
+        x = self.batcher(x)
+        x = self.image_meta_model(x)
+        x = self.debatcher(x)
+        return x
+
+
 class MetaModel(nn.Module):
     def __init__(
         self,
         lat_lons: list,
         *,
+        image_size,
         patch_size,
         depth,
         heads,
         mlp_dim,
-        image_size=(721, 1440),
-        channels=3,
+        channels,
         dim_head=64
     ):
         super().__init__()
-        self.image_size = pair(image_size)
+        self.i_h, self.i_w = pair(image_size)
 
         self.pos_x = torch.tensor(lat_lons)
         self.pos_y = torch.cartesian_prod(
             (
-                torch.arange(-self.image_size[0] / 2, self.image_size[0] / 2, 1)
-                / self.image_size[0]
+                torch.arange(-self.i_h / 2,
+                             self.i_h / 2, 1)
+                / self.i_h
                 * 180
             ).to(torch.long),
-            (torch.arange(0, self.image_size[1], 1) / self.image_size[1] * 360).to(torch.long),
+            (torch.arange(0, self.i_w, 1) /
+             self.i_w * 360).to(torch.long),
         )
 
-        self.image_model = ImageMetaModel(
+        self.image_meta_model = ImageMetaModel(
             image_size=image_size,
             patch_size=patch_size,
             depth=depth,
@@ -200,11 +224,65 @@ def forward(self, x):
         x = rearrange(x, "b n c -> n (b c)")
         x = knn_interpolate(x, self.pos_x, self.pos_y)
         x = rearrange(
-            x, "(w h) (b c) -> b c w h", b=b, c=c, w=self.image_size[0], h=self.image_size[1]
+            x, "(h w) (b c) -> b c h w", b=b, c=c,
+            h=self.i_h, w=self.i_w
         )
-        x = self.image_model(x)
+        x = self.image_meta_model(x)
 
-        x = rearrange(x, "b c w h -> (w h) (b c)")
+        x = rearrange(x, "b c h w -> (h w) (b c)")
         x = knn_interpolate(x, self.pos_y, self.pos_x)
         x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
         return x
+
+
+class WrapperMetaModel(nn.Module):
+    def __init__(
+        self,
+        lat_lons: list,
+        meta_model: MetaModel,
+        scale_factor
+    ):
+        super().__init__()
+        self.image_meta_model = meta_model.image_meta_model
+
+        s_h, s_w = pair(scale_factor)
+        self.i_h, self.i_w = meta_model.i_h*s_h, meta_model.i_w*s_w
+        self.pos_x = torch.tensor(lat_lons)
+        self.pos_y = torch.cartesian_prod(
+            (
+                torch.arange(-self.i_h / 2,
+                             self.i_h / 2, 1)
+                / self.i_h
+                * 180
+            ).to(torch.long),
+            (torch.arange(0, self.i_w, 1) /
+             self.i_w * 360).to(torch.long),
+        )
+
+
+        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
+                                 s_h=s_h, s_w=s_w)
+
+        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
+                                   s_h=s_h, s_w=s_w)
+
+    def forward(self, x):
+        b, n, c = x.shape
+
+        x = rearrange(x, "b n c -> n (b c)")
+        x = knn_interpolate(x, self.pos_x, self.pos_y)
+        x = rearrange(
+            x, "(h w) (b c) -> b c h w", b=b, c=c,
+            h=self.i_h, w=self.i_w
+        )
+
+        x = self.batcher(x)
+        x = self.image_meta_model(x)
+        x = self.debatcher(x)
+
+        x = rearrange(x, "b c h w -> (h w) (b c)")
+        x = knn_interpolate(x, self.pos_y, self.pos_x)
+        x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
+
+
+        return x
diff --git a/tests/test_model.py b/tests/test_model.py
index 5959349b..e19e831e 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -9,8 +9,10 @@
     Decoder,
     Encoder,
     Processor,
-    MetaModel,
     ImageMetaModel,
+    MetaModel,
+    WrapperImageModel,
+    WrapperMetaModel
 )
 from graph_weather.models.losses import NormalizedMSELoss
 from graph_weather.models.gencast.utils.noise import (
@@ -147,7 +149,8 @@ def test_assimilator_model():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             output_lat_lons.append((lat, lon))
-    model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24)
+    model = GraphWeatherAssimilator(
+        output_lat_lons=output_lat_lons, analysis_dim=24)
 
     features = torch.randn((1, len(obs_lat_lons), 2))
     lat_lon_heights = torch.tensor(obs_lat_lons)
@@ -161,7 +164,8 @@ def test_forecaster_and_loss():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(
+        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     model = GraphWeatherForecaster(lat_lons)
     # Add in auxiliary features
     features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -202,7 +206,8 @@ def test_forecaster_and_loss_grad_checkpoint():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(
+        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     model = GraphWeatherForecaster(lat_lons, use_checkpointing=True)
     # Add in auxiliary features
     features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -233,7 +238,8 @@ def test_normalized_loss():
 
     assert not torch.isnan(loss)
     # Since feature_variance = out**2 and target = 0, we expect loss = weights
-    assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())
+    assert torch.isclose(
+        loss, criterion.weights.expand_as(out.mean(-1)).mean())
 
 
 def test_image_meta_model():
@@ -243,13 +249,35 @@ def test_image_meta_model():
     batch = 2
     channels = 3
     size = 4
     patch_size = 2
     image = torch.randn((batch, channels, size, size))
     model = ImageMetaModel(
-        image_size=size, patch_size=patch_size, channels=channels, depth=1, heads=1, mlp_dim=7
+        image_size=size, patch_size=patch_size,
+        channels=channels, depth=1, heads=1, mlp_dim=7,
+        dim_head=64
     )
     out = model(image)
     assert not torch.isnan(out).any()
     assert not torch.isnan(out).any()
-    assert out.size() == (batch, channels, size, size)
+    assert out.size() == image.size()
+
+
+def test_wrapper_image_meta_model():
+    batch = 2
+    channels = 3
+    size = 4
+    patch_size = 2
+    model = ImageMetaModel(
+        image_size=size, patch_size=patch_size,
+        channels=channels, depth=1, heads=1, mlp_dim=7,
+        dim_head=64
+    )
+    scale_factor = 3
+    big_image = torch.randn((batch, channels,
+                             size*scale_factor, size*scale_factor))
+    big_model = WrapperImageModel(model, scale_factor)
+    out = big_model(big_image)
+    assert not torch.isnan(out).any()
+    assert not torch.isnan(out).any()
+    assert out.size() == big_image.size()
 
@@ -270,10 +298,42 @@ def test_meta_model():
         heads=1,
         mlp_dim=7,
         channels=channels,
+        dim_head=64
     )
     features = torch.randn((batch, len(lat_lons), channels))
 
     out = model(features)
     assert not torch.isnan(out).any()
     assert not torch.isnan(out).any()
-    assert out.size() == (batch, len(lat_lons), channels)
+    assert out.size() == features.size()
+
+
+def test_wrapper_meta_model():
+    lat_lons = []
+    for lat in range(-90, 90, 5):
+        for lon in range(0, 360, 5):
+            lat_lons.append((lat, lon))
+
+    batch = 2
+    channels = 3
+    image_size = 20
+    patch_size = 4
+    scale_factor = 3
+    model = MetaModel(
+        lat_lons,
+        image_size=image_size,
+        patch_size=patch_size,
+        depth=1,
+        heads=1,
+        mlp_dim=7,
+        channels=channels,
+        dim_head=64
+    )
+
+    big_features = torch.randn((batch, len(lat_lons), channels))
+    big_model = WrapperMetaModel(lat_lons, model, scale_factor)
+    out = big_model(big_features)
+
+    assert not torch.isnan(out).any()
+    assert not torch.isnan(out).any()
+    assert out.size() == big_features.size()
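
Usage sketch for the new wrappers (mirrors the tests above; all sizes and hyperparameters are the illustrative values from those tests, and torch_geometric is assumed to be installed since MetaModel.forward relies on knn_interpolate):

    import torch

    from graph_weather.models import (
        ImageMetaModel,
        MetaModel,
        WrapperImageModel,
        WrapperMetaModel,
    )

    # ImageMetaModel is a plain ViT over (batch, channels, height, width) grids.
    imm = ImageMetaModel(image_size=4, patch_size=2, channels=3,
                         depth=1, heads=1, mlp_dim=7, dim_head=64)

    # WrapperImageModel splits a scale_factor-times larger image into
    # scale_factor ** 2 tiles folded into the batch dimension, runs the shared
    # ImageMetaModel on every tile, and reassembles the output, so the small
    # model's weights are reused unchanged at the higher resolution.
    big_imm = WrapperImageModel(imm, scale_factor=3)
    out = big_imm(torch.randn(2, 3, 12, 12))  # output shape matches the input

    # MetaModel interpolates scattered (lat, lon) features onto a regular grid
    # with knn_interpolate, applies the ViT, and interpolates back to the points.
    lat_lons = [(lat, lon) for lat in range(-90, 90, 5) for lon in range(0, 360, 5)]
    mm = MetaModel(lat_lons, image_size=20, patch_size=4, depth=1,
                   heads=1, mlp_dim=7, channels=3, dim_head=64)

    # WrapperMetaModel does the same on a grid enlarged by scale_factor in both
    # height and width, reusing meta_model.image_meta_model for every tile.
    big_mm = WrapperMetaModel(lat_lons, mm, scale_factor=3)
    out = big_mm(torch.randn(2, len(lat_lons), 3))  # (batch, n_points, channels)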