refactor dataloader
dfulu committed Sep 16, 2024
1 parent 73bee7e commit 5f62f1f
Showing 2 changed files with 131 additions and 115 deletions.
123 changes: 8 additions & 115 deletions pvnet_app/app.py
@@ -15,7 +15,6 @@
import tempfile
import warnings
from datetime import timedelta
from pathlib import Path

import dask
import pandas as pd
@@ -26,20 +25,13 @@
from nowcasting_datamodel.models.base import Base_Forecast
from nowcasting_datamodel.read.read_gsp import get_latest_gsp_capacities
from nowcasting_datamodel.save.save import save as save_sql_forecasts
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
from ocf_datapipes.batch import stack_np_examples_into_batch, batch_to_tensor, copy_batch_to_device
from ocf_datapipes.batch import batch_to_tensor, copy_batch_to_device
from pvnet.models.base_model import BaseModel as PVNetBaseModel
from pvnet.utils import GSPLocationLookup
from torch.utils.data import DataLoader
import sentry_sdk



import pvnet_app
from pvnet_app.data.nwp import (
download_all_nwp_data,
preprocess_nwp_data,
)
from pvnet_app.data.nwp import download_all_nwp_data, preprocess_nwp_data
from pvnet_app.data.satellite import (
download_all_sat_data,
preprocess_sat_data,
@@ -53,11 +45,8 @@
save_yaml_config,
)

# Legacy imports
from ocf_datapipes.load import OpenGSPFromDatabase
from torch.utils.data.datapipes.iter import IterableWrapper
from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline
from ocf_datapipes.batch import BatchKey
from pvnet_app.dataloader import get_legacy_dataloader, get_dataloader


# sentry
sentry_sdk.init(
@@ -224,104 +213,6 @@ def filter(self, record):
# APP MAIN


def get_dataloader(config_filename: str, t0: pd.Timestamp, gsp_ids: list[int], num_workers: int):

# Populate the data config with production data paths
populated_data_config_filename = Path(config_filename).parent / "data_config.yaml"

populate_data_config_sources(config_filename, populated_data_config_filename)

dataset = PVNetUKRegionalDataset(
config_filename=populated_data_config_filename,
start_time=t0,
end_time=t0,
gsp_ids=gsp_ids,
)

# Set up dataloader for parallel loading
dataloader_kwargs = dict(
shuffle=False,
batch_size=batch_size,
sampler=None,
batch_sampler=None,
num_workers=num_workers,
collate_fn=stack_np_examples_into_batch,
pin_memory=False,
drop_last=False,
timeout=0,
prefetch_factor=None if num_workers == 0 else 2,
persistent_workers=False,
)

return DataLoader(dataset, **dataloader_kwargs)


def legacy_squeeze(batch):
batch[BatchKey.gsp_id] = batch[BatchKey.gsp_id].squeeze(1)
return batch


def get_legacy_dataloader(
config_filename: str,
t0: pd.Timestamp,
gsp_ids: list[int],
num_workers: int,
):

# Populate the data config with production data paths
populated_data_config_filename = Path(config_filename).parent / "data_config.yaml"

populate_data_config_sources(
config_filename,
populated_data_config_filename,
gsp_path=os.environ["DB_URL"],

)

# Set up ID location query object
ds_gsp = next(iter(OpenGSPFromDatabase()))
gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)

# Location and time datapipes
location_pipe = IterableWrapper([gsp_id_to_loc(gsp_id) for gsp_id in gsp_ids])
t0_datapipe = IterableWrapper([t0]).repeat(len(location_pipe))

location_pipe = location_pipe.sharding_filter()
t0_datapipe = t0_datapipe.sharding_filter()

# Batch datapipe
batch_datapipe = (
construct_sliced_data_pipeline(
config_filename=populated_data_config_filename,
location_pipe=location_pipe,
t0_datapipe=t0_datapipe,
production=True,
)
.batch(batch_size)
.map(stack_np_examples_into_batch)
.map(legacy_squeeze)
)

# Set up dataloader for parallel loading
dataloader_kwargs = dict(
shuffle=False,
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=num_workers,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=None if num_workers == 0 else 2,
persistent_workers=False,
)

return DataLoader(batch_datapipe, **dataloader_kwargs)



def app(
t0=None,
gsp_ids: list[int] = all_gsp_ids,
@@ -497,15 +388,17 @@ def app(
config_filename=common_config_path,
t0=t0,
gsp_ids=gsp_ids,
num_workers=num_workers
batch_size=batch_size,
num_workers=num_workers,
)

else:
dataloader = get_dataloader(
config_filename=common_config_path,
t0=t0,
gsp_ids=gsp_ids,
num_workers=num_workers
batch_size=batch_size,
num_workers=num_workers,
)


123 changes: 123 additions & 0 deletions pvnet_app/dataloader.py
@@ -0,0 +1,123 @@
from pathlib import Path
import os

import pandas as pd

from torch.utils.data import DataLoader
from ocf_datapipes.batch import stack_np_examples_into_batch
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset

from pvnet_app.utils import populate_data_config_sources

# Legacy imports
from ocf_datapipes.load import OpenGSPFromDatabase
from torch.utils.data.datapipes.iter import IterableWrapper
from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline
from ocf_datapipes.batch import BatchKey
from pvnet.utils import GSPLocationLookup



def get_dataloader(
config_filename: str,
t0: pd.Timestamp,
gsp_ids: list[int],
batch_size: int,
num_workers: int,
):
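    """Create a DataLoader for the given init-time and GSP IDs using ocf-data-sampler."""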

# Populate the data config with production data paths
populated_data_config_filename = Path(config_filename).parent / "data_config.yaml"

populate_data_config_sources(config_filename, populated_data_config_filename)

dataset = PVNetUKRegionalDataset(
config_filename=populated_data_config_filename,
start_time=t0,
end_time=t0,
gsp_ids=gsp_ids,
)

# Set up dataloader for parallel loading
dataloader_kwargs = dict(
shuffle=False,
batch_size=batch_size,
sampler=None,
batch_sampler=None,
num_workers=num_workers,
collate_fn=stack_np_examples_into_batch,
pin_memory=False,
drop_last=False,
timeout=0,
prefetch_factor=None if num_workers == 0 else 2,
persistent_workers=False,
)

return DataLoader(dataset, **dataloader_kwargs)


def legacy_squeeze(batch):
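    """Squeeze out the extra GSP ID dimension produced by the legacy batching."""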
batch[BatchKey.gsp_id] = batch[BatchKey.gsp_id].squeeze(1)
return batch


def get_legacy_dataloader(
config_filename: str,
t0: pd.Timestamp,
gsp_ids: list[int],
batch_size: int,
num_workers: int,
):
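    """Create a DataLoader for the given init-time and GSP IDs using the legacy ocf_datapipes pipeline."""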

# Populate the data config with production data paths
populated_data_config_filename = Path(config_filename).parent / "data_config.yaml"

populate_data_config_sources(
config_filename,
populated_data_config_filename,
        gsp_path=os.environ["DB_URL"],
    )

# Set up ID location query object
ds_gsp = next(iter(OpenGSPFromDatabase()))
gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)

# Location and time datapipes
location_pipe = IterableWrapper([gsp_id_to_loc(gsp_id) for gsp_id in gsp_ids])
t0_datapipe = IterableWrapper([t0]).repeat(len(location_pipe))

location_pipe = location_pipe.sharding_filter()
t0_datapipe = t0_datapipe.sharding_filter()

# Batch datapipe
batch_datapipe = (
construct_sliced_data_pipeline(
config_filename=populated_data_config_filename,
location_pipe=location_pipe,
t0_datapipe=t0_datapipe,
production=True,
)
.batch(batch_size)
.map(stack_np_examples_into_batch)
.map(legacy_squeeze)
)

# Set up dataloader for parallel loading
dataloader_kwargs = dict(
shuffle=False,
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=num_workers,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=None if num_workers == 0 else 2,
persistent_workers=False,
)

return DataLoader(batch_datapipe, **dataloader_kwargs)
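
For context, a minimal sketch of how app.py calls these helpers after the refactor, based on the call sites in the diff above. The config path, init-time, GSP ID list, and the use_legacy_pipeline flag name are illustrative placeholders, not values taken from the repository:

import pandas as pd

from pvnet_app.dataloader import get_dataloader, get_legacy_dataloader

# Illustrative inputs; the real app derives these from its own config and environment
common_config_path = "configs/data_config.yaml"  # placeholder path
t0 = pd.Timestamp("2024-09-16 07:00")
gsp_ids = list(range(1, 318))  # placeholder for all_gsp_ids
batch_size = 10
num_workers = 2
use_legacy_pipeline = False  # hypothetical flag name; the app's actual switch may differ

if use_legacy_pipeline:
    # Legacy ocf_datapipes-based loading
    dataloader = get_legacy_dataloader(
        config_filename=common_config_path,
        t0=t0,
        gsp_ids=gsp_ids,
        batch_size=batch_size,
        num_workers=num_workers,
    )
else:
    # New ocf-data-sampler-based loading
    dataloader = get_dataloader(
        config_filename=common_config_path,
        t0=t0,
        gsp_ids=gsp_ids,
        batch_size=batch_size,
        num_workers=num_workers,
    )

for batch in dataloader:
    ...  # each batch is then passed through batch_to_tensor / copy_batch_to_device and the model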
