Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 7, 2024
1 parent 5904d2e commit 8089876
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 14 deletions.
8 changes: 2 additions & 6 deletions scripts/backtest_uk_gsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,11 @@
NumpyBatch,
batch_to_tensor,
copy_batch_to_device,
stack_np_examples_into_batch,
)
from ocf_datapipes.config.load import load_yaml_configuration
from ocf_datapipes.load import OpenGSP
from ocf_datapipes.training.pvnet_all_gsp import (
create_t0_datapipe, construct_sliced_data_pipeline
)
from ocf_datapipes.training.common import _get_datapipes_dict
from ocf_datapipes.training.pvnet_all_gsp import construct_sliced_data_pipeline, create_t0_datapipe
from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
from omegaconf import DictConfig

Expand Down Expand Up @@ -201,7 +198,7 @@ def get_times_datapipe(config_path):

# Set up ID location query object
ds_gsp = get_gsp_ds(config_path)
gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)
GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)

# Filter the init-times to times we have all input data for
available_target_times = get_available_t0_times(
Expand Down Expand Up @@ -368,7 +365,6 @@ def get_datapipe(config_path: str) -> NumpyBatch:
# Convert to tensor for model
data_pipeline = data_pipeline.map(batch_to_tensor).set_length(len(t0_datapipe))


return data_pipeline


Expand Down
12 changes: 4 additions & 8 deletions scripts/save_concurrent_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,17 @@
import hydra
import numpy as np
import torch
from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch
from ocf_datapipes.training.common import (
open_and_return_datapipes,
)
from ocf_datapipes.batch import BatchKey
from ocf_datapipes.training.pvnet_all_gsp import (
construct_time_pipeline, construct_sliced_data_pipeline
construct_sliced_data_pipeline,
construct_time_pipeline,
)
from omegaconf import DictConfig, OmegaConf
from sqlalchemy import exc as sa_exc
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper
from tqdm import tqdm


warnings.filterwarnings("ignore", category=sa_exc.SAWarning)

logger = logging.getLogger(__name__)
Expand All @@ -63,7 +60,6 @@ def __call__(self, input):


def _get_datapipe(config_path, start_time, end_time, n_batches):

t0_datapipe = construct_time_pipeline(
config_path,
start_time,
Expand All @@ -72,7 +68,7 @@ def _get_datapipe(config_path, start_time, end_time, n_batches):

t0_datapipe = t0_datapipe.header(n_batches)
t0_datapipe = t0_datapipe.sharding_filter()

datapipe = construct_sliced_data_pipeline(
config_path,
t0_datapipe,
Expand Down

0 comments on commit 8089876

Please sign in to comment.