Allow GSP data to be optional #41

Merged · 2 commits · Aug 30, 2024
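This PR makes the GSP (Grid Supply Point) input optional throughout the PVNet UK regional dataset pipeline: loading, valid-t0 selection, spatial and temporal slicing, and batch assembly all skip GSP when no zarr path is configured. As a rough sketch of what this enables for users (the config filename is hypothetical; the empty-string value mirrors the new test at the bottom of this diff):

```python
from ocf_datapipes.config.load import load_yaml_configuration

config = load_yaml_configuration("pvnet_config.yaml")  # hypothetical config path
config.input_data.gsp.gsp_zarr_path = ""  # any falsy value now skips GSP loading
```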
ocf_data_sampler/torch_datasets/pvnet_uk_regional.py (128 changes: 76 additions & 52 deletions)

@@ -4,7 +4,7 @@
 import pandas as pd
 import xarray as xr
 from torch.utils.data import Dataset
-
+import pkg_resources
 
 from ocf_data_sampler.load.gsp import open_gsp
 from ocf_data_sampler.load.nwp import open_nwp

@@ -69,11 +69,12 @@ def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataAr
 
     datasets_dict = {}
 
-    # We always assume GSP will be included
-    da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path)
+    # Load GSP data unless the path is None
+    if in_config.gsp.gsp_zarr_path:
+        da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path)
 
-    # Remove national GSP
-    datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
+        # Remove national GSP
+        datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
 
     # Load NWP data if in config
     if in_config.nwp:
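The new guard is a plain truthiness check, so both `None` and an empty string count as "no GSP", which is what lets the new test disable GSP with `''`. A tiny illustration (the non-empty path is hypothetical):

```python
# Mirrors the `if in_config.gsp.gsp_zarr_path:` guard above
for path in (None, "", "gsp/uk_gsp.zarr"):  # the last path is hypothetical
    print(repr(path), "->", "load" if path else "skip")
```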
@@ -172,19 +173,19 @@ def find_valid_t0_times(
 
         contiguous_time_periods['sat'] = time_periods
 
-    # GSP always assumed to be in data
-    gsp_config = config.input_data.gsp
+    if "gsp" in datasets_dict:
+        gsp_config = config.input_data.gsp
 
-    time_periods = find_contiguous_t0_periods(
-        pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
-        sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-        history_duration=minutes(gsp_config.history_minutes),
-        forecast_duration=minutes(gsp_config.forecast_minutes),
-    )
+        time_periods = find_contiguous_t0_periods(
+            pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
+            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+            history_duration=minutes(gsp_config.history_minutes),
+            forecast_duration=minutes(gsp_config.forecast_minutes),
+        )
 
-    contiguous_time_periods['gsp'] = time_periods
+        contiguous_time_periods['gsp'] = time_periods
 
-    # just get the values (no the keys)
+    # just get the values (not the keys)
     contiguous_time_periods_values = list(contiguous_time_periods.values())
 
     # Find joint overlapping contiguous time periods

@@ -248,8 +249,8 @@ def slice_datasets_by_space(
             width_pixels=sat_config.satellite_image_size_pixels_width,
         )
 
-    # GSP always assumed to be in data
-    sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id)
+    if "gsp" in datasets_dict:
+        sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id)
 
     return sliced_datasets_dict

@@ -314,33 +315,33 @@ def slice_datasets_by_time(
             sat_dropout_time,
         )
 
-    # GSP always assumed to be included
-    gsp_config = config.input_data.gsp
+    if "gsp" in datasets_dict:
+        gsp_config = config.input_data.gsp
 
-    sliced_datasets_dict["gsp_future"] = select_time_slice(
-        datasets_dict["gsp"],
-        t0,
-        sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-        interval_start=minutes(30),
-        interval_end=minutes(gsp_config.forecast_minutes),
-    )
-
-    sliced_datasets_dict["gsp"] = select_time_slice(
-        datasets_dict["gsp"],
-        t0,
-        sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-        interval_start=-minutes(gsp_config.history_minutes),
-        interval_end=minutes(0),
-    )
+        sliced_datasets_dict["gsp_future"] = select_time_slice(
+            datasets_dict["gsp"],
+            t0,
+            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+            interval_start=minutes(30),
+            interval_end=minutes(gsp_config.forecast_minutes),
+        )
+        sliced_datasets_dict["gsp"] = select_time_slice(
+            datasets_dict["gsp"],
+            t0,
+            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+            interval_start=-minutes(gsp_config.history_minutes),
+            interval_end=minutes(0),
+        )
 
-    # Dropout on the GSP, but not the future GSP
-    gsp_dropout_time = draw_dropout_time(
-        t0,
-        dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
-        dropout_frac=gsp_config.dropout_fraction,
-    )
+        # Dropout on the GSP, but not the future GSP
+        gsp_dropout_time = draw_dropout_time(
+            t0,
+            dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
+            dropout_frac=gsp_config.dropout_fraction,
+        )
 
-    sliced_datasets_dict["gsp"] = apply_dropout_time(sliced_datasets_dict["gsp"], gsp_dropout_time)
+        sliced_datasets_dict["gsp"] = apply_dropout_time(sliced_datasets_dict["gsp"], gsp_dropout_time)
 
     return sliced_datasets_dict
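The two time slices partition the GSP series around `t0`: `"gsp"` covers `[t0 - history_minutes, t0]` and `"gsp_future"` covers `[t0 + 30 min, t0 + forecast_minutes]`, so dropout applied to the history slice never touches the future targets. A worked illustration with made-up durations (not values from this PR), assuming 30-minute GSP resolution:

```python
import pandas as pd

t0 = pd.Timestamp("2024-06-01 12:00")
history_minutes, forecast_minutes = 60, 120  # illustrative durations only

# "gsp" slice: [t0 - history, t0]; "gsp_future" slice: [t0 + 30 min, t0 + forecast]
past = pd.date_range(t0 - pd.Timedelta(minutes=history_minutes), t0, freq="30min")
future = pd.date_range(
    t0 + pd.Timedelta(minutes=30), t0 + pd.Timedelta(minutes=forecast_minutes), freq="30min"
)
print(list(past.time))    # 11:00, 11:30, 12:00
print(list(future.time))  # 12:30, 13:00, 13:30, 14:00
```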

@@ -385,17 +386,17 @@ def process_and_combine_datasets(
         # Convert to NumpyBatch
         numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
 
-    # GSP always assumed to be in data
     gsp_config = config.input_data.gsp
-    da_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
-    da_gsp = normalize_gsp(da_gsp)
-
-    numpy_modalities.append(
-        convert_gsp_to_numpy_batch(
-            da_gsp,
-            t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes
-        )
-    )
+    if "gsp" in dataset_dict:
+        da_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
+        da_gsp = normalize_gsp(da_gsp)
+
+        numpy_modalities.append(
+            convert_gsp_to_numpy_batch(
+                da_gsp,
+                t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes
+            )
+        )
 
     # Make sun coords NumpyBatch
     datetimes = pd.date_range(
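The `t0_idx` arithmetic locates `t0` inside the concatenated history-plus-future array: the history slice spans `[-history_minutes, 0]` inclusive, so `t0` sits `history_minutes / time_resolution_minutes` steps in. With illustrative numbers (not from this PR):

```python
history_minutes, time_resolution_minutes = 120, 30  # illustrative values only
t0_idx = history_minutes / time_resolution_minutes
print(t0_idx)  # 4.0 -> entries 0..4 of the concatenated array are history (t0 inclusive)
```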
@@ -440,6 +441,29 @@ def get_locations(ga_gsp: xr.DataArray) -> list[Location]:
     return locations
 
 
+def get_gsp_locations() -> list[Location]:
+    """Get list of locations of all GSPs"""
+    locations = []
+
+    # Load UK GSP locations
+    df_gsp_loc = pd.read_csv(
+        pkg_resources.resource_filename(__name__, "../data/uk_gsp_locations.csv"),
+        index_col="gsp_id",
+    )
+
+    for gsp_id in np.arange(1, 318):
+        locations.append(
+            Location(
+                coordinate_system = "osgb",
+                x=df_gsp_loc.loc[gsp_id].x_osgb,
+                y=df_gsp_loc.loc[gsp_id].y_osgb,
+                id=gsp_id,
+            )
+        )
+    return locations
+
+
+
 class PVNetUKRegionalDataset(Dataset):
     def __init__(
         self,

Expand Down Expand Up @@ -470,7 +494,7 @@ def __init__(
valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]

# Construct list of locations to sample from
locations = get_locations(datasets_dict["gsp"])
locations = get_gsp_locations()

# Construct a lookup for locations - useful for users to construct sample by GSP ID
location_lookup = {loc.id: loc for loc in locations}
@@ -542,4 +566,4 @@ def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> NumpyBatch:
         location = self.location_lookup[gsp_id]
 
 
-        return self._get_sample(t0, location)
\ No newline at end of file
+        return self._get_sample(t0, location)
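Because the new `get_gsp_locations()` reads GSP coordinates from a CSV bundled with the package rather than from the loaded GSP zarr, the dataset can enumerate sample locations even when no GSP data is configured. A quick sketch of the helper in use (assuming the package is installed together with its bundled `uk_gsp_locations.csv`):

```python
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import get_gsp_locations

locations = get_gsp_locations()
print(len(locations))  # 317 regional GSPs: IDs 1-317 (ID 0, the national total, is excluded)
loc = locations[0]
print(loc.id, loc.coordinate_system, loc.x, loc.y)  # OSGB coordinates for GSP 1
```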
tests/torch_datasets/test_pvnet_uk_regional.py (17 changes: 16 additions & 1 deletion)

@@ -1,4 +1,5 @@
 import pytest
+import tempfile
 
 from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
 from ocf_datapipes.config.load import load_yaml_configuration

@@ -56,4 +57,18 @@ def test_pvnet(pvnet_config_filename):
     assert sample[BatchKey.gsp_solar_elevation].shape == (7,)
 
 
+
+def test_pvnet_no_gsp(pvnet_config_filename):
+
+    # load config
+    config = load_yaml_configuration(pvnet_config_filename)
+    # remove gsp
+    config.input_data.gsp.gsp_zarr_path = ''
+
+    # save temp config file
+    with tempfile.NamedTemporaryFile() as temp_config_file:
+        save_yaml_configuration(config, temp_config_file.name)
+        # Create dataset object
+        dataset = PVNetUKRegionalDataset(temp_config_file.name)
+
+        # Generate a sample
+        _ = dataset[0]
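To exercise just the new test (assuming the repo's pytest fixtures are available), pytest can be run programmatically or from the command line:

```python
# Equivalent to the CLI: pytest tests/torch_datasets/test_pvnet_uk_regional.py::test_pvnet_no_gsp -v
import pytest

pytest.main(["tests/torch_datasets/test_pvnet_uk_regional.py::test_pvnet_no_gsp", "-v"])
```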