Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
remove future satellite imagery to save GPU RAM. #136
Browse files Browse the repository at this point in the history
  • Loading branch information
JackKelly committed Jun 16, 2022
1 parent 69c2659 commit 88183a4
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 16 deletions.
14 changes: 13 additions & 1 deletion experiments/026_dont_train_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
from power_perceiver.load_raw.data_sources.raw_satellite_data_source import RawSatelliteDataSource
from power_perceiver.load_raw.national_pv_dataset import NationalPVDataset
from power_perceiver.load_raw.raw_dataset import RawDataset
from power_perceiver.np_batch_processor.delete_forecast_satellite_imagery import (
DeleteForecastSatelliteImagery,
)
from power_perceiver.np_batch_processor.encode_space_time import EncodeSpaceTime
from power_perceiver.np_batch_processor.save_t0_time import SaveT0Time
from power_perceiver.np_batch_processor.sun_position import SunPosition
Expand Down Expand Up @@ -137,7 +140,8 @@ def get_dataloader(
start_date=start_date,
end_date=end_date,
history_duration=datetime.timedelta(hours=1),
forecast_duration=datetime.timedelta(hours=0),
# We delete the future satellite imagery in the np_batch_processors.
forecast_duration=datetime.timedelta(hours=2),
)

pv_data_source = RawPVDataSource(
Expand Down Expand Up @@ -189,6 +193,14 @@ def get_dataloader(
if USE_TOPOGRAPHY:
np_batch_processors.append(Topography("/home/jack/europe_dem_2km_osgb.tif"))

# Delete imagery of the future, because we're not training the U-Net,
# and we want to save GPU RAM.
# But we do want hrvsatellite_time_utc to continue into the future by 2 hours because
# downstream code relies on hrvsatellite_time_utc.
np_batch_processors.append(
DeleteForecastSatelliteImagery(num_hist_sat_images=NUM_HIST_SAT_IMAGES)
)

raw_dataset_kwargs = dict(
n_examples_per_batch=20, # TODO: Increase to more like 32!
n_batches_per_epoch=n_batches_per_epoch_per_worker,
Expand Down
28 changes: 18 additions & 10 deletions power_perceiver/load_raw/data_sources/raw_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def get_example(
The returned Dataset must not include an `example` dimension.
"""
self._allow_nans = False
xr_data = self.data_in_ram
xr_data = self._get_time_slice(xr_data, t0_datetime_utc=t0_datetime_utc)
xr_data = self._get_spatial_slice(xr_data, center_osgb=center_osgb)
xr_data = self._get_slice(t0_datetime_utc=t0_datetime_utc, center_osgb=center_osgb)
xr_data = self._post_process(xr_data)
xr_data = self._transform(xr_data)
try:
Expand All @@ -89,6 +87,16 @@ def get_example(
) from e
return xr_data

def _get_slice(self, t0_datetime_utc: datetime.datetime, center_osgb: Location) -> xr.DataArray:
    """Select a single example's slice of data: first in time, then in space.

    Child classes may override this. The returned data must not include an
    `example` dimension.
    """
    sliced = self._get_time_slice(self.data_in_ram, t0_datetime_utc=t0_datetime_utc)
    return self._get_spatial_slice(sliced, center_osgb=center_osgb)

def check_xarray_data(self, xr_data: xr.DataArray): # noqa: D102
if not self._allow_nans:
assert np.isfinite(xr_data).all(), "Some xr_data is non-finite!"
Expand Down Expand Up @@ -279,19 +287,19 @@ def _get_time_slice(
The returned data does not include an `example` dimension.
"""
start_dt_rounded = self._get_start_dt_ceil(t0_datetime_utc)
end_dt_rounded = self._get_end_dt_ceil(t0_datetime_utc)
start_dt_ceil = self._get_start_dt_ceil(t0_datetime_utc)
end_dt_ceil = self._get_end_dt_ceil(t0_datetime_utc)

# Sanity check!
assert (
start_dt_rounded in xr_data.time_utc
), f"{start_dt_rounded=} not in xr_data.time_utc! {t0_datetime_utc=}"
start_dt_ceil in xr_data.time_utc
), f"{start_dt_ceil=} not in xr_data.time_utc! {t0_datetime_utc=}"
assert (
end_dt_rounded in xr_data.time_utc
), f"{end_dt_rounded=} not in xr_data.time_utc! {t0_datetime_utc=}"
end_dt_ceil in xr_data.time_utc
), f"{end_dt_ceil=} not in xr_data.time_utc! {t0_datetime_utc=}"

# Get time slice:
time_slice = xr_data.sel({self._time_dim_name: slice(start_dt_rounded, end_dt_rounded)})
time_slice = xr_data.sel({self._time_dim_name: slice(start_dt_ceil, end_dt_ceil)})
self._sanity_check_time_slice(time_slice, self._time_dim_name, t0_datetime_utc)
return time_slice

Expand Down
7 changes: 4 additions & 3 deletions power_perceiver/load_raw/raw_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,10 @@ def _get_specific_xr_example(
else:
try:
xr_example[data_source.__class__] = data_source.empty_example
except AttributeError:
# This is probably a duplicate data_source. Ignore.
pass
except AttributeError as e:
raise AttributeError(
"If this is a duplicate data_source then we should ignore."
) from e

return xr_example

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from dataclasses import dataclass

from power_perceiver.consts import BatchKey
from power_perceiver.load_prepared_batches.data_sources.prepared_data_source import NumpyBatch


@dataclass
class DeleteForecastSatelliteImagery:
    """Drop satellite imagery of the future, to save GPU RAM.

    Useful when not training the U-Net. Note that `hrvsatellite_time_utc` is
    deliberately left untouched (it continues out to 2 hours) because
    downstream code relies on `hrvsatellite_time_utc`.
    """

    # How many historical timesteps to keep at the start of the time axis.
    num_hist_sat_images: int

    def __call__(self, np_batch: NumpyBatch) -> NumpyBatch:
        # hrvsatellite shape: time, channels, y, x.
        # Keep only the first `num_hist_sat_images` timesteps.
        imagery = np_batch[BatchKey.hrvsatellite]
        np_batch[BatchKey.hrvsatellite] = imagery[: self.num_hist_sat_images]
        return np_batch
6 changes: 4 additions & 2 deletions power_perceiver/np_batch_processor/sun_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@

@dataclass
class SunPosition:
"""This is kind of a duplicate of the info in the Sun pre-prepared batch.
"""Append the Sun's azimuth and elevation.
This is a duplicate of the info in the Sun pre-prepared batch.
But we don't have access to those pre-prepared batches when training directly
from the Zarr! Hence we need this when training directly from Zarr!
"""

def __call__(self, np_batch: NumpyBatch) -> NumpyBatch:
"""Sets `BatchKey.solar_azimuth_at_t0` and `BatchKey.solar_elevation_at_t0`."""
"""Set `BatchKey.solar_azimuth` and `BatchKey.solar_elevation`."""
y_osgb = np_batch[BatchKey.hrvsatellite_y_osgb] # example, y, x
x_osgb = np_batch[BatchKey.hrvsatellite_x_osgb] # example, y, x
time_utc = np_batch[BatchKey.hrvsatellite_time_utc] # example, time
Expand Down

0 comments on commit 88183a4

Please sign in to comment.