Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/openclimatefix/PVNet into main
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Oct 7, 2024
2 parents 5881da4 + dfce50a commit 16c9b5e
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[bumpversion]
commit = True
tag = True
current_version = 3.0.57
current_version = 3.0.58
message = Bump version: {current_version} → {new_version} [skip ci]

[bumpversion:file:pvnet/__init__.py]
Expand Down
2 changes: 1 addition & 1 deletion pvnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""PVNet"""
__version__ = "3.0.57"
__version__ = "3.0.58"
15 changes: 3 additions & 12 deletions scripts/backtest_uk_gsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,11 @@
NumpyBatch,
batch_to_tensor,
copy_batch_to_device,
stack_np_examples_into_batch,
)
from ocf_datapipes.config.load import load_yaml_configuration
from ocf_datapipes.load import OpenGSP
from ocf_datapipes.training.pvnet_all_gsp import (
create_t0_datapipe, construct_sliced_data_pipeline
)
from ocf_datapipes.training.common import _get_datapipes_dict
from ocf_datapipes.training.pvnet_all_gsp import construct_sliced_data_pipeline, create_t0_datapipe
from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
from omegaconf import DictConfig

Expand All @@ -57,7 +54,6 @@
from tqdm import tqdm

from pvnet.load_model import get_model_from_checkpoints
from pvnet.utils import GSPLocationLookup

# ------------------------------------------------------------------
# USER CONFIGURED VARIABLES
Expand Down Expand Up @@ -143,7 +139,7 @@ def get_available_t0_times(start_datetime, end_datetime, config_path):
# Pop out the config file
config = datapipes_dict.pop("config")

# We are going to abuse the `create_datapipes()` function to find the init-times in
# We are going to abuse the `create_t0_datapipe()` function to find the init-times in
# potential_init_times which we have input data for. To do this, we will feed in some fake GSP
# data which has the potential_init_times as timestamps. This is a bit hacky but works for now

Expand Down Expand Up @@ -171,7 +167,7 @@ def get_available_t0_times(start_datetime, end_datetime, config_path):
# Overwrite the GSP data which is already in the datapipes dict
datapipes_dict["gsp"] = IterableWrapper([ds_fake_gsp])

# Use create_t0_and_loc_datapipes to get datapipe of init-times
# Use create_t0_datapipe to get datapipe of init-times
t0_datapipe = create_t0_datapipe(
datapipes_dict,
configuration=config,
Expand Down Expand Up @@ -199,10 +195,6 @@ def get_times_datapipe(config_path):
Datapipe: A Datapipe yielding init-times
"""

# Set up ID location query object
ds_gsp = get_gsp_ds(config_path)
gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)

# Filter the init-times to times we have all input data for
available_target_times = get_available_t0_times(
start_datetime,
Expand Down Expand Up @@ -368,7 +360,6 @@ def get_datapipe(config_path: str) -> NumpyBatch:
# Convert to tensor for model
data_pipeline = data_pipeline.map(batch_to_tensor).set_length(len(t0_datapipe))


return data_pipeline


Expand Down
14 changes: 5 additions & 9 deletions scripts/save_concurrent_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,17 @@
import hydra
import numpy as np
import torch
from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch
from ocf_datapipes.training.common import (
open_and_return_datapipes,
)
from ocf_datapipes.batch import BatchKey, batch_to_tensor
from ocf_datapipes.training.pvnet_all_gsp import (
construct_time_pipeline, construct_sliced_data_pipeline
construct_sliced_data_pipeline,
construct_time_pipeline,
)
from omegaconf import DictConfig, OmegaConf
from sqlalchemy import exc as sa_exc
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper
from tqdm import tqdm


warnings.filterwarnings("ignore", category=sa_exc.SAWarning)

logger = logging.getLogger(__name__)
Expand All @@ -63,7 +60,6 @@ def __call__(self, input):


def _get_datapipe(config_path, start_time, end_time, n_batches):

t0_datapipe = construct_time_pipeline(
config_path,
start_time,
Expand All @@ -72,11 +68,11 @@ def _get_datapipe(config_path, start_time, end_time, n_batches):

t0_datapipe = t0_datapipe.header(n_batches)
t0_datapipe = t0_datapipe.sharding_filter()

datapipe = construct_sliced_data_pipeline(
config_path,
t0_datapipe,
)
).map(batch_to_tensor)

return datapipe

Expand Down

0 comments on commit 16c9b5e

Please sign in to comment.