diff --git a/scripts/backtest_uk_gsp.py b/scripts/backtest_uk_gsp.py
index 2d696ffe..f492453f 100644
--- a/scripts/backtest_uk_gsp.py
+++ b/scripts/backtest_uk_gsp.py
@@ -38,14 +38,11 @@
     NumpyBatch,
     batch_to_tensor,
     copy_batch_to_device,
-    stack_np_examples_into_batch,
 )
 from ocf_datapipes.config.load import load_yaml_configuration
 from ocf_datapipes.load import OpenGSP
-from ocf_datapipes.training.pvnet_all_gsp import (
-    create_t0_datapipe, construct_sliced_data_pipeline
-)
 from ocf_datapipes.training.common import _get_datapipes_dict
+from ocf_datapipes.training.pvnet_all_gsp import construct_sliced_data_pipeline, create_t0_datapipe
 from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
 from omegaconf import DictConfig
 
@@ -201,7 +198,7 @@ def get_times_datapipe(config_path):
 
     # Set up ID location query object
     ds_gsp = get_gsp_ds(config_path)
-    gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)
+    GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)
 
     # Filter the init-times to times we have all input data for
     available_target_times = get_available_t0_times(
@@ -368,7 +365,6 @@ def get_datapipe(config_path: str) -> NumpyBatch:
 
     # Convert to tensor for model
     data_pipeline = data_pipeline.map(batch_to_tensor).set_length(len(t0_datapipe))
-
     return data_pipeline
 
 
diff --git a/scripts/save_concurrent_batches.py b/scripts/save_concurrent_batches.py
index 3fc9df20..ed32dfba 100644
--- a/scripts/save_concurrent_batches.py
+++ b/scripts/save_concurrent_batches.py
@@ -32,12 +32,10 @@
 import hydra
 import numpy as np
 import torch
-from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch
-from ocf_datapipes.training.common import (
-    open_and_return_datapipes,
-)
+from ocf_datapipes.batch import BatchKey
 from ocf_datapipes.training.pvnet_all_gsp import (
-    construct_time_pipeline, construct_sliced_data_pipeline
+    construct_sliced_data_pipeline,
+    construct_time_pipeline,
 )
 from omegaconf import DictConfig, OmegaConf
 from sqlalchemy import exc as sa_exc
@@ -45,7 +43,6 @@
 from torch.utils.data.datapipes.iter import IterableWrapper
 from tqdm import tqdm
 
-
 warnings.filterwarnings("ignore", category=sa_exc.SAWarning)
 
 logger = logging.getLogger(__name__)
@@ -63,7 +60,6 @@ def __call__(self, input):
 
 
 def _get_datapipe(config_path, start_time, end_time, n_batches):
-
     t0_datapipe = construct_time_pipeline(
         config_path,
         start_time,
@@ -72,7 +68,7 @@ def _get_datapipe(config_path, start_time, end_time, n_batches):
     t0_datapipe = t0_datapipe.header(n_batches)
     t0_datapipe = t0_datapipe.sharding_filter()
-    
+
     datapipe = construct_sliced_data_pipeline(
         config_path,
         t0_datapipe,