Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use concurrent batch pipeline for ~30x speed up #236

Merged
merged 10 commits into from
Oct 7, 2024
73 changes: 19 additions & 54 deletions scripts/backtest_uk_gsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,11 @@
NumpyBatch,
batch_to_tensor,
copy_batch_to_device,
stack_np_examples_into_batch,
)
from ocf_datapipes.config.load import load_yaml_configuration
from ocf_datapipes.load import OpenGSP
from ocf_datapipes.training.common import create_t0_and_loc_datapipes
from ocf_datapipes.training.pvnet import (
_get_datapipes_dict,
construct_sliced_data_pipeline,
)
from ocf_datapipes.training.common import _get_datapipes_dict
from ocf_datapipes.training.pvnet_all_gsp import construct_sliced_data_pipeline, create_t0_datapipe
from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
from omegaconf import DictConfig

Expand All @@ -58,20 +54,19 @@
from tqdm import tqdm

from pvnet.load_model import get_model_from_checkpoints
from pvnet.utils import GSPLocationLookup

# ------------------------------------------------------------------
# USER CONFIGURED VARIABLES
output_dir = "/mnt/disks/backtest/test_backtest"
output_dir = "/mnt/disks/extra_batches/test_backtest"

# Local directory to load the PVNet checkpoint from. By default this should pull the best performing
# checkpoint on the val set
model_chckpoint_dir = "/home/jamesfulton/repos/PVNet/checkpoints/kqaknmuc"
model_chckpoint_dir = "/home/jamesfulton/repos/PVNet/checkpoints/q911tei5"

# Local directory to load the summation model checkpoint from. By default this should pull the best
# performing checkpoint on the val set. If set to None a simple sum is used instead
summation_chckpoint_dir = (
"/home/jamesfulton/repos/PVNet_summation/checkpoints/pvnet_summation/nw673nw2"
"/home/jamesfulton/repos/PVNet_summation/checkpoints/pvnet_summation/73oa4w9t"
)

# Forecasts will be made for all available init times between these
Expand Down Expand Up @@ -144,7 +139,7 @@ def get_available_t0_times(start_datetime, end_datetime, config_path):
# Pop out the config file
config = datapipes_dict.pop("config")

# We are going to abuse the `create_t0_and_loc_datapipes()` function to find the init-times in
# We are going to abuse the `create_t0_datapipe()` function to find the init-times in
# potential_init_times which we have input data for. To do this, we will feed in some fake GSP
# data which has the potential_init_times as timestamps. This is a bit hacky but works for now

Expand Down Expand Up @@ -172,18 +167,15 @@ def get_available_t0_times(start_datetime, end_datetime, config_path):
# Overwrite the GSP data which is already in the datapipes dict
datapipes_dict["gsp"] = IterableWrapper([ds_fake_gsp])

# Use create_t0_and_loc_datapipes to get datapipe of init-times
location_pipe, t0_datapipe = create_t0_and_loc_datapipes(
# Use create_t0_datapipe to get datapipe of init-times
t0_datapipe = create_t0_datapipe(
datapipes_dict,
configuration=config,
key_for_t0="gsp",
shuffle=False,
)

# Create a full list of available init-times. Note that we need to loop over the t0s AND
# locations to avoid the torch datapipes buffer overflow but we don't actually use the location
available_init_times = [t0 for _, t0 in zip(location_pipe, t0_datapipe)]
available_init_times = pd.to_datetime(available_init_times)
# Create a full list of available init-times
available_init_times = pd.to_datetime([t0 for t0 in t0_datapipe])

logger.info(
f"{len(available_init_times)} out of {len(potential_init_times)} "
Expand All @@ -193,22 +185,16 @@ def get_available_t0_times(start_datetime, end_datetime, config_path):
return available_init_times


def get_loctimes_datapipes(config_path):
"""Create location and init-time datapipes
def get_times_datapipe(config_path):
"""Create init-time datapipe

Args:
config_path: Path to data config file

Returns:
tuple: A tuple of datapipes
- Datapipe yielding locations
- Datapipe yielding init-times
Datapipe: A Datapipe yielding init-times
"""

# Set up ID location query object
ds_gsp = get_gsp_ds(config_path)
gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)

# Filter the init-times to times we have all input data for
available_target_times = get_available_t0_times(
start_datetime,
Expand All @@ -222,25 +208,13 @@ def get_loctimes_datapipes(config_path):
# the backtest will end up producing
available_target_times.to_frame().to_csv(f"{output_dir}/t0_times.csv")

# Cycle the GSP locations
location_pipe = IterableWrapper([[gsp_id_to_loc(gsp_id) for gsp_id in ALL_GSP_IDS]]).repeat(
num_t0s
)

# Shard and then unbatch the locations so that each worker will generate all samples for all
# GSPs and for a single init-time
location_pipe = location_pipe.sharding_filter()
location_pipe = location_pipe.unbatch(unbatch_level=1)

# Create times datapipe so each worker receives 317 copies of the same datetime for its batch
t0_datapipe = IterableWrapper([[t0 for gsp_id in ALL_GSP_IDS] for t0 in available_target_times])
t0_datapipe = IterableWrapper(available_target_times)
t0_datapipe = t0_datapipe.sharding_filter()
t0_datapipe = t0_datapipe.unbatch(unbatch_level=1)

t0_datapipe = t0_datapipe.set_length(num_t0s * len(ALL_GSP_IDS))
location_pipe = location_pipe.set_length(num_t0s * len(ALL_GSP_IDS))
t0_datapipe = t0_datapipe.set_length(num_t0s)

return location_pipe, t0_datapipe
return t0_datapipe


class ModelPipe:
Expand Down Expand Up @@ -375,25 +349,16 @@ def get_datapipe(config_path: str) -> NumpyBatch:
"""

# Construct location and init-time datapipes
location_pipe, t0_datapipe = get_loctimes_datapipes(config_path)

# Get the number of init-times
num_batches = len(t0_datapipe) // len(ALL_GSP_IDS)
t0_datapipe = get_times_datapipe(config_path)

# Construct sample datapipes
data_pipeline = construct_sliced_data_pipeline(
config_path,
location_pipe,
t0_datapipe,
)

# Batch so that each worker returns a batch of all locations for a single init-time
# Also convert to tensor for model
data_pipeline = (
data_pipeline.batch(len(ALL_GSP_IDS)).map(stack_np_examples_into_batch).map(batch_to_tensor)
)

data_pipeline = data_pipeline.set_length(num_batches)
# Convert to tensor for model
data_pipeline = data_pipeline.map(batch_to_tensor).set_length(len(t0_datapipe))

return data_pipeline

Expand Down
69 changes: 8 additions & 61 deletions scripts/save_concurrent_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,17 @@
import hydra
import numpy as np
import torch
from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch
from ocf_datapipes.training.common import (
open_and_return_datapipes,
from ocf_datapipes.batch import BatchKey
from ocf_datapipes.training.pvnet_all_gsp import (
construct_sliced_data_pipeline,
construct_time_pipeline,
)
from ocf_datapipes.training.pvnet import construct_loctime_pipelines, construct_sliced_data_pipeline
from omegaconf import DictConfig, OmegaConf
from sqlalchemy import exc as sa_exc
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper
from tqdm import tqdm

from pvnet.utils import GSPLocationLookup

warnings.filterwarnings("ignore", category=sa_exc.SAWarning)

logger = logging.getLogger(__name__)
Expand All @@ -61,73 +59,22 @@ def __call__(self, input):
torch.save(batch, f"{self.batch_dir}/{i:06}.pt")


def select_first(x):
"""Select zeroth element from indexable object"""
return x[0]


def _get_loctimes_datapipes(config_path, start_time, end_time, n_batches):
# Set up ID location query object
ds_gsp = next(
iter(
open_and_return_datapipes(
config_path,
use_gsp=True,
use_nwp=False,
use_pv=False,
use_sat=False,
use_hrv=False,
use_topo=False,
)["gsp"]
)
)
gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)

# Cycle the GSP locations
location_pipe = IterableWrapper([[gsp_id_to_loc(gsp_id) for gsp_id in range(1, 318)]]).repeat(
n_batches
)

# Shard and unbatch so each worker goes through GSP 1-317 for each batch
location_pipe = location_pipe.sharding_filter()
location_pipe = location_pipe.unbatch(unbatch_level=1)

# These two datapipes come from an earlier fork and must be iterated through together
# despite the fact that we don't want these random locations here
random_location_datapipe, t0_datapipe = construct_loctime_pipelines(
def _get_datapipe(config_path, start_time, end_time, n_batches):
t0_datapipe = construct_time_pipeline(
config_path,
start_time,
end_time,
)

# Iterate through both but select only time
t0_datapipe = t0_datapipe.zip(random_location_datapipe).map(select_first)

# Create times datapipe so we'll get the same time over each batch
t0_datapipe = t0_datapipe.header(n_batches)
t0_datapipe = IterableWrapper([[t0 for gsp_id in range(1, 318)] for t0 in t0_datapipe])
t0_datapipe = t0_datapipe.sharding_filter()
t0_datapipe = t0_datapipe.unbatch(unbatch_level=1)

return location_pipe, t0_datapipe


def _get_datapipe(config_path, start_time, end_time, n_batches):
# Open datasets from the config and filter to useable location-time pairs

location_pipe, t0_datapipe = _get_loctimes_datapipes(
config_path, start_time, end_time, n_batches
)

data_pipeline = construct_sliced_data_pipeline(
datapipe = construct_sliced_data_pipeline(
config_path,
location_pipe,
t0_datapipe,
)

data_pipeline = data_pipeline.batch(317).map(stack_np_examples_into_batch).map(batch_to_tensor)

return data_pipeline
return datapipe


def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
Expand Down
Loading