[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 2, 2024
1 parent e0f60ca commit b15110f
Showing 2 changed files with 98 additions and 95 deletions.
150 changes: 75 additions & 75 deletions graph_weather/models/fengwu_ghr/layers.py
@@ -76,8 +76,7 @@ def forward(self, x):
         x = self.norm(x)
 
         qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(
-            t, "b n (h d) -> b h n d", h=self.heads), qkv)
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

@@ -89,7 +88,9 @@ def forward(self, x):
 
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=None, scale_factor=None):
+    def __init__(
+        self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=None, scale_factor=None
+    ):
         super().__init__()
         self.depth = depth
         self.res = res
@@ -99,28 +100,39 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=N
         for _ in range(self.depth):
             self.layers.append(
                 nn.ModuleList(
-                    [Attention(dim, heads=heads, dim_head=dim_head),
-                     FeedForward(dim, mlp_dim)]
+                    [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)]
                 )
             )
         if self.res:
-            assert image_size is not None and scale_factor is not None, "If res=True, you must provide h, w and scale_factor"
+            assert (
+                image_size is not None and scale_factor is not None
+            ), "If res=True, you must provide h, w and scale_factor"
             h, w = pair(image_size)
             s_h, s_w = pair(scale_factor)
             self.res_layers.append(
                 nn.ModuleList(
                     [ # reshape to original shape window partition operation
                         # (b s_h s_w) (h w) d -> b (s_h h) (s_w w) d -> (b h w) (s_h s_w) d
-                        Rearrange("(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d",
-                                  h=h, w=w, s_h=s_h, s_w=s_w
-                                  ),
+                        Rearrange(
+                            "(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d",
+                            h=h,
+                            w=w,
+                            s_h=s_h,
+                            s_w=s_w,
+                        ),
                         # TODO ?????
                         Attention(dim, heads=heads, dim_head=dim_head),
                         # restore shape
-                        Rearrange("(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d",
-                                  h=h, w=w, s_h=s_h, s_w=s_w
-                                  ),
-                    ]))
+                        Rearrange(
+                            "(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d",
+                            h=h,
+                            w=w,
+                            s_h=s_h,
+                            s_w=s_w,
+                        ),
+                    ]
+                )
+            )
 
     def forward(self, x):
         for i in range(self.depth):
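Aside from the black-style line wrapping, the two `Rearrange` layers in this hunk are unchanged and are exact inverses: the first moves the token grid into the batch dimension so the attention in between mixes the s_h * s_w windows at each spatial position, and the second restores the original layout. A minimal standalone sketch of the round trip, with hypothetical sizes that are not part of the commit:

```python
import torch
from einops.layers.torch import Rearrange

# Hypothetical sizes: batch 2, an h x w = 4 x 5 token grid, a s_h x s_w = 2 x 3 window grid.
b, h, w, s_h, s_w, d = 2, 4, 5, 2, 3, 8
partition = Rearrange("(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d", h=h, w=w, s_h=s_h, s_w=s_w)
restore = Rearrange("(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d", h=h, w=w, s_h=s_h, s_w=s_w)

x = torch.randn(b * s_h * s_w, h * w, d)      # windowed tokens, as fed to res_layers
assert torch.equal(restore(partition(x)), x)  # the attention block sits between these two
```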
@@ -136,14 +148,22 @@ def forward(self, x):
 
 
 class ImageMetaModel(nn.Module):
-    def __init__(self, *, image_size,
-                 patch_size, depth, heads,
-                 mlp_dim, channels, dim_head,
-                 res=False,
-                 scale_factor=None,
-                 **kwargs):
+    def __init__(
+        self,
+        *,
+        image_size,
+        patch_size,
+        depth,
+        heads,
+        mlp_dim,
+        channels,
+        dim_head,
+        res=False,
+        scale_factor=None,
+        **kwargs
+    ):
         super().__init__()
-        #TODO this can probably be done better
+        # TODO this can probably be done better
         self.image_size = image_size
         self.patch_size = patch_size
         self.depth = depth
@@ -168,7 +188,9 @@ def __init__(self, *, image_size,
         dim = patch_dim
         self.to_patch_embedding = nn.Sequential(
             Rearrange(
-                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=self.patch_height, p_w=self.patch_width
+                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)",
+                p_h=self.patch_height,
+                p_w=self.patch_width,
             ),
             nn.LayerNorm(patch_dim), # TODO Do we need this?
             nn.Linear(patch_dim, dim), # TODO Do we need this?
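The reflowed pattern above is the standard ViT-style patchify: it cuts the image into p_h x p_w patches and flattens each one into a vector. A quick shape check with hypothetical sizes (not taken from the commit):

```python
import torch
from einops.layers.torch import Rearrange

# Hypothetical: one 3-channel 4x4 image, split into 2x2 patches.
to_patches = Rearrange("b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=2, p_w=2)
x = torch.randn(1, 3, 4, 4)
print(to_patches(x).shape)  # torch.Size([1, 4, 12]) -- 4 patches of 2*2*3 = 12 values each
```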
@@ -181,14 +203,19 @@ def __init__(self, *, image_size,
             dim=dim,
         )
 
-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim,
-                                       res=res,
-                                       image_size=(
-                                           self.image_height // self.patch_height,
-                                           self.image_width // self.patch_width),
-                                       scale_factor=(
-                                           s_h,
-                                           s_w))
+        self.transformer = Transformer(
+            dim,
+            depth,
+            heads,
+            dim_head,
+            mlp_dim,
+            res=res,
+            image_size=(
+                self.image_height // self.patch_height,
+                self.image_width // self.patch_width,
+            ),
+            scale_factor=(s_h, s_w),
+        )
 
         self.reshaper = nn.Sequential(
             Rearrange(
@@ -203,6 +230,7 @@ def __init__(self, *, image_size,
     def forward(self, x):
         device = x.device
         dtype = x.dtype
+
     def forward(self, x):
         device = x.device
         dtype = x.dtype
@@ -219,21 +247,17 @@ def forward(self, x):
 
 
 class WrapperImageModel(nn.Module):
-    def __init__(self, image_meta_model: ImageMetaModel,
-                 scale_factor):
+    def __init__(self, image_meta_model: ImageMetaModel, scale_factor):
         super().__init__()
         s_h, s_w = pair(scale_factor)
-        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
-                                 s_h=s_h, s_w=s_w)
+        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w)
 
         imm_args = vars(image_meta_model)
-        imm_args.update(
-            {"res": True, "scale_factor": scale_factor})
+        imm_args.update({"res": True, "scale_factor": scale_factor})
         self.image_meta_model = ImageMetaModel(**imm_args)
         self.image_meta_model.load_state_dict(image_meta_model.state_dict(), strict=False)
 
-        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
-                                   s_h=s_h, s_w=s_w)
+        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w)
 
     def forward(self, x):
         x = self.batcher(x)
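The `vars(image_meta_model)` idiom reformatted here deserves a note: `vars` on an `nn.Module` returns its `__dict__`, which mixes the constructor arguments that `ImageMetaModel.__init__` stores on `self` with PyTorch-internal entries such as `_parameters` and `_modules`, so re-instantiation only works because the signature ends in `**kwargs`. A self-contained sketch of the pattern; the `Tiny` module is hypothetical, not from the repository:

```python
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self, dim, res=False, **kwargs):  # **kwargs swallows the vars() extras
        super().__init__()
        self.dim = dim   # constructor args stashed on self, as ImageMetaModel does
        self.res = res
        self.proj = nn.Linear(dim, dim)

tiny = Tiny(dim=8)
args = vars(tiny)            # {'dim': 8, 'res': False, 'training': True, '_modules': ..., ...}
args.update({"res": True})   # same override the wrappers apply
clone = Tiny(**args)         # only valid because __init__ accepts **kwargs
clone.load_state_dict(tiny.state_dict(), strict=False)
```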
@@ -260,14 +284,8 @@ def __init__(
 
         self.pos_x = torch.tensor(lat_lons)
         self.pos_y = torch.cartesian_prod(
-            (
-                torch.arange(-self.i_h / 2,
-                             self.i_h / 2, 1)
-                / self.i_h
-                * 180
-            ).to(torch.long),
-            (torch.arange(0, self.i_w, 1) /
-             self.i_w * 360).to(torch.long),
+            (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
+            (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
         )
 
         self.image_meta_model = ImageMetaModel(
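For reference, the now single-line latitude/longitude expressions build the regular grid that `knn_interpolate` maps the node features onto. A standalone check with a hypothetical i_h = i_w = 4 (values chosen only for illustration):

```python
import torch

i_h = i_w = 4  # hypothetical grid size
lat = (torch.arange(-i_h / 2, i_h / 2, 1) / i_h * 180).to(torch.long)  # tensor([-90, -45, 0, 45])
lon = (torch.arange(0, i_w, 1) / i_w * 360).to(torch.long)             # tensor([0, 90, 180, 270])
grid = torch.cartesian_prod(lat, lon)  # shape (16, 2): one (lat, lon) row per grid cell
```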
@@ -285,10 +303,7 @@ def forward(self, x):
 
         x = rearrange(x, "b n c -> n (b c)")
         x = knn_interpolate(x, self.pos_x, self.pos_y)
-        x = rearrange(
-            x, "(h w) (b c) -> b c h w", b=b, c=c,
-            h=self.i_h, w=self.i_w
-        )
+        x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w)
         x = self.image_meta_model(x)
 
         x = rearrange(x, "b c h w -> (h w) (b c)")
Expand All @@ -298,48 +313,33 @@ def forward(self, x):


class WrapperMetaModel(nn.Module):
def __init__(
self,
lat_lons: list,
meta_model: MetaModel,
scale_factor
):
def __init__(self, lat_lons: list, meta_model: MetaModel, scale_factor):
super().__init__()
s_h, s_w = pair(scale_factor)
self.i_h, self.i_w = meta_model.i_h*s_h, meta_model.i_w*s_w
self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w
self.pos_x = torch.tensor(lat_lons)
self.pos_y = torch.cartesian_prod(
(
torch.arange(-self.i_h / 2,
self.i_h / 2, 1)
/ self.i_h
* 180
).to(torch.long),
(torch.arange(0, self.i_w, 1) /
self.i_w * 360).to(torch.long),
(torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
(torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
)

self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w",
s_h=s_h, s_w=s_w)
self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w)

imm_args = vars(meta_model.image_meta_model)
imm_args.update(
{"res": True, "scale_factor": scale_factor})
imm_args.update({"res": True, "scale_factor": scale_factor})
self.image_meta_model = ImageMetaModel(**imm_args)
self.image_meta_model.load_state_dict(meta_model.image_meta_model.state_dict(), strict=False)
self.image_meta_model.load_state_dict(
meta_model.image_meta_model.state_dict(), strict=False
)

self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)",
s_h=s_h, s_w=s_w)
self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w)

def forward(self, x):
b, n, c = x.shape

x = rearrange(x, "b n c -> n (b c)")
x = knn_interpolate(x, self.pos_x, self.pos_y)
x = rearrange(
x, "(h w) (b c) -> b c h w", b=b, c=c,
h=self.i_h, w=self.i_w
)
x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w)

x = self.batcher(x)
x = self.image_meta_model(x)
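As in `WrapperImageModel`, the batcher/debatcher pair reflowed above is lossless: the first pattern spreads each field over s_h * s_w sub-images in the batch dimension, and the second folds them back. A standalone round-trip check with hypothetical shapes (not from the commit):

```python
import torch
from einops.layers.torch import Rearrange

s_h = s_w = 2  # hypothetical scale factor
batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w)
debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w)

x = torch.randn(3, 5, 8, 8)                # (batch, channels, height, width)
windows = batcher(x)                       # shape (12, 5, 4, 4): four sub-images per input
assert torch.equal(debatcher(windows), x)  # exact inverse, nothing is lost
```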
43 changes: 23 additions & 20 deletions tests/test_model.py
@@ -12,7 +12,7 @@
     ImageMetaModel,
     MetaModel,
     WrapperImageModel,
-    WrapperMetaModel
+    WrapperMetaModel,
 )
 from graph_weather.models.losses import NormalizedMSELoss

@@ -151,8 +151,7 @@ def test_assimilator_model():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             output_lat_lons.append((lat, lon))
-    model = GraphWeatherAssimilator(
-        output_lat_lons=output_lat_lons, analysis_dim=24)
+    model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24)
 
     features = torch.randn((1, len(obs_lat_lons), 2))
     lat_lon_heights = torch.tensor(obs_lat_lons)
@@ -166,8 +165,7 @@ def test_forecaster_and_loss():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(
-        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     model = GraphWeatherForecaster(lat_lons)
     # Add in auxiliary features
     features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -208,8 +206,7 @@ def test_forecaster_and_loss_grad_checkpoint():
     for lat in range(-90, 90, 5):
         for lon in range(0, 360, 5):
             lat_lons.append((lat, lon))
-    criterion = NormalizedMSELoss(
-        lat_lons=lat_lons, feature_variance=torch.randn((78,)))
+    criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     model = GraphWeatherForecaster(lat_lons, use_checkpointing=True)
     # Add in auxiliary features
     features = torch.randn((2, len(lat_lons), 78 + 24))
@@ -240,8 +237,7 @@ def test_normalized_loss():
 
     assert not torch.isnan(loss)
     # Since feature_variance = out**2 and target = 0, we expect loss = weights
-    assert torch.isclose(
-        loss, criterion.weights.expand_as(out.mean(-1)).mean())
+    assert torch.isclose(loss, criterion.weights.expand_as(out.mean(-1)).mean())
 
 
 def test_image_meta_model():
@@ -251,9 +247,13 @@ def test_image_meta_model():
     patch_size = 2
     image = torch.randn((batch, channels, size, size))
     model = ImageMetaModel(
-        image_size=size, patch_size=patch_size,
-        channels=channels, depth=1, heads=1, mlp_dim=7,
-        dim_head=64
+        image_size=size,
+        patch_size=patch_size,
+        channels=channels,
+        depth=1,
+        heads=1,
+        mlp_dim=7,
+        dim_head=64,
     )
 
     out = model(image)
@@ -268,13 +268,16 @@ def test_wrapper_image_meta_model():
     size = 4
     patch_size = 2
     model = ImageMetaModel(
-        image_size=size, patch_size=patch_size,
-        channels=channels, depth=1, heads=1, mlp_dim=7,
-        dim_head=64
+        image_size=size,
+        patch_size=patch_size,
+        channels=channels,
+        depth=1,
+        heads=1,
+        mlp_dim=7,
+        dim_head=64,
     )
     scale_factor = 3
-    big_image = torch.randn((batch, channels,
-                             size*scale_factor, size*scale_factor))
+    big_image = torch.randn((batch, channels, size * scale_factor, size * scale_factor))
     big_model = WrapperImageModel(model, scale_factor)
     out = big_model(big_image)
     assert not torch.isnan(out).any()
@@ -300,7 +303,7 @@ def test_meta_model():
         heads=1,
         mlp_dim=7,
         channels=channels,
-        dim_head=64
+        dim_head=64,
     )
     features = torch.randn((batch, len(lat_lons), channels))

@@ -320,7 +323,7 @@ def test_wrapper_meta_model():
     channels = 3
     image_size = 20
     patch_size = 4
-    scale_factor=3
+    scale_factor = 3
     model = MetaModel(
         lat_lons,
         image_size=image_size,
@@ -329,7 +332,7 @@ def test_wrapper_meta_model():
         heads=1,
         mlp_dim=7,
         channels=channels,
-        dim_head=64
+        dim_head=64,
     )
 
     big_features = torch.randn((batch, len(lat_lons), channels))

