Site torch dataset update #82
Changes from 4 commits
@@ -1 +1,2 @@
from .pvnet_uk_regional import PVNetUKRegionalDataset
from .site import SitesDataset
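Assuming this hunk is the torch_datasets package __init__ (the filename is not shown in this view), the added line re-exports the new dataset so both classes can be imported straight from the subpackage:

# Import path inferred from the other files in this PR, which import from
# ocf_data_sampler.torch_datasets.*; treat it as an assumption.
from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset, SitesDataset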
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
import xarray as xr
from typing import Tuple

from ocf_data_sampler.config import Configuration
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
@@ -9,7 +10,6 @@
    convert_satellite_to_numpy_batch,
    convert_gsp_to_numpy_batch,
    make_sun_position_numpy_batch,
    convert_site_to_numpy_batch,
)
from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
@@ -73,18 +73,6 @@ def process_and_combine_datasets(
            }
        )

    if "site" in dataset_dict:
        site_config = config.input_data.site
        da_sites = dataset_dict["site"]
        da_sites = da_sites / da_sites.capacity_kwp

        numpy_modalities.append(
            convert_site_to_numpy_batch(
                da_sites, t0_idx=-site_config.interval_start_minutes / site_config.time_resolution_minutes
            )
        )

    if target_key == 'gsp':
        # Make sun coords NumpyBatch
        datetimes = pd.date_range(
@@ -95,16 +83,6 @@ def process_and_combine_datasets(

        lon, lat = osgb_to_lon_lat(location.x, location.y)

    elif target_key == 'site':
        # Make sun coords NumpyBatch
        datetimes = pd.date_range(
            t0+minutes(site_config.interval_start_minutes),
            t0+minutes(site_config.interval_end_minutes),
            freq=minutes(site_config.time_resolution_minutes),
        )

        lon, lat = location.x, location.y

    numpy_modalities.append(
        make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=target_key)
    )
@@ -115,6 +93,47 @@ def process_and_combine_datasets(

    return combined_sample


def process_and_combine_site_sample_dict(
    dataset_dict: dict,
    config: Configuration,
) -> xr.Dataset:
    """
    Normalize and combine data into a single xr Dataset

    Args:
        dataset_dict: dict containing sliced xr DataArrays
        config: Configuration for the model

    Returns:
        xr.Dataset: A merged Dataset with nans filled in.
    """

    data_arrays = []

    if "nwp" in dataset_dict:
        for nwp_key, da_nwp in dataset_dict["nwp"].items():
            # Standardise
            provider = config.input_data.nwp[nwp_key].provider
            da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
            data_arrays.append((f"nwp-{provider}", da_nwp))

    if "sat" in dataset_dict:
        # Satellite is already in the range [0-1] so no need to standardise
        da_sat = dataset_dict["sat"]
        data_arrays.append(("satellite", da_sat))

    if "site" in dataset_dict:
        # site_config = config.input_data.site
        da_sites = dataset_dict["site"]
        da_sites = da_sites / da_sites.capacity_kwp
        data_arrays.append(("sites", da_sites))

    combined_sample_dataset = merge_arrays(data_arrays)

    # Fill any nan values
    return combined_sample_dataset.fillna(0.0)


def merge_dicts(list_of_dicts: list[dict]) -> dict:
    """Merge a list of dictionaries into a single dictionary"""
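To make the new helper concrete: a minimal, self-contained sketch of calling process_and_combine_site_sample_dict with toy satellite and site slices. The coordinate names, shapes and values are illustrative assumptions, and passing config=None only works because the current implementation only reads the config when an "nwp" key is present:

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.torch_datasets.process_and_combine import (
    process_and_combine_site_sample_dict,
)

times = pd.date_range("2024-01-01 00:00", periods=4, freq="30min")

# Toy site generation with a scalar capacity_kwp coordinate; the helper divides
# by it to normalise generation to [0, 1].
da_site = xr.DataArray(
    np.random.rand(4),
    coords={"time_utc": times, "capacity_kwp": 10.0},
    dims=["time_utc"],
)

# Toy satellite data, assumed to already be scaled to [0, 1].
da_sat = xr.DataArray(
    np.random.rand(4, 2, 2),
    coords={"time_utc": times, "x_geostationary": [0, 1], "y_geostationary": [0, 1]},
    dims=["time_utc", "x_geostationary", "y_geostationary"],
)

# No "nwp" key, so the config is never touched and None is enough for this sketch.
sample = process_and_combine_site_sample_dict({"sat": da_sat, "site": da_site}, config=None)

print(list(sample.data_vars))  # contains "satellite" and "sites"
print(list(sample.coords))     # prefixed names, e.g. "sites__time_utc"

The result is a single xr.Dataset with per-key prefixes on every dimension and coordinate, and with any NaNs filled with 0.0.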
@@ -124,6 +143,59 @@ def merge_dicts(list_of_dicts: list[dict]) -> dict:
        combined_dict.update(d)
    return combined_dict


def merge_arrays(normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
    """
    Combine a list of DataArrays into a single Dataset with unique naming conventions.

    Args:
        normalised_data_arrays: List of tuples where each tuple contains:
            - A string (key name).
            - An xarray.DataArray.

    Returns:
        xr.Dataset: A merged Dataset with uniquely named variables, coordinates, and dimensions.
    """
    datasets = []

    for key, data_array in normalised_data_arrays:
        # Ensure all attributes are strings for consistency
        data_array = data_array.assign_attrs(
            {attr_key: str(attr_value) for attr_key, attr_value in data_array.attrs.items()}
        )

        # Convert DataArray to Dataset with the variable name as the key
        dataset = data_array.to_dataset(name=key)

        # Prepend key name to all dimension and coordinate names for uniqueness
        dataset = dataset.rename(
            {dim: f"{key}__{dim}" for dim in dataset.dims if dim not in dataset.coords}
        )
        dataset = dataset.rename(
            {coord: f"{key}__{coord}" for coord in dataset.coords}
        )

        # Handle concatenation dimension if applicable
        concat_dim = (
            f"{key}__target_time_utc" if f"{key}__target_time_utc" in dataset.coords
            else f"{key}__time_utc"
        )

        if f"{key}__init_time_utc" in dataset.coords:
            init_coord = f"{key}__init_time_utc"
            if dataset[init_coord].ndim == 0:  # Check if scalar
                expanded_init_times = [dataset[init_coord].values] * len(dataset[concat_dim])
                dataset = dataset.assign_coords({init_coord: (concat_dim, expanded_init_times)})

        datasets.append(dataset)

    # Ensure all datasets are valid xarray.Dataset objects
    for ds in datasets:
        assert isinstance(ds, xr.Dataset), f"Object is not an xr.Dataset: {type(ds)}"

    # Merge all prepared datasets
    combined_dataset = xr.merge(datasets)

    return combined_dataset
Review comment: Out of curiosity, because I think I've missed a conversation somewhere: am I right in thinking that eventually the gsp version will also output netcdfs, and we will merge it into this function and move convert_to_numpy_batch into the training pipeline?

Reply: Yep, exactly, I think that is the current idea. The convert_to_numpy_batch logic can still live in this repo, and then we can call it in PVNet in the training pipeline.
def fill_nans_in_arrays(batch: dict) -> dict:
    """Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
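A hedged sketch of what the merge_arrays helper above does to a single NWP-like slice; the array contents and the "nwp-ukv" key are illustrative assumptions. It shows the two behaviours worth noting: every dimension and coordinate gets a "<key>__" prefix, and a scalar init_time_utc coordinate is broadcast along the concatenation dimension so it lines up with the forecast steps:

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.torch_datasets.process_and_combine import merge_arrays

# Toy NWP-like slice: per-step target_time_utc plus a scalar init_time_utc.
target_times = pd.date_range("2024-01-01 03:00", periods=3, freq="1h")
da_nwp = xr.DataArray(
    np.random.rand(3, 2),
    coords={
        "target_time_utc": target_times,
        "channel": ["t", "dswrf"],
        "init_time_utc": pd.Timestamp("2024-01-01 00:00"),
    },
    dims=["target_time_utc", "channel"],
)

ds = merge_arrays([("nwp-ukv", da_nwp)])

print(list(ds.coords))                     # e.g. "nwp-ukv__target_time_utc", "nwp-ukv__channel", ...
print(ds["nwp-ukv__init_time_utc"].shape)  # (3,) - the scalar init time was expanded

The prefixing is what lets xr.merge combine several modalities that would otherwise clash on dimension names (for example, multiple sources all having a time_utc dimension) into one flat Dataset.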
@@ -15,7 +15,7 @@
    slice_datasets_by_time, slice_datasets_by_space
)
from ocf_data_sampler.utils import minutes
from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_site_sample_dict
from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods

xr.set_options(keep_attrs=True)
@@ -152,10 +152,9 @@ def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
        """
        sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
        sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
        sample_dict = compute(sample_dict)

        sample = process_and_combine_datasets(sample_dict, self.config, t0, location, target_key='site')

        sample = process_and_combine_site_sample_dict(sample_dict, self.config)
        sample = sample.compute()
Review comment: nice!
        return sample

    def get_location_from_site_id(self, site_id):
@@ -197,7 +197,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
    t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
    interval_start = pd.Timedelta(-6, "h")
    interval_end = pd.Timedelta(3, "h")
    freq = pd.Timedelta("1H")
    freq = pd.Timedelta("1h")
Review comment: To stop a deprecation warning from popping up (recent pandas deprecates the upper-case "H" unit in favour of the lower-case "h").
    dropout_timedelta = pd.Timedelta("-2h")

    t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)
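Regarding the review comment above about changing "1H" to "1h": a quick way to see the warning being avoided (the assumption here is that the deprecation of the upper-case unit arrived with pandas 2.2; on older versions the list printed below is simply empty):

import warnings

import pandas as pd

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    pd.Timedelta("1H")  # deprecated spelling, emits a FutureWarning on recent pandas
    pd.Timedelta("1h")  # current spelling, no warning

print([str(w.message) for w in caught])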
Review comment: Could you add a TODO here for normalising the satellite data?

Reply: #87