Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Site torch dataset update #82

Merged
merged 5 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions ocf_data_sampler/load/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,23 @@
def open_site(sites_config: Site) -> xr.DataArray:

# Load site generation xr.Dataset
data_ds = xr.open_dataset(sites_config.file_path)
site_generation_ds = xr.open_dataset(sites_config.file_path)

# Load site generation data
metadata_df = pd.read_csv(sites_config.metadata_file_path, index_col="site_id")

# Add coordinates
ds = data_ds.assign_coords(
latitude=(metadata_df.latitude.to_xarray()),
longitude=(metadata_df.longitude.to_xarray()),
capacity_kwp=data_ds.capacity_kwp,
# Ensure metadata aligns with the site_id dimension in data_ds
metadata_df = metadata_df.reindex(site_generation_ds.site_id.values)

# Assign coordinates to the Dataset using the aligned metadata
site_generation_ds = site_generation_ds.assign_coords(
latitude=("site_id", metadata_df["latitude"].values),
longitude=("site_id", metadata_df["longitude"].values),
capacity_kwp=("site_id", metadata_df["capacity_kwp"].values),
)

# Sanity checks
assert np.isfinite(data_ds.capacity_kwp.values).all()
assert (data_ds.capacity_kwp.values > 0).all()
assert np.isfinite(site_generation_ds.capacity_kwp.values).all()
assert (site_generation_ds.capacity_kwp.values > 0).all()
assert metadata_df.index.is_unique

return ds.generation_kw


return site_generation_ds.generation_kw
3 changes: 2 additions & 1 deletion ocf_data_sampler/torch_datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@

from .pvnet_uk_regional import PVNetUKRegionalDataset
from .site import SitesDataset
118 changes: 95 additions & 23 deletions ocf_data_sampler/torch_datasets/process_and_combine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
import xarray as xr
from typing import Tuple

from ocf_data_sampler.config import Configuration
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
Expand All @@ -9,7 +10,6 @@
convert_satellite_to_numpy_batch,
convert_gsp_to_numpy_batch,
make_sun_position_numpy_batch,
convert_site_to_numpy_batch,
)
from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
Expand Down Expand Up @@ -73,18 +73,6 @@ def process_and_combine_datasets(
}
)


if "site" in dataset_dict:
site_config = config.input_data.site
da_sites = dataset_dict["site"]
da_sites = da_sites / da_sites.capacity_kwp

numpy_modalities.append(
convert_site_to_numpy_batch(
da_sites, t0_idx=-site_config.interval_start_minutes / site_config.time_resolution_minutes
)
)

if target_key == 'gsp':
# Make sun coords NumpyBatch
datetimes = pd.date_range(
Expand All @@ -95,16 +83,6 @@ def process_and_combine_datasets(

lon, lat = osgb_to_lon_lat(location.x, location.y)

elif target_key == 'site':
# Make sun coords NumpyBatch
datetimes = pd.date_range(
t0+minutes(site_config.interval_start_minutes),
t0+minutes(site_config.interval_end_minutes),
freq=minutes(site_config.time_resolution_minutes),
)

lon, lat = location.x, location.y

numpy_modalities.append(
make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=target_key)
)
Expand All @@ -115,6 +93,47 @@ def process_and_combine_datasets(

return combined_sample

def process_and_combine_site_sample_dict(
dataset_dict: dict,
config: Configuration,
) -> xr.Dataset:
"""
Normalize and combine data into a single xr Dataset

Args:
dataset_dict: dict containing sliced xr DataArrays
config: Configuration for the model

Returns:
xr.Dataset: A merged Dataset with nans filled in.

"""

data_arrays = []

if "nwp" in dataset_dict:
for nwp_key, da_nwp in dataset_dict["nwp"].items():
# Standardise
provider = config.input_data.nwp[nwp_key].provider
da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
data_arrays.append((f"nwp-{provider}", da_nwp))

if "sat" in dataset_dict:
# TODO add some satellite normalisation
da_sat = dataset_dict["sat"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a TODO here for normalization satellite data

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#87

data_arrays.append(("satellite", da_sat))

if "site" in dataset_dict:
# site_config = config.input_data.site
da_sites = dataset_dict["site"]
da_sites = da_sites / da_sites.capacity_kwp
data_arrays.append(("sites", da_sites))

combined_sample_dataset = merge_arrays(data_arrays)

# Fill any nan values
return combined_sample_dataset.fillna(0.0)


def merge_dicts(list_of_dicts: list[dict]) -> dict:
"""Merge a list of dictionaries into a single dictionary"""
Expand All @@ -124,6 +143,59 @@ def merge_dicts(list_of_dicts: list[dict]) -> dict:
combined_dict.update(d)
return combined_dict

def merge_arrays(normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
"""
Combine a list of DataArrays into a single Dataset with unique naming conventions.

Args:
list_of_arrays: List of tuples where each tuple contains:
- A string (key name).
- An xarray.DataArray.

Returns:
xr.Dataset: A merged Dataset with uniquely named variables, coordinates, and dimensions.
"""
datasets = []

for key, data_array in normalised_data_arrays:
# Ensure all attributes are strings for consistency
data_array = data_array.assign_attrs(
{attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()}
)

# Convert DataArray to Dataset with the variable name as the key
dataset = data_array.to_dataset(name=key)

# Prepend key name to all dimension and coordinate names for uniqueness
dataset = dataset.rename(
{dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords}
)
dataset = dataset.rename(
{coord: f"{key}__{coord}" for coord in dataset.coords}
)

# Handle concatenation dimension if applicable
concat_dim = (
f"{key}__target_time_utc" if f"{key}__target_time_utc" in dataset.coords
else f"{key}__time_utc"
)

if f"{key}__init_time_utc" in dataset.coords:
init_coord = f"{key}__init_time_utc"
if dataset[init_coord].ndim == 0: # Check if scalar
expanded_init_times = [dataset[init_coord].values] * len(dataset[concat_dim])
dataset = dataset.assign_coords({init_coord: (concat_dim, expanded_init_times)})

datasets.append(dataset)

# Ensure all datasets are valid xarray.Dataset objects
for ds in datasets:
assert isinstance(ds, xr.Dataset), f"Object is not an xr.Dataset: {type(ds)}"

# Merge all prepared datasets
combined_dataset = xr.merge(datasets)

return combined_dataset
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity bc I think I've missed a conversation somewhere: am I right in thinking that eventually gsp version will also output netcdfs and we will merge it into this function and move convert_to_numpy_batch into the training pipeline?

Copy link
Member Author

@Sukh-P Sukh-P Dec 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep exactly I think that is the current idea, the convert_to_numpy_batch logic can still be in this repo and then we can call that in PVNet in the training pipeline


def fill_nans_in_arrays(batch: dict) -> dict:
"""Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
Expand Down
7 changes: 3 additions & 4 deletions ocf_data_sampler/torch_datasets/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
slice_datasets_by_time, slice_datasets_by_space
)
from ocf_data_sampler.utils import minutes
from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_site_sample_dict
from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods

xr.set_options(keep_attrs=True)
Expand Down Expand Up @@ -152,10 +152,9 @@ def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
"""
sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
sample_dict = compute(sample_dict)

sample = process_and_combine_datasets(sample_dict, self.config, t0, location, target_key='site')

sample = process_and_combine_site_sample_dict(sample_dict, self.config)
sample = sample.compute()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

return sample

def get_location_from_site_id(self, site_id):
Expand Down
2 changes: 1 addition & 1 deletion tests/select/test_select_time_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
interval_start = pd.Timedelta(-6, "h")
interval_end = pd.Timedelta(3, "h")
freq = pd.Timedelta("1H")
freq = pd.Timedelta("1h")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To stop a syntax deprecation warning from popping up

dropout_timedelta = pd.Timedelta("-2h")

t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)
Expand Down
2 changes: 1 addition & 1 deletion tests/torch_datasets/test_pvnet_uk_regional.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import tempfile

from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
from ocf_data_sampler.numpy_batch import NWPBatchKey, GSPBatchKey, SatelliteBatchKey

Expand Down
36 changes: 16 additions & 20 deletions tests/torch_datasets/test_site.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import pandas as pd
import pytest

from ocf_data_sampler.torch_datasets.site import SitesDataset
from ocf_data_sampler.torch_datasets import SitesDataset
from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
from ocf_data_sampler.numpy_batch.site import SiteBatchKey
from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
from xarray import Dataset


@pytest.fixture()
Expand Down Expand Up @@ -34,31 +35,26 @@ def test_site(site_config_filename):
# Generate a sample
sample = dataset[0]

assert isinstance(sample, dict)
assert isinstance(sample, Dataset)

for key in [
NWPBatchKey.nwp,
SatelliteBatchKey.satellite_actual,
SiteBatchKey.generation,
SiteBatchKey.site_solar_azimuth,
SiteBatchKey.site_solar_elevation,
]:
assert key in sample
# Expected dimensions and data variables
expected_dims = {'satellite__x_geostationary', 'sites__time_utc', 'nwp-ukv__target_time_utc',
'nwp-ukv__x_osgb', 'satellite__channel', 'satellite__y_geostationary',
'satellite__time_utc', 'nwp-ukv__channel', 'nwp-ukv__y_osgb'}
expected_data_vars = {"nwp-ukv", "satellite", "sites"}

for nwp_source in ["ukv"]:
assert nwp_source in sample[NWPBatchKey.nwp]
# Check dimensions
assert set(sample.dims) == expected_dims, f"Missing or extra dimensions: {set(sample.dims) ^ expected_dims}"
# Check data variables
assert set(sample.data_vars) == expected_data_vars, f"Missing or extra data variables: {set(sample.data_vars) ^ expected_data_vars}"

# check the shape of the data is correct
# 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
assert sample["satellite"].values.shape == (7, 1, 2, 2)
# 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
# 3 hours of 30 minute data (inclusive)
assert sample[SiteBatchKey.generation].shape == (4,)
# Solar angles have same shape as GSP data
assert sample[SiteBatchKey.site_solar_azimuth].shape == (4,)
assert sample[SiteBatchKey.site_solar_elevation].shape == (4,)

assert sample["nwp-ukv"].values.shape == (4, 1, 2, 2)
# 1.5 hours of 30 minute data (inclusive)
assert sample["sites"].values.shape == (4,)

def test_site_time_filter_start(site_config_filename):

Expand Down
Loading