diff --git a/india_forecast_app/models/pvnet/model.py b/india_forecast_app/models/pvnet/model.py index 365b775..95c4aa7 100644 --- a/india_forecast_app/models/pvnet/model.py +++ b/india_forecast_app/models/pvnet/model.py @@ -12,7 +12,8 @@ import torch from ocf_datapipes.batch import stack_np_examples_into_batch from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline as pv_base_pipeline -from ocf_datapipes.training.windnet import construct_sliced_data_pipeline as wind_base_pipeline +from ocf_datapipes.training.windnet import construct_sliced_data_pipeline as wind_base_pipeline, DictDatasetIterDataPipe +from ocf_datapipes.utils.utils import combine_to_single_dataset from pvnet.models.base_model import BaseModel as PVNetBaseModel from torch.utils.data import DataLoader from torch.utils.data.datapipes.iter import IterableWrapper @@ -113,19 +114,28 @@ def _create_dataloader(self): t0_datapipe = t0_datapipe.sharding_filter() # Batch datapipe - base_pipeline = wind_base_pipeline if self.asset_type == "wind" else pv_base_pipeline - batch_datapipe = ( - # TODO wind return dict, whereas PV returns IterDataPipe - need to resolve this - # Perhaps see https://github.com/openclimatefix/ocf_datapipes/blob/main/ocf_datapipes/training/windnet.py#L328 - base_pipeline( - config_filename=populated_data_config_filename, - location_pipe=location_pipe, - t0_datapipe=t0_datapipe, - production=False # TODO was True, but threw error as expecting GSP key to be defined + if self.asset_type == "wind": + base_datapipe_dict = ( + wind_base_pipeline( + config_filename=populated_data_config_filename, + location_pipe=location_pipe, + t0_datapipe=t0_datapipe + ) ) - .batch(BATCH_SIZE) - .map(stack_np_examples_into_batch) - ) + base_datapipe = DictDatasetIterDataPipe( + {k: v for k, v in base_datapipe_dict.items() if k != "config"}, + ).map(combine_to_single_dataset) + else: + base_datapipe = ( + pv_base_pipeline( + config_filename=populated_data_config_filename, + location_pipe=location_pipe, + t0_datapipe=t0_datapipe, + production=True + ) + ) + + batch_datapipe = base_datapipe.batch(BATCH_SIZE).map(stack_np_examples_into_batch) n_workers = os.cpu_count() - 1