Merge pull request #41 from openclimatefix/optional_gsp
Allow GSP data to be optional
peterdudfield authored Aug 30, 2024
2 parents e2c8558 + ce0e295 commit 92d9718
Showing 2 changed files with 92 additions and 53 deletions.
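
In practice this means a PVNet config can now leave the GSP zarr path empty and the dataset will simply skip the GSP modality. Below is a minimal sketch of that flow, mirroring the new test added in this commit; the config path is a placeholder, and the import location of save_yaml_configuration is an assumption, since the diff only shows load_yaml_configuration being imported from ocf_datapipes.config.load.

import tempfile

from ocf_datapipes.config.load import load_yaml_configuration
# Assumed import path for the save helper used by the new test below
from ocf_datapipes.config.save import save_yaml_configuration

from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset

# Load an existing PVNet config and blank out the GSP zarr path,
# which after this change makes the GSP modality optional
config = load_yaml_configuration("pvnet_config.yaml")  # placeholder path
config.input_data.gsp.gsp_zarr_path = ''

with tempfile.NamedTemporaryFile() as temp_config_file:
    save_yaml_configuration(config, temp_config_file.name)
    # The dataset builds without any GSP data; the other modalities are unaffected
    dataset = PVNetUKRegionalDataset(temp_config_file.name)
    sample = dataset[0]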
128 changes: 76 additions & 52 deletions ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
@@ -4,7 +4,7 @@
import pandas as pd
import xarray as xr
from torch.utils.data import Dataset

+import pkg_resources

from ocf_data_sampler.load.gsp import open_gsp
from ocf_data_sampler.load.nwp import open_nwp
@@ -67,11 +67,12 @@ def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataAr

    datasets_dict = {}

-    # We always assume GSP will be included
-    da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path)
+    # Load GSP data unless the path is None
+    if in_config.gsp.gsp_zarr_path:
+        da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path)

-    # Remove national GSP
-    datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
+        # Remove national GSP
+        datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))

    # Load NWP data if in config
    if in_config.nwp:
@@ -170,19 +171,19 @@ def find_valid_t0_times(

        contiguous_time_periods['sat'] = time_periods

-    # GSP always assumed to be in data
-    gsp_config = config.input_data.gsp
+    if "gsp" in datasets_dict:
+        gsp_config = config.input_data.gsp

-    time_periods = find_contiguous_t0_periods(
-        pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
-        sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-        history_duration=minutes(gsp_config.history_minutes),
-        forecast_duration=minutes(gsp_config.forecast_minutes),
-    )
+        time_periods = find_contiguous_t0_periods(
+            pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
+            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+            history_duration=minutes(gsp_config.history_minutes),
+            forecast_duration=minutes(gsp_config.forecast_minutes),
+        )

-    contiguous_time_periods['gsp'] = time_periods
+        contiguous_time_periods['gsp'] = time_periods

-    # just get the values (no the keys)
+    # just get the values (not the keys)
    contiguous_time_periods_values = list(contiguous_time_periods.values())

    # Find joint overlapping contiguous time periods
@@ -246,8 +247,8 @@ def slice_datasets_by_space(
            width_pixels=sat_config.satellite_image_size_pixels_width,
        )

-    # GSP always assumed to be in data
-    sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id)
+    if "gsp" in datasets_dict:
+        sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id)

    return sliced_datasets_dict

@@ -312,33 +313,33 @@ def slice_datasets_by_time(
            sat_dropout_time,
        )

-    # GSP always assumed to be included
-    gsp_config = config.input_data.gsp
+    if "gsp" in datasets_dict:
+        gsp_config = config.input_data.gsp

-    sliced_datasets_dict["gsp_future"] = select_time_slice(
-        datasets_dict["gsp"],
-        t0,
-        sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-        interval_start=minutes(30),
-        interval_end=minutes(gsp_config.forecast_minutes),
-    )

-    sliced_datasets_dict["gsp"] = select_time_slice(
-        datasets_dict["gsp"],
-        t0,
-        sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-        interval_start=-minutes(gsp_config.history_minutes),
-        interval_end=minutes(0),
-    )
+        sliced_datasets_dict["gsp_future"] = select_time_slice(
+            datasets_dict["gsp"],
+            t0,
+            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+            interval_start=minutes(30),
+            interval_end=minutes(gsp_config.forecast_minutes),
+        )
+        sliced_datasets_dict["gsp"] = select_time_slice(
+            datasets_dict["gsp"],
+            t0,
+            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+            interval_start=-minutes(gsp_config.history_minutes),
+            interval_end=minutes(0),
+        )

-    # Dropout on the GSP, but not the future GSP
-    gsp_dropout_time = draw_dropout_time(
-        t0,
-        dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
-        dropout_frac=gsp_config.dropout_fraction,
-    )
+        # Dropout on the GSP, but not the future GSP
+        gsp_dropout_time = draw_dropout_time(
+            t0,
+            dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
+            dropout_frac=gsp_config.dropout_fraction,
+        )

-    sliced_datasets_dict["gsp"] = apply_dropout_time(sliced_datasets_dict["gsp"], gsp_dropout_time)
+        sliced_datasets_dict["gsp"] = apply_dropout_time(sliced_datasets_dict["gsp"], gsp_dropout_time)

    return sliced_datasets_dict

@@ -383,17 +384,17 @@ def process_and_combine_datasets(
        # Convert to NumpyBatch
        numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))

-    # GSP always assumed to be in data
-    gsp_config = config.input_data.gsp
-    da_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
-    da_gsp = normalize_gsp(da_gsp)

-    numpy_modalities.append(
-        convert_gsp_to_numpy_batch(
-            da_gsp,
-            t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes
+    if "gsp" in dataset_dict:
+        gsp_config = config.input_data.gsp
+        da_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
+        da_gsp = normalize_gsp(da_gsp)

+        numpy_modalities.append(
+            convert_gsp_to_numpy_batch(
+                da_gsp,
+                t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes
+            )
        )
-    )

    # Make sun coords NumpyBatch
    datetimes = pd.date_range(
@@ -438,6 +439,29 @@ def get_locations(ga_gsp: xr.DataArray) -> list[Location]:
    return locations


+def get_gsp_locations() -> list[Location]:
+    """Get list of locations of all GSPs"""
+    locations = []

+    # Load UK GSP locations
+    df_gsp_loc = pd.read_csv(
+        pkg_resources.resource_filename(__name__, "../data/uk_gsp_locations.csv"),
+        index_col="gsp_id",
+    )

+    for gsp_id in np.arange(1, 318):
+        locations.append(
+            Location(
+                coordinate_system = "osgb",
+                x=df_gsp_loc.loc[gsp_id].x_osgb,
+                y=df_gsp_loc.loc[gsp_id].y_osgb,
+                id=gsp_id,
+            )
+        )
+    return locations



class PVNetUKRegionalDataset(Dataset):
    def __init__(
        self,
Expand Down Expand Up @@ -468,7 +492,7 @@ def __init__(
        valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]

        # Construct list of locations to sample from
-        locations = get_locations(datasets_dict["gsp"])
+        locations = get_gsp_locations()

        # Construct a lookup for locations - useful for users to construct sample by GSP ID
        location_lookup = {loc.id: loc for loc in locations}
@@ -539,4 +563,4 @@ def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> NumpyBatch:

        location = self.location_lookup[gsp_id]

-        return self._get_sample(t0, location)
+        return self._get_sample(t0, location)
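
A short usage sketch of the per-GSP sampling path shown above, for a config that still enables GSP data; the config path and timestamp are placeholders, and t0 has to be one of the dataset's valid t0 times for the slicing to succeed.

import pandas as pd

from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset

# Locations now come from the bundled uk_gsp_locations.csv (GSP IDs 1-317)
# rather than from the GSP zarr itself
dataset = PVNetUKRegionalDataset("pvnet_config.yaml")  # placeholder config path

# Look up a sample for a specific GSP at a specific init time
sample = dataset.get_sample(t0=pd.Timestamp("2023-01-01 12:00"), gsp_id=1)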
17 changes: 16 additions & 1 deletion tests/torch_datasets/test_pvnet_uk_regional.py
@@ -1,4 +1,5 @@
import pytest
+import tempfile

from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
from ocf_datapipes.config.load import load_yaml_configuration
@@ -56,4 +57,18 @@ def test_pvnet(pvnet_config_filename):
    assert sample[BatchKey.gsp_solar_elevation].shape == (7,)



+def test_pvnet_no_gsp(pvnet_config_filename):

+    # load config
+    config = load_yaml_configuration(pvnet_config_filename)
+    # remove gsp
+    config.input_data.gsp.gsp_zarr_path = ''

+    # save temp config file
+    with tempfile.NamedTemporaryFile() as temp_config_file:
+        save_yaml_configuration(config, temp_config_file.name)
+        # Create dataset object
+        dataset = PVNetUKRegionalDataset(temp_config_file.name)

+        # Generate a sample
+        _ = dataset[0]
