Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
bugfix. #136
Browse files Browse the repository at this point in the history
  • Loading branch information
JackKelly committed Jun 16, 2022
1 parent 88183a4 commit 97d8ca5
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
1 change: 1 addition & 0 deletions experiments/026_dont_train_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def get_dataloader(
# and we want to save GPU RAM.
# But we do want hrvsatellite_time_utc to continue into the future by 2 hours because
# downstream code relies on hrvsatellite_time_utc.
# This must come last.
np_batch_processors.append(
DeleteForecastSatelliteImagery(num_hist_sat_images=NUM_HIST_SAT_IMAGES)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from power_perceiver.consts import BatchKey
from power_perceiver.load_prepared_batches.data_sources.prepared_data_source import NumpyBatch
from power_perceiver.utils import assert_num_dims


@dataclass
Expand All @@ -17,8 +18,9 @@ class DeleteForecastSatelliteImagery:
num_hist_sat_images: int

def __call__(self, np_batch: NumpyBatch) -> NumpyBatch:
# Shape: time, channels, y, x
# Shape: example, time, channels, y, x
assert_num_dims(np_batch[BatchKey.hrvsatellite], 5)
np_batch[BatchKey.hrvsatellite] = np_batch[BatchKey.hrvsatellite][
: self.num_hist_sat_images
:, : self.num_hist_sat_images
]
return np_batch

0 comments on commit 97d8ca5

Please sign in to comment.