From cd84968524f41ff654a16bcd1587ab2a1c59ae38 Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Tue, 2 Jul 2024 11:15:32 +0200 Subject: [PATCH] load RES state_dict --- graph_weather/models/fengwu_ghr/layers.py | 45 ++++++++++++++++------- 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py index ad82fcad..e42d7545 100644 --- a/graph_weather/models/fengwu_ghr/layers.py +++ b/graph_weather/models/fengwu_ghr/layers.py @@ -89,7 +89,7 @@ def forward(self, x): class Transformer(nn.Module): - def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, h=None, w=None, scale_factor=None): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=None, scale_factor=None): super().__init__() self.depth = depth self.res = res @@ -104,7 +104,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, h=None, w=No ) ) if self.res: - assert h is not None and w is not None and scale_factor is not None, "If res=True, you must provide h, w and scale_factor" + assert image_size is not None and scale_factor is not None, "If res=True, you must provide h, w and scale_factor" + h, w = pair(image_size) s_h, s_w = pair(scale_factor) self.res_layers.append( nn.ModuleList( @@ -139,8 +140,20 @@ def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels, dim_head, res=False, - scale_factor=None): + scale_factor=None, + **kwargs): super().__init__() + #TODO this can probably be done better + self.image_size = image_size + self.patch_size = patch_size + self.depth = depth + self.heads = heads + self.mlp_dim = mlp_dim + self.channels = channels + self.dim_head = dim_head + self.res = res + self.scale_factor = scale_factor + self.image_height, self.image_width = pair(image_size) self.patch_height, self.patch_width = pair(patch_size) s_h, s_w = pair(scale_factor) @@ -170,10 +183,12 @@ def __init__(self, *, image_size, self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, res=res, - h=self.image_height // self.patch_height, - w=self.image_width // self.patch_width, - s_h=s_h, - s_w=s_w) + image_size=( + self.image_height // self.patch_height, + self.image_width // self.patch_width), + scale_factor=( + s_h, + s_w)) self.reshaper = nn.Sequential( Rearrange( @@ -205,12 +220,13 @@ def __init__(self, image_meta_model: ImageMetaModel, s_h, s_w = pair(scale_factor) self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) - - imm_args = image_meta_model.vars().update( + + imm_args = vars(image_meta_model) + imm_args.update( {"res": True, "scale_factor": scale_factor}) self.image_meta_model = ImageMetaModel(**imm_args) - self.image_meta_model.load(image_meta_model, strict=False) - + self.image_meta_model.load_state_dict(image_meta_model.state_dict(), strict=False) + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w) @@ -301,11 +317,12 @@ def __init__( self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w) - imm_args = meta_model.image_meta_model.vars().update( + imm_args = vars(meta_model.image_meta_model) + imm_args.update( {"res": True, "scale_factor": scale_factor}) self.image_meta_model = ImageMetaModel(**imm_args) - self.image_meta_model.load(meta_model.image_meta_model, strict=False) - + self.image_meta_model.load_state_dict(meta_model.image_meta_model.state_dict(), strict=False) + self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w)