diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index 2d032ab..a9ec895 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -363,93 +363,3 @@ def forward(self, x): x = rearrange(x, "n (b c) -> b n c", b=b, c=c) return x - -class MetaModel(nn.Module): - def __init__( - self, - lat_lons: list, - *, - image_size, - patch_size, - depth, - heads, - mlp_dim, - channels, - dim_head=64 - ): - super().__init__() - self.i_h, self.i_w = pair(image_size) - - self.pos_x = torch.tensor(lat_lons) - self.pos_y = torch.cartesian_prod( - (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), - (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), - ) - - self.image_meta_model = ImageMetaModel( - image_size=image_size, - patch_size=patch_size, - depth=depth, - heads=heads, - mlp_dim=mlp_dim, - channels=channels, - dim_head=dim_head, - ) - - def forward(self, x): - b, n, c = x.shape - - x = rearrange(x, "b n c -> n (b c)") - x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w) - x = self.image_meta_model(x) - - x = rearrange(x, "b c h w -> (h w) (b c)") - x = knn_interpolate(x, self.pos_y, self.pos_x) - x = rearrange(x, "n (b c) -> b n c", b=b, c=c) - return x - - -class WrapperMetaModel(nn.Module): - def __init__( - self, - lat_lons: list, - meta_model: MetaModel, - scale_factor - ): - super().__init__() - s_h, s_w = pair(scale_factor) - self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w - self.pos_x = torch.tensor(lat_lons) - self.pos_y = torch.cartesian_prod( - (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long), - (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long), - ) - - self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) - - imm_args = vars(meta_model.image_meta_model) - imm_args.update({"res": True, "scale_factor": scale_factor}) - self.image_meta_model = ImageMetaModel(**imm_args) - self.image_meta_model.load_state_dict( - meta_model.image_meta_model.state_dict(), strict=False - ) - - self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) - - def forward(self, x): - b, n, c = x.shape - - x = rearrange(x, "b n c -> n (b c)") - x = knn_interpolate(x, self.pos_x, self.pos_y) - x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w) - - x = self.batcher(x) - x = self.image_meta_model(x) - x = self.debatcher(x) - - x = rearrange(x, "b c h w -> (h w) (b c)") - x = knn_interpolate(x, self.pos_y, self.pos_x) - x = rearrange(x, "n (b c) -> b n c", b=b, c=c) - - return x