diff --git a/ocf_data_sampler/load/site.py b/ocf_data_sampler/load/site.py
index a303fe2..81b7719 100644
--- a/ocf_data_sampler/load/site.py
+++ b/ocf_data_sampler/load/site.py
@@ -8,23 +8,23 @@ def open_site(sites_config: Site) -> xr.DataArray:
 
     # Load site generation xr.Dataset
-    data_ds = xr.open_dataset(sites_config.file_path)
+    site_generation_ds = xr.open_dataset(sites_config.file_path)
 
     # Load site generation data
     metadata_df = pd.read_csv(sites_config.metadata_file_path, index_col="site_id")
 
-    # Add coordinates
-    ds = data_ds.assign_coords(
-        latitude=(metadata_df.latitude.to_xarray()),
-        longitude=(metadata_df.longitude.to_xarray()),
-        capacity_kwp=data_ds.capacity_kwp,
+    # Ensure metadata aligns with the site_id dimension in site_generation_ds
+    metadata_df = metadata_df.reindex(site_generation_ds.site_id.values)
+
+    # Assign coordinates to the Dataset using the aligned metadata
+    site_generation_ds = site_generation_ds.assign_coords(
+        latitude=("site_id", metadata_df["latitude"].values),
+        longitude=("site_id", metadata_df["longitude"].values),
+        capacity_kwp=("site_id", metadata_df["capacity_kwp"].values),
     )
 
     # Sanity checks
-    assert np.isfinite(data_ds.capacity_kwp.values).all()
-    assert (data_ds.capacity_kwp.values > 0).all()
+    assert np.isfinite(site_generation_ds.capacity_kwp.values).all()
+    assert (site_generation_ds.capacity_kwp.values > 0).all()
     assert metadata_df.index.is_unique
-
-    return ds.generation_kw
-
-
+    return site_generation_ds.generation_kw
diff --git a/ocf_data_sampler/torch_datasets/__init__.py b/ocf_data_sampler/torch_datasets/__init__.py
index 8b13789..a185aa3 100644
--- a/ocf_data_sampler/torch_datasets/__init__.py
+++ b/ocf_data_sampler/torch_datasets/__init__.py
@@ -1 +1,2 @@
-
+from .pvnet_uk_regional import PVNetUKRegionalDataset
+from .site import SitesDataset
diff --git a/ocf_data_sampler/torch_datasets/process_and_combine.py b/ocf_data_sampler/torch_datasets/process_and_combine.py
index a732ef0..cae1e5f 100644
--- a/ocf_data_sampler/torch_datasets/process_and_combine.py
+++ b/ocf_data_sampler/torch_datasets/process_and_combine.py
@@ -1,6 +1,7 @@
 import numpy as np
 import pandas as pd
 import xarray as xr
+from typing import Tuple
 
 from ocf_data_sampler.config import Configuration
 from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
@@ -9,7 +10,6 @@
     convert_satellite_to_numpy_batch,
     convert_gsp_to_numpy_batch,
     make_sun_position_numpy_batch,
-    convert_site_to_numpy_batch,
 )
 from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
 from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
@@ -73,18 +73,6 @@ def process_and_combine_datasets(
             }
         )
 
-
-    if "site" in dataset_dict:
-        site_config = config.input_data.site
-        da_sites = dataset_dict["site"]
-        da_sites = da_sites / da_sites.capacity_kwp
-
-        numpy_modalities.append(
-            convert_site_to_numpy_batch(
-                da_sites, t0_idx=-site_config.interval_start_minutes / site_config.time_resolution_minutes
-            )
-        )
-
     if target_key == 'gsp':
         # Make sun coords NumpyBatch
         datetimes = pd.date_range(
@@ -95,16 +83,6 @@
 
         lon, lat = osgb_to_lon_lat(location.x, location.y)
 
-    elif target_key == 'site':
-        # Make sun coords NumpyBatch
-        datetimes = pd.date_range(
-            t0+minutes(site_config.interval_start_minutes),
-            t0+minutes(site_config.interval_end_minutes),
-            freq=minutes(site_config.time_resolution_minutes),
-        )
-
-        lon, lat = location.x, location.y
-
     numpy_modalities.append(
         make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=target_key)
     )
@@ -115,6 +93,47 @@ def process_and_combine_datasets(
 
     return combined_sample
 
+def process_and_combine_site_sample_dict(
+    dataset_dict: dict,
+    config: Configuration,
+) -> xr.Dataset:
+    """
+    Normalize and combine data into a single xr.Dataset.
+
+    Args:
+        dataset_dict: dict containing sliced xr DataArrays
+        config: Configuration for the model
+
+    Returns:
+        xr.Dataset: A merged Dataset with NaNs filled in.
+    """
+
+    data_arrays = []
+
+    if "nwp" in dataset_dict:
+        for nwp_key, da_nwp in dataset_dict["nwp"].items():
+            # Standardise
+            provider = config.input_data.nwp[nwp_key].provider
+            da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
+            data_arrays.append((f"nwp-{provider}", da_nwp))
+
+    if "sat" in dataset_dict:
+        # TODO add some satellite normalisation
+        da_sat = dataset_dict["sat"]
+        data_arrays.append(("satellite", da_sat))
+
+    if "site" in dataset_dict:
+        # site_config = config.input_data.site
+        da_sites = dataset_dict["site"]
+        da_sites = da_sites / da_sites.capacity_kwp
+        data_arrays.append(("sites", da_sites))
+
+    combined_sample_dataset = merge_arrays(data_arrays)
+
+    # Fill any NaN values
+    return combined_sample_dataset.fillna(0.0)
+
 
 def merge_dicts(list_of_dicts: list[dict]) -> dict:
     """Merge a list of dictionaries into a single dictionary"""
@@ -124,6 +143,59 @@ def merge_dicts(list_of_dicts: list[dict]) -> dict:
         combined_dict.update(d)
     return combined_dict
 
+def merge_arrays(normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
+    """
+    Combine a list of DataArrays into a single Dataset with unique naming conventions.
+
+    Args:
+        normalised_data_arrays: List of tuples where each tuple contains:
+            - A string (key name).
+            - An xarray.DataArray.
+
+    Returns:
+        xr.Dataset: A merged Dataset with uniquely named variables, coordinates, and dimensions.
+    """
+    datasets = []
+
+    for key, data_array in normalised_data_arrays:
+        # Ensure all attributes are strings for consistency
+        data_array = data_array.assign_attrs(
+            {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()}
+        )
+
+        # Convert DataArray to Dataset with the variable name as the key
+        dataset = data_array.to_dataset(name=key)
+
+        # Prepend key name to all dimension and coordinate names for uniqueness
+        dataset = dataset.rename(
+            {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords}
+        )
+        dataset = dataset.rename(
+            {coord: f"{key}__{coord}" for coord in dataset.coords}
+        )
+
+        # Handle concatenation dimension if applicable
+        concat_dim = (
+            f"{key}__target_time_utc" if f"{key}__target_time_utc" in dataset.coords
+            else f"{key}__time_utc"
+        )
+
+        if f"{key}__init_time_utc" in dataset.coords:
+            init_coord = f"{key}__init_time_utc"
+            if dataset[init_coord].ndim == 0:  # Check if scalar
+                expanded_init_times = [dataset[init_coord].values] * len(dataset[concat_dim])
+                dataset = dataset.assign_coords({init_coord: (concat_dim, expanded_init_times)})
+
+        datasets.append(dataset)
+
+    # Ensure all datasets are valid xarray.Dataset objects
+    for ds in datasets:
+        assert isinstance(ds, xr.Dataset), f"Object is not an xr.Dataset: {type(ds)}"
+
+    # Merge all prepared datasets
+    combined_dataset = xr.merge(datasets)
+
+    return combined_dataset
+
 
 def fill_nans_in_arrays(batch: dict) -> dict:
     """Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
diff --git a/ocf_data_sampler/torch_datasets/site.py b/ocf_data_sampler/torch_datasets/site.py
index 3ec2fc3..96d5c8c 100644
--- a/ocf_data_sampler/torch_datasets/site.py
+++ b/ocf_data_sampler/torch_datasets/site.py
@@ -15,7 +15,7 @@
     slice_datasets_by_time, slice_datasets_by_space
 )
 from ocf_data_sampler.utils import minutes
-from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
+from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_site_sample_dict
 from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods
 
 xr.set_options(keep_attrs=True)
@@ -152,10 +152,9 @@ def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
         """
         sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
-        sample_dict = compute(sample_dict)
-
-        sample = process_and_combine_datasets(sample_dict, self.config, t0, location, target_key='site')
+        sample = process_and_combine_site_sample_dict(sample_dict, self.config)
+        sample = sample.compute()
 
         return sample
 
     def get_location_from_site_id(self, site_id):
diff --git a/tests/select/test_select_time_slice.py b/tests/select/test_select_time_slice.py
index b4c8d9c..e69277e 100644
--- a/tests/select/test_select_time_slice.py
+++ b/tests/select/test_select_time_slice.py
@@ -197,7 +197,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
     interval_start = pd.Timedelta(-6, "h")
     interval_end = pd.Timedelta(3, "h")
-    freq = pd.Timedelta("1H")
+    freq = pd.Timedelta("1h")
     dropout_timedelta = pd.Timedelta("-2h")
 
     t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)
diff --git a/tests/torch_datasets/test_pvnet_uk_regional.py b/tests/torch_datasets/test_pvnet_uk_regional.py
index ee8ea8a..3265bfb 100644
--- a/tests/torch_datasets/test_pvnet_uk_regional.py
+++ b/tests/torch_datasets/test_pvnet_uk_regional.py
@@ -1,7 +1,7 @@
 import pytest
 import tempfile
 
-from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
+from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
 from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
 from ocf_data_sampler.numpy_batch import NWPBatchKey, GSPBatchKey, SatelliteBatchKey
diff --git a/tests/torch_datasets/test_site.py b/tests/torch_datasets/test_site.py
index a7423ac..71f4d98 100644
--- a/tests/torch_datasets/test_site.py
+++ b/tests/torch_datasets/test_site.py
@@ -1,11 +1,12 @@
 import pandas as pd
 import pytest
 
-from ocf_data_sampler.torch_datasets.site import SitesDataset
+from ocf_data_sampler.torch_datasets import SitesDataset
 from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
 from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
 from ocf_data_sampler.numpy_batch.site import SiteBatchKey
 from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
+from xarray import Dataset
 
 
 @pytest.fixture()
@@ -34,31 +35,26 @@ def test_site(site_config_filename):
 
     # Generate a sample
     sample = dataset[0]
 
-    assert isinstance(sample, dict)
+    assert isinstance(sample, Dataset)
 
-    for key in [
-        NWPBatchKey.nwp,
-        SatelliteBatchKey.satellite_actual,
-        SiteBatchKey.generation,
-        SiteBatchKey.site_solar_azimuth,
-        SiteBatchKey.site_solar_elevation,
-    ]:
-        assert key in sample
+    # Expected dimensions and data variables
+    expected_dims = {'satellite__x_geostationary', 'sites__time_utc', 'nwp-ukv__target_time_utc',
+                     'nwp-ukv__x_osgb', 'satellite__channel', 'satellite__y_geostationary',
+                     'satellite__time_utc', 'nwp-ukv__channel', 'nwp-ukv__y_osgb'}
+    expected_data_vars = {"nwp-ukv", "satellite", "sites"}
 
-    for nwp_source in ["ukv"]:
-        assert nwp_source in sample[NWPBatchKey.nwp]
+    # Check dimensions
+    assert set(sample.dims) == expected_dims, f"Missing or extra dimensions: {set(sample.dims) ^ expected_dims}"
+    # Check data variables
+    assert set(sample.data_vars) == expected_data_vars, f"Missing or extra data variables: {set(sample.data_vars) ^ expected_data_vars}"
 
     # check the shape of the data is correct
     # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
-    assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
+    assert sample["satellite"].values.shape == (7, 1, 2, 2)
     # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
-    assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
-    # 3 hours of 30 minute data (inclusive)
-    assert sample[SiteBatchKey.generation].shape == (4,)
-    # Solar angles have same shape as GSP data
-    assert sample[SiteBatchKey.site_solar_azimuth].shape == (4,)
-    assert sample[SiteBatchKey.site_solar_elevation].shape == (4,)
-
+    assert sample["nwp-ukv"].values.shape == (4, 1, 2, 2)
+    # 1.5 hours of 30 minute data (inclusive)
+    assert sample["sites"].values.shape == (4,)
 
 
 def test_site_time_filter_start(site_config_filename):
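
Usage sketch (not part of the diff): a minimal example of what the new SitesDataset sample looks like after process_and_combine_site_sample_dict. It assumes, as the tests suggest, that SitesDataset is constructed from a config path ("site_config.yaml" is a placeholder) and that dimension and coordinate names carry the "{key}__" prefix applied by merge_arrays.

from ocf_data_sampler.torch_datasets import SitesDataset

# Placeholder config path; substitute a real site configuration file.
dataset = SitesDataset("site_config.yaml")

# Each sample is an xr.Dataset with variables "nwp-ukv", "satellite" and "sites",
# already normalised (NWP standardised, site generation divided by capacity_kwp)
# and with NaNs filled with 0.0.
sample = dataset[0]

site_generation = sample["sites"]    # dims like "sites__time_utc"
nwp = sample["nwp-ukv"]              # dims like "nwp-ukv__target_time_utc"
satellite = sample["satellite"]      # dims like "satellite__time_utc"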