From c93996cdf26d892b7186c6a17ae5493f1a5bc59c Mon Sep 17 00:00:00 2001 From: James Fulton <41546094+dfulu@users.noreply.github.com> Date: Mon, 7 Oct 2024 11:33:34 +0100 Subject: [PATCH] Use concurrent batch pipeline for ~30x speed up (#236) * bug fix * use concurrent datapipe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean up * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update comment * update comment * save as tensor --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- scripts/backtest_uk_gsp.py | 73 ++++++++---------------------- scripts/save_concurrent_batches.py | 71 ++++------------------------- 2 files changed, 28 insertions(+), 116 deletions(-) diff --git a/scripts/backtest_uk_gsp.py b/scripts/backtest_uk_gsp.py index 33d7eb62..e002af3f 100644 --- a/scripts/backtest_uk_gsp.py +++ b/scripts/backtest_uk_gsp.py @@ -38,15 +38,11 @@ NumpyBatch, batch_to_tensor, copy_batch_to_device, - stack_np_examples_into_batch, ) from ocf_datapipes.config.load import load_yaml_configuration from ocf_datapipes.load import OpenGSP -from ocf_datapipes.training.common import create_t0_and_loc_datapipes -from ocf_datapipes.training.pvnet import ( - _get_datapipes_dict, - construct_sliced_data_pipeline, -) +from ocf_datapipes.training.common import _get_datapipes_dict +from ocf_datapipes.training.pvnet_all_gsp import construct_sliced_data_pipeline, create_t0_datapipe from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD from omegaconf import DictConfig @@ -58,20 +54,19 @@ from tqdm import tqdm from pvnet.load_model import get_model_from_checkpoints -from pvnet.utils import GSPLocationLookup # ------------------------------------------------------------------ # USER CONFIGURED VARIABLES -output_dir = "/mnt/disks/backtest/test_backtest" +output_dir = "/mnt/disks/extra_batches/test_backtest" # Local directory to load the PVNet checkpoint from. By default this should pull the best performing # checkpoint on the val set -model_chckpoint_dir = "/home/jamesfulton/repos/PVNet/checkpoints/kqaknmuc" +model_chckpoint_dir = "/home/jamesfulton/repos/PVNet/checkpoints/q911tei5" # Local directory to load the summation model checkpoint from. By default this should pull the best # performing checkpoint on the val set. If set to None a simple sum is used instead summation_chckpoint_dir = ( - "/home/jamesfulton/repos/PVNet_summation/checkpoints/pvnet_summation/nw673nw2" + "/home/jamesfulton/repos/PVNet_summation/checkpoints/pvnet_summation/73oa4w9t" ) # Forecasts will be made for all available init times between these @@ -144,7 +139,7 @@ def get_available_t0_times(start_datetime, end_datetime, config_path): # Pop out the config file config = datapipes_dict.pop("config") - # We are going to abuse the `create_t0_and_loc_datapipes()` function to find the init-times in + # We are going to abuse the `create_t0_datapipe()` function to find the init-times in # potential_init_times which we have input data for. To do this, we will feed in some fake GSP # data which has the potential_init_times as timestamps. This is a bit hacky but works for now @@ -172,18 +167,15 @@ def get_available_t0_times(start_datetime, end_datetime, config_path): # Overwrite the GSP data which is already in the datapipes dict datapipes_dict["gsp"] = IterableWrapper([ds_fake_gsp]) - # Use create_t0_and_loc_datapipes to get datapipe of init-times - location_pipe, t0_datapipe = create_t0_and_loc_datapipes( + # Use create_t0_datapipe to get datapipe of init-times + t0_datapipe = create_t0_datapipe( datapipes_dict, configuration=config, - key_for_t0="gsp", shuffle=False, ) - # Create a full list of available init-times. Note that we need to loop over the t0s AND - # locations to avoid the torch datapipes buffer overflow but we don't actually use the location - available_init_times = [t0 for _, t0 in zip(location_pipe, t0_datapipe)] - available_init_times = pd.to_datetime(available_init_times) + # Create a full list of available init-times + available_init_times = pd.to_datetime([t0 for t0 in t0_datapipe]) logger.info( f"{len(available_init_times)} out of {len(potential_init_times)} " @@ -193,22 +185,16 @@ def get_available_t0_times(start_datetime, end_datetime, config_path): return available_init_times -def get_loctimes_datapipes(config_path): - """Create location and init-time datapipes +def get_times_datapipe(config_path): + """Create init-time datapipe Args: config_path: Path to data config file Returns: - tuple: A tuple of datapipes - - Datapipe yielding locations - - Datapipe yielding init-times + Datapipe: A Datapipe yielding init-times """ - # Set up ID location query object - ds_gsp = get_gsp_ds(config_path) - gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb) - # Filter the init-times to times we have all input data for available_target_times = get_available_t0_times( start_datetime, @@ -222,25 +208,13 @@ def get_loctimes_datapipes(config_path): # the backtest will end up producing available_target_times.to_frame().to_csv(f"{output_dir}/t0_times.csv") - # Cycle the GSP locations - location_pipe = IterableWrapper([[gsp_id_to_loc(gsp_id) for gsp_id in ALL_GSP_IDS]]).repeat( - num_t0s - ) - - # Shard and then unbatch the locations so that each worker will generate all samples for all - # GSPs and for a single init-time - location_pipe = location_pipe.sharding_filter() - location_pipe = location_pipe.unbatch(unbatch_level=1) - # Create times datapipe so each worker receives 317 copies of the same datetime for its batch - t0_datapipe = IterableWrapper([[t0 for gsp_id in ALL_GSP_IDS] for t0 in available_target_times]) + t0_datapipe = IterableWrapper(available_target_times) t0_datapipe = t0_datapipe.sharding_filter() - t0_datapipe = t0_datapipe.unbatch(unbatch_level=1) - t0_datapipe = t0_datapipe.set_length(num_t0s * len(ALL_GSP_IDS)) - location_pipe = location_pipe.set_length(num_t0s * len(ALL_GSP_IDS)) + t0_datapipe = t0_datapipe.set_length(num_t0s) - return location_pipe, t0_datapipe + return t0_datapipe class ModelPipe: @@ -375,25 +349,16 @@ def get_datapipe(config_path: str) -> NumpyBatch: """ # Construct location and init-time datapipes - location_pipe, t0_datapipe = get_loctimes_datapipes(config_path) - - # Get the number of init-times - num_batches = len(t0_datapipe) // len(ALL_GSP_IDS) + t0_datapipe = get_times_datapipe(config_path) # Construct sample datapipes data_pipeline = construct_sliced_data_pipeline( config_path, - location_pipe, t0_datapipe, ) - # Batch so that each worker returns a batch of all locations for a single init-time - # Also convert to tensor for model - data_pipeline = ( - data_pipeline.batch(len(ALL_GSP_IDS)).map(stack_np_examples_into_batch).map(batch_to_tensor) - ) - - data_pipeline = data_pipeline.set_length(num_batches) + # Convert to tensor for model + data_pipeline = data_pipeline.map(batch_to_tensor).set_length(len(t0_datapipe)) return data_pipeline diff --git a/scripts/save_concurrent_batches.py b/scripts/save_concurrent_batches.py index f421887e..37833b9e 100644 --- a/scripts/save_concurrent_batches.py +++ b/scripts/save_concurrent_batches.py @@ -32,19 +32,17 @@ import hydra import numpy as np import torch -from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch -from ocf_datapipes.training.common import ( - open_and_return_datapipes, +from ocf_datapipes.batch import BatchKey, batch_to_tensor +from ocf_datapipes.training.pvnet_all_gsp import ( + construct_sliced_data_pipeline, + construct_time_pipeline, ) -from ocf_datapipes.training.pvnet import construct_loctime_pipelines, construct_sliced_data_pipeline from omegaconf import DictConfig, OmegaConf from sqlalchemy import exc as sa_exc from torch.utils.data import DataLoader from torch.utils.data.datapipes.iter import IterableWrapper from tqdm import tqdm -from pvnet.utils import GSPLocationLookup - warnings.filterwarnings("ignore", category=sa_exc.SAWarning) logger = logging.getLogger(__name__) @@ -61,73 +59,22 @@ def __call__(self, input): torch.save(batch, f"{self.batch_dir}/{i:06}.pt") -def select_first(x): - """Select zeroth element from indexable object""" - return x[0] - - -def _get_loctimes_datapipes(config_path, start_time, end_time, n_batches): - # Set up ID location query object - ds_gsp = next( - iter( - open_and_return_datapipes( - config_path, - use_gsp=True, - use_nwp=False, - use_pv=False, - use_sat=False, - use_hrv=False, - use_topo=False, - )["gsp"] - ) - ) - gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb) - - # Cycle the GSP locations - location_pipe = IterableWrapper([[gsp_id_to_loc(gsp_id) for gsp_id in range(1, 318)]]).repeat( - n_batches - ) - - # Shard and unbatch so each worker goes through GSP 1-317 for each batch - location_pipe = location_pipe.sharding_filter() - location_pipe = location_pipe.unbatch(unbatch_level=1) - - # These two datapipes come from an earlier fork and must be iterated through together - # despite the fact that we don't want these random locations here - random_location_datapipe, t0_datapipe = construct_loctime_pipelines( +def _get_datapipe(config_path, start_time, end_time, n_batches): + t0_datapipe = construct_time_pipeline( config_path, start_time, end_time, ) - # Iterate through both but select only time - t0_datapipe = t0_datapipe.zip(random_location_datapipe).map(select_first) - - # Create times datapipe so we'll get the same time over each batch t0_datapipe = t0_datapipe.header(n_batches) - t0_datapipe = IterableWrapper([[t0 for gsp_id in range(1, 318)] for t0 in t0_datapipe]) t0_datapipe = t0_datapipe.sharding_filter() - t0_datapipe = t0_datapipe.unbatch(unbatch_level=1) - return location_pipe, t0_datapipe - - -def _get_datapipe(config_path, start_time, end_time, n_batches): - # Open datasets from the config and filter to useable location-time pairs - - location_pipe, t0_datapipe = _get_loctimes_datapipes( - config_path, start_time, end_time, n_batches - ) - - data_pipeline = construct_sliced_data_pipeline( + datapipe = construct_sliced_data_pipeline( config_path, - location_pipe, t0_datapipe, - ) - - data_pipeline = data_pipeline.batch(317).map(stack_np_examples_into_batch).map(batch_to_tensor) + ).map(batch_to_tensor) - return data_pipeline + return datapipe def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):