diff --git a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
index af150bd..735aa81 100644
--- a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
+++ b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
@@ -4,7 +4,7 @@
 import pandas as pd
 import xarray as xr
 from torch.utils.data import Dataset
-
+import pkg_resources
 from ocf_data_sampler.load.gsp import open_gsp
 from ocf_data_sampler.load.nwp import open_nwp
 
@@ -69,11 +69,12 @@ def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataAr
 
     datasets_dict = {}
 
-    # We always assume GSP will be included
-    da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path)
+    # Load GSP data if the path is set
+    if in_config.gsp.gsp_zarr_path:
+        da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path)
 
-    # Remove national GSP
-    datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
+        # Remove national GSP
+        datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None))
 
     # Load NWP data if in config
     if in_config.nwp:
@@ -172,19 +173,19 @@ def find_valid_t0_times(
 
         contiguous_time_periods['sat'] = time_periods
 
-    # GSP always assumed to be in data
-    gsp_config = config.input_data.gsp
+    if "gsp" in datasets_dict:
+        gsp_config = config.input_data.gsp
 
-    time_periods = find_contiguous_t0_periods(
-        pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
-        sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-        history_duration=minutes(gsp_config.history_minutes),
-        forecast_duration=minutes(gsp_config.forecast_minutes),
-    )
+        time_periods = find_contiguous_t0_periods(
+            pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
+            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+            history_duration=minutes(gsp_config.history_minutes),
+            forecast_duration=minutes(gsp_config.forecast_minutes),
+        )
 
-    contiguous_time_periods['gsp'] = time_periods
+        contiguous_time_periods['gsp'] = time_periods
 
-    # just get the values (no the keys)
+    # just get the values (not the keys)
     contiguous_time_periods_values = list(contiguous_time_periods.values())
 
     # Find joint overlapping contiguous time periods
@@ -248,8 +249,8 @@ def slice_datasets_by_space(
             width_pixels=sat_config.satellite_image_size_pixels_width,
         )
 
-    # GSP always assumed to be in data
-    sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id)
+    if "gsp" in datasets_dict:
+        sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id)
 
     return sliced_datasets_dict
 
@@ -314,33 +315,33 @@ def slice_datasets_by_time(
             sat_dropout_time,
         )
 
-    # GSP always assumed to be included
-    gsp_config = config.input_data.gsp
+    if "gsp" in datasets_dict:
+        gsp_config = config.input_data.gsp
 
-    sliced_datasets_dict["gsp_future"] = select_time_slice(
-        datasets_dict["gsp"],
-        t0,
-        sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-        interval_start=minutes(30),
-        interval_end=minutes(gsp_config.forecast_minutes),
-    )
-
-    sliced_datasets_dict["gsp"] = select_time_slice(
-        datasets_dict["gsp"],
-        t0,
-        sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-        interval_start=-minutes(gsp_config.history_minutes),
-        interval_end=minutes(0),
-    )
+        sliced_datasets_dict["gsp_future"] = select_time_slice(
+            datasets_dict["gsp"],
+            t0,
+            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+            interval_start=minutes(30),
+            interval_end=minutes(gsp_config.forecast_minutes),
+        )
+
+        sliced_datasets_dict["gsp"] = select_time_slice(
+            datasets_dict["gsp"],
+            t0,
+            sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+            interval_start=-minutes(gsp_config.history_minutes),
+            interval_end=minutes(0),
+        )
 
-    # Dropout on the GSP, but not the future GSP
-    gsp_dropout_time = draw_dropout_time(
-        t0,
-        dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
-        dropout_frac=gsp_config.dropout_fraction,
-    )
+        # Dropout on the GSP, but not the future GSP
+        gsp_dropout_time = draw_dropout_time(
+            t0,
+            dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
+            dropout_frac=gsp_config.dropout_fraction,
+        )
 
-    sliced_datasets_dict["gsp"] = apply_dropout_time(sliced_datasets_dict["gsp"], gsp_dropout_time)
+        sliced_datasets_dict["gsp"] = apply_dropout_time(sliced_datasets_dict["gsp"], gsp_dropout_time)
 
     return sliced_datasets_dict
 
@@ -385,17 +386,17 @@ def process_and_combine_datasets(
         # Convert to NumpyBatch
         numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
 
-    # GSP always assumed to be in data
     gsp_config = config.input_data.gsp
 
-    da_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
-    da_gsp = normalize_gsp(da_gsp)
-
-    numpy_modalities.append(
-        convert_gsp_to_numpy_batch(
-            da_gsp,
-            t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes
+    if "gsp" in dataset_dict:
+        da_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
+        da_gsp = normalize_gsp(da_gsp)
+
+        numpy_modalities.append(
+            convert_gsp_to_numpy_batch(
+                da_gsp,
+                t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes
+            )
         )
-    )
 
     # Make sun coords NumpyBatch
     datetimes = pd.date_range(
@@ -440,6 +441,29 @@ def get_locations(ga_gsp: xr.DataArray) -> list[Location]:
 
     return locations
 
+
+def get_gsp_locations() -> list[Location]:
+    """Get list of locations of all GSPs"""
+    locations = []
+
+    # Load UK GSP locations
+    df_gsp_loc = pd.read_csv(
+        pkg_resources.resource_filename(__name__, "../data/uk_gsp_locations.csv"),
+        index_col="gsp_id",
+    )
+
+    for gsp_id in range(1, 318):
+        locations.append(
+            Location(
+                coordinate_system="osgb",
+                x=df_gsp_loc.loc[gsp_id].x_osgb,
+                y=df_gsp_loc.loc[gsp_id].y_osgb,
+                id=gsp_id,
+            )
+        )
+    return locations
+
+
 class PVNetUKRegionalDataset(Dataset):
     def __init__(
         self,
@@ -470,7 +494,7 @@ def __init__(
         valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]
 
         # Construct list of locations to sample from
-        locations = get_locations(datasets_dict["gsp"])
+        locations = get_gsp_locations()
 
         # Construct a lookup for locations - useful for users to construct sample by GSP ID
         location_lookup = {loc.id: loc for loc in locations}
@@ -542,4 +566,4 @@ def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> NumpyBatch:
 
         location = self.location_lookup[gsp_id]
 
-        return self._get_sample(t0, location)
+        return self._get_sample(t0, location)
\ No newline at end of file
diff --git a/tests/torch_datasets/test_pvnet_uk_regional.py b/tests/torch_datasets/test_pvnet_uk_regional.py
index ff5c2fe..1e4ac88 100644
--- a/tests/torch_datasets/test_pvnet_uk_regional.py
+++ b/tests/torch_datasets/test_pvnet_uk_regional.py
@@ -1,4 +1,6 @@
 import pytest
+import tempfile
 
 from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
 from ocf_datapipes.config.load import load_yaml_configuration
+from ocf_datapipes.config.save import save_yaml_configuration
@@ -56,4 +58,18 @@
     assert sample[BatchKey.gsp_solar_elevation].shape == (7,)
 
 
-
+def test_pvnet_no_gsp(pvnet_config_filename):
+
+    # load config
+    config = load_yaml_configuration(pvnet_config_filename)
+    # remove gsp by setting its zarr path to empty
+    config.input_data.gsp.gsp_zarr_path = ''
+
+    # save temp config file
+    with tempfile.NamedTemporaryFile() as temp_config_file:
+        save_yaml_configuration(config, temp_config_file.name)
+        # Create dataset object
+        dataset = PVNetUKRegionalDataset(temp_config_file.name)
+
+        # Generate a sample
+        _ = dataset[0]
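
For reference, a minimal usage sketch of the optional-GSP path introduced above. The "config.yaml" filename is a placeholder; `PVNetUKRegionalDataset`, `get_gsp_locations`, and the 317 GSP IDs are taken from the diff itself.

```python
# Sketch only, assuming a config file "config.yaml" (placeholder name)
# whose input_data.gsp.gsp_zarr_path is left empty.
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import (
    PVNetUKRegionalDataset,
    get_gsp_locations,
)

# GSP locations now come from the packaged uk_gsp_locations.csv rather than
# the GSP zarr, so they are available even when no GSP data is loaded.
locations = get_gsp_locations()
assert len(locations) == 317  # GSP IDs 1-317; national GSP (id 0) excluded

# With the GSP zarr path unset, the GSP loading/slicing branches are skipped
# and samples contain only the remaining modalities (e.g. NWP, satellite).
dataset = PVNetUKRegionalDataset("config.yaml")
sample = dataset[0]
```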