From 5f62f1f492d45d7b9177a1c23a9c97d13a69a000 Mon Sep 17 00:00:00 2001
From: James Fulton
Date: Mon, 16 Sep 2024 15:14:14 +0000
Subject: [PATCH] refactor dataloader

---
 pvnet_app/app.py        | 123 +++-------------------------------------
 pvnet_app/dataloader.py | 123 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 131 insertions(+), 115 deletions(-)
 create mode 100644 pvnet_app/dataloader.py

diff --git a/pvnet_app/app.py b/pvnet_app/app.py
index 4e5148e..169bb62 100644
--- a/pvnet_app/app.py
+++ b/pvnet_app/app.py
@@ -15,7 +15,6 @@
 import tempfile
 import warnings
 from datetime import timedelta
-from pathlib import Path
 
 import dask
 import pandas as pd
@@ -26,20 +25,13 @@
 from nowcasting_datamodel.models.base import Base_Forecast
 from nowcasting_datamodel.read.read_gsp import get_latest_gsp_capacities
 from nowcasting_datamodel.save.save import save as save_sql_forecasts
-from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
-from ocf_datapipes.batch import stack_np_examples_into_batch, batch_to_tensor, copy_batch_to_device
+from ocf_datapipes.batch import batch_to_tensor, copy_batch_to_device
 from pvnet.models.base_model import BaseModel as PVNetBaseModel
-from pvnet.utils import GSPLocationLookup
-from torch.utils.data import DataLoader
 
 import sentry_sdk
-
 import pvnet_app
-from pvnet_app.data.nwp import (
-    download_all_nwp_data,
-    preprocess_nwp_data,
-)
+from pvnet_app.data.nwp import download_all_nwp_data, preprocess_nwp_data
 from pvnet_app.data.satellite import (
     download_all_sat_data,
     preprocess_sat_data,
@@ -53,11 +45,8 @@
     save_yaml_config,
 )
 
-# Legacy imports
-from ocf_datapipes.load import OpenGSPFromDatabase
-from torch.utils.data.datapipes.iter import IterableWrapper
-from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline
-from ocf_datapipes.batch import BatchKey
+from pvnet_app.dataloader import get_legacy_dataloader, get_dataloader
+
 
 # sentry
 sentry_sdk.init(
@@ -224,104 +213,6 @@
 # APP MAIN
 
 
-def get_dataloader(config_filename: str, t0: pd.Timestamp, gsp_ids: list[int], num_workers: int):
-
-    # Populate the data config with production data paths
-    populated_data_config_filename = Path(config_filename).parent / "data_config.yaml"
-
-    populate_data_config_sources(config_filename, populated_data_config_filename)
-
-    dataset = PVNetUKRegionalDataset(
-        config_filename=populated_data_config_filename,
-        start_time=t0,
-        end_time=t0,
-        gsp_ids=gsp_ids,
-    )
-
-    # Set up dataloader for parallel loading
-    dataloader_kwargs = dict(
-        shuffle=False,
-        batch_size=batch_size,
-        sampler=None,
-        batch_sampler=None,
-        num_workers=num_workers,
-        collate_fn=stack_np_examples_into_batch,
-        pin_memory=False,
-        drop_last=False,
-        timeout=0,
-        prefetch_factor=None if num_workers == 0 else 2,
-        persistent_workers=False,
-    )
-
-    return DataLoader(dataset, **dataloader_kwargs)
-
-
-def legacy_squeeze(batch):
-    batch[BatchKey.gsp_id] = batch[BatchKey.gsp_id].squeeze(1)
-    return batch
-
-
-def get_legacy_dataloader(
-    config_filename: str,
-    t0: pd.Timestamp,
-    gsp_ids: list[int],
-    num_workers: int,
-):
-
-    # Populate the data config with production data paths
-    populated_data_config_filename = Path(config_filename).parent / "data_config.yaml"
-
-    populate_data_config_sources(
-        config_filename,
-        populated_data_config_filename,
-        gsp_path=os.environ["DB_URL"],
-
-    )
-
-    # Set up ID location query object
-    ds_gsp = next(iter(OpenGSPFromDatabase()))
-    gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)
-
-    # Location and time datapipes
-    location_pipe = IterableWrapper([gsp_id_to_loc(gsp_id) for gsp_id in gsp_ids])
-    t0_datapipe = IterableWrapper([t0]).repeat(len(location_pipe))
-
-    location_pipe = location_pipe.sharding_filter()
-    t0_datapipe = t0_datapipe.sharding_filter()
-
-    # Batch datapipe
-    batch_datapipe = (
-        construct_sliced_data_pipeline(
-            config_filename=populated_data_config_filename,
-            location_pipe=location_pipe,
-            t0_datapipe=t0_datapipe,
-            production=True,
-        )
-        .batch(batch_size)
-        .map(stack_np_examples_into_batch)
-        .map(legacy_squeeze)
-    )
-
-    # Set up dataloader for parallel loading
-    dataloader_kwargs = dict(
-        shuffle=False,
-        batch_size=None,  # batched in datapipe step
-        sampler=None,
-        batch_sampler=None,
-        num_workers=num_workers,
-        collate_fn=None,
-        pin_memory=False,
-        drop_last=False,
-        timeout=0,
-        worker_init_fn=None,
-        prefetch_factor=None if num_workers == 0 else 2,
-        persistent_workers=False,
-    )
-
-    return DataLoader(batch_datapipe, **dataloader_kwargs)
-
-
-
 def app(
     t0=None,
     gsp_ids: list[int] = all_gsp_ids,
@@ -497,7 +388,8 @@
             config_filename=common_config_path,
             t0=t0,
             gsp_ids=gsp_ids,
-            num_workers=num_workers
+            batch_size=batch_size,
+            num_workers=num_workers,
         )
 
     else:
@@ -505,7 +397,8 @@
             config_filename=common_config_path,
             t0=t0,
             gsp_ids=gsp_ids,
-            num_workers=num_workers
+            batch_size=batch_size,
+            num_workers=num_workers,
         )
 
 
diff --git a/pvnet_app/dataloader.py b/pvnet_app/dataloader.py
new file mode 100644
index 0000000..da4615e
--- /dev/null
+++ b/pvnet_app/dataloader.py
@@ -0,0 +1,123 @@
+from pathlib import Path
+import os
+
+import pandas as pd
+
+from torch.utils.data import DataLoader
+from ocf_datapipes.batch import stack_np_examples_into_batch
+from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
+
+from pvnet_app.utils import populate_data_config_sources
+
+# Legacy imports
+from ocf_datapipes.load import OpenGSPFromDatabase
+from torch.utils.data.datapipes.iter import IterableWrapper
+from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline
+from ocf_datapipes.batch import BatchKey
+from pvnet.utils import GSPLocationLookup
+
+
+
+def get_dataloader(
+    config_filename: str,
+    t0: pd.Timestamp,
+    gsp_ids: list[int],
+    batch_size: int,
+    num_workers: int,
+):
+
+    # Populate the data config with production data paths
+    populated_data_config_filename = Path(config_filename).parent / "data_config.yaml"
+
+    populate_data_config_sources(config_filename, populated_data_config_filename)
+
+    dataset = PVNetUKRegionalDataset(
+        config_filename=populated_data_config_filename,
+        start_time=t0,
+        end_time=t0,
+        gsp_ids=gsp_ids,
+    )
+
+    # Set up dataloader for parallel loading
+    dataloader_kwargs = dict(
+        shuffle=False,
+        batch_size=batch_size,
+        sampler=None,
+        batch_sampler=None,
+        num_workers=num_workers,
+        collate_fn=stack_np_examples_into_batch,
+        pin_memory=False,
+        drop_last=False,
+        timeout=0,
+        prefetch_factor=None if num_workers == 0 else 2,
+        persistent_workers=False,
+    )
+
+    return DataLoader(dataset, **dataloader_kwargs)
+
+
+def legacy_squeeze(batch):
+    batch[BatchKey.gsp_id] = batch[BatchKey.gsp_id].squeeze(1)
+    return batch
+
+
+def get_legacy_dataloader(
+    config_filename: str,
+    t0: pd.Timestamp,
+    gsp_ids: list[int],
+    batch_size: int,
+    num_workers: int,
+):
+
+    # Populate the data config with production data paths
+    populated_data_config_filename = Path(config_filename).parent / "data_config.yaml"
+
+    populate_data_config_sources(
+        config_filename,
+        populated_data_config_filename,
+        gsp_path=os.environ["DB_URL"],
+
+    )
+
+    # Set up ID location query object
+    ds_gsp = next(iter(OpenGSPFromDatabase()))
+    gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)
+
+    # Location and time datapipes
+    location_pipe = IterableWrapper([gsp_id_to_loc(gsp_id) for gsp_id in gsp_ids])
+    t0_datapipe = IterableWrapper([t0]).repeat(len(location_pipe))
+
+    location_pipe = location_pipe.sharding_filter()
+    t0_datapipe = t0_datapipe.sharding_filter()
+
+    # Batch datapipe
+    batch_datapipe = (
+        construct_sliced_data_pipeline(
+            config_filename=populated_data_config_filename,
+            location_pipe=location_pipe,
+            t0_datapipe=t0_datapipe,
+            production=True,
+        )
+        .batch(batch_size)
+        .map(stack_np_examples_into_batch)
+        .map(legacy_squeeze)
+    )
+
+    # Set up dataloader for parallel loading
+    dataloader_kwargs = dict(
+        shuffle=False,
+        batch_size=None,  # batched in datapipe step
+        sampler=None,
+        batch_sampler=None,
+        num_workers=num_workers,
+        collate_fn=None,
+        pin_memory=False,
+        drop_last=False,
+        timeout=0,
+        worker_init_fn=None,
+        prefetch_factor=None if num_workers == 0 else 2,
+        persistent_workers=False,
+    )
+
+    return DataLoader(batch_datapipe, **dataloader_kwargs)
+
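
A minimal usage sketch of the relocated functions, assuming the patch above is applied. The config path, GSP ids, batch size, and worker count below are illustrative placeholders rather than values taken from the patch, and the legacy path additionally needs the DB_URL environment variable set, since it reads GSP data from the database.

    import pandas as pd

    from pvnet_app.dataloader import get_dataloader, get_legacy_dataloader

    # Illustrative inputs; in app() these come from the environment and model configs
    common_config_path = "data_config.yaml"  # placeholder path
    t0 = pd.Timestamp.now().floor("30min")   # most recent half hour
    gsp_ids = [1, 2, 3]                      # small subset; app() passes the full GSP list
    batch_size = 10
    num_workers = 2

    # New ocf_data_sampler-based dataloader
    dataloader = get_dataloader(
        config_filename=common_config_path,
        t0=t0,
        gsp_ids=gsp_ids,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    # Legacy ocf_datapipes-based dataloader (same signature, requires DB_URL)
    # dataloader = get_legacy_dataloader(
    #     config_filename=common_config_path,
    #     t0=t0,
    #     gsp_ids=gsp_ids,
    #     batch_size=batch_size,
    #     num_workers=num_workers,
    # )

    for batch in dataloader:
        ...  # each batch is a stacked numpy batch, ready for batch_to_tensor

Both functions now take batch_size explicitly instead of relying on an outer-scope variable, which is why the two call sites in app() gain the extra keyword argument.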