From 2a4acfca9cfea53c44eb4c99c065086f52e8637c Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Fri, 25 Oct 2024 11:15:44 +0100 Subject: [PATCH 01/24] first try at site workflow --- ocf_data_sampler/config/model.py | 39 +++ ocf_data_sampler/load/load_dataset.py | 54 +++ ocf_data_sampler/load/sites.py | 23 ++ ocf_data_sampler/numpy_batch/__init__.py | 1 + ocf_data_sampler/numpy_batch/site.py | 31 ++ ocf_data_sampler/select/geospatial.py | 44 ++- .../select/select_spatial_slice.py | 10 +- .../select/space_slice_for_dataset.py | 52 +++ .../select/time_slice_for_dataset.py | 120 +++++++ ocf_data_sampler/torch_datasets/Readme.md | 42 +++ .../torch_datasets/process_and_combine.py | 151 +++++++++ .../torch_datasets/pvnet_uk_regional.py | 319 +----------------- ocf_data_sampler/torch_datasets/site.py | 274 +++++++++++++++ .../torch_datasets/xarray_compute.py | 8 + tests/conftest.py | 63 ++++ tests/load/test_load_sites.py | 14 + tests/torch_datasets/test_site.py | 60 ++++ 17 files changed, 989 insertions(+), 316 deletions(-) create mode 100644 ocf_data_sampler/load/load_dataset.py create mode 100644 ocf_data_sampler/load/sites.py create mode 100644 ocf_data_sampler/numpy_batch/site.py create mode 100644 ocf_data_sampler/select/space_slice_for_dataset.py create mode 100644 ocf_data_sampler/select/time_slice_for_dataset.py create mode 100644 ocf_data_sampler/torch_datasets/process_and_combine.py create mode 100644 ocf_data_sampler/torch_datasets/site.py create mode 100644 ocf_data_sampler/torch_datasets/xarray_compute.py create mode 100644 tests/load/test_load_sites.py create mode 100644 tests/torch_datasets/test_site.py diff --git a/ocf_data_sampler/config/model.py b/ocf_data_sampler/config/model.py index dfe9d31..3ba6f81 100644 --- a/ocf_data_sampler/config/model.py +++ b/ocf_data_sampler/config/model.py @@ -102,6 +102,44 @@ class TimeResolutionMixin(Base): ) +class Sites(DataSourceMixin, TimeResolutionMixin, DropoutMixin): + """Site configuration model""" + + filename: str = Field( + ..., + description="The NetCDF files holding the power timeseries.", + ) + metadata_filename: str = Field( + ..., + description="The CSV files describing wind system.", + ) + + # site_ids: List[int] = Field( + # None, + # description="List of the ML IDs of the Wind systems you'd like to filter to.", + # ) + + @field_validator("forecast_minutes") + def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int: + """Check forecast length requested will give stable number of timesteps""" + if v % info.data["time_resolution_minutes"] != 0: + message = "Forecast duration must be divisible by time resolution" + logger.error(message) + raise Exception(message) + return v + + @field_validator("history_minutes") + def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int: + """Check history length requested will give stable number of timesteps""" + if v % info.data["time_resolution_minutes"] != 0: + message = "History duration must be divisible by time resolution" + logger.error(message) + raise Exception(message) + return v + + # TODO validate the netcdf for sites + # TODO validate the csv for metadata + class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin): """Satellite configuration model""" @@ -240,6 +278,7 @@ class InputData(Base): satellite: Optional[Satellite] = None nwp: Optional[MultiNWP] = None gsp: Optional[GSP] = None + site: Optional[Sites] = None class Configuration(Base): diff --git a/ocf_data_sampler/load/load_dataset.py 
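The validators on the new `Sites` model mirror the other data sources: both `history_minutes` and `forecast_minutes` must be whole multiples of `time_resolution_minutes`, otherwise a sample would not contain an integer number of timesteps. A construction sketch, mirroring the fixture added to `tests/conftest.py` later in this patch (file paths are placeholders):

```python
from ocf_data_sampler.config.model import Sites

# Placeholder paths - the real test fixture writes a NetCDF of generation data
# and a CSV of site metadata (id, latitude, longitude, capacity) to a temp dir.
site_config = Sites(
    filename="sites.netcdf",
    metadata_filename="sites_metadata.csv",
    time_resolution_minutes=30,
    history_minutes=30,   # 30 % 30 == 0, accepted
    forecast_minutes=60,  # 60 % 30 == 0, accepted
)

# forecast_minutes=45 would raise, since 45 is not divisible by 30.
```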
b/ocf_data_sampler/load/load_dataset.py new file mode 100644 index 0000000..972a6e0 --- /dev/null +++ b/ocf_data_sampler/load/load_dataset.py @@ -0,0 +1,54 @@ +import xarray as xr + +from ocf_data_sampler.config import Configuration +from ocf_data_sampler.load.gsp import open_gsp +from ocf_data_sampler.load.nwp import open_nwp +from ocf_data_sampler.load.satellite import open_sat_data +from ocf_data_sampler.load.sites import open_sites + + +def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataArray]]: + """Construct dictionary of all of the input data sources + + Args: + config: Configuration file + """ + + in_config = config.input_data + + datasets_dict = {} + + # Load GSP data unless the path is None + if in_config.gsp and in_config.gsp.gsp_zarr_path: + da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path).compute() + + # Remove national GSP + datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None)) + + # Load NWP data if in config + if in_config.nwp: + + datasets_dict["nwp"] = {} + for nwp_source, nwp_config in in_config.nwp.items(): + + da_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider) + + da_nwp = da_nwp.sel(channel=list(nwp_config.nwp_channels)) + + datasets_dict["nwp"][nwp_source] = da_nwp + + # Load satellite data if in config + if in_config.satellite: + sat_config = config.input_data.satellite + + da_sat = open_sat_data(sat_config.satellite_zarr_path) + + da_sat = da_sat.sel(channel=list(sat_config.satellite_channels)) + + datasets_dict["sat"] = da_sat + + if in_config.site: + da_sites = open_sites(in_config.site) + datasets_dict["site"] = da_sites + + return datasets_dict diff --git a/ocf_data_sampler/load/sites.py b/ocf_data_sampler/load/sites.py new file mode 100644 index 0000000..3487b60 --- /dev/null +++ b/ocf_data_sampler/load/sites.py @@ -0,0 +1,23 @@ +import pandas as pd +import xarray as xr + +from ocf_data_sampler.config.model import Sites + + +def open_sites(sites_config: Sites) -> xr.DataArray: + + # Load site generation xr.Dataset + data_ds = xr.open_dataset(sites_config.filename) + + # Load site generation data + metadata_df = pd.read_csv(sites_config.metadata_filename) + metadata_df.set_index("system_id", inplace=True, drop=True) + + # Add coordinates + ds = data_ds.assign_coords( + latitude=(metadata_df.latitude.to_xarray()), + longitude=(metadata_df.longitude.to_xarray()), + capacity_kwp=data_ds.capacity_kwp, + ) + + return ds.generation_kw diff --git a/ocf_data_sampler/numpy_batch/__init__.py b/ocf_data_sampler/numpy_batch/__init__.py index 8845068..25535f7 100644 --- a/ocf_data_sampler/numpy_batch/__init__.py +++ b/ocf_data_sampler/numpy_batch/__init__.py @@ -4,4 +4,5 @@ from .nwp import convert_nwp_to_numpy_batch from .satellite import convert_satellite_to_numpy_batch from .sun_position import make_sun_position_numpy_batch +from .site import convert_site_to_numpy_batch diff --git a/ocf_data_sampler/numpy_batch/site.py b/ocf_data_sampler/numpy_batch/site.py new file mode 100644 index 0000000..779e354 --- /dev/null +++ b/ocf_data_sampler/numpy_batch/site.py @@ -0,0 +1,31 @@ +"""Convert site to Numpy Batch""" + +import xarray as xr + + +class SiteBatchKey: + + site = "site" + site_capacity_kwp = "site_capacity_kwp" + site_time_utc = "site_time_utc" + site_t0_idx = "site_t0_idx" + site_solar_azimuth = "site_solar_azimuth" + site_solar_elevation = "site_solar_elevation" + site_id = "site_id" + site_latitude = "site_latitude" + site_longitude = "site_longitude" + + +def convert_site_to_numpy_batch(da: 
xr.DataArray, t0_idx: int | None = None) -> dict: + """Convert from Xarray to NumpyBatch""" + + example = { + SiteBatchKey.site: da.values, + SiteBatchKey.site_capacity_kwp: da.isel(time_utc=0)["capacity_kwp"].values, + SiteBatchKey.site_time_utc: da["time_utc"].values.astype(float), + } + + if t0_idx is not None: + example[SiteBatchKey.site_t0_idx] = t0_idx + + return example diff --git a/ocf_data_sampler/select/geospatial.py b/ocf_data_sampler/select/geospatial.py index 8137f16..31ed9ca 100644 --- a/ocf_data_sampler/select/geospatial.py +++ b/ocf_data_sampler/select/geospatial.py @@ -55,6 +55,23 @@ def lon_lat_to_osgb( return _lon_lat_to_osgb(xx=x, yy=y) +def lat_lon_to_geostationary_area_coords( + longitude: Union[Number, np.ndarray], + latitude: Union[Number, np.ndarray], + xr_data: xr.DataArray, +) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]: + """Loads geostationary area and transformation from lat-lon to geostationary coords + + Args: + longitude: longitude + latitude: latitude + xr_data: xarray object with geostationary area + + Returns: + Geostationary coords: x, y + """ + return coordinates_to_geostationary_area_coords(longitude, latitude, xr_data, WGS84) + def osgb_to_geostationary_area_coords( x: Union[Number, np.ndarray], y: Union[Number, np.ndarray], @@ -70,6 +87,31 @@ def osgb_to_geostationary_area_coords( Returns: Geostationary coords: x, y """ + + return coordinates_to_geostationary_area_coords(x, y, xr_data, OSGB36) + + + +def coordinates_to_geostationary_area_coords( + x: Union[Number, np.ndarray], + y: Union[Number, np.ndarray], + xr_data: xr.DataArray, + crs_from: int +) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]: + """Loads geostationary area and transformation from respective coordiates to geostationary coords + + Args: + x: osgb east-west, or latitude + y: osgb north-south, or longitude + xr_data: xarray object with geostationary area + crs_from: the cordiates system of x,y + + Returns: + Geostationary coords: x, y + """ + + assert crs_from in [OSGB36, WGS84], f"Unrecognized coordinate system: {crs_from}" + # Only load these if using geostationary projection import pyresample @@ -80,7 +122,7 @@ def osgb_to_geostationary_area_coords( ) geostationary_crs = geostationary_area_definition.crs osgb_to_geostationary = pyproj.Transformer.from_crs( - crs_from=OSGB36, crs_to=geostationary_crs, always_xy=True + crs_from=crs_from, crs_to=geostationary_crs, always_xy=True ).transform return osgb_to_geostationary(xx=x, yy=y) diff --git a/ocf_data_sampler/select/select_spatial_slice.py b/ocf_data_sampler/select/select_spatial_slice.py index 9dad8ea..0330dfe 100644 --- a/ocf_data_sampler/select/select_spatial_slice.py +++ b/ocf_data_sampler/select/select_spatial_slice.py @@ -8,6 +8,7 @@ from ocf_data_sampler.select.location import Location from ocf_data_sampler.select.geospatial import ( lon_lat_to_osgb, + lat_lon_to_geostationary_area_coords, osgb_to_geostationary_area_coords, osgb_to_lon_lat, spatial_coord_type, @@ -101,7 +102,7 @@ def _get_idx_of_pixel_closest_to_poi( def _get_idx_of_pixel_closest_to_poi_geostationary( da: xr.DataArray, - center_osgb: Location, + center: Location, ) -> Location: """ Return x and y index location of pixel at center of region of interest. 
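The generalised `coordinates_to_geostationary_area_coords` only changes which source CRS the `pyproj.Transformer` starts from; the target geostationary CRS is still built from the satellite data's area metadata. A standalone sketch of that transform, with the area definition replaced by an illustrative proj string instead of one read from `xr_data`, and example coordinates chosen arbitrarily:

```python
import numpy as np
import pyproj

# Illustrative geostationary projection (roughly a Meteosat-style view at 9.5E);
# in the library this CRS comes from the satellite data's area attribute.
geostationary_crs = "+proj=geos +h=35785831 +lon_0=9.5 +a=6378169 +b=6356583.8"

WGS84 = 4326  # lon/lat source CRS, as used by lat_lon_to_geostationary_area_coords

lonlat_to_geostationary = pyproj.Transformer.from_crs(
    crs_from=WGS84, crs_to=geostationary_crs, always_xy=True
).transform

# Transform an example point over the UK into geostationary x/y metres
x_geo, y_geo = lonlat_to_geostationary(xx=np.array([-3.5]), yy=np.array([51.5]))
```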
@@ -116,7 +117,12 @@ def _get_idx_of_pixel_closest_to_poi_geostationary( _, x_dim, y_dim = spatial_coord_type(da) - x, y = osgb_to_geostationary_area_coords(x=center_osgb.x, y=center_osgb.y, xr_data=da) + if center.coordinate_system == 'osgb': + x, y = osgb_to_geostationary_area_coords(x=center.x, y=center.y, xr_data=da) + elif center.coordinate_system == 'lon_lat': + x, y = lat_lon_to_geostationary_area_coords(longitude=center.x, latitude=center.y, xr_data=da) + else: + x,y = center.x, center.y center_geostationary = Location(x=x, y=y, coordinate_system="geostationary") # Check that the requested point lies within the data diff --git a/ocf_data_sampler/select/space_slice_for_dataset.py b/ocf_data_sampler/select/space_slice_for_dataset.py new file mode 100644 index 0000000..e6d7b80 --- /dev/null +++ b/ocf_data_sampler/select/space_slice_for_dataset.py @@ -0,0 +1,52 @@ +from ocf_data_sampler.config import Configuration +from ocf_data_sampler.select.location import Location +from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels + + +def slice_datasets_by_space( + datasets_dict: dict, + location: Location, + config: Configuration, +) -> dict: + """Slice a dictionaries of input data sources around a given location + + Args: + datasets_dict: Dictionary of the input data sources + location: The location to sample around + config: Configuration object. + """ + + assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp", "site"}) + + sliced_datasets_dict = {} + + if "nwp" in datasets_dict: + + sliced_datasets_dict["nwp"] = {} + + for nwp_key, nwp_config in config.input_data.nwp.items(): + + sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels( + datasets_dict["nwp"][nwp_key], + location, + height_pixels=nwp_config.nwp_image_size_pixels_height, + width_pixels=nwp_config.nwp_image_size_pixels_width, + ) + + if "sat" in datasets_dict: + sat_config = config.input_data.satellite + + sliced_datasets_dict["sat"] = select_spatial_slice_pixels( + datasets_dict["sat"], + location, + height_pixels=sat_config.satellite_image_size_pixels_height, + width_pixels=sat_config.satellite_image_size_pixels_width, + ) + + if "gsp" in datasets_dict: + sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id) + + if "site" in datasets_dict: + sliced_datasets_dict["site"] = datasets_dict["site"].sel(system_id=location.id) + + return sliced_datasets_dict diff --git a/ocf_data_sampler/select/time_slice_for_dataset.py b/ocf_data_sampler/select/time_slice_for_dataset.py new file mode 100644 index 0000000..769e495 --- /dev/null +++ b/ocf_data_sampler/select/time_slice_for_dataset.py @@ -0,0 +1,120 @@ +import pandas as pd + +from ocf_data_sampler.config import Configuration +from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time +from ocf_data_sampler.select.select_time_slice import select_time_slice_nwp, select_time_slice + + +def minutes(minutes: list[float]): + """Timedelta minutes + + Args: + m: minutes + """ + return pd.to_timedelta(minutes, unit="m") + + +def slice_datasets_by_time( + datasets_dict: dict, + t0: pd.Timestamp, + config: Configuration, +) -> dict: + """Slice a dictionaries of input data sources around a given t0 time + + Args: + datasets_dict: Dictionary of the input data sources + t0: The init-time + config: Configuration object. 
+ """ + + sliced_datasets_dict = {} + + if "nwp" in datasets_dict: + + sliced_datasets_dict["nwp"] = {} + + for nwp_key, da_nwp in datasets_dict["nwp"].items(): + + nwp_config = config.input_data.nwp[nwp_key] + + sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp( + da_nwp, + t0, + sample_period_duration=minutes(nwp_config.time_resolution_minutes), + history_duration=minutes(nwp_config.history_minutes), + forecast_duration=minutes(nwp_config.forecast_minutes), + dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes), + dropout_frac=nwp_config.dropout_fraction, + accum_channels=nwp_config.nwp_accum_channels, + ) + + if "sat" in datasets_dict: + + sat_config = config.input_data.satellite + + sliced_datasets_dict["sat"] = select_time_slice( + datasets_dict["sat"], + t0, + sample_period_duration=minutes(sat_config.time_resolution_minutes), + interval_start=minutes(-sat_config.history_minutes), + interval_end=minutes(-sat_config.live_delay_minutes), + max_steps_gap=2, + ) + + # Randomly sample dropout + sat_dropout_time = draw_dropout_time( + t0, + dropout_timedeltas=minutes(sat_config.dropout_timedeltas_minutes), + dropout_frac=sat_config.dropout_fraction, + ) + + # Apply the dropout + sliced_datasets_dict["sat"] = apply_dropout_time( + sliced_datasets_dict["sat"], + sat_dropout_time, + ) + + if "gsp" in datasets_dict: + gsp_config = config.input_data.gsp + + sliced_datasets_dict["gsp_future"] = select_time_slice( + datasets_dict["gsp"], + t0, + sample_period_duration=minutes(gsp_config.time_resolution_minutes), + interval_start=minutes(30), + interval_end=minutes(gsp_config.forecast_minutes), + ) + + sliced_datasets_dict["gsp"] = select_time_slice( + datasets_dict["gsp"], + t0, + sample_period_duration=minutes(gsp_config.time_resolution_minutes), + interval_start=-minutes(gsp_config.history_minutes), + interval_end=minutes(0), + ) + + # Dropout on the GSP, but not the future GSP + gsp_dropout_time = draw_dropout_time( + t0, + dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes), + dropout_frac=gsp_config.dropout_fraction, + ) + + sliced_datasets_dict["gsp"] = apply_dropout_time( + sliced_datasets_dict["gsp"], gsp_dropout_time + ) + + if "site" in datasets_dict: + site_config = config.input_data.site + + sliced_datasets_dict["site"] = select_time_slice( + datasets_dict["site"], + t0, + sample_period_duration=minutes(site_config.time_resolution_minutes), + interval_start=-minutes(site_config.history_minutes), + interval_end=minutes(site_config.forecast_minutes), + ) + + # TODO add dropout? + + return sliced_datasets_dict diff --git a/ocf_data_sampler/torch_datasets/Readme.md b/ocf_data_sampler/torch_datasets/Readme.md index c557d6e..606de43 100644 --- a/ocf_data_sampler/torch_datasets/Readme.md +++ b/ocf_data_sampler/torch_datasets/Readme.md @@ -47,3 +47,45 @@ graph TD E0 --> F[Sample] ``` +## Site +The Site torch dataset gets sample for each site. 
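A minimal usage sketch of that dataset (the YAML path is a placeholder; the configuration needs an `input_data.site` section, as set up in `tests/torch_datasets/test_site.py` below):

```python
from ocf_data_sampler.torch_datasets.site import SitesDataset

# Path is illustrative - the YAML must point at site, NWP and satellite inputs.
dataset = SitesDataset("configuration.yaml")

print(len(dataset))  # one entry per valid (t0, site_id) pair
sample = dataset[0]  # dict of numpy arrays keyed by the *BatchKey attributes,
                     # e.g. SiteBatchKey.site, NWPBatchKey.nwp
```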
+This works for mulitple sites with different valid time periods of data + +```mermaid +graph TD + A1([Load Site]) + A2([Load NWP]) + A3([Load Satellite]) + D0[All Site Locations] + A1 --> D1 + A2 --> D1 + A3 --> D1 + A1 --> D0 + A1 --> D3 + A2 --> D3 + A3 --> D3 + D1[T0 and Site Ids \n for each Site] --> D2 + D2[T0 and Site Ids Pairs] + D3[Data] +``` + +### Get a Sample + +```mermaid +graph TD + A0([Index]) + A1([T0 and Site Ids Pairs]) + A2([All Site Locations]) + A0 --> B0 + A1 --> B0 + A2 --> B0 + B0[T0 and Location] + B1([Data]) + B0 --> D0 + B1 --> D0 + D0[Filter by Location \n Site, Satellite and NWP] --> D1 + D1[Filter by Time \n Site, Satellite and NWP] --> D2 + D2[Load into Memory] --> E0 + E0[Add Site Sun Features] + E0 --> F[Sample] +``` \ No newline at end of file diff --git a/ocf_data_sampler/torch_datasets/process_and_combine.py b/ocf_data_sampler/torch_datasets/process_and_combine.py new file mode 100644 index 0000000..73694be --- /dev/null +++ b/ocf_data_sampler/torch_datasets/process_and_combine.py @@ -0,0 +1,151 @@ +import numpy as np +import pandas as pd +import xarray as xr + +from ocf_data_sampler.config import Configuration +from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS +from ocf_data_sampler.numpy_batch import ( + convert_nwp_to_numpy_batch, + convert_satellite_to_numpy_batch, + convert_gsp_to_numpy_batch, + make_sun_position_numpy_batch, + convert_site_to_numpy_batch, +) +from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey +from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey +from ocf_data_sampler.select.geospatial import osgb_to_lon_lat +from ocf_data_sampler.select.location import Location + + +def minutes(minutes: list[float]): + """Timedelta minutes + + Args: + m: minutes + """ + return pd.to_timedelta(minutes, unit="m") + + +def process_and_combine_datasets( + dataset_dict: dict, + config: Configuration, + t0: pd.Timestamp, + location: Location, + sun_position_key: str = 'gsp' +) -> dict: + """Normalize and convert data to numpy arrays""" + + numpy_modalities = [] + + if "nwp" in dataset_dict: + + nwp_numpy_modalities = dict() + + for nwp_key, da_nwp in dataset_dict["nwp"].items(): + # Standardise + provider = config.input_data.nwp[nwp_key].nwp_provider + da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider] + # Convert to NumpyBatch + nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp) + + # Combine the NWPs into NumpyBatch + numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities}) + + if "sat" in dataset_dict: + # Satellite is already in the range [0-1] so no need to standardise + da_sat = dataset_dict["sat"] + + # Convert to NumpyBatch + numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat)) + + gsp_config = config.input_data.gsp + + if "gsp" in dataset_dict: + da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc") + da_gsp = da_gsp / da_gsp.effective_capacity_mwp + + numpy_modalities.append( + convert_gsp_to_numpy_batch( + da_gsp, t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes + ) + ) + + # Add coordinate data + # TODO: Do we need all of these? 
+ numpy_modalities.append( + { + GSPBatchKey.gsp_id: location.id, + GSPBatchKey.gsp_x_osgb: location.x, + GSPBatchKey.gsp_y_osgb: location.y, + } + ) + + + if "site" in dataset_dict: + site_config = config.input_data.site + da_sites = dataset_dict["site"] + da_sites = da_sites / da_sites.capacity_kwp + + numpy_modalities.append( + convert_site_to_numpy_batch( + da_sites, t0_idx=site_config.history_minutes / site_config.time_resolution_minutes + ) + ) + + if sun_position_key == 'gsp': + # Make sun coords NumpyBatch + datetimes = pd.date_range( + t0 - minutes(gsp_config.history_minutes), + t0 + minutes(gsp_config.forecast_minutes), + freq=minutes(gsp_config.time_resolution_minutes), + ) + + lon, lat = osgb_to_lon_lat(location.x, location.y) + key_prefix = "gsp" + + elif sun_position_key == 'site': + # Make sun coords NumpyBatch + datetimes = pd.date_range( + t0 - minutes(site_config.history_minutes), + t0 + minutes(site_config.forecast_minutes), + freq=minutes(site_config.time_resolution_minutes), + ) + + lon, lat = location.x, location.y + key_prefix = "site" + + numpy_modalities.append( + make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=key_prefix) + ) + + # Combine all the modalities and fill NaNs + combined_sample = merge_dicts(numpy_modalities) + combined_sample = fill_nans_in_arrays(combined_sample) + + return combined_sample + + +def merge_dicts(list_of_dicts: list[dict]) -> dict: + """Merge a list of dictionaries into a single dictionary""" + # TODO: This doesn't account for duplicate keys, which will be overwritten + combined_dict = {} + for d in list_of_dicts: + combined_dict.update(d) + return combined_dict + + +def fill_nans_in_arrays(batch: dict) -> dict: + """Fills all NaN values in each np.ndarray in the batch dictionary with zeros. + + Operation is performed in-place on the batch. 
+ """ + for k, v in batch.items(): + if isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number): + if np.isnan(v).any(): + batch[k] = np.nan_to_num(v, copy=False, nan=0.0) + + # Recursion is included to reach NWP arrays in subdict + elif isinstance(v, dict): + fill_nans_in_arrays(v) + + return batch diff --git a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py index d62f84c..2da1741 100644 --- a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +++ b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py @@ -6,96 +6,33 @@ from torch.utils.data import Dataset import pkg_resources -from ocf_data_sampler.load.gsp import open_gsp -from ocf_data_sampler.load.nwp import open_nwp -from ocf_data_sampler.load.satellite import open_sat_data - from ocf_data_sampler.select.find_contiguous_time_periods import ( find_contiguous_t0_periods, find_contiguous_t0_periods_nwp, intersection_of_multiple_dataframes_of_periods, ) from ocf_data_sampler.select.fill_time_periods import fill_time_periods -from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp -from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time -from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels - -from ocf_data_sampler.numpy_batch import ( - convert_gsp_to_numpy_batch, - convert_nwp_to_numpy_batch, - convert_satellite_to_numpy_batch, - make_sun_position_numpy_batch, -) - from ocf_data_sampler.config import Configuration, load_yaml_configuration -from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey -from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey from ocf_data_sampler.select.location import Location -from ocf_data_sampler.select.geospatial import osgb_to_lon_lat - -from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS - - +from ocf_data_sampler.load.load_dataset import get_dataset_dict +from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets +from ocf_data_sampler.select.space_slice_for_dataset import slice_datasets_by_space +from ocf_data_sampler.select.time_slice_for_dataset import slice_datasets_by_time +from ocf_data_sampler.torch_datasets.xarray_compute import compute xr.set_options(keep_attrs=True) - - def minutes(minutes: list[float]): """Timedelta minutes - + Args: m: minutes """ return pd.to_timedelta(minutes, unit="m") -def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataArray]]: - """Construct dictionary of all of the input data sources - - Args: - config: Configuration file - """ - - in_config = config.input_data - - datasets_dict = {} - - # Load GSP data unless the path is None - if in_config.gsp.gsp_zarr_path: - da_gsp = open_gsp(zarr_path=in_config.gsp.gsp_zarr_path).compute() - - # Remove national GSP - datasets_dict["gsp"] = da_gsp.sel(gsp_id=slice(1, None)) - - # Load NWP data if in config - if in_config.nwp: - - datasets_dict["nwp"] = {} - for nwp_source, nwp_config in in_config.nwp.items(): - - da_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider) - - da_nwp = da_nwp.sel(channel=list(nwp_config.nwp_channels)) - - datasets_dict["nwp"][nwp_source] = da_nwp - - # Load satellite data if in config - if in_config.satellite: - sat_config = config.input_data.satellite - - da_sat = open_sat_data(sat_config.satellite_zarr_path) - - da_sat = da_sat.sel(channel=list(sat_config.satellite_channels)) - - datasets_dict["sat"] = da_sat - - return datasets_dict - - - def 
find_valid_t0_times( datasets_dict: dict, config: Configuration, @@ -203,250 +140,6 @@ def find_valid_t0_times( return valid_t0_times -def slice_datasets_by_space( - datasets_dict: dict, - location: Location, - config: Configuration, -) -> dict: - """Slice a dictionaries of input data sources around a given location - - Args: - datasets_dict: Dictionary of the input data sources - location: The location to sample around - config: Configuration object. - """ - - assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"}) - - sliced_datasets_dict = {} - - if "nwp" in datasets_dict: - - sliced_datasets_dict["nwp"] = {} - - for nwp_key, nwp_config in config.input_data.nwp.items(): - - sliced_datasets_dict["nwp"][nwp_key] = select_spatial_slice_pixels( - datasets_dict["nwp"][nwp_key], - location, - height_pixels=nwp_config.nwp_image_size_pixels_height, - width_pixels=nwp_config.nwp_image_size_pixels_width, - ) - - if "sat" in datasets_dict: - sat_config = config.input_data.satellite - - sliced_datasets_dict["sat"] = select_spatial_slice_pixels( - datasets_dict["sat"], - location, - height_pixels=sat_config.satellite_image_size_pixels_height, - width_pixels=sat_config.satellite_image_size_pixels_width, - ) - - if "gsp" in datasets_dict: - sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id) - - return sliced_datasets_dict - - -def slice_datasets_by_time( - datasets_dict: dict, - t0: pd.Timedelta, - config: Configuration, -) -> dict: - """Slice a dictionaries of input data sources around a given t0 time - - Args: - datasets_dict: Dictionary of the input data sources - t0: The init-time - config: Configuration object. - """ - - sliced_datasets_dict = {} - - if "nwp" in datasets_dict: - - sliced_datasets_dict["nwp"] = {} - - for nwp_key, da_nwp in datasets_dict["nwp"].items(): - - nwp_config = config.input_data.nwp[nwp_key] - - sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp( - da_nwp, - t0, - sample_period_duration=minutes(nwp_config.time_resolution_minutes), - history_duration=minutes(nwp_config.history_minutes), - forecast_duration=minutes(nwp_config.forecast_minutes), - dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes), - dropout_frac=nwp_config.dropout_fraction, - accum_channels=nwp_config.nwp_accum_channels, - ) - - if "sat" in datasets_dict: - - sat_config = config.input_data.satellite - - sliced_datasets_dict["sat"] = select_time_slice( - datasets_dict["sat"], - t0, - sample_period_duration=minutes(sat_config.time_resolution_minutes), - interval_start=minutes(-sat_config.history_minutes), - interval_end=minutes(-sat_config.live_delay_minutes), - max_steps_gap=2, - ) - - # Randomly sample dropout - sat_dropout_time = draw_dropout_time( - t0, - dropout_timedeltas=minutes(sat_config.dropout_timedeltas_minutes), - dropout_frac=sat_config.dropout_fraction, - ) - - # Apply the dropout - sliced_datasets_dict["sat"] = apply_dropout_time( - sliced_datasets_dict["sat"], - sat_dropout_time, - ) - - if "gsp" in datasets_dict: - gsp_config = config.input_data.gsp - - sliced_datasets_dict["gsp_future"] = select_time_slice( - datasets_dict["gsp"], - t0, - sample_period_duration=minutes(gsp_config.time_resolution_minutes), - interval_start=minutes(30), - interval_end=minutes(gsp_config.forecast_minutes), - ) - - sliced_datasets_dict["gsp"] = select_time_slice( - datasets_dict["gsp"], - t0, - sample_period_duration=minutes(gsp_config.time_resolution_minutes), - interval_start=-minutes(gsp_config.history_minutes), - interval_end=minutes(0), - ) - - # 
Dropout on the GSP, but not the future GSP - gsp_dropout_time = draw_dropout_time( - t0, - dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes), - dropout_frac=gsp_config.dropout_fraction, - ) - - sliced_datasets_dict["gsp"] = apply_dropout_time(sliced_datasets_dict["gsp"], gsp_dropout_time) - - return sliced_datasets_dict - - -def fill_nans_in_arrays(batch: dict) -> dict: - """Fills all NaN values in each np.ndarray in the batch dictionary with zeros. - - Operation is performed in-place on the batch. - """ - for k, v in batch.items(): - if isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number): - if np.isnan(v).any(): - batch[k] = np.nan_to_num(v, copy=False, nan=0.0) - - # Recursion is included to reach NWP arrays in subdict - elif isinstance(v, dict): - fill_nans_in_arrays(v) - - return batch - - - -def merge_dicts(list_of_dicts: list[dict]) -> dict: - """Merge a list of dictionaries into a single dictionary""" - # TODO: This doesn't account for duplicate keys, which will be overwritten - combined_dict = {} - for d in list_of_dicts: - combined_dict.update(d) - return combined_dict - - -def process_and_combine_datasets( - dataset_dict: dict, - config: Configuration, - t0: pd.Timedelta, - location: Location, - ) -> dict: - """Normalize and convert data to numpy arrays""" - - numpy_modalities = [] - - if "nwp" in dataset_dict: - - nwp_numpy_modalities = dict() - - for nwp_key, da_nwp in dataset_dict["nwp"].items(): - # Standardise - provider = config.input_data.nwp[nwp_key].nwp_provider - da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider] - # Convert to NumpyBatch - nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp) - - # Combine the NWPs into NumpyBatch - numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities}) - - if "sat" in dataset_dict: - # Satellite is already in the range [0-1] so no need to standardise - da_sat = dataset_dict["sat"] - - # Convert to NumpyBatch - numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat)) - - gsp_config = config.input_data.gsp - - if "gsp" in dataset_dict: - da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc") - da_gsp = da_gsp / da_gsp.effective_capacity_mwp - - numpy_modalities.append( - convert_gsp_to_numpy_batch( - da_gsp, - t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes - ) - ) - - # Make sun coords NumpyBatch - datetimes = pd.date_range( - t0-minutes(gsp_config.history_minutes), - t0+minutes(gsp_config.forecast_minutes), - freq=minutes(gsp_config.time_resolution_minutes), - ) - - lon, lat = osgb_to_lon_lat(location.x, location.y) - - numpy_modalities.append(make_sun_position_numpy_batch(datetimes, lon, lat)) - - # Add coordinate data - # TODO: Do we need all of these? 
- numpy_modalities.append({ - GSPBatchKey.gsp_id: location.id, - GSPBatchKey.gsp_x_osgb: location.x, - GSPBatchKey.gsp_y_osgb: location.y, - }) - - # Combine all the modalities and fill NaNs - combined_sample = merge_dicts(numpy_modalities) - combined_sample = fill_nans_in_arrays(combined_sample) - - return combined_sample - - -def compute(xarray_dict: dict) -> dict: - """Eagerly load a nested dictionary of xarray DataArrays""" - for k, v in xarray_dict.items(): - if isinstance(v, dict): - xarray_dict[k] = compute(v) - else: - xarray_dict[k] = v.compute(scheduler="single-threaded") - return xarray_dict - - def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]: """Get list of locations of all GSPs""" diff --git a/ocf_data_sampler/torch_datasets/site.py b/ocf_data_sampler/torch_datasets/site.py new file mode 100644 index 0000000..de4ba11 --- /dev/null +++ b/ocf_data_sampler/torch_datasets/site.py @@ -0,0 +1,274 @@ +"""Torch dataset for sites""" +import logging + +import numpy as np +import pandas as pd +import xarray as xr +from torch.utils.data import Dataset + +from ocf_data_sampler.select.find_contiguous_time_periods import ( + find_contiguous_t0_periods, find_contiguous_t0_periods_nwp, + intersection_of_multiple_dataframes_of_periods, +) +from ocf_data_sampler.select.fill_time_periods import fill_time_periods + +from ocf_data_sampler.config import Configuration, load_yaml_configuration + +from ocf_data_sampler.select.location import Location + +from ocf_data_sampler.load.load_dataset import get_dataset_dict +from ocf_data_sampler.select.time_slice_for_dataset import slice_datasets_by_time +from ocf_data_sampler.select.space_slice_for_dataset import slice_datasets_by_space +from ocf_data_sampler.torch_datasets.xarray_compute import compute +from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets + + + +xr.set_options(keep_attrs=True) + + +def find_valid_t0_and_system_ids( + datasets_dict: dict, + config: Configuration, +) -> pd.DataFrame: + """Find the t0 times where all of the requested input data is available + + The idea is to + 1. Get valid time periods for nwp + 2. Get valid time periods for satellite + 3. Get valid time period for nwp and satellite + 4. 
For each site location, find valid periods for that location + + Args: + datasets_dict: A dictionary of input datasets + config: Configuration file + """ + + assert set(datasets_dict.keys()).issubset({"nwp", "sat", "site", "gsp"}) + + contiguous_time_periods: dict[str: pd.DataFrame] = {} # Used to store contiguous time periods from each data source + + # TODO refactor as this code is duplicated + if "nwp" in datasets_dict: + for nwp_key, nwp_config in config.input_data.nwp.items(): + + da = datasets_dict["nwp"][nwp_key] + + if nwp_config.dropout_timedeltas_minutes is None: + max_dropout = pd.to_timedelta(0, unit="m") + else: + max_dropout = pd.to_timedelta(np.max(np.abs(nwp_config.dropout_timedeltas_minutes)), unit="m") + + if nwp_config.max_staleness_minutes is None: + max_staleness = None + else: + max_staleness = pd.to_timedelta(nwp_config.max_staleness_minutes, unit="m") + + # The last step of the forecast is lost if we have to diff channels + if len(nwp_config.nwp_accum_channels) > 0: + end_buffer = pd.to_timedelta(nwp_config.time_resolution_minutes, unit="m") + else: + end_buffer = pd.to_timedelta(0, unit="m") + + # This is the max staleness we can use considering the max step of the input data + max_possible_staleness = ( + pd.Timedelta(da["step"].max().item()) + - pd.to_timedelta(nwp_config.forecast_minutes, unit='m') + - end_buffer + ) + + # Default to use max possible staleness unless specified in config + if max_staleness is None: + max_staleness = max_possible_staleness + else: + # Make sure the max acceptable staleness isn't longer than the max possible + assert max_staleness <= max_possible_staleness + + time_periods = find_contiguous_t0_periods_nwp( + datetimes=pd.DatetimeIndex(da["init_time_utc"]), + history_duration=pd.to_timedelta(nwp_config.history_minutes, unit="m"), + max_staleness=max_staleness, + max_dropout=max_dropout, + ) + + contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods + + if "sat" in datasets_dict: + sat_config = config.input_data.satellite + + time_periods = find_contiguous_t0_periods( + pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]), + sample_period_duration=pd.to_timedelta(sat_config.time_resolution_minutes, unit="m"), + history_duration=pd.to_timedelta(sat_config.history_minutes, unit="m"), + forecast_duration=pd.to_timedelta(sat_config.forecast_minutes, unit="m"), + ) + + contiguous_time_periods['sat'] = time_periods + + # just get the values (not the keys) + contiguous_time_periods_values = list(contiguous_time_periods.values()) + + # Find joint overlapping contiguous time periods + if len(contiguous_time_periods_values) > 1: + valid_time_periods = intersection_of_multiple_dataframes_of_periods( + contiguous_time_periods_values + ) + else: + valid_time_periods = contiguous_time_periods_values[0] + + # check there are some valid time periods + if len(valid_time_periods) == 0: + raise ValueError(f"No valid time periods found, {contiguous_time_periods=}") + + # 4. Now lets loop over each location in system id and find the valid periods + # Should we have a different option if there are not nans + sites = datasets_dict["site"] + system_ids = sites.system_id.values + site_config = config.input_data.site + valid_t0_and_system_ids = [] + for system_id in system_ids: + site = sites.sel(system_id=system_id) + + # drop any nan values + # not sure this is right? 
+ site = site.dropna(dim='time_utc') + + # Get the valid time periods for this location + time_periods = find_contiguous_t0_periods( + pd.DatetimeIndex(site["time_utc"]), + sample_period_duration=pd.to_timedelta(site_config.time_resolution_minutes, unit="m"), + history_duration=pd.to_timedelta(site_config.history_minutes, unit="m"), + forecast_duration=pd.to_timedelta(site_config.forecast_minutes, unit="m"), + ) + valid_time_periods_per_site = intersection_of_multiple_dataframes_of_periods( + [valid_time_periods, time_periods] + ) + + # Fill out the contiguous time periods to get the t0 times + valid_t0_times_per_site = fill_time_periods( + valid_time_periods_per_site, + freq=pd.to_timedelta(site_config.time_resolution_minutes,unit='m') + ) + + valid_t0_per_site = pd.DataFrame(index=valid_t0_times_per_site) + valid_t0_per_site['system_id'] = system_id + valid_t0_and_system_ids.append(valid_t0_per_site) + + valid_t0_and_system_ids = pd.concat(valid_t0_and_system_ids) + valid_t0_and_system_ids.index.name = 't0' + valid_t0_and_system_ids.reset_index(inplace=True) + + print(valid_t0_and_system_ids) + + return valid_t0_and_system_ids + + +def get_locations(site_xr:xr.Dataset): + """Get list of locations of all sites""" + + locations = [] + for system_id in site_xr.system_id.values: + site = site_xr.sel(system_id=system_id) + location = Location( + id=system_id, + x=site.longitude.values, + y=site.latitude.values, + coordinate_system="lon_lat" + ) + locations.append(location) + + return locations + + +class SitesDataset(Dataset): + def __init__( + self, + config_filename: str, + start_time: str | None = None, + end_time: str | None = None, + gsp_ids: list[int] | None = None, + ): + """A torch Dataset for creating PVNet UK GSP samples + + Args: + config_filename: Path to the configuration file + start_time: Limit the init-times to be after this + end_time: Limit the init-times to be before this + gsp_ids: List of GSP IDs to create samples for. 
Defaults to all + """ + + config = load_yaml_configuration(config_filename) + + datasets_dict = get_dataset_dict(config) + + # get all locations + self.locations = get_locations(datasets_dict['site']) + + # Get t0 times where all input data is available + valid_t0_and_system_ids = find_valid_t0_and_system_ids(datasets_dict, config) + + # Filter t0 times to given range + + # Assign coords and indices to self + self.valid_t0_and_system_ids = valid_t0_and_system_ids + + # Assign config and input data to self + self.datasets_dict = datasets_dict + self.config = config + + def __len__(self): + return len(self.valid_t0_and_system_ids) + + def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict: + """Generate the PVNet sample for given coordinates + + Args: + t0: init-time for sample + location: location for sample + """ + sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config) + sample_dict = slice_datasets_by_time(sample_dict, t0, self.config) + sample_dict = compute(sample_dict) + + sample = process_and_combine_datasets(sample_dict, self.config, t0, location, sun_position_key='site') + + return sample + + def get_location_from_system_id(self, system_id): + """Get location from system id""" + + locations = [loc for loc in self.locations if loc.id == system_id] + if len(locations) == 0: + raise ValueError(f"Location not found for system_id {system_id}") + + if len(locations) > 1: + logging.warning(f"Multiple locations found for system_id {system_id}, but will take the first") + + return locations[0] + + def __getitem__(self, idx): + + # Get the coordinates of the sample + # TOD change to system ids + t0_and_system_id = self.valid_t0_and_system_ids.iloc[idx] + t0, system_id = t0_and_system_id + + # get location from system_id + location = self.get_location_from_system_id(system_id) + + # Generate the sample + return self._get_sample(t0, location) + + def get_sample(self, t0: pd.Timestamp, location: Location) -> dict: + """Generate a sample for the given coordinates. 
+ + Useful for users to generate samples by t0 and location + + Args: + t0: init-time for sample + location: location object + """ + # Check the user has asked for a sample which we have the data for + # TODO + + return self._get_sample(t0, location) \ No newline at end of file diff --git a/ocf_data_sampler/torch_datasets/xarray_compute.py b/ocf_data_sampler/torch_datasets/xarray_compute.py new file mode 100644 index 0000000..2efaa53 --- /dev/null +++ b/ocf_data_sampler/torch_datasets/xarray_compute.py @@ -0,0 +1,8 @@ +def compute(xarray_dict: dict) -> dict: + """Eagerly load a nested dictionary of xarray DataArrays""" + for k, v in xarray_dict.items(): + if isinstance(v, dict): + xarray_dict[k] = compute(v) + else: + xarray_dict[k] = v.compute(scheduler="single-threaded") + return xarray_dict diff --git a/tests/conftest.py b/tests/conftest.py index b92b801..af62ba5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,8 @@ import xarray as xr import tempfile +from ocf_data_sampler.config.model import Sites + _top_test_directory = os.path.dirname(os.path.realpath(__file__)) @pytest.fixture() @@ -197,6 +199,67 @@ def ds_uk_gsp(): }) +@pytest.fixture(scope="session") +def data_sites() -> Sites: + """ + Make fake data for sites + Returns: filename for netcdf file, and csv metadata + """ + times = pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="30min") + system_ids = list(range(0,10)) + capacity_kwp_1d = np.array([0.1,1.1,4,6,8,9,15,2,3,4]) + # these are quite specific for the fake satellite data + longitude = np.arange(-4, -3, 0.1) + latitude = np.arange(51, 52, 0.1) + + generation = np.random.uniform(0, 200, size=(len(times), len(system_ids))).astype(np.float32) + + # repeat capacity in new dims len(times) times + capacity_kwp = (np.tile(capacity_kwp_1d, len(times))).reshape(len(times),10) + + coords = ( + ("time_utc", times), + ("system_id", system_ids), + ) + + + da_cap = xr.DataArray( + capacity_kwp, + coords=coords, + ) + + da_gen = xr.DataArray( + generation, + coords=coords, + ) + + # metadata + meta_df = pd.DataFrame(columns=[], data = []) + meta_df['system_id'] = system_ids + meta_df['capacity_kwp'] = capacity_kwp_1d + meta_df['longitude'] = longitude + meta_df['latitude'] = latitude + + generation = xr.Dataset({ + "capacity_kwp": da_cap, + "generation_kw":da_gen + }) + + with tempfile.TemporaryDirectory() as tmpdir: + filename = tmpdir + "/sites.netcdf" + filename_csv = tmpdir + "/sites_metadata.csv" + generation.to_netcdf(filename) + meta_df.to_csv(filename_csv) + + site = Sites(filename=filename, + metadata_filename=filename_csv, + time_resolution_minutes=30, + forecast_minutes=60, + history_minutes=30) + + yield site + + @pytest.fixture(scope="session") def uk_gsp_zarr_path(ds_uk_gsp): diff --git a/tests/load/test_load_sites.py b/tests/load/test_load_sites.py new file mode 100644 index 0000000..6cd46e1 --- /dev/null +++ b/tests/load/test_load_sites.py @@ -0,0 +1,14 @@ +from ocf_data_sampler.load.sites import open_sites +import xarray as xr + + +def test_open_site(data_sites): + da = open_sites(data_sites) + + assert isinstance(da, xr.DataArray) + assert da.dims == ("time_utc", "system_id") + + assert "capacity_kwp" in da.coords + assert "latitude" in da.coords + assert "longitude" in da.coords + assert da.shape == (49, 10) diff --git a/tests/torch_datasets/test_site.py b/tests/torch_datasets/test_site.py new file mode 100644 index 0000000..53fc13f --- /dev/null +++ b/tests/torch_datasets/test_site.py @@ -0,0 +1,60 @@ +import pytest +import tempfile + +from 
ocf_data_sampler.torch_datasets.site import SitesDataset +from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration +from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey +from ocf_data_sampler.numpy_batch.site import SiteBatchKey +from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey + + +@pytest.fixture() +def site_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, sat_zarr_path, data_sites): + + # adjust config to point to the zarr file + config = load_yaml_configuration(config_filename) + config.input_data.nwp["ukv"].nwp_zarr_path = nwp_ukv_zarr_path + config.input_data.satellite.satellite_zarr_path = sat_zarr_path + config.input_data.site = data_sites + config.input_data.gsp = None + + filename = f"{tmp_path}/configuration.yaml" + save_yaml_configuration(config, filename) + return filename + + +def test_site(site_config_filename): + + # Create dataset object + dataset = SitesDataset(site_config_filename) + + assert len(dataset) == 10 * 41 + # TODO check 41 + + # Generate a sample + sample = dataset[0] + + assert isinstance(sample, dict) + + for key in [ + NWPBatchKey.nwp, + SatelliteBatchKey.satellite_actual, + SiteBatchKey.site, + SiteBatchKey.site_solar_azimuth, + SiteBatchKey.site_solar_elevation, + ]: + assert key in sample + + for nwp_source in ["ukv"]: + assert nwp_source in sample[NWPBatchKey.nwp] + + # check the shape of the data is correct + # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels + assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2) + # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels + assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2) + # 3 hours of 30 minute data (inclusive) + assert sample[SiteBatchKey.site].shape == (4,) + # Solar angles have same shape as GSP data + assert sample[SiteBatchKey.site_solar_azimuth].shape == (4,) + assert sample[SiteBatchKey.site_solar_elevation].shape == (4,) From 941cf032beef5ba7d2e491737161e409528bd624 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Fri, 25 Oct 2024 15:24:54 +0100 Subject: [PATCH 02/24] remove side_ids, rename sites --> site, system_id --> site_id --- ocf_data_sampler/config/model.py | 5 -- ocf_data_sampler/load/load_dataset.py | 4 +- ocf_data_sampler/load/{sites.py => site.py} | 4 +- .../select/space_slice_for_dataset.py | 2 +- ocf_data_sampler/torch_datasets/site.py | 51 +++++++++---------- tests/conftest.py | 8 +-- tests/load/test_load_sites.py | 6 +-- 7 files changed, 36 insertions(+), 44 deletions(-) rename ocf_data_sampler/load/{sites.py => site.py} (81%) diff --git a/ocf_data_sampler/config/model.py b/ocf_data_sampler/config/model.py index 3ba6f81..aba0076 100644 --- a/ocf_data_sampler/config/model.py +++ b/ocf_data_sampler/config/model.py @@ -114,11 +114,6 @@ class Sites(DataSourceMixin, TimeResolutionMixin, DropoutMixin): description="The CSV files describing wind system.", ) - # site_ids: List[int] = Field( - # None, - # description="List of the ML IDs of the Wind systems you'd like to filter to.", - # ) - @field_validator("forecast_minutes") def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int: """Check forecast length requested will give stable number of timesteps""" diff --git a/ocf_data_sampler/load/load_dataset.py b/ocf_data_sampler/load/load_dataset.py index 972a6e0..65addf3 100644 --- a/ocf_data_sampler/load/load_dataset.py +++ b/ocf_data_sampler/load/load_dataset.py @@ -4,7 +4,7 @@ from ocf_data_sampler.load.gsp import 
open_gsp from ocf_data_sampler.load.nwp import open_nwp from ocf_data_sampler.load.satellite import open_sat_data -from ocf_data_sampler.load.sites import open_sites +from ocf_data_sampler.load.site import open_site def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataArray]]: @@ -48,7 +48,7 @@ def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataAr datasets_dict["sat"] = da_sat if in_config.site: - da_sites = open_sites(in_config.site) + da_sites = open_site(in_config.site) datasets_dict["site"] = da_sites return datasets_dict diff --git a/ocf_data_sampler/load/sites.py b/ocf_data_sampler/load/site.py similarity index 81% rename from ocf_data_sampler/load/sites.py rename to ocf_data_sampler/load/site.py index 3487b60..3eb47cb 100644 --- a/ocf_data_sampler/load/sites.py +++ b/ocf_data_sampler/load/site.py @@ -4,14 +4,14 @@ from ocf_data_sampler.config.model import Sites -def open_sites(sites_config: Sites) -> xr.DataArray: +def open_site(sites_config: Sites) -> xr.DataArray: # Load site generation xr.Dataset data_ds = xr.open_dataset(sites_config.filename) # Load site generation data metadata_df = pd.read_csv(sites_config.metadata_filename) - metadata_df.set_index("system_id", inplace=True, drop=True) + metadata_df.set_index("site_id", inplace=True, drop=True) # Add coordinates ds = data_ds.assign_coords( diff --git a/ocf_data_sampler/select/space_slice_for_dataset.py b/ocf_data_sampler/select/space_slice_for_dataset.py index e6d7b80..556aac5 100644 --- a/ocf_data_sampler/select/space_slice_for_dataset.py +++ b/ocf_data_sampler/select/space_slice_for_dataset.py @@ -47,6 +47,6 @@ def slice_datasets_by_space( sliced_datasets_dict["gsp"] = datasets_dict["gsp"].sel(gsp_id=location.id) if "site" in datasets_dict: - sliced_datasets_dict["site"] = datasets_dict["site"].sel(system_id=location.id) + sliced_datasets_dict["site"] = datasets_dict["site"].sel(site_id=location.id) return sliced_datasets_dict diff --git a/ocf_data_sampler/torch_datasets/site.py b/ocf_data_sampler/torch_datasets/site.py index de4ba11..120d99b 100644 --- a/ocf_data_sampler/torch_datasets/site.py +++ b/ocf_data_sampler/torch_datasets/site.py @@ -27,7 +27,7 @@ xr.set_options(keep_attrs=True) -def find_valid_t0_and_system_ids( +def find_valid_t0_and_site_ids( datasets_dict: dict, config: Configuration, ) -> pd.DataFrame: @@ -123,11 +123,11 @@ def find_valid_t0_and_system_ids( # 4. Now lets loop over each location in system id and find the valid periods # Should we have a different option if there are not nans sites = datasets_dict["site"] - system_ids = sites.system_id.values + site_ids = sites.site_id.values site_config = config.input_data.site - valid_t0_and_system_ids = [] - for system_id in system_ids: - site = sites.sel(system_id=system_id) + valid_t0_and_site_ids = [] + for site_id in site_ids: + site = sites.sel(site_id=site_id) # drop any nan values # not sure this is right? 
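After these renames the lookup table built by `find_valid_t0_and_site_ids` is a plain two-column frame, and `__getitem__` simply unpacks one row and maps the `site_id` back to a `Location`. A sketch of its shape, using made-up timestamps and site ids:

```python
import pandas as pd

# Hypothetical contents of SitesDataset.valid_t0_and_site_ids - one row per t0
# at which every input source can supply a full sample for that site.
valid_t0_and_site_ids = pd.DataFrame(
    {
        "t0": pd.to_datetime(
            ["2023-01-01 03:00", "2023-01-01 03:30", "2023-01-01 03:00"]
        ),
        "site_id": [0, 0, 1],
    }
)

t0, site_id = valid_t0_and_site_ids.iloc[0]  # mirrors SitesDataset.__getitem__
```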
@@ -151,26 +151,24 @@ def find_valid_t0_and_system_ids( ) valid_t0_per_site = pd.DataFrame(index=valid_t0_times_per_site) - valid_t0_per_site['system_id'] = system_id - valid_t0_and_system_ids.append(valid_t0_per_site) + valid_t0_per_site['site_id'] = site_id + valid_t0_and_site_ids.append(valid_t0_per_site) - valid_t0_and_system_ids = pd.concat(valid_t0_and_system_ids) - valid_t0_and_system_ids.index.name = 't0' - valid_t0_and_system_ids.reset_index(inplace=True) + valid_t0_and_site_ids = pd.concat(valid_t0_and_site_ids) + valid_t0_and_site_ids.index.name = 't0' + valid_t0_and_site_ids.reset_index(inplace=True) - print(valid_t0_and_system_ids) - - return valid_t0_and_system_ids + return valid_t0_and_site_ids def get_locations(site_xr:xr.Dataset): """Get list of locations of all sites""" locations = [] - for system_id in site_xr.system_id.values: - site = site_xr.sel(system_id=system_id) + for site_id in site_xr.site_id.values: + site = site_xr.sel(site_id=site_id) location = Location( - id=system_id, + id=site_id, x=site.longitude.values, y=site.latitude.values, coordinate_system="lon_lat" @@ -205,19 +203,19 @@ def __init__( self.locations = get_locations(datasets_dict['site']) # Get t0 times where all input data is available - valid_t0_and_system_ids = find_valid_t0_and_system_ids(datasets_dict, config) + valid_t0_and_site_ids = find_valid_t0_and_site_ids(datasets_dict, config) # Filter t0 times to given range # Assign coords and indices to self - self.valid_t0_and_system_ids = valid_t0_and_system_ids + self.valid_t0_and_site_ids = valid_t0_and_site_ids # Assign config and input data to self self.datasets_dict = datasets_dict self.config = config def __len__(self): - return len(self.valid_t0_and_system_ids) + return len(self.valid_t0_and_site_ids) def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict: """Generate the PVNet sample for given coordinates @@ -234,15 +232,15 @@ def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict: return sample - def get_location_from_system_id(self, system_id): + def get_location_from_site_id(self, site_id): """Get location from system id""" - locations = [loc for loc in self.locations if loc.id == system_id] + locations = [loc for loc in self.locations if loc.id == site_id] if len(locations) == 0: - raise ValueError(f"Location not found for system_id {system_id}") + raise ValueError(f"Location not found for site_id {site_id}") if len(locations) > 1: - logging.warning(f"Multiple locations found for system_id {system_id}, but will take the first") + logging.warning(f"Multiple locations found for site_id {site_id}, but will take the first") return locations[0] @@ -250,11 +248,10 @@ def __getitem__(self, idx): # Get the coordinates of the sample # TOD change to system ids - t0_and_system_id = self.valid_t0_and_system_ids.iloc[idx] - t0, system_id = t0_and_system_id + t0, site_id = self.valid_t0_and_site_ids.iloc[idx] - # get location from system_id - location = self.get_location_from_system_id(system_id) + # get location from site id + location = self.get_location_from_site_id(site_id) # Generate the sample return self._get_sample(t0, location) diff --git a/tests/conftest.py b/tests/conftest.py index af62ba5..6f574bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -206,20 +206,20 @@ def data_sites() -> Sites: Returns: filename for netcdf file, and csv metadata """ times = pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="30min") - system_ids = list(range(0,10)) + site_ids = list(range(0,10)) capacity_kwp_1d = 
np.array([0.1,1.1,4,6,8,9,15,2,3,4]) # these are quite specific for the fake satellite data longitude = np.arange(-4, -3, 0.1) latitude = np.arange(51, 52, 0.1) - generation = np.random.uniform(0, 200, size=(len(times), len(system_ids))).astype(np.float32) + generation = np.random.uniform(0, 200, size=(len(times), len(site_ids))).astype(np.float32) # repeat capacity in new dims len(times) times capacity_kwp = (np.tile(capacity_kwp_1d, len(times))).reshape(len(times),10) coords = ( ("time_utc", times), - ("system_id", system_ids), + ("site_id", site_ids), ) @@ -235,7 +235,7 @@ def data_sites() -> Sites: # metadata meta_df = pd.DataFrame(columns=[], data = []) - meta_df['system_id'] = system_ids + meta_df['site_id'] = site_ids meta_df['capacity_kwp'] = capacity_kwp_1d meta_df['longitude'] = longitude meta_df['latitude'] = latitude diff --git a/tests/load/test_load_sites.py b/tests/load/test_load_sites.py index 6cd46e1..79b92a7 100644 --- a/tests/load/test_load_sites.py +++ b/tests/load/test_load_sites.py @@ -1,12 +1,12 @@ -from ocf_data_sampler.load.sites import open_sites +from ocf_data_sampler.load.site import open_site import xarray as xr def test_open_site(data_sites): - da = open_sites(data_sites) + da = open_site(data_sites) assert isinstance(da, xr.DataArray) - assert da.dims == ("time_utc", "system_id") + assert da.dims == ("time_utc", "site_id") assert "capacity_kwp" in da.coords assert "latitude" in da.coords From acdad11bbb468e674d79727ba3daa725e2f6510d Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Fri, 25 Oct 2024 15:47:58 +0100 Subject: [PATCH 03/24] add site dropout --- ocf_data_sampler/select/time_slice_for_dataset.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/ocf_data_sampler/select/time_slice_for_dataset.py b/ocf_data_sampler/select/time_slice_for_dataset.py index 769e495..110038d 100644 --- a/ocf_data_sampler/select/time_slice_for_dataset.py +++ b/ocf_data_sampler/select/time_slice_for_dataset.py @@ -115,6 +115,17 @@ def slice_datasets_by_time( interval_end=minutes(site_config.forecast_minutes), ) - # TODO add dropout? 
+ # Randomly sample dropout + site_dropout_time = draw_dropout_time( + t0, + dropout_timedeltas=minutes(site_config.dropout_timedeltas_minutes), + dropout_frac=site_config.dropout_fraction, + ) + + # Apply the dropout + sliced_datasets_dict["site"] = apply_dropout_time( + sliced_datasets_dict["site"], + site_dropout_time, + ) return sliced_datasets_dict From 658f0fd8dc0366260dca9540b6b05cd744edde90 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Fri, 25 Oct 2024 16:09:00 +0100 Subject: [PATCH 04/24] refactor minutes transfer --- ocf_data_sampler/select/dropout.py | 3 +- .../select/space_slice_for_dataset.py | 1 + .../select/time_slice_for_dataset.py | 11 ++----- ocf_data_sampler/time_functions.py | 11 +++++++ .../torch_datasets/process_and_combine.py | 10 +------ .../torch_datasets/pvnet_uk_regional.py | 8 +---- ocf_data_sampler/torch_datasets/site.py | 30 +++++++++---------- 7 files changed, 33 insertions(+), 41 deletions(-) create mode 100644 ocf_data_sampler/time_functions.py diff --git a/ocf_data_sampler/select/dropout.py b/ocf_data_sampler/select/dropout.py index 2405546..e4795cb 100644 --- a/ocf_data_sampler/select/dropout.py +++ b/ocf_data_sampler/select/dropout.py @@ -1,3 +1,4 @@ +""" Functions for simulating dropout in time series data """ import numpy as np import pandas as pd import xarray as xr @@ -5,7 +6,7 @@ def draw_dropout_time( t0: pd.Timestamp, - dropout_timedeltas: list[pd.Timedelta] | None, + dropout_timedeltas: list[pd.Timedelta] | pd.Timedelta | None, dropout_frac: float = 0, ): diff --git a/ocf_data_sampler/select/space_slice_for_dataset.py b/ocf_data_sampler/select/space_slice_for_dataset.py index 556aac5..b94a16f 100644 --- a/ocf_data_sampler/select/space_slice_for_dataset.py +++ b/ocf_data_sampler/select/space_slice_for_dataset.py @@ -1,3 +1,4 @@ +""" Functions for selecting data around a given location """ from ocf_data_sampler.config import Configuration from ocf_data_sampler.select.location import Location from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels diff --git a/ocf_data_sampler/select/time_slice_for_dataset.py b/ocf_data_sampler/select/time_slice_for_dataset.py index 110038d..2597157 100644 --- a/ocf_data_sampler/select/time_slice_for_dataset.py +++ b/ocf_data_sampler/select/time_slice_for_dataset.py @@ -1,17 +1,10 @@ +""" Slice datasets by time""" import pandas as pd from ocf_data_sampler.config import Configuration from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time from ocf_data_sampler.select.select_time_slice import select_time_slice_nwp, select_time_slice - - -def minutes(minutes: list[float]): - """Timedelta minutes - - Args: - m: minutes - """ - return pd.to_timedelta(minutes, unit="m") +from ocf_data_sampler.time_functions import minutes def slice_datasets_by_time( diff --git a/ocf_data_sampler/time_functions.py b/ocf_data_sampler/time_functions.py new file mode 100644 index 0000000..bbb09fd --- /dev/null +++ b/ocf_data_sampler/time_functions.py @@ -0,0 +1,11 @@ +import pandas as pd + + +def minutes(minutes: int | list[float]) -> pd.Timedelta | pd.TimedeltaIndex: + """Timedelta minutes + + Args: + minutes: the number of minutes, single value or list + """ + minutes_delta = pd.to_timedelta(minutes, unit="m") + return minutes_delta diff --git a/ocf_data_sampler/torch_datasets/process_and_combine.py b/ocf_data_sampler/torch_datasets/process_and_combine.py index 73694be..3b3926e 100644 --- a/ocf_data_sampler/torch_datasets/process_and_combine.py +++ 
b/ocf_data_sampler/torch_datasets/process_and_combine.py @@ -15,15 +15,7 @@ from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey from ocf_data_sampler.select.geospatial import osgb_to_lon_lat from ocf_data_sampler.select.location import Location - - -def minutes(minutes: list[float]): - """Timedelta minutes - - Args: - m: minutes - """ - return pd.to_timedelta(minutes, unit="m") +from ocf_data_sampler.time_functions import minutes def process_and_combine_datasets( diff --git a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py index 2da1741..d464c69 100644 --- a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +++ b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py @@ -21,16 +21,10 @@ from ocf_data_sampler.select.space_slice_for_dataset import slice_datasets_by_space from ocf_data_sampler.select.time_slice_for_dataset import slice_datasets_by_time from ocf_data_sampler.torch_datasets.xarray_compute import compute +from ocf_data_sampler.time_functions import minutes xr.set_options(keep_attrs=True) -def minutes(minutes: list[float]): - """Timedelta minutes - - Args: - m: minutes - """ - return pd.to_timedelta(minutes, unit="m") def find_valid_t0_times( diff --git a/ocf_data_sampler/torch_datasets/site.py b/ocf_data_sampler/torch_datasets/site.py index 120d99b..0de6a90 100644 --- a/ocf_data_sampler/torch_datasets/site.py +++ b/ocf_data_sampler/torch_datasets/site.py @@ -21,7 +21,7 @@ from ocf_data_sampler.select.space_slice_for_dataset import slice_datasets_by_space from ocf_data_sampler.torch_datasets.xarray_compute import compute from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets - +from ocf_data_sampler.time_functions import minutes xr.set_options(keep_attrs=True) @@ -55,25 +55,25 @@ def find_valid_t0_and_site_ids( da = datasets_dict["nwp"][nwp_key] if nwp_config.dropout_timedeltas_minutes is None: - max_dropout = pd.to_timedelta(0, unit="m") + max_dropout = minutes(0) else: - max_dropout = pd.to_timedelta(np.max(np.abs(nwp_config.dropout_timedeltas_minutes)), unit="m") + max_dropout = minutes(np.max(np.abs(nwp_config.dropout_timedeltas_minutes))) if nwp_config.max_staleness_minutes is None: max_staleness = None else: - max_staleness = pd.to_timedelta(nwp_config.max_staleness_minutes, unit="m") + max_staleness = minutes(nwp_config.max_staleness_minutes) # The last step of the forecast is lost if we have to diff channels if len(nwp_config.nwp_accum_channels) > 0: - end_buffer = pd.to_timedelta(nwp_config.time_resolution_minutes, unit="m") + end_buffer = pd.to_timedelta(nwp_config.time_resolution_minutes) else: - end_buffer = pd.to_timedelta(0, unit="m") + end_buffer =minutes(0) # This is the max staleness we can use considering the max step of the input data max_possible_staleness = ( pd.Timedelta(da["step"].max().item()) - - pd.to_timedelta(nwp_config.forecast_minutes, unit='m') + - minutes(nwp_config.forecast_minutes) - end_buffer ) @@ -86,7 +86,7 @@ def find_valid_t0_and_site_ids( time_periods = find_contiguous_t0_periods_nwp( datetimes=pd.DatetimeIndex(da["init_time_utc"]), - history_duration=pd.to_timedelta(nwp_config.history_minutes, unit="m"), + history_duration=minutes(nwp_config.history_minutes), max_staleness=max_staleness, max_dropout=max_dropout, ) @@ -98,9 +98,9 @@ def find_valid_t0_and_site_ids( time_periods = find_contiguous_t0_periods( pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]), - sample_period_duration=pd.to_timedelta(sat_config.time_resolution_minutes, 
unit="m"), - history_duration=pd.to_timedelta(sat_config.history_minutes, unit="m"), - forecast_duration=pd.to_timedelta(sat_config.forecast_minutes, unit="m"), + sample_period_duration=minutes(sat_config.time_resolution_minutes), + history_duration=minutes(sat_config.history_minutes), + forecast_duration=minutes(sat_config.forecast_minutes), ) contiguous_time_periods['sat'] = time_periods @@ -136,9 +136,9 @@ def find_valid_t0_and_site_ids( # Get the valid time periods for this location time_periods = find_contiguous_t0_periods( pd.DatetimeIndex(site["time_utc"]), - sample_period_duration=pd.to_timedelta(site_config.time_resolution_minutes, unit="m"), - history_duration=pd.to_timedelta(site_config.history_minutes, unit="m"), - forecast_duration=pd.to_timedelta(site_config.forecast_minutes, unit="m"), + sample_period_duration=minutes(site_config.time_resolution_minutes), + history_duration=minutes(site_config.history_minutes), + forecast_duration=minutes(site_config.forecast_minutes), ) valid_time_periods_per_site = intersection_of_multiple_dataframes_of_periods( [valid_time_periods, time_periods] @@ -147,7 +147,7 @@ def find_valid_t0_and_site_ids( # Fill out the contiguous time periods to get the t0 times valid_t0_times_per_site = fill_time_periods( valid_time_periods_per_site, - freq=pd.to_timedelta(site_config.time_resolution_minutes,unit='m') + freq=minutes(site_config.time_resolution_minutes) ) valid_t0_per_site = pd.DataFrame(index=valid_t0_times_per_site) From 9ad097533fd934960687a05ab15aa55abc5d515b Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Fri, 25 Oct 2024 16:22:19 +0100 Subject: [PATCH 05/24] Add comment --- ocf_data_sampler/load/load_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ocf_data_sampler/load/load_dataset.py b/ocf_data_sampler/load/load_dataset.py index 65addf3..ad08e1c 100644 --- a/ocf_data_sampler/load/load_dataset.py +++ b/ocf_data_sampler/load/load_dataset.py @@ -1,3 +1,4 @@ +""" Loads all data sources """ import xarray as xr from ocf_data_sampler.config import Configuration From 0964a4802a75458db8c542174325cca6477499ae Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Mon, 28 Oct 2024 17:49:52 +0000 Subject: [PATCH 06/24] add Legacy support --- ocf_data_sampler/load/nwp/providers/ecmwf.py | 7 ++- ocf_data_sampler/load/site.py | 53 ++++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/ocf_data_sampler/load/nwp/providers/ecmwf.py b/ocf_data_sampler/load/nwp/providers/ecmwf.py index 9ba0ee4..1130f96 100755 --- a/ocf_data_sampler/load/nwp/providers/ecmwf.py +++ b/ocf_data_sampler/load/nwp/providers/ecmwf.py @@ -9,7 +9,6 @@ ) - def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray: """ Opens the ECMWF IFS NWP data @@ -27,10 +26,14 @@ def open_ifs(zarr_path: Path | str | list[Path] | list[str]) -> xr.DataArray: ds = ds.rename( { "init_time": "init_time_utc", - "variable": "channel", } ) + # LEGACY SUPPORT + # rename variable to channel if it exists + if "variable" in ds: + ds = ds.rename({"variable": "channel"}) + # Check the timestamps are unique and increasing check_time_unique_increasing(ds.init_time_utc) diff --git a/ocf_data_sampler/load/site.py b/ocf_data_sampler/load/site.py index 3eb47cb..fddfa98 100644 --- a/ocf_data_sampler/load/site.py +++ b/ocf_data_sampler/load/site.py @@ -11,6 +11,10 @@ def open_site(sites_config: Sites) -> xr.DataArray: # Load site generation data metadata_df = pd.read_csv(sites_config.metadata_filename) + + # LEGACY SUPPORT + data_ds = legacy_format(data_ds, 
metadata_df) + metadata_df.set_index("site_id", inplace=True, drop=True) # Add coordinates @@ -21,3 +25,52 @@ def open_site(sites_config: Sites) -> xr.DataArray: ) return ds.generation_kw + + +def legacy_format(data_ds, metadata_df): + """This formats old legacy data to the new format. + + 1. This renames the columns in the metadata + 2. Re-formats the site data from data variables named by the site_id to + a data array with a site_id dimension + """ + + if "system_id" in metadata_df.columns: + metadata_df["site_id"] = metadata_df["system_id"] + + if "capacity_megawatts" in metadata_df.columns: + metadata_df["capacity_kwp"] = metadata_df["capacity_megawatts"] * 1000 + + # only site data has the site_id as data variables. + # We want to join them all together and create another variable canned site_id + if "0" in data_ds: + gen_df = data_ds.to_dataframe() + gen_da = xr.DataArray( + data=gen_df.values, + coords=( + ("time_utc", gen_df.index.values), + ("site_id", metadata_df["site_id"]), + ), + name="generation_kw", + ) + + capacity_df = gen_df + for col in capacity_df.columns: + capacity_df[col] = metadata_df[metadata_df["site_id"].astype(str) == col][ + "capacity_kwp" + ].iloc[0] + capacity_da = xr.DataArray( + data=capacity_df.values, + coords=( + ("time_utc", gen_df.index.values), + ("site_id", metadata_df["site_id"]), + ), + name="capacity_kwp", + ) + data_ds = xr.Dataset( + { + "generation_kw": gen_da, + "capacity_kwp": capacity_da, + } + ) + return data_ds From e4882d5df4fb082bff39918501d4dcb21b88b51f Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Thu, 31 Oct 2024 13:56:53 +0000 Subject: [PATCH 07/24] Update ocf_data_sampler/config/model.py Co-authored-by: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> --- ocf_data_sampler/config/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocf_data_sampler/config/model.py b/ocf_data_sampler/config/model.py index aba0076..f4d2015 100644 --- a/ocf_data_sampler/config/model.py +++ b/ocf_data_sampler/config/model.py @@ -111,7 +111,7 @@ class Sites(DataSourceMixin, TimeResolutionMixin, DropoutMixin): ) metadata_filename: str = Field( ..., - description="The CSV files describing wind system.", + description="The CSV files describing power system", ) @field_validator("forecast_minutes") From 021cf4133ec998c17e57068eddfec2245d0e43d7 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 31 Oct 2024 14:00:40 +0000 Subject: [PATCH 08/24] reanme Sites->Site --- ocf_data_sampler/config/model.py | 4 ++-- ocf_data_sampler/load/site.py | 4 ++-- tests/conftest.py | 17 ++++++++--------- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/ocf_data_sampler/config/model.py b/ocf_data_sampler/config/model.py index aba0076..5ac403b 100644 --- a/ocf_data_sampler/config/model.py +++ b/ocf_data_sampler/config/model.py @@ -102,7 +102,7 @@ class TimeResolutionMixin(Base): ) -class Sites(DataSourceMixin, TimeResolutionMixin, DropoutMixin): +class Site(DataSourceMixin, TimeResolutionMixin, DropoutMixin): """Site configuration model""" filename: str = Field( @@ -273,7 +273,7 @@ class InputData(Base): satellite: Optional[Satellite] = None nwp: Optional[MultiNWP] = None gsp: Optional[GSP] = None - site: Optional[Sites] = None + site: Optional[Site] = None class Configuration(Base): diff --git a/ocf_data_sampler/load/site.py b/ocf_data_sampler/load/site.py index fddfa98..3c5b7e9 100644 --- a/ocf_data_sampler/load/site.py +++ b/ocf_data_sampler/load/site.py @@ -1,10 
+1,10 @@ import pandas as pd import xarray as xr -from ocf_data_sampler.config.model import Sites +from ocf_data_sampler.config.model import Site -def open_site(sites_config: Sites) -> xr.DataArray: +def open_site(sites_config: Site) -> xr.DataArray: # Load site generation xr.Dataset data_ds = xr.open_dataset(sites_config.filename) diff --git a/tests/conftest.py b/tests/conftest.py index 6f574bf..9d6649a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ import xarray as xr import tempfile -from ocf_data_sampler.config.model import Sites +from ocf_data_sampler.config.model import Site _top_test_directory = os.path.dirname(os.path.realpath(__file__)) @@ -200,7 +200,7 @@ def ds_uk_gsp(): @pytest.fixture(scope="session") -def data_sites() -> Sites: +def data_sites() -> Site: """ Make fake data for sites Returns: filename for netcdf file, and csv metadata @@ -222,7 +222,6 @@ def data_sites() -> Sites: ("site_id", site_ids), ) - da_cap = xr.DataArray( capacity_kwp, coords=coords, @@ -242,7 +241,7 @@ def data_sites() -> Sites: generation = xr.Dataset({ "capacity_kwp": da_cap, - "generation_kw":da_gen + "generation_kw": da_gen, }) with tempfile.TemporaryDirectory() as tmpdir: @@ -251,11 +250,11 @@ def data_sites() -> Sites: generation.to_netcdf(filename) meta_df.to_csv(filename_csv) - site = Sites(filename=filename, - metadata_filename=filename_csv, - time_resolution_minutes=30, - forecast_minutes=60, - history_minutes=30) + site = Site(filename=filename, + metadata_filename=filename_csv, + time_resolution_minutes=30, + forecast_minutes=60, + history_minutes=30) yield site From 3a85983311d7fdb85783686a1599565b6851eae1 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 31 Oct 2024 14:02:44 +0000 Subject: [PATCH 09/24] remove legacy format to script --- ocf_data_sampler/load/site.py | 50 ----------------------------------- scripts/refactor_site.py | 50 +++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 50 deletions(-) create mode 100644 scripts/refactor_site.py diff --git a/ocf_data_sampler/load/site.py b/ocf_data_sampler/load/site.py index 3c5b7e9..1c5a0af 100644 --- a/ocf_data_sampler/load/site.py +++ b/ocf_data_sampler/load/site.py @@ -12,9 +12,6 @@ def open_site(sites_config: Site) -> xr.DataArray: # Load site generation data metadata_df = pd.read_csv(sites_config.metadata_filename) - # LEGACY SUPPORT - data_ds = legacy_format(data_ds, metadata_df) - metadata_df.set_index("site_id", inplace=True, drop=True) # Add coordinates @@ -27,50 +24,3 @@ def open_site(sites_config: Site) -> xr.DataArray: return ds.generation_kw -def legacy_format(data_ds, metadata_df): - """This formats old legacy data to the new format. - - 1. This renames the columns in the metadata - 2. Re-formats the site data from data variables named by the site_id to - a data array with a site_id dimension - """ - - if "system_id" in metadata_df.columns: - metadata_df["site_id"] = metadata_df["system_id"] - - if "capacity_megawatts" in metadata_df.columns: - metadata_df["capacity_kwp"] = metadata_df["capacity_megawatts"] * 1000 - - # only site data has the site_id as data variables. 
- # We want to join them all together and create another variable canned site_id - if "0" in data_ds: - gen_df = data_ds.to_dataframe() - gen_da = xr.DataArray( - data=gen_df.values, - coords=( - ("time_utc", gen_df.index.values), - ("site_id", metadata_df["site_id"]), - ), - name="generation_kw", - ) - - capacity_df = gen_df - for col in capacity_df.columns: - capacity_df[col] = metadata_df[metadata_df["site_id"].astype(str) == col][ - "capacity_kwp" - ].iloc[0] - capacity_da = xr.DataArray( - data=capacity_df.values, - coords=( - ("time_utc", gen_df.index.values), - ("site_id", metadata_df["site_id"]), - ), - name="capacity_kwp", - ) - data_ds = xr.Dataset( - { - "generation_kw": gen_da, - "capacity_kwp": capacity_da, - } - ) - return data_ds diff --git a/scripts/refactor_site.py b/scripts/refactor_site.py new file mode 100644 index 0000000..fc0f7fd --- /dev/null +++ b/scripts/refactor_site.py @@ -0,0 +1,50 @@ +""" Helper functions for refactoring sitethe site data """ + + +def legacy_format(data_ds, metadata_df): + """This formats old legacy data to the new format. + + 1. This renames the columns in the metadata + 2. Re-formats the site data from data variables named by the site_id to + a data array with a site_id dimension. Also adds capacity_kwp to the dataset as a time series for each site_id + """ + + if "system_id" in metadata_df.columns: + metadata_df["site_id"] = metadata_df["system_id"] + + if "capacity_megawatts" in metadata_df.columns: + metadata_df["capacity_kwp"] = metadata_df["capacity_megawatts"] * 1000 + + # only site data has the site_id as data variables. + # We want to join them all together and create another variable canned site_id + if "0" in data_ds: + gen_df = data_ds.to_dataframe() + gen_da = xr.DataArray( + data=gen_df.values, + coords=( + ("time_utc", gen_df.index.values), + ("site_id", metadata_df["site_id"]), + ), + name="generation_kw", + ) + + capacity_df = gen_df + for col in capacity_df.columns: + capacity_df[col] = metadata_df[metadata_df["site_id"].astype(str) == col][ + "capacity_kwp" + ].iloc[0] + capacity_da = xr.DataArray( + data=capacity_df.values, + coords=( + ("time_utc", gen_df.index.values), + ("site_id", metadata_df["site_id"]), + ), + name="capacity_kwp", + ) + data_ds = xr.Dataset( + { + "generation_kw": gen_da, + "capacity_kwp": capacity_da, + } + ) + return data_ds \ No newline at end of file From ae80011184c3dbef53d5f416ea552721622e922b Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:17:40 +0000 Subject: [PATCH 10/24] Update scripts/refactor_site.py Co-authored-by: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> --- scripts/refactor_site.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/refactor_site.py b/scripts/refactor_site.py index fc0f7fd..caa42da 100644 --- a/scripts/refactor_site.py +++ b/scripts/refactor_site.py @@ -16,7 +16,7 @@ def legacy_format(data_ds, metadata_df): metadata_df["capacity_kwp"] = metadata_df["capacity_megawatts"] * 1000 # only site data has the site_id as data variables. 
- # We want to join them all together and create another variable canned site_id + # We want to join them all together and create another coordinate called site_id if "0" in data_ds: gen_df = data_ds.to_dataframe() gen_da = xr.DataArray( From 972bb0fbeb29a2bce6cf3732b6bbbcaeb650a0b7 Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:18:38 +0000 Subject: [PATCH 11/24] Update ocf_data_sampler/torch_datasets/site.py Co-authored-by: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> --- ocf_data_sampler/torch_datasets/site.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocf_data_sampler/torch_datasets/site.py b/ocf_data_sampler/torch_datasets/site.py index 0de6a90..5483a22 100644 --- a/ocf_data_sampler/torch_datasets/site.py +++ b/ocf_data_sampler/torch_datasets/site.py @@ -44,7 +44,7 @@ def find_valid_t0_and_site_ids( config: Configuration file """ - assert set(datasets_dict.keys()).issubset({"nwp", "sat", "site", "gsp"}) + assert set(datasets_dict.keys()).issubset({"nwp", "sat", "site"}) contiguous_time_periods: dict[str: pd.DataFrame] = {} # Used to store contiguous time periods from each data source From f930ff53854fdec0d308b06862fcb9418ddbc66c Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:18:54 +0000 Subject: [PATCH 12/24] Update ocf_data_sampler/torch_datasets/site.py Co-authored-by: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> --- ocf_data_sampler/torch_datasets/site.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocf_data_sampler/torch_datasets/site.py b/ocf_data_sampler/torch_datasets/site.py index 5483a22..0cce14b 100644 --- a/ocf_data_sampler/torch_datasets/site.py +++ b/ocf_data_sampler/torch_datasets/site.py @@ -66,7 +66,7 @@ def find_valid_t0_and_site_ids( # The last step of the forecast is lost if we have to diff channels if len(nwp_config.nwp_accum_channels) > 0: - end_buffer = pd.to_timedelta(nwp_config.time_resolution_minutes) + end_buffer = minutes(nwp_config.time_resolution_minutes) else: end_buffer =minutes(0) From 7d00ab5eef8656159f3aaf4f6230bf2a215b4fbb Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:19:10 +0000 Subject: [PATCH 13/24] Update ocf_data_sampler/select/space_slice_for_dataset.py Co-authored-by: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> --- ocf_data_sampler/select/space_slice_for_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocf_data_sampler/select/space_slice_for_dataset.py b/ocf_data_sampler/select/space_slice_for_dataset.py index b94a16f..d5d6c43 100644 --- a/ocf_data_sampler/select/space_slice_for_dataset.py +++ b/ocf_data_sampler/select/space_slice_for_dataset.py @@ -9,7 +9,7 @@ def slice_datasets_by_space( location: Location, config: Configuration, ) -> dict: - """Slice a dictionaries of input data sources around a given location + """Slice the dictionary of input data sources around a given location Args: datasets_dict: Dictionary of the input data sources From 7dc24328114e748ffba070235de865ed049a2e8e Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:19:34 +0000 Subject: [PATCH 14/24] Update ocf_data_sampler/torch_datasets/process_and_combine.py Co-authored-by: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> --- 
ocf_data_sampler/torch_datasets/process_and_combine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocf_data_sampler/torch_datasets/process_and_combine.py b/ocf_data_sampler/torch_datasets/process_and_combine.py index 5b593dc..22c24e2 100644 --- a/ocf_data_sampler/torch_datasets/process_and_combine.py +++ b/ocf_data_sampler/torch_datasets/process_and_combine.py @@ -58,7 +58,7 @@ def process_and_combine_datasets( numpy_modalities.append( convert_gsp_to_numpy_batch( - da_gsp, t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes + da_gsp, t0_idx=gsp_config.history_minutes // gsp_config.time_resolution_minutes ) ) From e75f7ade7edc4c48a6f142b76910322870f72dd6 Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:20:37 +0000 Subject: [PATCH 15/24] Update ocf_data_sampler/load/load_dataset.py Co-authored-by: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> --- ocf_data_sampler/load/load_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocf_data_sampler/load/load_dataset.py b/ocf_data_sampler/load/load_dataset.py index ad08e1c..53ad2d6 100644 --- a/ocf_data_sampler/load/load_dataset.py +++ b/ocf_data_sampler/load/load_dataset.py @@ -8,7 +8,7 @@ from ocf_data_sampler.load.site import open_site -def get_dataset_dict(config: Configuration) -> dict[xr.DataArray, dict[xr.DataArray]]: +def get_dataset_dict(config: Configuration) -> dict[str, dict[xr.DataArray]]: """Construct dictionary of all of the input data sources Args: From e7a7a22ee0a3c22801ccaa4f86ccfc8975f2d365 Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:22:41 +0000 Subject: [PATCH 16/24] Update scripts/refactor_site.py Co-authored-by: Sukhil Patel <42407101+Sukh-P@users.noreply.github.com> --- scripts/refactor_site.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/refactor_site.py b/scripts/refactor_site.py index caa42da..b3bf5ea 100644 --- a/scripts/refactor_site.py +++ b/scripts/refactor_site.py @@ -1,4 +1,4 @@ -""" Helper functions for refactoring sitethe site data """ +""" Helper functions for refactoring legacy site data """ def legacy_format(data_ds, metadata_df): From 3a666c4b140223a6ecb48105aa26b05c1bb437b7 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Mon, 4 Nov 2024 16:44:12 +0000 Subject: [PATCH 17/24] Some of PR comments --- ocf_data_sampler/config/model.py | 4 +-- ocf_data_sampler/load/site.py | 6 ++-- ocf_data_sampler/numpy_batch/site.py | 6 ++-- ocf_data_sampler/select/geospatial.py | 2 +- .../select/select_spatial_slice.py | 4 +-- ...ataset.py => spatial_slice_for_dataset.py} | 0 .../select/time_slice_for_dataset.py | 2 +- .../torch_datasets/process_and_combine.py | 10 ++++++ .../torch_datasets/pvnet_uk_regional.py | 5 ++- ocf_data_sampler/torch_datasets/site.py | 33 +++++++------------ .../torch_datasets/xarray_compute.py | 8 ----- tests/conftest.py | 4 +-- tests/torch_datasets/test_site.py | 4 +-- 13 files changed, 38 insertions(+), 50 deletions(-) rename ocf_data_sampler/select/{space_slice_for_dataset.py => spatial_slice_for_dataset.py} (100%) delete mode 100644 ocf_data_sampler/torch_datasets/xarray_compute.py diff --git a/ocf_data_sampler/config/model.py b/ocf_data_sampler/config/model.py index 2d7482b..05f8472 100644 --- a/ocf_data_sampler/config/model.py +++ b/ocf_data_sampler/config/model.py @@ -105,11 +105,11 @@ class TimeResolutionMixin(Base): class 
Site(DataSourceMixin, TimeResolutionMixin, DropoutMixin): """Site configuration model""" - filename: str = Field( + file_path: str = Field( ..., description="The NetCDF files holding the power timeseries.", ) - metadata_filename: str = Field( + metadata_file_path: str = Field( ..., description="The CSV files describing power system", ) diff --git a/ocf_data_sampler/load/site.py b/ocf_data_sampler/load/site.py index 1c5a0af..9df0a0c 100644 --- a/ocf_data_sampler/load/site.py +++ b/ocf_data_sampler/load/site.py @@ -7,12 +7,10 @@ def open_site(sites_config: Site) -> xr.DataArray: # Load site generation xr.Dataset - data_ds = xr.open_dataset(sites_config.filename) + data_ds = xr.open_dataset(sites_config.file_path) # Load site generation data - metadata_df = pd.read_csv(sites_config.metadata_filename) - - metadata_df.set_index("site_id", inplace=True, drop=True) + metadata_df = pd.read_csv(sites_config.metadata_file_path, index_col="site_id") # Add coordinates ds = data_ds.assign_coords( diff --git a/ocf_data_sampler/numpy_batch/site.py b/ocf_data_sampler/numpy_batch/site.py index 779e354..157039c 100644 --- a/ocf_data_sampler/numpy_batch/site.py +++ b/ocf_data_sampler/numpy_batch/site.py @@ -5,22 +5,20 @@ class SiteBatchKey: - site = "site" + generation = "site" site_capacity_kwp = "site_capacity_kwp" site_time_utc = "site_time_utc" site_t0_idx = "site_t0_idx" site_solar_azimuth = "site_solar_azimuth" site_solar_elevation = "site_solar_elevation" site_id = "site_id" - site_latitude = "site_latitude" - site_longitude = "site_longitude" def convert_site_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict: """Convert from Xarray to NumpyBatch""" example = { - SiteBatchKey.site: da.values, + SiteBatchKey.generation: da.values, SiteBatchKey.site_capacity_kwp: da.isel(time_utc=0)["capacity_kwp"].values, SiteBatchKey.site_time_utc: da["time_utc"].values.astype(float), } diff --git a/ocf_data_sampler/select/geospatial.py b/ocf_data_sampler/select/geospatial.py index 31ed9ca..b8a4411 100644 --- a/ocf_data_sampler/select/geospatial.py +++ b/ocf_data_sampler/select/geospatial.py @@ -55,7 +55,7 @@ def lon_lat_to_osgb( return _lon_lat_to_osgb(xx=x, yy=y) -def lat_lon_to_geostationary_area_coords( +def lon_lat_to_geostationary_area_coords( longitude: Union[Number, np.ndarray], latitude: Union[Number, np.ndarray], xr_data: xr.DataArray, diff --git a/ocf_data_sampler/select/select_spatial_slice.py b/ocf_data_sampler/select/select_spatial_slice.py index 0330dfe..03d93c6 100644 --- a/ocf_data_sampler/select/select_spatial_slice.py +++ b/ocf_data_sampler/select/select_spatial_slice.py @@ -8,7 +8,7 @@ from ocf_data_sampler.select.location import Location from ocf_data_sampler.select.geospatial import ( lon_lat_to_osgb, - lat_lon_to_geostationary_area_coords, + lon_lat_to_geostationary_area_coords, osgb_to_geostationary_area_coords, osgb_to_lon_lat, spatial_coord_type, @@ -120,7 +120,7 @@ def _get_idx_of_pixel_closest_to_poi_geostationary( if center.coordinate_system == 'osgb': x, y = osgb_to_geostationary_area_coords(x=center.x, y=center.y, xr_data=da) elif center.coordinate_system == 'lon_lat': - x, y = lat_lon_to_geostationary_area_coords(longitude=center.x, latitude=center.y, xr_data=da) + x, y = lon_lat_to_geostationary_area_coords(longitude=center.x, latitude=center.y, xr_data=da) else: x,y = center.x, center.y center_geostationary = Location(x=x, y=y, coordinate_system="geostationary") diff --git a/ocf_data_sampler/select/space_slice_for_dataset.py 
b/ocf_data_sampler/select/spatial_slice_for_dataset.py similarity index 100% rename from ocf_data_sampler/select/space_slice_for_dataset.py rename to ocf_data_sampler/select/spatial_slice_for_dataset.py diff --git a/ocf_data_sampler/select/time_slice_for_dataset.py b/ocf_data_sampler/select/time_slice_for_dataset.py index 2597157..2edf55a 100644 --- a/ocf_data_sampler/select/time_slice_for_dataset.py +++ b/ocf_data_sampler/select/time_slice_for_dataset.py @@ -12,7 +12,7 @@ def slice_datasets_by_time( t0: pd.Timestamp, config: Configuration, ) -> dict: - """Slice a dictionaries of input data sources around a given t0 time + """Slice the dictionary of input data sources around a given t0 time Args: datasets_dict: Dictionary of the input data sources diff --git a/ocf_data_sampler/torch_datasets/process_and_combine.py b/ocf_data_sampler/torch_datasets/process_and_combine.py index 22c24e2..2828bf7 100644 --- a/ocf_data_sampler/torch_datasets/process_and_combine.py +++ b/ocf_data_sampler/torch_datasets/process_and_combine.py @@ -141,3 +141,13 @@ def fill_nans_in_arrays(batch: dict) -> dict: fill_nans_in_arrays(v) return batch + + +def compute(xarray_dict: dict) -> dict: + """Eagerly load a nested dictionary of xarray DataArrays""" + for k, v in xarray_dict.items(): + if isinstance(v, dict): + xarray_dict[k] = compute(v) + else: + xarray_dict[k] = v.compute(scheduler="single-threaded") + return xarray_dict diff --git a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py index d464c69..e0b0dbe 100644 --- a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +++ b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py @@ -17,10 +17,9 @@ from ocf_data_sampler.select.location import Location from ocf_data_sampler.load.load_dataset import get_dataset_dict -from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets -from ocf_data_sampler.select.space_slice_for_dataset import slice_datasets_by_space +from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute +from ocf_data_sampler.select.spatial_slice_for_dataset import slice_datasets_by_space from ocf_data_sampler.select.time_slice_for_dataset import slice_datasets_by_time -from ocf_data_sampler.torch_datasets.xarray_compute import compute from ocf_data_sampler.time_functions import minutes xr.set_options(keep_attrs=True) diff --git a/ocf_data_sampler/torch_datasets/site.py b/ocf_data_sampler/torch_datasets/site.py index 0cce14b..babd28f 100644 --- a/ocf_data_sampler/torch_datasets/site.py +++ b/ocf_data_sampler/torch_datasets/site.py @@ -18,9 +18,8 @@ from ocf_data_sampler.load.load_dataset import get_dataset_dict from ocf_data_sampler.select.time_slice_for_dataset import slice_datasets_by_time -from ocf_data_sampler.select.space_slice_for_dataset import slice_datasets_by_space -from ocf_data_sampler.torch_datasets.xarray_compute import compute -from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets +from ocf_data_sampler.select.spatial_slice_for_dataset import slice_datasets_by_space +from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute from ocf_data_sampler.time_functions import minutes @@ -184,15 +183,13 @@ def __init__( config_filename: str, start_time: str | None = None, end_time: str | None = None, - gsp_ids: list[int] | None = None, ): - """A torch Dataset for creating PVNet UK GSP samples + """A torch Dataset for creating PVNet 
Site samples Args: config_filename: Path to the configuration file start_time: Limit the init-times to be after this end_time: Limit the init-times to be before this - gsp_ids: List of GSP IDs to create samples for. Defaults to all """ config = load_yaml_configuration(config_filename) @@ -205,6 +202,15 @@ def __init__( # Get t0 times where all input data is available valid_t0_and_site_ids = find_valid_t0_and_site_ids(datasets_dict, config) + # Filter t0 times to given range + if start_time is not None: + valid_t0_and_site_ids \ + = valid_t0_and_site_ids[valid_t0_and_site_ids['t0'] >= pd.Timestamp(start_time)] + + if end_time is not None: + valid_t0_and_site_ids \ + = valid_t0_and_site_ids[valid_t0_and_site_ids['t0'] <= pd.Timestamp(end_time)] + # Filter t0 times to given range # Assign coords and indices to self @@ -247,7 +253,6 @@ def get_location_from_site_id(self, site_id): def __getitem__(self, idx): # Get the coordinates of the sample - # TOD change to system ids t0, site_id = self.valid_t0_and_site_ids.iloc[idx] # get location from site id @@ -255,17 +260,3 @@ def __getitem__(self, idx): # Generate the sample return self._get_sample(t0, location) - - def get_sample(self, t0: pd.Timestamp, location: Location) -> dict: - """Generate a sample for the given coordinates. - - Useful for users to generate samples by t0 and location - - Args: - t0: init-time for sample - location: location object - """ - # Check the user has asked for a sample which we have the data for - # TODO - - return self._get_sample(t0, location) \ No newline at end of file diff --git a/ocf_data_sampler/torch_datasets/xarray_compute.py b/ocf_data_sampler/torch_datasets/xarray_compute.py deleted file mode 100644 index 2efaa53..0000000 --- a/ocf_data_sampler/torch_datasets/xarray_compute.py +++ /dev/null @@ -1,8 +0,0 @@ -def compute(xarray_dict: dict) -> dict: - """Eagerly load a nested dictionary of xarray DataArrays""" - for k, v in xarray_dict.items(): - if isinstance(v, dict): - xarray_dict[k] = compute(v) - else: - xarray_dict[k] = v.compute(scheduler="single-threaded") - return xarray_dict diff --git a/tests/conftest.py b/tests/conftest.py index 9d6649a..2964d99 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -250,8 +250,8 @@ def data_sites() -> Site: generation.to_netcdf(filename) meta_df.to_csv(filename_csv) - site = Site(filename=filename, - metadata_filename=filename_csv, + site = Site(file_path=filename, + metadata_file_path=filename_csv, time_resolution_minutes=30, forecast_minutes=60, history_minutes=30) diff --git a/tests/torch_datasets/test_site.py b/tests/torch_datasets/test_site.py index 53fc13f..2115867 100644 --- a/tests/torch_datasets/test_site.py +++ b/tests/torch_datasets/test_site.py @@ -39,7 +39,7 @@ def test_site(site_config_filename): for key in [ NWPBatchKey.nwp, SatelliteBatchKey.satellite_actual, - SiteBatchKey.site, + SiteBatchKey.generation, SiteBatchKey.site_solar_azimuth, SiteBatchKey.site_solar_elevation, ]: @@ -54,7 +54,7 @@ def test_site(site_config_filename): # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2) # 3 hours of 30 minute data (inclusive) - assert sample[SiteBatchKey.site].shape == (4,) + assert sample[SiteBatchKey.generation].shape == (4,) # Solar angles have same shape as GSP data assert sample[SiteBatchKey.site_solar_azimuth].shape == (4,) assert sample[SiteBatchKey.site_solar_elevation].shape == (4,) From b59aee29f206db359312ac2c6d3fb68fb98afdd6 Mon Sep 17 00:00:00 2001 
From: peterdudfield Date: Mon, 4 Nov 2024 17:24:45 +0000 Subject: [PATCH 18/24] refactor --- .../torch_datasets/pvnet_uk_regional.py | 95 +-------------- ocf_data_sampler/torch_datasets/site.py | 89 ++------------- .../torch_datasets/valid_time_periods.py | 108 ++++++++++++++++++ 3 files changed, 119 insertions(+), 173 deletions(-) create mode 100644 ocf_data_sampler/torch_datasets/valid_time_periods.py diff --git a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py index e0b0dbe..d51b947 100644 --- a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +++ b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py @@ -6,10 +6,6 @@ from torch.utils.data import Dataset import pkg_resources -from ocf_data_sampler.select.find_contiguous_time_periods import ( - find_contiguous_t0_periods, find_contiguous_t0_periods_nwp, - intersection_of_multiple_dataframes_of_periods, -) from ocf_data_sampler.select.fill_time_periods import fill_time_periods from ocf_data_sampler.config import Configuration, load_yaml_configuration @@ -21,11 +17,11 @@ from ocf_data_sampler.select.spatial_slice_for_dataset import slice_datasets_by_space from ocf_data_sampler.select.time_slice_for_dataset import slice_datasets_by_time from ocf_data_sampler.time_functions import minutes +from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods xr.set_options(keep_attrs=True) - def find_valid_t0_times( datasets_dict: dict, config: Configuration, @@ -33,96 +29,11 @@ def find_valid_t0_times( """Find the t0 times where all of the requested input data is available Args: - datasets_dict: A dictionary of input datasets + datasets_dict: A dictionary of input datasets config: Configuration file """ - assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"}) - - contiguous_time_periods: dict[str: pd.DataFrame] = {} # Used to store contiguous time periods from each data source - - if "nwp" in datasets_dict: - for nwp_key, nwp_config in config.input_data.nwp.items(): - - da = datasets_dict["nwp"][nwp_key] - - if nwp_config.dropout_timedeltas_minutes is None: - max_dropout = minutes(0) - else: - max_dropout = minutes(np.max(np.abs(nwp_config.dropout_timedeltas_minutes))) - - if nwp_config.max_staleness_minutes is None: - max_staleness = None - else: - max_staleness = minutes(nwp_config.max_staleness_minutes) - - # The last step of the forecast is lost if we have to diff channels - if len(nwp_config.nwp_accum_channels) > 0: - end_buffer = minutes(nwp_config.time_resolution_minutes) - else: - end_buffer = minutes(0) - - # This is the max staleness we can use considering the max step of the input data - max_possible_staleness = ( - pd.Timedelta(da["step"].max().item()) - - minutes(nwp_config.forecast_minutes) - - end_buffer - ) - - # Default to use max possible staleness unless specified in config - if max_staleness is None: - max_staleness = max_possible_staleness - else: - # Make sure the max acceptable staleness isn't longer than the max possible - assert max_staleness <= max_possible_staleness - - time_periods = find_contiguous_t0_periods_nwp( - datetimes=pd.DatetimeIndex(da["init_time_utc"]), - history_duration=minutes(nwp_config.history_minutes), - max_staleness=max_staleness, - max_dropout=max_dropout, - ) - - contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods - - if "sat" in datasets_dict: - sat_config = config.input_data.satellite - - time_periods = find_contiguous_t0_periods( - pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]), - 
sample_period_duration=minutes(sat_config.time_resolution_minutes), - history_duration=minutes(sat_config.history_minutes), - forecast_duration=minutes(sat_config.forecast_minutes), - ) - - contiguous_time_periods['sat'] = time_periods - - if "gsp" in datasets_dict: - gsp_config = config.input_data.gsp - - time_periods = find_contiguous_t0_periods( - pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]), - sample_period_duration=minutes(gsp_config.time_resolution_minutes), - history_duration=minutes(gsp_config.history_minutes), - forecast_duration=minutes(gsp_config.forecast_minutes), - ) - - contiguous_time_periods['gsp'] = time_periods - - # just get the values (not the keys) - contiguous_time_periods_values = list(contiguous_time_periods.values()) - - # Find joint overlapping contiguous time periods - if len(contiguous_time_periods_values) > 1: - valid_time_periods = intersection_of_multiple_dataframes_of_periods( - contiguous_time_periods_values - ) - else: - valid_time_periods = contiguous_time_periods_values[0] - - # check there are some valid time periods - if len(valid_time_periods) == 0: - raise ValueError(f"No valid time periods found, {contiguous_time_periods=}") + valid_time_periods = find_valid_time_periods(datasets_dict, config) # Fill out the contiguous time periods to get the t0 times valid_t0_times = fill_time_periods( diff --git a/ocf_data_sampler/torch_datasets/site.py b/ocf_data_sampler/torch_datasets/site.py index babd28f..209fd9b 100644 --- a/ocf_data_sampler/torch_datasets/site.py +++ b/ocf_data_sampler/torch_datasets/site.py @@ -21,6 +21,7 @@ from ocf_data_sampler.select.spatial_slice_for_dataset import slice_datasets_by_space from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute from ocf_data_sampler.time_functions import minutes +from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods xr.set_options(keep_attrs=True) @@ -33,93 +34,19 @@ def find_valid_t0_and_site_ids( """Find the t0 times where all of the requested input data is available The idea is to - 1. Get valid time periods for nwp - 2. Get valid time periods for satellite - 3. Get valid time period for nwp and satellite - 4. For each site location, find valid periods for that location + 1. Get valid time period for nwp and satellite + 2. For each site location, find valid periods for that location Args: datasets_dict: A dictionary of input datasets config: Configuration file """ - assert set(datasets_dict.keys()).issubset({"nwp", "sat", "site"}) + # 1. 
Get valid time period for nwp and satellite + datasets_nwp_and_sat_dict = {"nwp": datasets_dict["nwp"], "sat": datasets_dict["sat"]} + valid_time_periods = find_valid_time_periods(datasets_nwp_and_sat_dict, config) - contiguous_time_periods: dict[str: pd.DataFrame] = {} # Used to store contiguous time periods from each data source - - # TODO refactor as this code is duplicated - if "nwp" in datasets_dict: - for nwp_key, nwp_config in config.input_data.nwp.items(): - - da = datasets_dict["nwp"][nwp_key] - - if nwp_config.dropout_timedeltas_minutes is None: - max_dropout = minutes(0) - else: - max_dropout = minutes(np.max(np.abs(nwp_config.dropout_timedeltas_minutes))) - - if nwp_config.max_staleness_minutes is None: - max_staleness = None - else: - max_staleness = minutes(nwp_config.max_staleness_minutes) - - # The last step of the forecast is lost if we have to diff channels - if len(nwp_config.nwp_accum_channels) > 0: - end_buffer = minutes(nwp_config.time_resolution_minutes) - else: - end_buffer =minutes(0) - - # This is the max staleness we can use considering the max step of the input data - max_possible_staleness = ( - pd.Timedelta(da["step"].max().item()) - - minutes(nwp_config.forecast_minutes) - - end_buffer - ) - - # Default to use max possible staleness unless specified in config - if max_staleness is None: - max_staleness = max_possible_staleness - else: - # Make sure the max acceptable staleness isn't longer than the max possible - assert max_staleness <= max_possible_staleness - - time_periods = find_contiguous_t0_periods_nwp( - datetimes=pd.DatetimeIndex(da["init_time_utc"]), - history_duration=minutes(nwp_config.history_minutes), - max_staleness=max_staleness, - max_dropout=max_dropout, - ) - - contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods - - if "sat" in datasets_dict: - sat_config = config.input_data.satellite - - time_periods = find_contiguous_t0_periods( - pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]), - sample_period_duration=minutes(sat_config.time_resolution_minutes), - history_duration=minutes(sat_config.history_minutes), - forecast_duration=minutes(sat_config.forecast_minutes), - ) - - contiguous_time_periods['sat'] = time_periods - - # just get the values (not the keys) - contiguous_time_periods_values = list(contiguous_time_periods.values()) - - # Find joint overlapping contiguous time periods - if len(contiguous_time_periods_values) > 1: - valid_time_periods = intersection_of_multiple_dataframes_of_periods( - contiguous_time_periods_values - ) - else: - valid_time_periods = contiguous_time_periods_values[0] - - # check there are some valid time periods - if len(valid_time_periods) == 0: - raise ValueError(f"No valid time periods found, {contiguous_time_periods=}") - - # 4. Now lets loop over each location in system id and find the valid periods + # 2. 
Now lets loop over each location in system id and find the valid periods # Should we have a different option if there are not nans sites = datasets_dict["site"] site_ids = sites.site_id.values @@ -160,7 +87,7 @@ def find_valid_t0_and_site_ids( return valid_t0_and_site_ids -def get_locations(site_xr:xr.Dataset): +def get_locations(site_xr: xr.Dataset): """Get list of locations of all sites""" locations = [] diff --git a/ocf_data_sampler/torch_datasets/valid_time_periods.py b/ocf_data_sampler/torch_datasets/valid_time_periods.py new file mode 100644 index 0000000..9ee93c1 --- /dev/null +++ b/ocf_data_sampler/torch_datasets/valid_time_periods.py @@ -0,0 +1,108 @@ +import numpy as np +import pandas as pd + +from ocf_data_sampler.config import Configuration +from ocf_data_sampler.select.find_contiguous_time_periods import find_contiguous_t0_periods_nwp, \ + find_contiguous_t0_periods, intersection_of_multiple_dataframes_of_periods +from ocf_data_sampler.time_functions import minutes + + +def find_valid_time_periods( + datasets_dict: dict, + config: Configuration, +): + """Find the t0 times where all of the requested input data is available + + Args: + datasets_dict: A dictionary of input datasets + config: Configuration file + """ + + assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"}) + + contiguous_time_periods: dict[str: pd.DataFrame] = {} # Used to store contiguous time periods from each data source + + if "nwp" in datasets_dict: + for nwp_key, nwp_config in config.input_data.nwp.items(): + + da = datasets_dict["nwp"][nwp_key] + + if nwp_config.dropout_timedeltas_minutes is None: + max_dropout = minutes(0) + else: + max_dropout = minutes(np.max(np.abs(nwp_config.dropout_timedeltas_minutes))) + + if nwp_config.max_staleness_minutes is None: + max_staleness = None + else: + max_staleness = minutes(nwp_config.max_staleness_minutes) + + # The last step of the forecast is lost if we have to diff channels + if len(nwp_config.nwp_accum_channels) > 0: + end_buffer = minutes(nwp_config.time_resolution_minutes) + else: + end_buffer = minutes(0) + + # This is the max staleness we can use considering the max step of the input data + max_possible_staleness = ( + pd.Timedelta(da["step"].max().item()) + - minutes(nwp_config.forecast_minutes) + - end_buffer + ) + + # Default to use max possible staleness unless specified in config + if max_staleness is None: + max_staleness = max_possible_staleness + else: + # Make sure the max acceptable staleness isn't longer than the max possible + assert max_staleness <= max_possible_staleness + + time_periods = find_contiguous_t0_periods_nwp( + datetimes=pd.DatetimeIndex(da["init_time_utc"]), + history_duration=minutes(nwp_config.history_minutes), + max_staleness=max_staleness, + max_dropout=max_dropout, + ) + + contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods + + if "sat" in datasets_dict: + sat_config = config.input_data.satellite + + time_periods = find_contiguous_t0_periods( + pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]), + sample_period_duration=minutes(sat_config.time_resolution_minutes), + history_duration=minutes(sat_config.history_minutes), + forecast_duration=minutes(sat_config.forecast_minutes), + ) + + contiguous_time_periods['sat'] = time_periods + + if "gsp" in datasets_dict: + gsp_config = config.input_data.gsp + + time_periods = find_contiguous_t0_periods( + pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]), + sample_period_duration=minutes(gsp_config.time_resolution_minutes), + 
history_duration=minutes(gsp_config.history_minutes), + forecast_duration=minutes(gsp_config.forecast_minutes), + ) + + contiguous_time_periods['gsp'] = time_periods + + # just get the values (not the keys) + contiguous_time_periods_values = list(contiguous_time_periods.values()) + + # Find joint overlapping contiguous time periods + if len(contiguous_time_periods_values) > 1: + valid_time_periods = intersection_of_multiple_dataframes_of_periods( + contiguous_time_periods_values + ) + else: + valid_time_periods = contiguous_time_periods_values[0] + + # check there are some valid time periods + if len(valid_time_periods) == 0: + raise ValueError(f"No valid time periods found, {contiguous_time_periods=}") + + return valid_time_periods From 1d9317cba628dfe6586580213db4d491cbf84044 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Mon, 4 Nov 2024 17:46:38 +0000 Subject: [PATCH 19/24] add sanity checks --- ocf_data_sampler/load/site.py | 6 ++++++ ocf_data_sampler/torch_datasets/pvnet_uk_regional.py | 1 - 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ocf_data_sampler/load/site.py b/ocf_data_sampler/load/site.py index 9df0a0c..a303fe2 100644 --- a/ocf_data_sampler/load/site.py +++ b/ocf_data_sampler/load/site.py @@ -1,5 +1,6 @@ import pandas as pd import xarray as xr +import numpy as np from ocf_data_sampler.config.model import Site @@ -19,6 +20,11 @@ def open_site(sites_config: Site) -> xr.DataArray: capacity_kwp=data_ds.capacity_kwp, ) + # Sanity checks + assert np.isfinite(data_ds.capacity_kwp.values).all() + assert (data_ds.capacity_kwp.values > 0).all() + assert metadata_df.index.is_unique + return ds.generation_kw diff --git a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py index d51b947..d6ff362 100644 --- a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +++ b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py @@ -70,7 +70,6 @@ def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]: return locations - class PVNetUKRegionalDataset(Dataset): def __init__( self, From f9d49d331dea6d45bd51f6fe75a43fe7f6aa7e8b Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Mon, 4 Nov 2024 17:50:17 +0000 Subject: [PATCH 20/24] tidy imports --- ocf_data_sampler/select/__init__.py | 9 +++++++- .../torch_datasets/pvnet_uk_regional.py | 12 +++------- ocf_data_sampler/torch_datasets/site.py | 22 +++++++------------ 3 files changed, 19 insertions(+), 24 deletions(-) diff --git a/ocf_data_sampler/select/__init__.py b/ocf_data_sampler/select/__init__.py index 8b13789..2e6e5c6 100644 --- a/ocf_data_sampler/select/__init__.py +++ b/ocf_data_sampler/select/__init__.py @@ -1 +1,8 @@ - +from .fill_time_periods import fill_time_periods +from .find_contiguous_time_periods import ( + find_contiguous_t0_periods, + intersection_of_multiple_dataframes_of_periods, +) +from .location import Location +from .spatial_slice_for_dataset import slice_datasets_by_space +from .time_slice_for_dataset import slice_datasets_by_time diff --git a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py index d6ff362..e9c1c42 100644 --- a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +++ b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py @@ -2,21 +2,15 @@ import numpy as np import pandas as pd +import pkg_resources import xarray as xr from torch.utils.data import Dataset -import pkg_resources - -from ocf_data_sampler.select.fill_time_periods import fill_time_periods from 
ocf_data_sampler.config import Configuration, load_yaml_configuration - -from ocf_data_sampler.select.location import Location - from ocf_data_sampler.load.load_dataset import get_dataset_dict -from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute -from ocf_data_sampler.select.spatial_slice_for_dataset import slice_datasets_by_space -from ocf_data_sampler.select.time_slice_for_dataset import slice_datasets_by_time +from ocf_data_sampler.select import fill_time_periods, Location, slice_datasets_by_space, slice_datasets_by_time from ocf_data_sampler.time_functions import minutes +from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods xr.set_options(keep_attrs=True) diff --git a/ocf_data_sampler/torch_datasets/site.py b/ocf_data_sampler/torch_datasets/site.py index 209fd9b..e595b44 100644 --- a/ocf_data_sampler/torch_datasets/site.py +++ b/ocf_data_sampler/torch_datasets/site.py @@ -1,29 +1,23 @@ """Torch dataset for sites""" import logging -import numpy as np import pandas as pd import xarray as xr from torch.utils.data import Dataset -from ocf_data_sampler.select.find_contiguous_time_periods import ( - find_contiguous_t0_periods, find_contiguous_t0_periods_nwp, - intersection_of_multiple_dataframes_of_periods, -) -from ocf_data_sampler.select.fill_time_periods import fill_time_periods - from ocf_data_sampler.config import Configuration, load_yaml_configuration - -from ocf_data_sampler.select.location import Location - from ocf_data_sampler.load.load_dataset import get_dataset_dict -from ocf_data_sampler.select.time_slice_for_dataset import slice_datasets_by_time -from ocf_data_sampler.select.spatial_slice_for_dataset import slice_datasets_by_space -from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute +from ocf_data_sampler.select import ( + Location, + fill_time_periods, + find_contiguous_t0_periods, + intersection_of_multiple_dataframes_of_periods, + slice_datasets_by_time, slice_datasets_by_space +) from ocf_data_sampler.time_functions import minutes +from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods - xr.set_options(keep_attrs=True) From 8660732248bc8c100ac1bb7605e40a38aa2e44eb Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Mon, 4 Nov 2024 17:52:45 +0000 Subject: [PATCH 21/24] add tests --- tests/torch_datasets/test_site.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/torch_datasets/test_site.py b/tests/torch_datasets/test_site.py index 2115867..f0307e6 100644 --- a/tests/torch_datasets/test_site.py +++ b/tests/torch_datasets/test_site.py @@ -58,3 +58,19 @@ def test_site(site_config_filename): # Solar angles have same shape as GSP data assert sample[SiteBatchKey.site_solar_azimuth].shape == (4,) assert sample[SiteBatchKey.site_solar_elevation].shape == (4,) + + +def test_site_time_filter_start(site_config_filename): + + # Create dataset object + dataset = SitesDataset(site_config_filename, start_time="2024-01-01") + + assert len(dataset) == 0 + + +def test_site_time_filter_end(site_config_filename): + + # Create dataset object + dataset = SitesDataset(site_config_filename, end_time="2000-01-01") + + assert len(dataset) == 0 From 61f697b73e8edaab98a8272eebbdccd36c9945e2 Mon Sep 17 00:00:00 
2001 From: peterdudfield Date: Mon, 4 Nov 2024 17:57:28 +0000 Subject: [PATCH 22/24] add get_sample back in + test --- ocf_data_sampler/torch_datasets/site.py | 14 ++++++++++++++ tests/torch_datasets/test_site.py | 11 ++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/ocf_data_sampler/torch_datasets/site.py b/ocf_data_sampler/torch_datasets/site.py index e595b44..885a861 100644 --- a/ocf_data_sampler/torch_datasets/site.py +++ b/ocf_data_sampler/torch_datasets/site.py @@ -181,3 +181,17 @@ def __getitem__(self, idx): # Generate the sample return self._get_sample(t0, location) + + def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict: + """Generate a sample for a given site id and t0. + + Useful for users to generate samples by t0 and site id + + Args: + t0: init-time for sample + site_id: location object + """ + + location = self.get_location_from_site_id(site_id) + + return self._get_sample(t0, location) diff --git a/tests/torch_datasets/test_site.py b/tests/torch_datasets/test_site.py index f0307e6..ab27763 100644 --- a/tests/torch_datasets/test_site.py +++ b/tests/torch_datasets/test_site.py @@ -1,5 +1,5 @@ +import pandas as pd import pytest -import tempfile from ocf_data_sampler.torch_datasets.site import SitesDataset from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration @@ -74,3 +74,12 @@ def test_site_time_filter_end(site_config_filename): dataset = SitesDataset(site_config_filename, end_time="2000-01-01") assert len(dataset) == 0 + + +def test_site_get_sample(site_config_filename): + + # Create dataset object + dataset = SitesDataset(site_config_filename) + + assert len(dataset) == 410 + sample = dataset.get_sample(t0=pd.Timestamp("2023-01-01 12:00"), site_id=1) From 7c03a8110cf42682e37ef4939148d45e4cb64799 Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Tue, 5 Nov 2024 12:02:30 +0000 Subject: [PATCH 23/24] Update ocf_data_sampler/torch_datasets/site.py Co-authored-by: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> --- ocf_data_sampler/torch_datasets/site.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocf_data_sampler/torch_datasets/site.py b/ocf_data_sampler/torch_datasets/site.py index 885a861..937a4db 100644 --- a/ocf_data_sampler/torch_datasets/site.py +++ b/ocf_data_sampler/torch_datasets/site.py @@ -189,7 +189,7 @@ def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict: Args: t0: init-time for sample - site_id: location object + site_id: site id as int """ location = self.get_location_from_site_id(site_id) From 158fee1c59a0daaec4af1becec46e86ad2c55fee Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Tue, 5 Nov 2024 12:02:36 +0000 Subject: [PATCH 24/24] Update ocf_data_sampler/torch_datasets/site.py Co-authored-by: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> --- ocf_data_sampler/torch_datasets/site.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ocf_data_sampler/torch_datasets/site.py b/ocf_data_sampler/torch_datasets/site.py index 937a4db..b92f821 100644 --- a/ocf_data_sampler/torch_datasets/site.py +++ b/ocf_data_sampler/torch_datasets/site.py @@ -132,7 +132,6 @@ def __init__( valid_t0_and_site_ids \ = valid_t0_and_site_ids[valid_t0_and_site_ids['t0'] <= pd.Timestamp(end_time)] - # Filter t0 times to given range # Assign coords and indices to self self.valid_t0_and_site_ids = valid_t0_and_site_ids
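For orientation, the sketch below shows how the SitesDataset introduced in this series would be used once the patches are applied. It is an illustrative example rather than part of the patch set: the config path and timestamps are made up, and it assumes a YAML configuration whose input_data.site section points at the NetCDF and CSV files described by the Site model (file_path, metadata_file_path, time_resolution_minutes, history_minutes, forecast_minutes).

import pandas as pd

from ocf_data_sampler.numpy_batch.site import SiteBatchKey
from ocf_data_sampler.torch_datasets.site import SitesDataset

# Build the dataset, optionally restricting the init-times considered
dataset = SitesDataset(
    "site_config.yaml",       # hypothetical path to a Configuration YAML
    start_time="2023-01-01",  # keep only t0 on or after this timestamp
    end_time="2023-12-31",    # keep only t0 on or before this timestamp
)

print(len(dataset))  # number of valid (t0, site_id) pairs found

# Positional access returns a numpy-batch style dict for one (t0, site_id) pair
sample = dataset[0]
print(sample[SiteBatchKey.generation].shape)

# Explicit access by init-time and site id (added in PATCH 22/24)
sample = dataset.get_sample(t0=pd.Timestamp("2023-01-01 12:00"), site_id=1)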
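The repeated pd.to_timedelta(..., unit="m") calls are consolidated in PATCH 04/24 into ocf_data_sampler.time_functions.minutes. A minimal sketch of its behaviour, following directly from the function added in that patch:

import pandas as pd

from ocf_data_sampler.time_functions import minutes

# A single value becomes a Timedelta ...
assert minutes(30) == pd.Timedelta(minutes=30)

# ... and a list becomes a TimedeltaIndex, e.g. for dropout_timedeltas_minutes
assert isinstance(minutes([-30, -60]), pd.TimedeltaIndex)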
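PATCH 03/24 wires the existing dropout helpers onto the site data. The call pattern below mirrors the lines added to slice_datasets_by_time; the toy DataArray only exists to make the snippet self-contained, and the exact masking behaviour lives in ocf_data_sampler.select.dropout rather than in this patch series.

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.select.dropout import apply_dropout_time, draw_dropout_time
from ocf_data_sampler.time_functions import minutes

# Toy site slice: one day of half-hourly generation for a single site
times = pd.date_range("2023-01-01", periods=48, freq="30min")
da_site = xr.DataArray(np.random.rand(48), coords=[("time_utc", times)])

t0 = pd.Timestamp("2023-01-01 12:00")

# Randomly draw a dropout time relative to t0 (may be None depending on dropout_frac)
dropout_time = draw_dropout_time(
    t0,
    dropout_timedeltas=minutes([-30]),  # e.g. site data only available up to t0 - 30 min
    dropout_frac=1.0,                   # always apply dropout in this toy example
)

# Apply it to the site slice, as done for sliced_datasets_dict["site"] in the patch
da_site_dropped = apply_dropout_time(da_site, dropout_time)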