This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Add Loss per Frame #53

Merged 5 commits on Jul 14, 2021
3 changes: 3 additions & 0 deletions satflow/baseline/README.md
@@ -20,3 +20,6 @@ but the optical flow usually ended up not actually changing anything. Instead, we used the
MSG HRV satellite channel to compute the optical flow. This was chosen as that is the highest
resolution satellite channel available, and it resulted in optical flow actually computing some movement.
This flow field was then applied to the cloud masks directly to obtain the flow results.

Avg Total Loss: 0.15720261434381796 Avg Baseline Loss: 0.1598897848692671
Overall Loss: 0.15720261434381738 Baseline: 0.1598897848692671
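For context on the baseline itself, below is a minimal sketch of the flow-and-warp step the README describes, assuming OpenCV's Farneback optical flow on single-channel uint8 HRV frames. The function name, parameters, and warping details are illustrative, not the repository's exact implementation.

```python
import cv2
import numpy as np

def flow_baseline_step(prev_hrv, curr_hrv, cloud_mask):
    """Warp a cloud mask one step forward using flow between two HRV frames."""
    # Dense flow from the previous to the current high-resolution frame.
    flow = cv2.calcOpticalFlowFarneback(
        prev_hrv, curr_hrv, None,
        pyr_scale=0.5, levels=3, winsize=15,
        iterations=3, poly_n=5, poly_sigma=1.2, flags=0,
    )
    h, w = cloud_mask.shape
    grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
    # Backward sampling: each output pixel pulls from where the flow says it came from.
    map_x = (grid_x - flow[..., 0]).astype(np.float32)
    map_y = (grid_y - flow[..., 1]).astype(np.float32)
    return cv2.remap(cloud_mask.astype(np.float32), map_x, map_y, cv2.INTER_NEAREST)
```

Applying the same flow field repeatedly gives a persistence-style forecast for later frames, presumably the kind of forecast the averaged baseline losses above were computed from.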
16 changes: 6 additions & 10 deletions satflow/data/datamodules.py
@@ -33,10 +33,8 @@ def prepare_data(self):
def setup(self, stage: Optional[str] = None):
# Assign train/val datasets for use in dataloaders
if stage == "fit" or stage is None:
train_dset = wds.WebDataset(
os.path.join(self.data_dir, self.sources["train"])
).decode()
val_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["val"])).decode()
train_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["train"]))
val_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["val"]))
if self.shuffle > 0:
# Add shuffling, each sample is still quite large, so too many examples ends up running out of ram
train_dset = train_dset.shuffle(self.shuffle)
@@ -45,7 +43,7 @@ def setup(self, stage: Optional[str] = None):

# Assign test dataset for use in dataloader(s)
if stage == "test" or stage is None:
test_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["test"])).decode()
test_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["test"]))
self.test_dataset = SatFlowDataset([test_dset], config=self.config, train=False)

def train_dataloader(self):
@@ -100,10 +98,8 @@ def prepare_data(self):
def setup(self, stage: Optional[str] = None):
# Assign train/val datasets for use in dataloaders
if stage == "fit" or stage is None:
train_dset = wds.WebDataset(
os.path.join(self.data_dir, self.sources["train"])
).decode()
val_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["val"])).decode()
train_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["train"]))
val_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["val"]))
if self.shuffle > 0:
# Add shuffling, each sample is still quite large, so too many examples ends up running out of ram
train_dset = train_dset.shuffle(self.shuffle)
@@ -112,7 +108,7 @@ def setup(self, stage: Optional[str] = None):

# Assign test dataset for use in dataloader(s)
if stage == "test" or stage is None:
test_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["test"])).decode()
test_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["test"]))
self.test_dataset = CloudFlowDataset([test_dset], config=self.config, train=False)

def train_dataloader(self):
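The change in both datamodules drops the eager `.decode()` call when constructing the raw `WebDataset`, leaving decoding to the wrapping `SatFlowDataset`/`CloudFlowDataset`. A minimal sketch of the resulting construction (the shard path and shuffle-buffer size are illustrative):

```python
import os
import webdataset as wds

data_dir = "/path/to/shards"  # hypothetical location
# Samples stay undecoded here; the wrapping dataset decodes them lazily.
train_dset = wds.WebDataset(os.path.join(data_dir, "train-{0000..0099}.tar"))
# Small shuffle buffer: each sample is large, so big buffers exhaust RAM.
train_dset = train_dset.shuffle(64)
```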
24 changes: 14 additions & 10 deletions satflow/data/datasets.py
@@ -384,7 +384,7 @@ def __iter__(self) -> Iterator[T_co]:
# Now in a Time x W x H x Channel order
target_image, target_mask = self.get_timestep(
sample,
target_timestep,
idx + 1,
return_target=True,
return_image=self.use_image,
)
@@ -398,7 +398,7 @@ def __iter__(self) -> Iterator[T_co]:
if np.isclose(np.min(target_mask), np.max(target_mask)):
continue # Ignore if target timestep has no clouds, or only clouds
# Now create stack here
for i in range(idx + 1, target_timestep):
for i in range(idx + 2, target_timestep + 1):
t_image, t_mask = self.get_timestep(
sample,
i,
@@ -407,12 +407,12 @@
)
t_mask = self.aug.replay(replay, image=t_mask)["image"]
target_mask = np.concatenate(
[np.expand_dims(t_mask, axis=0), target_mask]
[target_mask, np.expand_dims(t_mask, axis=0)]
)
if self.use_image:
t_image = self.aug.replay(replay, image=t_image)["image"]
target_image = np.concatenate(
[np.expand_dims(t_image, axis=0), target_image]
[target_image, np.expand_dims(t_image, axis=0)]
)
# Ensure last target mask is also different than previous ones -> only want ones where things change
if np.allclose(target_mask[0], target_mask[-1]):
@@ -566,7 +566,7 @@ def __iter__(self) -> Iterator[T_co]:
# Now in a Time x W x H x Channel order
_, target_mask = self.get_timestep(
sample,
target_timestep,
idx + 1,
return_target=True,
return_image=False,
)
@@ -576,7 +576,7 @@
if np.isclose(np.min(target_mask), np.max(target_mask)):
continue # Ignore if target timestep has no clouds, or only clouds
# Now create stack here
for i in range(idx + 1, target_timestep):
for i in range(idx + 2, target_timestep + 1):
_, t_mask = self.get_timestep(
sample,
i,
@@ -585,7 +585,7 @@
)
t_mask = self.aug.replay(replay, image=t_mask)["image"]
target_mask = np.concatenate(
[np.expand_dims(t_mask, axis=0), target_mask]
[target_mask, np.expand_dims(t_mask, axis=0)]
)
# Ensure last target mask is also different than previous ones -> only want ones where things change
if np.allclose(target_mask[0], target_mask[-1]):
@@ -663,6 +663,10 @@ def __iter__(self) -> Iterator[T_co]:
for idx in idxs:
for _ in range(self.num_crops): # Do random crops as well for training
logger.debug(f"IDX: {idx}")
logger.debug(
f"Timesteps: Current: {timesteps[idx]} Prev: {timesteps[idx - 1]} Next: {timesteps[idx + 1]} Final: {timesteps[idx + self.forecast_times - 1]} "
f"Timedelta: Next - Curr: {timesteps[idx + 1] - timesteps[idx]} End - Curr: {timesteps[idx + self.forecast_times - 1] - timesteps[idx]}"
)
image, mask = self.get_timestep(
sample,
idx,
@@ -684,7 +688,7 @@ # Now in a Time x W x H x Channel order
# Now in a Time x W x H x Channel order
_, target_mask = self.get_timestep(
sample,
idx + self.forecast_times,
idx + 1,
return_target=True,
return_image=False,
)
@@ -694,7 +698,7 @@
if np.isclose(np.min(target_mask), np.max(target_mask)):
continue # Ignore if target timestep has no clouds, or only clouds
# Now create stack here
for i in range(idx + 1, idx + self.forecast_times):
for i in range(idx + 2, idx + self.forecast_times + 1):
_, t_mask = self.get_timestep(
sample,
i,
@@ -703,7 +707,7 @@
)
t_mask = self.aug.replay(replay, image=t_mask)["image"]
target_mask = np.concatenate(
[np.expand_dims(t_mask, axis=0), target_mask]
[target_mask, np.expand_dims(t_mask, axis=0)]
)
target_mask = np.round(target_mask).astype(np.int8)
# Convert to float/half-precision
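The dataset changes above do two related things: the first target frame is now fetched at `idx + 1` rather than at the final timestep, and subsequent frames are appended after it instead of prepended before it, so the target stack comes out in chronological order. A tiny standalone numpy illustration of that pattern (the `get_mask` helper and toy frames are hypothetical):

```python
import numpy as np

def stack_targets(get_mask, idx, forecast_times):
    # Fetch the first future frame up front, then append later frames,
    # so axis 0 of the stack runs in chronological order.
    target = np.expand_dims(get_mask(idx + 1), axis=0)
    for i in range(idx + 2, idx + forecast_times + 1):
        target = np.concatenate([target, np.expand_dims(get_mask(i), axis=0)])
    return target  # shape: (forecast_times, H, W)

frames = np.arange(10)[:, None, None] * np.ones((1, 4, 4))  # toy timesteps 0..9
stacked = stack_targets(lambda i: frames[i], idx=2, forecast_times=3)
assert (stacked[:, 0, 0] == [3, 4, 5]).all()  # frames 3, 4, 5, in order
```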
13 changes: 13 additions & 0 deletions satflow/models/conv_lstm.py
@@ -161,13 +161,26 @@ def training_step(self, batch, batch_idx):
self.visualize(x, y, y_hat, batch_idx)
loss = self.criterion(y_hat, y)
self.log("train/loss", loss, on_step=True)
y_hat = torch.moveaxis(y_hat, 2, 1)
frame_loss_dict = {}
for f in range(self.forecast_steps):
frame_loss = self.criterion(y_hat[:, f, :, :, :], y[:, f, :, :, :]).item()
frame_loss_dict[f"train/frame_{f}_loss"] = frame_loss
self.log_dict(frame_loss_dict, on_step=False, on_epoch=True)
return loss

def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x, self.forecast_steps)
val_loss = self.criterion(y_hat, y)
# Save out loss per frame as well
frame_loss_dict = {}
y_hat = torch.moveaxis(y_hat, 2, 1)
for f in range(self.forecast_steps):
frame_loss = self.criterion(y_hat[:, f, :, :, :], y[:, f, :, :, :]).item()
frame_loss_dict[f"val/frame_{f}_loss"] = frame_loss
self.log("val/loss", val_loss, on_step=True, on_epoch=True)
self.log_dict(frame_loss_dict, on_step=False, on_epoch=True)
return val_loss

def test_step(self, batch, batch_idx):
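The per-frame logging added to the ConvLSTM steps can be read as one small helper. The sketch below assumes the `(batch, channels, time, H, W)` output layout implied by the `torch.moveaxis(y_hat, 2, 1)` call above; the helper name and toy shapes are illustrative, and MSE stands in for whatever `self.criterion` is configured to be.

```python
import torch
import torch.nn.functional as F

def per_frame_losses(y_hat, y, prefix):
    # Move time next to batch: (B, C, T, H, W) -> (B, T, C, H, W),
    # then score each forecast frame separately.
    y_hat = torch.moveaxis(y_hat, 2, 1)
    return {
        f"{prefix}/frame_{f}_loss": F.mse_loss(y_hat[:, f], y[:, f]).item()
        for f in range(y_hat.shape[1])
    }

y_hat = torch.rand(2, 1, 6, 32, 32)  # toy prediction, (B, C, T, H, W)
y = torch.rand(2, 6, 1, 32, 32)      # toy target, already (B, T, C, H, W)
losses = per_frame_losses(y_hat, y, prefix="val")  # keys: val/frame_0_loss, ...
```

Calling `.item()` detaches each frame loss, so only the aggregate loss tensor carries gradients back to the optimizer.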
20 changes: 16 additions & 4 deletions satflow/models/metnet.py
@@ -48,6 +48,7 @@ def __init__(

self.horizon = forecast_steps
self.lr = lr
self.criterion = F.mse_loss
self.drop = nn.Dropout(temporal_dropout)
if image_encoder in ["downsampler", "default"]:
image_encoder = DownSampler(input_channels + forecast_steps)
@@ -114,16 +115,27 @@ def training_step(self, batch, batch_idx):
# self.visualize(x, y, y_hat, batch_idx)
# Generally only care about the center x crop, so the model can take into account the clouds in the area without
# being penalized for that, but for now, just do general MSE loss, also only care about first 12 channels
loss = F.mse_loss(y_hat, y)
self.log("train/loss", loss, on_step=True)
loss = self.criterion(y_hat, y)
self.log("train/loss", loss)
frame_loss_dict = {}
for f in range(self.forecast_steps):
frame_loss = self.criterion(y_hat[:, f], y[:, f]).item()  # index the frame axis (dim 1), not the batch axis
frame_loss_dict[f"train/frame_{f}_loss"] = frame_loss
self.log_dict(frame_loss_dict)
return loss

def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
y = torch.squeeze(y)
val_loss = F.mse_loss(y_hat, y)
self.log("val/loss", val_loss, on_step=True, on_epoch=True)
val_loss = self.criterion(y_hat, y)
self.log("val/loss", val_loss)
# Save out loss per frame as well
frame_loss_dict = {}
for f in range(self.forecast_steps):
frame_loss = self.criterion(y_hat[:, f], y[:, f]).item()  # match the training step's frame-axis slicing
frame_loss_dict[f"val/frame_{f}_loss"] = frame_loss
self.log_dict(frame_loss_dict)
return val_loss

def test_step(self, batch, batch_idx):
15 changes: 13 additions & 2 deletions satflow/models/unet.py
@@ -67,18 +67,29 @@ def training_step(self, batch, batch_idx):
# being penalized for that, but for now, just do general MSE loss, also only care about first 12 channels
loss = self.criterion(y_hat, y)
self.log("train/loss", loss, on_step=True)
frame_loss_dict = {}
for f in range(self.forecast_steps):
frame_loss = self.criterion(y_hat[:, f, :, :], y[:, f, :, :]).item()
frame_loss_dict[f"train/frame_{f}_loss"] = frame_loss
self.log_dict(frame_loss_dict)
return loss

def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
val_loss = self.criterion(y_hat, y)
self.log("val/loss", val_loss, on_step=True, on_epoch=True)
self.log("val/loss", val_loss)
# Save out loss per frame as well
frame_loss_dict = {}
for f in range(self.forecast_steps):
frame_loss = self.criterion(y_hat[:, f, :, :], y[:, f, :, :]).item()
frame_loss_dict[f"val/frame_{f}_loss"] = frame_loss
self.log_dict(frame_loss_dict)
return val_loss

def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x, self.forecast_steps)
y_hat = self(x)
loss = self.criterion(y_hat, y)
return loss
