diff --git a/satflow/baseline/README.md b/satflow/baseline/README.md
index 669773c7..0e023e8d 100644
--- a/satflow/baseline/README.md
+++ b/satflow/baseline/README.md
@@ -20,3 +20,6 @@ but the optical flow usually ended up not actually changing anything. Instead, w
 MSG HRV satellite channel to compute the optical flow. This was chosen as that is the highest resolution satellite
 channel available, and it resulted in optical flow actually computing some movement. This flow field was then applied
 to the cloud masks directly to obtain the flow results.
+
+Avg Total Loss: 0.15720261434381796 Avg Baseline Loss: 0.1598897848692671
+Overall Loss: 0.15720261434381738 Baseline: 0.1598897848692671
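For reference, the warping step described above can be sketched in a few lines. This is an illustrative sketch only, not the code that produced the numbers above: it assumes OpenCV's Farnebäck dense flow, two consecutive single-channel HRV frames, and a 2-D cloud mask, and it approximates the forward warp by sampling the current mask against the most recent flow field. The names `prev_hrv`, `curr_hrv`, `cloud_mask`, and `warp_mask_with_flow` are all stand-ins.

```python
import cv2
import numpy as np


def warp_mask_with_flow(mask: np.ndarray, flow: np.ndarray) -> np.ndarray:
    """Advect a cloud mask along a dense flow field (backward-sampling approximation)."""
    h, w = mask.shape
    grid_x, grid_y = np.meshgrid(np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32))
    # Pixels that moved by (dx, dy) between the last two frames are assumed to keep moving,
    # so the predicted mask at (x, y) is sampled from (x - dx, y - dy) in the current mask.
    map_x = grid_x - flow[..., 0]
    map_y = grid_y - flow[..., 1]
    return cv2.remap(mask.astype(np.float32), map_x, map_y, cv2.INTER_NEAREST)


if __name__ == "__main__":
    # Stand-ins for two consecutive HRV frames and the current cloud mask
    prev_hrv = np.random.randint(0, 255, (128, 128), dtype=np.uint8)
    curr_hrv = np.roll(prev_hrv, shift=2, axis=1)  # fake motion: everything shifts 2 px in x
    cloud_mask = (curr_hrv > 127).astype(np.float32)

    # Dense Farnebäck flow; positional args are (prev, next, flow, pyr_scale, levels,
    # winsize, iterations, poly_n, poly_sigma, flags)
    flow = cv2.calcOpticalFlowFarneback(prev_hrv, curr_hrv, None, 0.5, 3, 15, 3, 5, 1.2, 0)
    predicted_mask = warp_mask_with_flow(cloud_mask, flow)
```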
diff --git a/satflow/data/datamodules.py b/satflow/data/datamodules.py
index e1216d84..4e8f4947 100644
--- a/satflow/data/datamodules.py
+++ b/satflow/data/datamodules.py
@@ -33,10 +33,8 @@ def prepare_data(self):
     def setup(self, stage: Optional[str] = None):
         # Assign train/val datasets for use in dataloaders
         if stage == "fit" or stage is None:
-            train_dset = wds.WebDataset(
-                os.path.join(self.data_dir, self.sources["train"])
-            ).decode()
-            val_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["val"])).decode()
+            train_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["train"]))
+            val_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["val"]))
             if self.shuffle > 0:
                 # Add shuffling, each sample is still quite large, so too many examples ends up running out of ram
                 train_dset = train_dset.shuffle(self.shuffle)
@@ -45,7 +43,7 @@ def setup(self, stage: Optional[str] = None):

         # Assign test dataset for use in dataloader(s)
         if stage == "test" or stage is None:
-            test_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["test"])).decode()
+            test_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["test"]))
             self.test_dataset = SatFlowDataset([test_dset], config=self.config, train=False)

     def train_dataloader(self):
@@ -100,10 +98,8 @@ def prepare_data(self):
     def setup(self, stage: Optional[str] = None):
         # Assign train/val datasets for use in dataloaders
         if stage == "fit" or stage is None:
-            train_dset = wds.WebDataset(
-                os.path.join(self.data_dir, self.sources["train"])
-            ).decode()
-            val_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["val"])).decode()
+            train_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["train"]))
+            val_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["val"]))
             if self.shuffle > 0:
                 # Add shuffling, each sample is still quite large, so too many examples ends up running out of ram
                 train_dset = train_dset.shuffle(self.shuffle)
@@ -112,7 +108,7 @@ def setup(self, stage: Optional[str] = None):

         # Assign test dataset for use in dataloader(s)
         if stage == "test" or stage is None:
-            test_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["test"])).decode()
+            test_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["test"]))
             self.test_dataset = CloudFlowDataset([test_dset], config=self.config, train=False)

     def train_dataloader(self):
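The datamodule changes above hand the raw WebDataset to the dataset wrappers undecoded (decoding now happens inside `SatFlowDataset`/`CloudFlowDataset`), and keep the shuffle buffer bounded because individual samples are large. A minimal sketch of that pipeline, assuming the `webdataset` package; the path and buffer size here are illustrative, not the project's configuration:

```python
import os

import webdataset as wds

data_dir = "data"      # illustrative
source = "train.tar"   # illustrative shard name

# Raw, undecoded WebDataset; the downstream Dataset class is expected to decode samples itself
dset = wds.WebDataset(os.path.join(data_dir, source))
# Small shuffle buffer: each sample is large, so a big buffer would exhaust RAM
dset = dset.shuffle(100)
```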
diff --git a/satflow/data/datasets.py b/satflow/data/datasets.py
index 69c96463..dba3246d 100644
--- a/satflow/data/datasets.py
+++ b/satflow/data/datasets.py
@@ -384,7 +384,7 @@ def __iter__(self) -> Iterator[T_co]:
             # Now in a Time x W x H x Channel order
             target_image, target_mask = self.get_timestep(
                 sample,
-                target_timestep,
+                idx + 1,
                 return_target=True,
                 return_image=self.use_image,
             )
@@ -398,7 +398,7 @@ def __iter__(self) -> Iterator[T_co]:
            if np.isclose(np.min(target_mask), np.max(target_mask)):
                continue  # Ignore if target timestep has no clouds, or only clouds
            # Now create stack here
-            for i in range(idx + 1, target_timestep):
+            for i in range(idx + 2, target_timestep + 1):
                t_image, t_mask = self.get_timestep(
                    sample,
                    i,
@@ -407,12 +407,12 @@ def __iter__(self) -> Iterator[T_co]:
                )
                t_mask = self.aug.replay(replay, image=t_mask)["image"]
                target_mask = np.concatenate(
-                    [np.expand_dims(t_mask, axis=0), target_mask]
+                    [target_mask, np.expand_dims(t_mask, axis=0)]
                )
                if self.use_image:
                    t_image = self.aug.replay(replay, image=t_image)["image"]
                    target_image = np.concatenate(
-                        [np.expand_dims(t_image, axis=0), target_image]
+                        [target_image, np.expand_dims(t_image, axis=0)]
                    )
            # Ensure last target mask is also different than previous ones -> only want ones where things change
            if np.allclose(target_mask[0], target_mask[-1]):
@@ -566,7 +566,7 @@ def __iter__(self) -> Iterator[T_co]:
            # Now in a Time x W x H x Channel order
            _, target_mask = self.get_timestep(
                sample,
-                target_timestep,
+                idx + 1,
                return_target=True,
                return_image=False,
            )
@@ -576,7 +576,7 @@ def __iter__(self) -> Iterator[T_co]:
            if np.isclose(np.min(target_mask), np.max(target_mask)):
                continue  # Ignore if target timestep has no clouds, or only clouds
            # Now create stack here
-            for i in range(idx + 1, target_timestep):
+            for i in range(idx + 2, target_timestep + 1):
                _, t_mask = self.get_timestep(
                    sample,
                    i,
@@ -585,7 +585,7 @@ def __iter__(self) -> Iterator[T_co]:
                )
                t_mask = self.aug.replay(replay, image=t_mask)["image"]
                target_mask = np.concatenate(
-                    [np.expand_dims(t_mask, axis=0), target_mask]
+                    [target_mask, np.expand_dims(t_mask, axis=0)]
                )
            # Ensure last target mask is also different than previous ones -> only want ones where things change
            if np.allclose(target_mask[0], target_mask[-1]):
@@ -663,6 +663,10 @@ def __iter__(self) -> Iterator[T_co]:
        for idx in idxs:
            for _ in range(self.num_crops):  # Do random crops as well for training
                logger.debug(f"IDX: {idx}")
+                print(
+                    f"Timesteps: Current: {timesteps[idx]} Prev: {timesteps[idx - 1]} Next: {timesteps[idx + 1]} Final: {timesteps[idx + self.forecast_times - 1]} "
+                    f"Timedelta: Next - Curr: {timesteps[idx + 1] - timesteps[idx]} End - Curr: {timesteps[idx + self.forecast_times - 1] - timesteps[idx]}"
+                )
                image, mask = self.get_timestep(
                    sample,
                    idx,
@@ -684,7 +688,7 @@ def __iter__(self) -> Iterator[T_co]:
                # Now in a Time x W x H x Channel order
                _, target_mask = self.get_timestep(
                    sample,
-                    idx + self.forecast_times,
+                    idx + 1,
                    return_target=True,
                    return_image=False,
                )
@@ -694,7 +698,7 @@ def __iter__(self) -> Iterator[T_co]:
                if np.isclose(np.min(target_mask), np.max(target_mask)):
                    continue  # Ignore if target timestep has no clouds, or only clouds
                # Now create stack here
-                for i in range(idx + 1, idx + self.forecast_times):
+                for i in range(idx + 2, idx + self.forecast_times + 1):
                    _, t_mask = self.get_timestep(
                        sample,
                        i,
@@ -703,7 +707,7 @@ def __iter__(self) -> Iterator[T_co]:
                    )
                    t_mask = self.aug.replay(replay, image=t_mask)["image"]
                    target_mask = np.concatenate(
-                        [np.expand_dims(t_mask, axis=0), target_mask]
+                        [target_mask, np.expand_dims(t_mask, axis=0)]
                    )
                target_mask = np.round(target_mask).astype(np.int8)
                # Convert to float/half-precision
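The reordered ranges and concatenations above build the target stack in chronological order: the frame at `idx + 1` seeds the stack and each later frame is appended, so element `k` is the mask `k + 1` steps after the current index. A minimal sketch of that pattern, using a hypothetical `get_mask(t)` in place of `get_timestep` and ignoring augmentation replay:

```python
import numpy as np


def stack_future_masks(get_mask, idx: int, forecast_times: int) -> np.ndarray:
    """Stack future masks in time order; result[k] is the mask k + 1 steps after idx."""
    target_mask = np.expand_dims(get_mask(idx + 1), axis=0)
    for i in range(idx + 2, idx + forecast_times + 1):
        target_mask = np.concatenate([target_mask, np.expand_dims(get_mask(i), axis=0)])
    return target_mask  # shape: (forecast_times, H, W)


# Toy mask source: get_mask(t) returns a 4x4 array filled with t
stack = stack_future_masks(lambda t: np.full((4, 4), t), idx=0, forecast_times=3)
assert stack.shape == (3, 4, 4) and stack[0, 0, 0] == 1 and stack[-1, 0, 0] == 3
```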
diff --git a/satflow/models/conv_lstm.py b/satflow/models/conv_lstm.py
index 1d422d37..76439135 100644
--- a/satflow/models/conv_lstm.py
+++ b/satflow/models/conv_lstm.py
@@ -161,13 +161,26 @@ def training_step(self, batch, batch_idx):
             self.visualize(x, y, y_hat, batch_idx)
         loss = self.criterion(y_hat, y)
         self.log("train/loss", loss, on_step=True)
+        y_hat = torch.moveaxis(y_hat, 2, 1)
+        frame_loss_dict = {}
+        for f in range(self.forecast_steps):
+            frame_loss = self.criterion(y_hat[:, f, :, :, :], y[:, f, :, :, :]).item()
+            frame_loss_dict[f"train/frame_{f}_loss"] = frame_loss
+        self.log_dict(frame_loss_dict, on_step=False, on_epoch=True)
         return loss

     def validation_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x, self.forecast_steps)
         val_loss = self.criterion(y_hat, y)
+        # Save out loss per frame as well
+        frame_loss_dict = {}
+        y_hat = torch.moveaxis(y_hat, 2, 1)
+        for f in range(self.forecast_steps):
+            frame_loss = self.criterion(y_hat[:, f, :, :, :], y[:, f, :, :, :]).item()
+            frame_loss_dict[f"val/frame_{f}_loss"] = frame_loss
         self.log("val/loss", val_loss, on_step=True, on_epoch=True)
+        self.log_dict(frame_loss_dict, on_step=False, on_epoch=True)
         return val_loss

     def test_step(self, batch, batch_idx):
diff --git a/satflow/models/metnet.py b/satflow/models/metnet.py
index e1a0704f..ceed7244 100644
--- a/satflow/models/metnet.py
+++ b/satflow/models/metnet.py
@@ -48,6 +48,7 @@ def __init__(

         self.horizon = forecast_steps
         self.lr = lr
+        self.criterion = F.mse_loss
         self.drop = nn.Dropout(temporal_dropout)
         if image_encoder in ["downsampler", "default"]:
             image_encoder = DownSampler(input_channels + forecast_steps)
@@ -114,16 +115,27 @@ def training_step(self, batch, batch_idx):
         # self.visualize(x, y, y_hat, batch_idx)
         # Generally only care about the center x crop, so the model can take into account the clouds in the area without
         # being penalized for that, but for now, just do general MSE loss, also only care about first 12 channels
-        loss = F.mse_loss(y_hat, y)
-        self.log("train/loss", loss, on_step=True)
+        loss = self.criterion(y_hat, y)
+        self.log("train/loss", loss)
+        frame_loss_dict = {}
+        for f in range(self.forecast_steps):
+            frame_loss = self.criterion(y_hat[:, f, :, :, :], y[:, f, :, :, :]).item()
+            frame_loss_dict[f"train/frame_{f}_loss"] = frame_loss
+        self.log_dict(frame_loss_dict)
         return loss

     def validation_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x)
         y = torch.squeeze(y)
-        val_loss = F.mse_loss(y_hat, y)
-        self.log("val/loss", val_loss, on_step=True, on_epoch=True)
+        val_loss = self.criterion(y_hat, y)
+        self.log("val/loss", val_loss)
+        # Save out loss per frame as well
+        frame_loss_dict = {}
+        for f in range(self.forecast_steps):
+            frame_loss = self.criterion(y_hat[:, f, :, :, :], y[:, f, :, :, :]).item()
+            frame_loss_dict[f"val/frame_{f}_loss"] = frame_loss
+        self.log_dict(frame_loss_dict)
         return val_loss

     def test_step(self, batch, batch_idx):
diff --git a/satflow/models/unet.py b/satflow/models/unet.py
index 5873de7e..a31b36e4 100644
--- a/satflow/models/unet.py
+++ b/satflow/models/unet.py
@@ -67,18 +67,29 @@ def training_step(self, batch, batch_idx):
         # being penalized for that, but for now, just do general MSE loss, also only care about first 12 channels
         loss = self.criterion(y_hat, y)
         self.log("train/loss", loss, on_step=True)
+        frame_loss_dict = {}
+        for f in range(self.forecast_steps):
+            frame_loss = self.criterion(y_hat[:, f, :, :], y[:, f, :, :]).item()
+            frame_loss_dict[f"train/frame_{f}_loss"] = frame_loss
+        self.log_dict(frame_loss_dict)
         return loss

     def validation_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x)
         val_loss = self.criterion(y_hat, y)
-        self.log("val/loss", val_loss, on_step=True, on_epoch=True)
+        self.log("val/loss", val_loss)
+        # Save out loss per frame as well
+        frame_loss_dict = {}
+        for f in range(self.forecast_steps):
+            frame_loss = self.criterion(y_hat[:, f, :, :], y[:, f, :, :]).item()
+            frame_loss_dict[f"val/frame_{f}_loss"] = frame_loss
+        self.log_dict(frame_loss_dict)
         return val_loss

     def test_step(self, batch, batch_idx):
         x, y = batch
-        y_hat = self(x, self.forecast_steps)
+        y_hat = self(x)
         loss = self.criterion(y_hat, y)
         return loss
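All three models above now log a per-forecast-frame loss alongside the aggregate loss, which makes it easy to see how error grows with lead time. A minimal sketch of that pattern, assuming predictions and targets whose forecast-step dimension is dim 1; `per_frame_losses` is an illustrative helper, not part of the repository:

```python
import torch
import torch.nn.functional as F


def per_frame_losses(y_hat: torch.Tensor, y: torch.Tensor, prefix: str) -> dict:
    """Compute an MSE loss for every forecast step separately (dim 1 = forecast step)."""
    losses = {}
    for f in range(y_hat.shape[1]):
        # .item() keeps only the scalar value, so no autograd graph is retained for logging
        losses[f"{prefix}/frame_{f}_loss"] = F.mse_loss(y_hat[:, f], y[:, f]).item()
    return losses


# Inside a LightningModule step this would be logged with, for example:
#     self.log_dict(per_frame_losses(y_hat, y, "val"), on_step=False, on_epoch=True)
```

Logging these with `on_step=False, on_epoch=True` keeps the per-frame curves to one aggregated point per epoch rather than one per batch.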