Skip to content

Commit

Permalink
Fixed windnet batch datapipe
Browse files: browse the repository at this point in the history
  • Loading branch information
confusedmatrix committed Feb 8, 2024
1 parent 45dc45e commit e432f23
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions india_forecast_app/models/pvnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import torch
from ocf_datapipes.batch import stack_np_examples_into_batch
from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline as pv_base_pipeline
from ocf_datapipes.training.windnet import construct_sliced_data_pipeline as wind_base_pipeline
from ocf_datapipes.training.windnet import construct_sliced_data_pipeline as wind_base_pipeline, DictDatasetIterDataPipe

Check failure on line 15 in india_forecast_app/models/pvnet/model.py

View workflow job for this annotation

GitHub Actions / lint_and_test / Lint the code and run the tests

Ruff (E501)

india_forecast_app/models/pvnet/model.py:15:101: E501 Line too long (120 > 100)
from ocf_datapipes.utils.utils import combine_to_single_dataset
from pvnet.models.base_model import BaseModel as PVNetBaseModel
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper
Expand Down Expand Up @@ -113,19 +114,28 @@ def _create_dataloader(self):
t0_datapipe = t0_datapipe.sharding_filter()

# Batch datapipe
base_pipeline = wind_base_pipeline if self.asset_type == "wind" else pv_base_pipeline
batch_datapipe = (
# TODO wind return dict, whereas PV returns IterDataPipe - need to resolve this
# Perhaps see https://github.com/openclimatefix/ocf_datapipes/blob/main/ocf_datapipes/training/windnet.py#L328
base_pipeline(
config_filename=populated_data_config_filename,
location_pipe=location_pipe,
t0_datapipe=t0_datapipe,
production=False # TODO was True, but threw error as expecting GSP key to be defined
if self.asset_type == "wind":
base_datapipe_dict = (
wind_base_pipeline(
config_filename=populated_data_config_filename,
location_pipe=location_pipe,
t0_datapipe=t0_datapipe
)
)
.batch(BATCH_SIZE)
.map(stack_np_examples_into_batch)
)
base_datapipe = DictDatasetIterDataPipe(
{k: v for k, v in base_datapipe_dict.items() if k != "config"},
).map(combine_to_single_dataset)
else:
base_datapipe = (
pv_base_pipeline(
config_filename=populated_data_config_filename,
location_pipe=location_pipe,
t0_datapipe=t0_datapipe,
production=True
)
)

batch_datapipe = base_datapipe.batch(BATCH_SIZE).map(stack_np_examples_into_batch)

n_workers = os.cpu_count() - 1

Expand Down

0 comments on commit e432f23

Please sign in to comment.