diff --git a/experiments/026_dont_train_unet.py b/experiments/026_dont_train_unet.py index c721a2b..e4c7d68 100644 --- a/experiments/026_dont_train_unet.py +++ b/experiments/026_dont_train_unet.py @@ -197,6 +197,7 @@ def get_dataloader( # and we want to save GPU RAM. # But we do want hrvsatellite_time_utc to continue into the future by 2 hours because # downstream code relies on hrvsatellite_time_utc. + # This must come last. np_batch_processors.append( DeleteForecastSatelliteImagery(num_hist_sat_images=NUM_HIST_SAT_IMAGES) ) diff --git a/power_perceiver/np_batch_processor/delete_forecast_satellite_imagery.py b/power_perceiver/np_batch_processor/delete_forecast_satellite_imagery.py index 001b033..997eeb6 100644 --- a/power_perceiver/np_batch_processor/delete_forecast_satellite_imagery.py +++ b/power_perceiver/np_batch_processor/delete_forecast_satellite_imagery.py @@ -2,6 +2,7 @@ from power_perceiver.consts import BatchKey from power_perceiver.load_prepared_batches.data_sources.prepared_data_source import NumpyBatch +from power_perceiver.utils import assert_num_dims @dataclass @@ -17,8 +18,9 @@ class DeleteForecastSatelliteImagery: num_hist_sat_images: int def __call__(self, np_batch: NumpyBatch) -> NumpyBatch: - # Shape: time, channels, y, x + # Shape: example, time, channels, y, x + assert_num_dims(np_batch[BatchKey.hrvsatellite], 5) np_batch[BatchKey.hrvsatellite] = np_batch[BatchKey.hrvsatellite][ - : self.num_hist_sat_images + :, : self.num_hist_sat_images ] return np_batch