diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index f03235e3..bbe19a85 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -76,8 +76,7 @@ def forward(self, x):
         x = self.norm(x)

         qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(
-            t, "b n (h d) -> b h n d", h=self.heads), qkv)
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)

         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

@@ -89,7 +88,9 @@ def forward(self, x):


 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=None, scale_factor=None):
+    def __init__(
+        self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=None, scale_factor=None
+    ):
         super().__init__()
         self.depth = depth
         self.res = res
@@ -99,28 +100,39 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=N
         for _ in range(self.depth):
             self.layers.append(
                 nn.ModuleList(
-                    [Attention(dim, heads=heads, dim_head=dim_head),
-                     FeedForward(dim, mlp_dim)]
+                    [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)]
                 )
             )
         if self.res:
-            assert image_size is not None and scale_factor is not None, "If res=True, you must provide h, w and scale_factor"
+            assert (
+                image_size is not None and scale_factor is not None
+            ), "If res=True, you must provide h, w and scale_factor"
             h, w = pair(image_size)
             s_h, s_w = pair(scale_factor)
             self.res_layers.append(
                 nn.ModuleList(
                     [
                         # reshape to original shape window partition operation
                         # (b s_h s_w) (h w) d -> b (s_h h) (s_w w) d -> (b h w) (s_h s_w) d
-                        Rearrange("(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d",
-                                  h=h, w=w, s_h=s_h, s_w=s_w
-                                  ),
+                        Rearrange(
+                            "(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d",
+                            h=h,
+                            w=w,
+                            s_h=s_h,
+                            s_w=s_w,
+                        ),
                         # TODO ?????
                         Attention(dim, heads=heads, dim_head=dim_head),
                         # restore shape
-                        Rearrange("(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d",
-                                  h=h, w=w, s_h=s_h, s_w=s_w
-                                  ),
-                    ]))
+                        Rearrange(
+                            "(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d",
+                            h=h,
+                            w=w,
+                            s_h=s_h,
+                            s_w=s_w,
+                        ),
+                    ]
+                )
+            )

     def forward(self, x):
         for i in range(self.depth):
@@ -136,14 +148,22 @@ def forward(self, x):


 class ImageMetaModel(nn.Module):
-    def __init__(self, *, image_size,
-                 patch_size, depth, heads,
-                 mlp_dim, channels, dim_head,
-                 res=False,
-                 scale_factor=None,
-                 **kwargs):
+    def __init__(
+        self,
+        *,
+        image_size,
+        patch_size,
+        depth,
+        heads,
+        mlp_dim,
+        channels,
+        dim_head,
+        res=False,
+        scale_factor=None,
+        **kwargs
+    ):
         super().__init__()
-        #TODO this can probably be done better
+        # TODO this can probably be done better
         self.image_size = image_size
         self.patch_size = patch_size
         self.depth = depth
@@ -168,7 +188,9 @@ def __init__(self, *, image_size,
         dim = patch_dim
         self.to_patch_embedding = nn.Sequential(
             Rearrange(
-                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=self.patch_height, p_w=self.patch_width
+                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)",
+                p_h=self.patch_height,
+                p_w=self.patch_width,
             ),
             nn.LayerNorm(patch_dim),  # TODO Do we need this?
             nn.Linear(patch_dim, dim),  # TODO Do we need this?
@@ -181,14 +203,19 @@ def __init__(self, *, image_size,
             dim=dim,
         )

-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim,
-                                       res=res,
-                                       image_size=(
-                                           self.image_height // self.patch_height,
-                                           self.image_width // self.patch_width),
-                                       scale_factor=(
-                                           s_h,
-                                           s_w))
+        self.transformer = Transformer(
+            dim,
+            depth,
+            heads,
+            dim_head,
+            mlp_dim,
+            res=res,
+            image_size=(
+                self.image_height // self.patch_height,
+                self.image_width // self.patch_width,
+            ),
+            scale_factor=(s_h, s_w),
+        )

         self.reshaper = nn.Sequential(
             Rearrange(
@@ -203,6 +230,7 @@ def __init__(self, *, image_size,
     def forward(self, x):
         device = x.device
         dtype = x.dtype
+
     def forward(self, x):
         device = x.device
         dtype = x.dtype
@@ -219,21 +247,17 @@ def forward(self, x):


 class WrapperImageModel(nn.Module):
-    def __init__(self, image_meta_model: ImageMetaModel,
-                 scale_factor):
+    def __init__(self, image_meta_model: ImageMetaModel, scale_factor):
         super().__init__()
         s_h, s_w = pair(scale_factor)
-        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
-                                 s_h=s_h, s_w=s_w)
+        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w)

         imm_args = vars(image_meta_model)
-        imm_args.update(
-            {"res": True, "scale_factor": scale_factor})
+        imm_args.update({"res": True, "scale_factor": scale_factor})
         self.image_meta_model = ImageMetaModel(**imm_args)
         self.image_meta_model.load_state_dict(image_meta_model.state_dict(), strict=False)

-        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
-                                   s_h=s_h, s_w=s_w)
+        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w)

     def forward(self, x):
         x = self.batcher(x)
@@ -260,14 +284,8 @@ def __init__(

         self.pos_x = torch.tensor(lat_lons)
         self.pos_y = torch.cartesian_prod(
-            (
-                torch.arange(-self.i_h / 2,
-                             self.i_h / 2, 1)
-                / self.i_h
-                * 180
-            ).to(torch.long),
-            (torch.arange(0, self.i_w, 1) /
-             self.i_w * 360).to(torch.long),
+            (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
+            (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
         )

         self.image_meta_model = ImageMetaModel(
@@ -285,10 +303,7 @@ def forward(self, x):
         x = rearrange(x, "b n c -> n (b c)")
         x = knn_interpolate(x, self.pos_x, self.pos_y)
-        x = rearrange(
-            x, "(h w) (b c) -> b c h w", b=b, c=c,
-            h=self.i_h, w=self.i_w
-        )
+        x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w)

         x = self.image_meta_model(x)

         x = rearrange(x, "b c h w -> (h w) (b c)")
@@ -298,48 +313,33 @@ def forward(self, x):


 class WrapperMetaModel(nn.Module):
-    def __init__(
-        self,
-        lat_lons: list,
-        meta_model: MetaModel,
-        scale_factor
-    ):
+    def __init__(self, lat_lons: list, meta_model: MetaModel, scale_factor):
         super().__init__()
         s_h, s_w = pair(scale_factor)
-        self.i_h, self.i_w = meta_model.i_h*s_h, meta_model.i_w*s_w
+        self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w

         self.pos_x = torch.tensor(lat_lons)
         self.pos_y = torch.cartesian_prod(
-            (
-                torch.arange(-self.i_h / 2,
-                             self.i_h / 2, 1)
-                / self.i_h
-                * 180
-            ).to(torch.long),
-            (torch.arange(0, self.i_w, 1) /
-             self.i_w * 360).to(torch.long),
+            (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
+            (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
         )

-        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
-                                 s_h=s_h, s_w=s_w)
+        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w)

         imm_args = vars(meta_model.image_meta_model)
-        imm_args.update(
-            {"res": True, "scale_factor": scale_factor})
+        imm_args.update({"res": True, "scale_factor": scale_factor})
         self.image_meta_model = ImageMetaModel(**imm_args)
-        self.image_meta_model.load_state_dict(meta_model.image_meta_model.state_dict(), strict=False)
+        self.image_meta_model.load_state_dict(
+            meta_model.image_meta_model.state_dict(), strict=False
+        )

-        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
-                                   s_h=s_h, s_w=s_w)
+        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w)

     def forward(self, x):
         b, n, c = x.shape
         x = rearrange(x, "b n c -> n (b c)")
         x = knn_interpolate(x, self.pos_x, self.pos_y)
-        x = rearrange(
-            x, "(h w) (b c) -> b c h w", b=b, c=c,
-            h=self.i_h, w=self.i_w
-        )
+        x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w)

         x = self.batcher(x)
         x = self.image_meta_model(x)
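For reference, the two Rearrange patterns reformatted in Transformer's res branch above implement the window partition named in the code comment: tokens at the same patch-grid position are gathered across the s_h * s_w sub-images, attention mixes them, and the original per-sub-image layout is then restored. A minimal standalone einops sketch of that round trip (the sizes b, s_h, s_w, h, w, d below are illustrative assumptions, not values from this patch):

# Sketch of the window-partition round trip used in Transformer's res branch.
# All sizes are illustrative assumptions, not values taken from the patch.
import torch
from einops import rearrange

b, s_h, s_w, h, w, d = 2, 3, 3, 4, 4, 8
# one (h * w)-token grid per sub-image, sub-images flattened into the batch axis
x = torch.randn(b * s_h * s_w, h * w, d)

# gather tokens at the same (h, w) grid position across sub-images, so attention
# would see sequences of length s_h * s_w that span the sub-images
windows = rearrange(
    x, "(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d", h=h, w=w, s_h=s_h, s_w=s_w
)
# ... the Attention module sits here in the patched code ...

# restore the original per-sub-image token layout
restored = rearrange(
    windows, "(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d", h=h, w=w, s_h=s_h, s_w=s_w
)
assert restored.shape == x.shape
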
diff --git a/tests/test_model.py b/tests/test_model.py
index ac260b48..8f52116b 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -12,7 +12,7 @@
     ImageMetaModel,
     MetaModel,
     WrapperImageModel,
-    WrapperMetaModel
+    WrapperMetaModel,
 )
 from graph_weather.models.losses import NormalizedMSELoss

@@ -151,8 +151,7 @@ def test_assimilator_model():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             output_lat_lons.append((lat, lon))
-    model = GraphWeatherAssimilator(
-        output_lat_lons=output_lat_lons, analysis_dim=24)
+    model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24)

     features = torch.randn((1, len(obs_lat_lons), 2))
     lat_lon_heights = torch.tensor(obs_lat_lons)
@@ -166,8 +165,7 @@ def test_forecaster_and_loss():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(
-        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     model = GraphWeatherForecaster(lat_lons)
     # Add in auxiliary features
     features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -208,8 +206,7 @@ def test_forecaster_and_loss_grad_checkpoint():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(
-        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     model = GraphWeatherForecaster(lat_lons, use_checkpointing=True)
     # Add in auxiliary features
     features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -240,8 +237,7 @@ def test_normalized_loss():

     assert not torch.isnan(loss)
     # Since feature_variance = out**2 and target = 0, we expect loss = weights
-    assert torch.isclose(
-        loss, criterion.weights.expand_as(out.mean(-1)).mean())
+    assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())


 def test_image_meta_model():
@@ -251,9 +247,13 @@ def test_image_meta_model():
     patch_size = 2
     image = torch.randn((batch, channels, size, size))
     model = ImageMetaModel(
-        image_size=size, patch_size=patch_size,
-        channels=channels, depth=1, heads=1, mlp_dim=7,
-        dim_head=64
+        image_size=size,
+        patch_size=patch_size,
+        channels=channels,
+        depth=1,
+        heads=1,
+        mlp_dim=7,
+        dim_head=64,
     )

     out = model(image)
@@ -268,13 +268,16 @@ def test_wrapper_image_meta_model():
     size = 4
     patch_size = 2
     model = ImageMetaModel(
-        image_size=size, patch_size=patch_size,
-        channels=channels, depth=1, heads=1, mlp_dim=7,
-        dim_head=64
+        image_size=size,
+        patch_size=patch_size,
+        channels=channels,
+        depth=1,
+        heads=1,
+        mlp_dim=7,
+        dim_head=64,
     )
     scale_factor = 3
-    big_image = torch.randn((batch, channels,
-                             size*scale_factor, size*scale_factor))
+    big_image = torch.randn((batch, channels, size * scale_factor, size * scale_factor))
     big_model = WrapperImageModel(model, scale_factor)
     out = big_model(big_image)
     assert not torch.isnan(out).any()
@@ -300,7 +303,7 @@ def test_meta_model():
         heads=1,
         mlp_dim=7,
         channels=channels,
-        dim_head=64
+        dim_head=64,
     )

     features = torch.randn((batch, len(lat_lons), channels))
@@ -320,7 +323,7 @@ def test_wrapper_meta_model():
     channels = 3
     image_size = 20
     patch_size = 4
-    scale_factor=3
+    scale_factor = 3
     model = MetaModel(
         lat_lons,
         image_size=image_size,
@@ -329,7 +332,7 @@ def test_wrapper_meta_model():
         heads=1,
         mlp_dim=7,
         channels=channels,
-        dim_head=64
+        dim_head=64,
     )

     big_features = torch.randn((batch, len(lat_lons), channels))
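The reformatted tests above also document the intended wrapper usage: an ImageMetaModel built for one resolution is reused on a scale_factor-times larger image by WrapperImageModel, which batches the large image into sub-images, runs a res=True copy of the model with the original weights loaded, and debatches the result. A condensed, self-contained sketch of that flow, mirroring test_wrapper_image_meta_model (hyperparameter values are copied from the test; the import path assumes the layers.py module touched by this patch):

# Condensed version of test_wrapper_image_meta_model; values mirror the test.
import torch

from graph_weather.models.fengwu_ghr.layers import ImageMetaModel, WrapperImageModel

batch, channels, size, patch_size, scale_factor = 1, 2, 4, 2, 3
model = ImageMetaModel(
    image_size=size,
    patch_size=patch_size,
    channels=channels,
    depth=1,
    heads=1,
    mlp_dim=7,
    dim_head=64,
)
# WrapperImageModel copies the weights into a res=True ImageMetaModel and adds
# the batcher/debatcher Rearrange layers around it (see the layers.py diff).
big_model = WrapperImageModel(model, scale_factor)

big_image = torch.randn((batch, channels, size * scale_factor, size * scale_factor))
out = big_model(big_image)
assert not torch.isnan(out).any()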