Skip to content
This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Commit

Permalink
Bug fixes for running MetNet
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Jun 29, 2021
1 parent 9e1f655 commit 425dded
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 7 deletions.
7 changes: 4 additions & 3 deletions satflow/configs/datamodule/metnet_datamodule.yaml
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
# @package _group_
_target_: satflow.data.datamodules.SatFlowDataModule

batch_size: 1
batch_size: 4
data_dir: ${data_dir} # data_dir is specified in config.yaml
shuffle: 0
sources:
train: "satflow-flow-144-tiled-{00001..00149..2}.tar"
val: "satflow-flow-144-tiled-{00002..00149..4}.tar"
test: "satflow-flow-144-tiled-{000014..00149..4}.tar"
num_workers: 1
test: "satflow-flow-144-tiled-{00004..00149..4}.tar"
num_workers: 4
pin_memory: True
config:
visualize: False
num_timesteps: 12
skip_timesteps: 1
forecast_times: 48
output_shape: 256
output_target: 64
target_type: "cloudmask"
num_crops: 10
use_topo: True
Expand Down
2 changes: 1 addition & 1 deletion satflow/configs/model/metnet_model.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# @package _group_
_target_: satflow.models.metnet.MetNet
input_channels: 17
hidden_dim: 64
hidden_dim: 128
forecast_steps: 48
learning_rate: 0.0078
kernel_size: 3
Expand Down
10 changes: 10 additions & 0 deletions satflow/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch.utils.data as thd
import webdataset as wds
from torch.utils.data.dataset import T_co
from satflow.data.utils.utils import crop_center

REGISTERED_DATASET_CLASSES = {}

Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(self, datasets, config, train=True):

# Defined output sizes, etc.
self.output_shape = config["output_shape"]
self.output_target = config.get("output_target", config["output_shape"])
self.target_type = config.get("target", "cloudmask")
# Should load the common data here
self.bands = config.get(
Expand Down Expand Up @@ -382,6 +384,14 @@ def __iter__(self) -> Iterator[T_co]:
for t in target_mask[1:]:
ts = np.concatenate([ts, t], axis=0)
target_mask = ts
if self.output_target != self.output_shape:
if self.use_image:
target_image = crop_center(
target_image, self.output_target, self.output_target
)
target_mask = crop_center(
target_mask, self.output_target, self.output_target
)
if self.vis:
self.visualize(image, target_image, target_mask)
if self.use_time and self.time_aux:
Expand Down
8 changes: 8 additions & 0 deletions satflow/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
from satpy import Scene


def crop_center(img, cropx, cropy):
"""Crops center of image through timestack, fails if all the images are concatenated as channels"""
t, c, y, x = img.shape
startx = x // 2 - (cropx // 2)
starty = y // 2 - (cropy // 2)
return img[:, :, starty : starty + cropy, startx : startx + cropx]


def eumetsat_filename_to_datetime(inner_tar_name):
"""Takes a file from the EUMETSAT API and returns
the date and time part of the filename"""
Expand Down
2 changes: 1 addition & 1 deletion satflow/models/layers/ConvGRU.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def reset_parameters(self):

def one_param(m):
"First parameter in `m`"
return m.parameters()[0]
return next(m.parameters())


def dropout_mask(x, sz, p):
Expand Down
9 changes: 7 additions & 2 deletions satflow/models/metnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
self.lr = learning_rate
self.drop = nn.Dropout(temporal_dropout)
if image_encoder is None:
image_encoder = DownSampler(input_channels)
image_encoder = DownSampler(input_channels + forecast_steps)
nf = 256 # from the simple image encoder
self.image_encoder = TimeDistributed(image_encoder)
self.ct = ConditionTime(forecast_steps)
Expand All @@ -62,6 +62,8 @@ def __init__(
)

self.head = head
self.head = nn.Conv2d(hidden_dim, 1, kernel_size=(1, 1)) # Reduces to mask
# self.head = nn.Sequential(nn.AdaptiveAvgPool2d(1), )

def encode_timestep(self, x, fstep=1):

Expand Down Expand Up @@ -102,6 +104,7 @@ def configure_optimizers(self):
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
y = torch.squeeze(y)

# if self.make_vis:
# if np.random.random() < 0.01:
Expand All @@ -115,13 +118,15 @@ def training_step(self, batch, batch_idx):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
y = torch.squeeze(y)
val_loss = F.mse_loss(y_hat, y)
self.log("val/loss", val_loss, on_step=True, on_epoch=True)
return val_loss

def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x, self.forecast_steps)
y_hat = self(x)
y = torch.squeeze(y)
loss = F.mse_loss(y_hat, y)
return loss

Expand Down

0 comments on commit 425dded

Please sign in to comment.