Skip to content

Commit

Permalink
load RES state_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
rnwzd committed Jul 2, 2024
1 parent 07a8d0f commit cd84968
Showing 1 changed file with 31 additions and 14 deletions.
45 changes: 31 additions & 14 deletions graph_weather/models/fengwu_ghr/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def forward(self, x):


class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, h=None, w=None, scale_factor=None):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=None, scale_factor=None):
super().__init__()
self.depth = depth
self.res = res
Expand All @@ -104,7 +104,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, h=None, w=No
)
)
if self.res:
assert h is not None and w is not None and scale_factor is not None, "If res=True, you must provide h, w and scale_factor"
assert image_size is not None and scale_factor is not None, "If res=True, you must provide h, w and scale_factor"
h, w = pair(image_size)
s_h, s_w = pair(scale_factor)
self.res_layers.append(
nn.ModuleList(
Expand Down Expand Up @@ -139,8 +140,20 @@ def __init__(self, *, image_size,
patch_size, depth, heads,
mlp_dim, channels, dim_head,
res=False,
scale_factor=None):
scale_factor=None,
**kwargs):
super().__init__()
#TODO this can probably be done better
self.image_size = image_size
self.patch_size = patch_size
self.depth = depth
self.heads = heads
self.mlp_dim = mlp_dim
self.channels = channels
self.dim_head = dim_head
self.res = res
self.scale_factor = scale_factor

self.image_height, self.image_width = pair(image_size)
self.patch_height, self.patch_width = pair(patch_size)
s_h, s_w = pair(scale_factor)
Expand Down Expand Up @@ -170,10 +183,12 @@ def __init__(self, *, image_size,

self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim,
res=res,
h=self.image_height // self.patch_height,
w=self.image_width // self.patch_width,
s_h=s_h,
s_w=s_w)
image_size=(
self.image_height // self.patch_height,
self.image_width // self.patch_width),
scale_factor=(
s_h,
s_w))

self.reshaper = nn.Sequential(
Rearrange(
Expand Down Expand Up @@ -205,12 +220,13 @@ def __init__(self, image_meta_model: ImageMetaModel,
s_h, s_w = pair(scale_factor)
self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
s_h=s_h, s_w=s_w)

imm_args = image_meta_model.vars().update(

imm_args = vars(image_meta_model)
imm_args.update(
{"res": True, "scale_factor": scale_factor})
self.image_meta_model = ImageMetaModel(**imm_args)
self.image_meta_model.load(image_meta_model, strict=False)
self.image_meta_model.load_state_dict(image_meta_model.state_dict(), strict=False)

self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
s_h=s_h, s_w=s_w)

Expand Down Expand Up @@ -301,11 +317,12 @@ def __init__(
self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
s_h=s_h, s_w=s_w)

imm_args = meta_model.image_meta_model.vars().update(
imm_args = vars(meta_model.image_meta_model)
imm_args.update(
{"res": True, "scale_factor": scale_factor})
self.image_meta_model = ImageMetaModel(**imm_args)
self.image_meta_model.load(meta_model.image_meta_model, strict=False)
self.image_meta_model.load_state_dict(meta_model.image_meta_model.state_dict(), strict=False)

self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
s_h=s_h, s_w=s_w)

Expand Down

0 comments on commit cd84968

Please sign in to comment.