From 88183a440db7c3f110572073f34bd9ca06478aa1 Mon Sep 17 00:00:00 2001 From: Jack Kelly Date: Thu, 16 Jun 2022 12:36:04 +0100 Subject: [PATCH] remove future satellite imagery to save GPU RAM. #136 --- experiments/026_dont_train_unet.py | 14 +++++++++- .../load_raw/data_sources/raw_data_source.py | 28 ++++++++++++------- power_perceiver/load_raw/raw_dataset.py | 7 +++-- .../delete_forecast_satellite_imagery.py | 24 ++++++++++++++++ .../np_batch_processor/sun_position.py | 6 ++-- 5 files changed, 63 insertions(+), 16 deletions(-) create mode 100644 power_perceiver/np_batch_processor/delete_forecast_satellite_imagery.py diff --git a/experiments/026_dont_train_unet.py b/experiments/026_dont_train_unet.py index f6e8d4c..c721a2b 100644 --- a/experiments/026_dont_train_unet.py +++ b/experiments/026_dont_train_unet.py @@ -42,6 +42,9 @@ from power_perceiver.load_raw.data_sources.raw_satellite_data_source import RawSatelliteDataSource from power_perceiver.load_raw.national_pv_dataset import NationalPVDataset from power_perceiver.load_raw.raw_dataset import RawDataset +from power_perceiver.np_batch_processor.delete_forecast_satellite_imagery import ( + DeleteForecastSatelliteImagery, +) from power_perceiver.np_batch_processor.encode_space_time import EncodeSpaceTime from power_perceiver.np_batch_processor.save_t0_time import SaveT0Time from power_perceiver.np_batch_processor.sun_position import SunPosition @@ -137,7 +140,8 @@ def get_dataloader( start_date=start_date, end_date=end_date, history_duration=datetime.timedelta(hours=1), - forecast_duration=datetime.timedelta(hours=0), + # We delete the future satellite imagery in the np_batch_processors. + forecast_duration=datetime.timedelta(hours=2), ) pv_data_source = RawPVDataSource( @@ -189,6 +193,14 @@ def get_dataloader( if USE_TOPOGRAPHY: np_batch_processors.append(Topography("/home/jack/europe_dem_2km_osgb.tif")) + # Delete imagery of the future, because we're not training the U-Net, + # and we want to save GPU RAM. 
+ # But we do want hrvsatellite_time_utc to continue into the future by 2 hours because + # downstream code relies on hrvsatellite_time_utc. + np_batch_processors.append( + DeleteForecastSatelliteImagery(num_hist_sat_images=NUM_HIST_SAT_IMAGES) + ) + raw_dataset_kwargs = dict( n_examples_per_batch=20, # TODO: Increase to more like 32! n_batches_per_epoch=n_batches_per_epoch_per_worker, diff --git a/power_perceiver/load_raw/data_sources/raw_data_source.py b/power_perceiver/load_raw/data_sources/raw_data_source.py index afe633a..46c6622 100644 --- a/power_perceiver/load_raw/data_sources/raw_data_source.py +++ b/power_perceiver/load_raw/data_sources/raw_data_source.py @@ -76,9 +76,7 @@ def get_example( The returned Dataset must not include an `example` dimension. """ self._allow_nans = False - xr_data = self.data_in_ram - xr_data = self._get_time_slice(xr_data, t0_datetime_utc=t0_datetime_utc) - xr_data = self._get_spatial_slice(xr_data, center_osgb=center_osgb) + xr_data = self._get_slice(t0_datetime_utc=t0_datetime_utc, center_osgb=center_osgb) xr_data = self._post_process(xr_data) xr_data = self._transform(xr_data) try: @@ -89,6 +87,16 @@ def get_example( ) from e return xr_data + def _get_slice(self, t0_datetime_utc: datetime.datetime, center_osgb: Location) -> xr.DataArray: + """Can be overridden by child classes. + + The returned Dataset must not include an `example` dimension. + """ + xr_data = self.data_in_ram + xr_data = self._get_time_slice(xr_data, t0_datetime_utc=t0_datetime_utc) + xr_data = self._get_spatial_slice(xr_data, center_osgb=center_osgb) + return xr_data + def check_xarray_data(self, xr_data: xr.DataArray): # noqa: D102 if not self._allow_nans: assert np.isfinite(xr_data).all(), "Some xr_data is non-finite!" @@ -279,19 +287,19 @@ def _get_time_slice( The returned data does not include an `example` dimension. 
""" - start_dt_rounded = self._get_start_dt_ceil(t0_datetime_utc) - end_dt_rounded = self._get_end_dt_ceil(t0_datetime_utc) + start_dt_ceil = self._get_start_dt_ceil(t0_datetime_utc) + end_dt_ceil = self._get_end_dt_ceil(t0_datetime_utc) # Sanity check! assert ( - start_dt_rounded in xr_data.time_utc - ), f"{start_dt_rounded=} not in xr_data.time_utc! {t0_datetime_utc=}" + start_dt_ceil in xr_data.time_utc + ), f"{start_dt_ceil=} not in xr_data.time_utc! {t0_datetime_utc=}" assert ( - end_dt_rounded in xr_data.time_utc - ), f"{end_dt_rounded=} not in xr_data.time_utc! {t0_datetime_utc=}" + end_dt_ceil in xr_data.time_utc + ), f"{end_dt_ceil=} not in xr_data.time_utc! {t0_datetime_utc=}" # Get time slice: - time_slice = xr_data.sel({self._time_dim_name: slice(start_dt_rounded, end_dt_rounded)}) + time_slice = xr_data.sel({self._time_dim_name: slice(start_dt_ceil, end_dt_ceil)}) self._sanity_check_time_slice(time_slice, self._time_dim_name, t0_datetime_utc) return time_slice diff --git a/power_perceiver/load_raw/raw_dataset.py b/power_perceiver/load_raw/raw_dataset.py index e0f48a8..598eae2 100644 --- a/power_perceiver/load_raw/raw_dataset.py +++ b/power_perceiver/load_raw/raw_dataset.py @@ -264,9 +264,10 @@ def _get_specific_xr_example( else: try: xr_example[data_source.__class__] = data_source.empty_example - except AttributeError: - # This is probably a duplicate data_source. Ignore. - pass + except AttributeError as e: + raise AttributeError( + "If this is a duplicate data_source then we should ignore." 
+ ) from e return xr_example diff --git a/power_perceiver/np_batch_processor/delete_forecast_satellite_imagery.py b/power_perceiver/np_batch_processor/delete_forecast_satellite_imagery.py new file mode 100644 index 0000000..001b033 --- /dev/null +++ b/power_perceiver/np_batch_processor/delete_forecast_satellite_imagery.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass + +from power_perceiver.consts import BatchKey +from power_perceiver.load_prepared_batches.data_sources.prepared_data_source import NumpyBatch + + +@dataclass +class DeleteForecastSatelliteImagery: + """Delete imagery of the future. + + Useful when not training the U-Net, and we want to save GPU RAM. + + But we do want hrvsatellite_time_utc to continue out to 2 hours because + downstream code relies on hrvsatellite_time_utc. + """ + + num_hist_sat_images: int + + def __call__(self, np_batch: NumpyBatch) -> NumpyBatch: + # Shape: time, channels, y, x + np_batch[BatchKey.hrvsatellite] = np_batch[BatchKey.hrvsatellite][ + : self.num_hist_sat_images + ] + return np_batch diff --git a/power_perceiver/np_batch_processor/sun_position.py b/power_perceiver/np_batch_processor/sun_position.py index 266bdfc..d0fc2d2 100644 --- a/power_perceiver/np_batch_processor/sun_position.py +++ b/power_perceiver/np_batch_processor/sun_position.py @@ -17,14 +17,16 @@ @dataclass class SunPosition: - """This is kind of a duplicate of the info in the Sun pre-prepared batch. + """Append the Sun's azimuth and elevation. + + This is a duplicate of the info in the Sun pre-prepared batch. But we don't have access to those pre-prepared batches when training directly from the Zarr! Hence we need this when training directly from Zarr! 
""" def __call__(self, np_batch: NumpyBatch) -> NumpyBatch: - """Sets `BatchKey.solar_azimuth_at_t0` and `BatchKey.solar_elevation_at_t0`.""" + """Set `BatchKey.solar_azimuth` and `BatchKey.solar_elevation`.""" y_osgb = np_batch[BatchKey.hrvsatellite_y_osgb] # example, y, x x_osgb = np_batch[BatchKey.hrvsatellite_x_osgb] # example, y, x time_utc = np_batch[BatchKey.hrvsatellite_time_utc] # example, time