From eb13a9ad32a30367f812bb3fa057fe58fefbffb9 Mon Sep 17 00:00:00 2001
From: peterdudfield
Date: Thu, 25 Jul 2024 20:05:19 +0100
Subject: [PATCH 1/7] remove satellite datapipe, remove metnet

---
 ocf_datapipes/load/__init__.py                |   3 +-
 ocf_datapipes/load/satellite.py               |  20 -
 ocf_datapipes/training/common.py              |  28 +-
 .../training/example/gsp_pv_nwp_satellite.py  |  10 +-
 .../training/metnet/metnet_gsp_national.py    | 188 ----------
 .../training/metnet/metnet_national.md        |  35 --
 .../training/metnet/metnet_national.py        | 341 -----------------
 .../training/metnet/metnet_national_class.py  | 330 ----------------
 .../training/metnet/metnet_preprocessor.py    | 351 ------------------
 .../training/metnet/metnet_pv_national.py     | 346 -----------------
 .../training/metnet/metnet_pv_site.py         | 213 -----------
 tests/conftest.py                             |   9 +-
 tests/load/test_load_satellite.py             |  17 +-
 .../metnet/test_metnet_gsp_national.py        |  30 --
 tests/training/metnet/test_metnet_national.py |  18 -
 .../metnet/test_metnet_preprocessor.py        | 161 --------
 .../metnet/test_metnet_pv_national.py         |  19 -
 tests/training/metnet/test_metnet_pv_site.py  |  17 -
 tests/transform/numpy_batch/conftest.py       |   7 +-
 19 files changed, 45 insertions(+), 2098 deletions(-)
 delete mode 100644 ocf_datapipes/training/metnet/metnet_gsp_national.py
 delete mode 100644 ocf_datapipes/training/metnet/metnet_national.md
 delete mode 100644 ocf_datapipes/training/metnet/metnet_national.py
 delete mode 100644 ocf_datapipes/training/metnet/metnet_national_class.py
 delete mode 100644 ocf_datapipes/training/metnet/metnet_preprocessor.py
 delete mode 100644 ocf_datapipes/training/metnet/metnet_pv_national.py
 delete mode 100644 ocf_datapipes/training/metnet/metnet_pv_site.py
 delete mode 100644 tests/training/metnet/test_metnet_gsp_national.py
 delete mode 100644 tests/training/metnet/test_metnet_national.py
 delete mode 100644 tests/training/metnet/test_metnet_preprocessor.py
 delete mode 100644 tests/training/metnet/test_metnet_pv_national.py
 delete mode 100644 tests/training/metnet/test_metnet_pv_site.py

diff --git a/ocf_datapipes/load/__init__.py b/ocf_datapipes/load/__init__.py
index 3061528b8..5c2eef9e6 100644
--- a/ocf_datapipes/load/__init__.py
+++ b/ocf_datapipes/load/__init__.py
@@ -13,7 +13,8 @@
 from .configuration import OpenConfigurationIterDataPipe as OpenConfiguration
 from .nwp.nwp import OpenNWPIterDataPipe as OpenNWP
-from .satellite import OpenSatelliteIterDataPipe as OpenSatellite
+# from .satellite import OpenSatelliteIterDataPipe as OpenSatellite
+from .satellite import open_sat_data
 
 
 try:
     import rioxarray  # Rioxarray is sometimes a pain to install, so only load this if its installed

diff --git a/ocf_datapipes/load/satellite.py b/ocf_datapipes/load/satellite.py
index a61688e9c..fb2818d76 100644
--- a/ocf_datapipes/load/satellite.py
+++ b/ocf_datapipes/load/satellite.py
@@ -164,23 +164,3 @@ def open_sat_data(zarr_path: Union[Path, str, list[Path], list[str]]) -> xr.Data
 
     return data_array
-
-
-@functional_datapipe("open_satellite")
-class OpenSatelliteIterDataPipe(IterDataPipe):
-    """Open Satellite Zarr"""
-
-    def __init__(self, zarr_path: Union[Path, str]):
-        """
-        Opens the satellite Zarr
-
-        Args:
-            zarr_path: path to the zarr file
-        """
-        self.zarr_path = zarr_path
-        super().__init__()
-
-    def __iter__(self) -> xr.DataArray:
-        """Open the Zarr file"""
-        data: xr.DataArray = open_sat_data(zarr_path=self.zarr_path)
-        while True:
-            yield data

diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py
index d5f0d3d95..c3d93cdc5 100644
--- a/ocf_datapipes/training/common.py
+++ b/ocf_datapipes/training/common.py
@@ -19,7 +19,7 @@
     OpenNWP,
     OpenPVFromNetCDF,
     OpenPVFromPVSitesDB,
-    OpenSatellite,
+    open_sat_data,
     OpenWindFromNetCDF,
 )
 from ocf_datapipes.utils.utils import flatten_nwp_source_dict
@@ -38,6 +38,17 @@
 logger = logging.getLogger(__name__)
 
+
+@functional_datapipe("fake_iter")
+class FakeIter(IterDataPipe):
+    """Datapipe that yields a pre-loaded xarray object indefinitely"""
+
+    def __init__(self, data_xr):
+        self.data_xr = data_xr
+
+    def __iter__(self) -> xr.DataArray:
+        while True:
+            yield self.data_xr
 
 def is_config_and_path_valid(
     use_flag: bool,
     config,
@@ -171,8 +182,11 @@
 
     if use_sat:
         logger.debug("Opening Satellite Data")
-        sat_datapipe = (
-            OpenSatellite(configuration.input_data.satellite.satellite_zarr_path)
+
+        sat_xr = open_sat_data(configuration.input_data.satellite.satellite_zarr_path)
+        sat_pipe = FakeIter(sat_xr)
+
+        sat_datapipe = (sat_pipe
             .filter_channels(configuration.input_data.satellite.satellite_channels)
             .add_t0_idx_and_sample_period_duration(
                 sample_period_duration=minutes(
@@ -186,9 +200,11 @@
 
     if use_hrv:
         logger.debug("Opening HRV Satellite Data")
-        sat_hrv_datapipe = OpenSatellite(
-            configuration.input_data.hrvsatellite.hrvsatellite_zarr_path
-        ).add_t0_idx_and_sample_period_duration(
+
+        sat_hrv_xr = open_sat_data(configuration.input_data.hrvsatellite.hrvsatellite_zarr_path)
+        sat_hrv_pipe = FakeIter(sat_hrv_xr)
+
+        sat_hrv_datapipe = sat_hrv_pipe.add_t0_idx_and_sample_period_duration(
             sample_period_duration=minutes(
                 configuration.input_data.hrvsatellite.time_resolution_minutes
             ),
diff --git a/ocf_datapipes/training/example/gsp_pv_nwp_satellite.py b/ocf_datapipes/training/example/gsp_pv_nwp_satellite.py
index 9ac221e23..3b4a004b4 100644
--- a/ocf_datapipes/training/example/gsp_pv_nwp_satellite.py
+++ b/ocf_datapipes/training/example/gsp_pv_nwp_satellite.py
@@ -12,8 +12,8 @@
 from ocf_datapipes.batch import MergeNumpyModalities, MergeNWPNumpyModalities
 from ocf_datapipes.config.load import load_yaml_configuration
 from ocf_datapipes.config.model import Configuration
-from ocf_datapipes.load import OpenGSP, OpenNWP, OpenPVFromNetCDF, OpenSatellite
-from ocf_datapipes.training.common import normalize_gsp, normalize_pv
+from ocf_datapipes.load import OpenGSP, OpenNWP, OpenPVFromNetCDF, open_sat_data
+from ocf_datapipes.training.common import normalize_gsp, normalize_pv, FakeIter
 from ocf_datapipes.utils.consts import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
 
 logger = logging.getLogger(__name__)
@@ -53,9 +53,9 @@ def gsp_pv_nwp_satellite_data_pipeline(configuration: Union[Path, str]) -> IterD
     )
 
     # Load and normalize satellite data
-    satellite_datapipe = OpenSatellite(
-        zarr_path=configuration.input_data.satellite.satellite_zarr_path
-    ).normalize(mean=RSS_MEAN, std=RSS_STD)
+    sat_xr = open_sat_data(zarr_path=configuration.input_data.satellite.satellite_zarr_path)
+    satellite_datapipe = FakeIter(sat_xr)
+    satellite_datapipe = satellite_datapipe.normalize(mean=RSS_MEAN, std=RSS_STD)
 
     # Load and normalize NWP data - There may be multiple NWP sources
     nwp_datapipe_dict = {}
diff --git a/ocf_datapipes/training/metnet/metnet_gsp_national.py b/ocf_datapipes/training/metnet/metnet_gsp_national.py
deleted file mode 100644
index 50cdb6216..000000000
--- a/ocf_datapipes/training/metnet/metnet_gsp_national.py
+++ /dev/null
@@ -1,188 +0,0 @@
-"""Create the training/validation datapipe for training the national MetNet/-2 Model"""
-
-import datetime
-import logging
-from pathlib import Path
-from typing import Union
-
-import xarray
-from
torch.utils.data.datapipes.datapipe import IterDataPipe - -from ocf_datapipes.convert import ConvertGSPToNumpy -from ocf_datapipes.select import FilterGSPIDs, PickLocations -from ocf_datapipes.training.common import ( - add_selected_time_slices_from_datapipes, - get_and_return_overlapping_time_periods_and_t0, - open_and_return_datapipes, -) -from ocf_datapipes.training.metnet.metnet_preprocessor import ( - PreProcessMetNetIterDataPipe as PreProcessMetNet, -) -from ocf_datapipes.utils.consts import RSS_MEAN, RSS_STD, UKV_MEAN, UKV_STD - -xarray.set_options(keep_attrs=True) -logger = logging.getLogger("metnet_datapipe") -logger.setLevel(logging.DEBUG) - - -def normalize_gsp(x): # So it can be pickled - """ - Normalize the GSP data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x / x.nominal_capacity_mwp - - -def normalize_pv(x): # So it can be pickled - """ - Normalize the GSP data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x / x.observed_capacity_wp - - -def _select_non_nan_times(x): - return x.fillna(0.0) - - -def metnet_national_datapipe( - configuration_filename: Union[Path, str], - use_sun: bool = True, - use_nwp: bool = True, - use_sat: bool = True, - use_hrv: bool = True, - use_pv: bool = False, - use_gsp: bool = True, - use_topo: bool = True, - output_size: int = 256, - gsp_in_image: bool = False, - start_time: datetime.datetime = datetime.datetime(2014, 1, 1), - end_time: datetime.datetime = datetime.datetime(2023, 1, 1), -) -> IterDataPipe: - """ - Make GSP national data pipe - - Currently only has GSP and NWP's in them - - Args: - configuration_filename: the configruation filename for the pipe - use_sun: Whether to add sun features or not - use_pv: Whether to use PV input or not - use_hrv: Whether to use HRV Satellite or not - use_sat: Whether to use non-HRV Satellite or not - use_nwp: Whether to use NWP or not - use_topo: Whether to use topographic map or not - use_gsp: Whether to use GSP history - start_time: Start time to select on - end_time: End time to select from - output_size: Size, in pixels, of the output image - gsp_in_image: Add GSP history as channels in MetNet image - - Returns: datapipe - """ - - # load datasets - used_datapipes = open_and_return_datapipes( - configuration_filename=configuration_filename, - use_nwp=use_nwp, - use_topo=use_topo, - use_sat=use_sat, - use_hrv=use_hrv, - use_gsp=use_gsp, - use_pv=use_pv, - ) - # Load GSP national data - used_datapipes["gsp"] = used_datapipes["gsp"].filter_times(start_time, end_time) - - # Now get overlapping time periods - used_datapipes = get_and_return_overlapping_time_periods_and_t0(used_datapipes) - - # And now get time slices - used_datapipes = add_selected_time_slices_from_datapipes(used_datapipes) - - # Now do the extra processing - gsp_history = used_datapipes["gsp"].normalize(normalize_fn=normalize_gsp) - gsp_datapipe = used_datapipes["gsp_future"].normalize(normalize_fn=normalize_gsp) - # Split into GSP for target, only national, and one for history - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - - if "nwp" in used_datapipes.keys(): - # take nwp time slices - logger.debug("Take NWP time slices") - nwp_datapipe = used_datapipes["nwp"].normalize(mean=UKV_MEAN, std=UKV_STD) - - if "sat" in used_datapipes.keys(): - logger.debug("Take Satellite time slices") - # take sat time slices - sat_datapipe = used_datapipes["sat"].normalize(mean=RSS_MEAN, std=RSS_STD) - - if "hrv" in used_datapipes.keys(): - logger.debug("Take HRV 
Satellite time slices") - sat_hrv_datapipe = used_datapipes["hrv"].normalize( - mean=RSS_MEAN.sel(channel="HRV"), std=RSS_STD.sel(channel="HRV") - ) - - if "topo" in used_datapipes.keys(): - topo_datapipe = used_datapipes["topo"].map(_select_non_nan_times) - - # Now combine in the MetNet format - modalities = [] - if gsp_in_image and "hrv" in used_datapipes.keys(): - sat_hrv_datapipe, sat_gsp_datapipe = sat_hrv_datapipe.fork(2) - gsp_history = gsp_history.filter_gsp_ids(gsps_to_keep=[0]).create_gsp_image( - image_datapipe=sat_gsp_datapipe - ) - elif gsp_in_image and "sat" in used_datapipes.keys(): - sat_datapipe, sat_gsp_datapipe = sat_datapipe.fork(2) - gsp_history = gsp_history.filter_gsp_ids(gsps_to_keep=[0]).create_gsp_image( - image_datapipe=sat_gsp_datapipe - ) - elif gsp_in_image and "nwp" in used_datapipes.keys(): - nwp_datapipe, nwp_gsp_datapipe = nwp_datapipe.fork(2) - gsp_history = gsp_history.filter_gsp_ids(gsps_to_keep=[0]).create_gsp_image( - image_datapipe=nwp_gsp_datapipe, image_dim="osgb" - ) - if "nwp" in used_datapipes.keys(): - modalities.append(nwp_datapipe) - if "hrv" in used_datapipes.keys(): - modalities.append(sat_hrv_datapipe) - if "sat" in used_datapipes.keys(): - modalities.append(sat_datapipe) - if "topo" in used_datapipes.keys(): - modalities.append(topo_datapipe) - if gsp_in_image: - modalities.append(gsp_history) - - gsp_datapipe, gsp_loc_datapipe = gsp_datapipe.fork(2, buffer_size=5) - - location_datapipe = PickLocations(gsp_loc_datapipe) - - metnet_datapipe = PreProcessMetNet( - modalities, - location_datapipe=location_datapipe, - center_width=500_000, - center_height=1_000_000, - context_height=10_000_000, - context_width=10_000_000, - output_width_pixels=output_size, - output_height_pixels=output_size, - add_sun_features=use_sun, - ) - gsp_datapipe = ConvertGSPToNumpy(gsp_datapipe) - - if not gsp_in_image: - gsp_history = gsp_history.map(_select_non_nan_times) - gsp_history = ConvertGSPToNumpy(gsp_history, return_id=True) - return metnet_datapipe.zip_ocf(gsp_history, gsp_datapipe) # Makes (Inputs, Label) tuples - else: - return metnet_datapipe.zip(gsp_datapipe) diff --git a/ocf_datapipes/training/metnet/metnet_national.md b/ocf_datapipes/training/metnet/metnet_national.md deleted file mode 100644 index eaf210986..000000000 --- a/ocf_datapipes/training/metnet/metnet_national.md +++ /dev/null @@ -1,35 +0,0 @@ -# MetNet National Pipeline - -metnet_national.py is a training pipeline for loading NWP,PV,Satellite,and Topographic data and transforming it as in -the MetNet paper. - -The location is chosen using the center of the National GSP shape. Only the modalities wanted are loaded. -Then a time is chosen, and PV and NWP examples are made. 
- -```mermaid -graph TD - A[Load GSP] -->|Select Train/Test Times| B(Drop Regional GSP) --> A1 - C[Load NWP] --> CA[Filter] --> A1 - D[Load Satellite] --> DA[Filter] --> A1 - E[Load PV] --> EA[Filter] --> A1 - F[Load Topo] - A1[Select Joint Time Periods] - B1[Select T0 Time] - A1 --> B1 - B1 --> C1 - A1 --> C1 - B1 --> CAA - CA --> CAA[Convert to Target Time] - DA --> C1 - EA --> C1 - C1[Select Time Slice] - AA[Get Location] - B --> AA - A11[PreProcess MetNet] - C1 --> A11 - CAA --> A11 - F --> A11 - AA --> A11 - A111[Return Example] - A11 --> A111 -``` diff --git a/ocf_datapipes/training/metnet/metnet_national.py b/ocf_datapipes/training/metnet/metnet_national.py deleted file mode 100644 index 60c3c867d..000000000 --- a/ocf_datapipes/training/metnet/metnet_national.py +++ /dev/null @@ -1,341 +0,0 @@ -"""Create the training/validation datapipe for training the national MetNet/-2 Model""" - -import datetime -import logging -from datetime import timedelta -from pathlib import Path -from typing import Union - -import xarray -from torch.utils.data.datapipes.datapipe import IterDataPipe - -from ocf_datapipes.config.model import Configuration -from ocf_datapipes.convert import ConvertGSPToNumpy -from ocf_datapipes.load import ( - OpenConfiguration, - OpenGSP, - OpenNWP, - OpenPVFromNetCDF, - OpenSatellite, - OpenTopography, -) -from ocf_datapipes.select import FilterGSPIDs, PickLocations -from ocf_datapipes.training.metnet.metnet_preprocessor import ( - PreProcessMetNetIterDataPipe as PreProcessMetNet, -) -from ocf_datapipes.utils.consts import RSS_MEAN, RSS_STD, UKV_MEAN, UKV_STD - -xarray.set_options(keep_attrs=True) -logger = logging.getLogger("metnet_datapipe") -logger.setLevel(logging.DEBUG) - - -def normalize_gsp(x): # So it can be pickled - """ - Normalize the GSP data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x / x.nominal_capacity_mwp - - -def _select_non_nan_times(x): - return x.fillna(0.0) - - -def metnet_national_datapipe( - configuration_filename: Union[Path, str], - use_sun: bool = True, - use_nwp: bool = True, - use_sat: bool = True, - use_hrv: bool = True, - use_pv: bool = True, - use_topo: bool = True, - mode: str = "train", - max_num_pv_systems: int = -1, - start_time: datetime.datetime = datetime.datetime(2014, 1, 1), - end_time: datetime.datetime = datetime.datetime(2023, 1, 1), -) -> IterDataPipe: - """ - Make GSP national data pipe - - Currently only has GSP and NWP's in them - - Args: - configuration_filename: the configruation filename for the pipe - use_sun: Whether to add sun features or not - use_pv: Whether to use PV input or not - use_hrv: Whether to use HRV Satellite or not - use_sat: Whether to use non-HRV Satellite or not - use_nwp: Whether to use NWP or not - use_topo: Whether to use topographic map or not - mode: Either 'train', where random times are selected, - or 'test' or 'val' where times are sequential - max_num_pv_systems: max number of PV systems to include, <= 0 if no sampling - start_time: Start time to select on - end_time: End time to select from - - Returns: datapipe - """ - - # load configuration - config_datapipe = OpenConfiguration(configuration_filename) - configuration: Configuration = next(iter(config_datapipe)) - - # Check which modalities to use - if use_nwp: - use_nwp = True if configuration.input_data.nwp.nwp_zarr_path != "" else False - if use_pv: - use_pv = True if configuration.input_data.pv.pv_files_groups[0].pv_filename != "" else False - if use_sat: - use_sat = True if 
configuration.input_data.satellite.satellite_zarr_path != "" else False - if use_hrv: - use_hrv = ( - True if configuration.input_data.hrvsatellite.hrvsatellite_zarr_path != "" else False - ) - if use_topo: - use_topo = ( - True if configuration.input_data.topographic.topographic_filename != "" else False - ) - print( - f"NWP: {use_nwp} Sat: {use_sat}, HRV: {use_hrv} " - f"PV: {use_pv} Sun: {use_sun} Topo: {use_topo}" - ) - # Load GSP national data - logger.debug("Opening GSP Data") - gsp_datapipe = OpenGSP( - gsp_pv_power_zarr_path=configuration.input_data.gsp.gsp_zarr_path - ).filter_times(start_time, end_time) - - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - - logger.debug("Add t0 idx and normalize") - - gsp_datapipe, gsp_time_periods_datapipe, gsp_t0_datapipe = ( - gsp_datapipe.normalize(normalize_fn=normalize_gsp) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), - ) - .fork(3) - ) - # get time periods - # get contiguous time periods - logger.debug("Getting contiguous time periods") - gsp_time_periods_datapipe = gsp_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), - ) - - secondary_datapipes = [] - - # Load NWP data - if use_nwp: - logger.debug("Opening NWP Data") - nwp_datapipe, nwp_time_periods_datapipe = ( - OpenNWP(configuration.input_data.nwp.nwp_zarr_path) - .filter_channels(configuration.input_data.nwp.nwp_channels) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(hours=1), - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - ) - .fork(2) - ) - - nwp_time_periods_datapipe = nwp_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(hours=3), # Init times are 3 hours apart - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), - time_dim="init_time_utc", - ) - secondary_datapipes.append(nwp_time_periods_datapipe) - - if use_sat: - logger.debug("Opening Satellite Data") - sat_datapipe, sat_time_periods_datapipe = ( - OpenSatellite(configuration.input_data.satellite.satellite_zarr_path) - .filter_channels(configuration.input_data.satellite.satellite_channels) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.satellite.history_minutes - ), - ) - .fork(2) - ) - - sat_time_periods_datapipe = sat_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.satellite.history_minutes), - forecast_duration=timedelta(minutes=1), - ) - secondary_datapipes.append(sat_time_periods_datapipe) - - if use_hrv: - logger.debug("Opening HRV Satellite Data") - sat_hrv_datapipe, sat_hrv_time_periods_datapipe = ( - OpenSatellite(configuration.input_data.hrvsatellite.hrvsatellite_zarr_path) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - ) - .fork(2) - ) - - sat_hrv_time_periods_datapipe = ( - 
sat_hrv_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - forecast_duration=timedelta(minutes=1), - ) - ) - secondary_datapipes.append(sat_hrv_time_periods_datapipe) - - if use_pv: - logger.debug("Opening PV") - pv_datapipe, pv_time_periods_datapipe = ( - OpenPVFromNetCDF(pv=configuration.input_data.pv) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - ) - .fork(2) - ) - - pv_time_periods_datapipe = pv_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - forecast_duration=timedelta(minutes=1), - ) - secondary_datapipes.append(pv_time_periods_datapipe) - - # find joint overlapping timer periods - logger.debug("Getting joint time periods") - overlapping_datapipe = gsp_time_periods_datapipe.filter_to_overlapping_time_periods( - secondary_datapipes=secondary_datapipes, - ) - - # select time periods - gsp_t0_datapipe = gsp_t0_datapipe.filter_time_periods(time_periods=overlapping_datapipe) - - # select t0 periods - logger.debug("Select t0 joint") - num_t0_datapipes = ( - 1 + len(secondary_datapipes) if mode == "train" else 2 + len(secondary_datapipes) - ) - t0_datapipes = gsp_t0_datapipe.pick_t0_times().fork(num_t0_datapipes) - - # take pv time slices - logger.debug("Take GSP time slices") - gsp_datapipe = gsp_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[0], - history_duration=timedelta(minutes=0), - forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), - sample_period_duration=timedelta(minutes=30), - ) - - if use_nwp: - # take nwp time slices - logger.debug("Take NWP time slices") - nwp_datapipe = nwp_datapipe.select_time_slice_nwp( - t0_datapipe=t0_datapipes[1], - sample_period_duration=timedelta(hours=1), - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), - ).normalize(mean=UKV_MEAN, std=UKV_STD) - - if use_sat: - logger.debug("Take Satellite time slices") - # take sat time slices - sat_datapipe = sat_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat])], - history_duration=timedelta(minutes=configuration.input_data.satellite.history_minutes), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ).normalize(mean=RSS_MEAN, std=RSS_STD) - - if use_hrv: - logger.debug("Take HRV Satellite time slices") - sat_hrv_datapipe = sat_hrv_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat, use_hrv])], - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ).normalize(mean=RSS_MEAN.sel(channel="HRV"), std=RSS_STD.sel(channel="HRV")) - - if use_pv: - logger.debug("Take PV Time Slices") - # take pv time slices - pv_datapipe = pv_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat, use_hrv, use_pv])], - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ) - - if use_hrv: - 
image_datapipe = OpenSatellite( - configuration.input_data.hrvsatellite.hrvsatellite_zarr_path - ) - elif use_sat: - image_datapipe = OpenSatellite(configuration.input_data.satellite.satellite_zarr_path) - elif use_nwp: - image_datapipe = OpenNWP(configuration.input_data.nwp.nwp_zarr_path) - - pv_datapipe = pv_datapipe.create_pv_image( - image_datapipe, - normalize=True, - max_num_pv_systems=max_num_pv_systems, - ) - - if use_topo: - topo_datapipe = OpenTopography( - configuration.input_data.topographic.topographic_filename - ).map(_select_non_nan_times) - - # Now combine in the MetNet format - modalities = [] - if use_nwp: - modalities.append(nwp_datapipe) - if use_hrv: - modalities.append(sat_hrv_datapipe) - if use_sat: - modalities.append(sat_datapipe) - if use_pv: - modalities.append(pv_datapipe) - if use_topo: - modalities.append(topo_datapipe) - - gsp_datapipe, gsp_loc_datapipe = gsp_datapipe.fork(2) - - location_datapipe = PickLocations(gsp_loc_datapipe) - - combined_datapipe = PreProcessMetNet( - modalities, - location_datapipe=location_datapipe, - center_width=500_000, - center_height=1_000_000, - context_height=10_000_000, - context_width=10_000_000, - output_width_pixels=256, - output_height_pixels=256, - add_sun_features=use_sun, - ) - - gsp_datapipe = ConvertGSPToNumpy(gsp_datapipe) - if mode == "train": - return combined_datapipe.zip_ocf(gsp_datapipe) # Makes (Inputs, Label) tuples - else: - start_time_datapipe = t0_datapipes[len(t0_datapipes) - 1] # The one extra one - return combined_datapipe.zip_ocf(gsp_datapipe, start_time_datapipe) diff --git a/ocf_datapipes/training/metnet/metnet_national_class.py b/ocf_datapipes/training/metnet/metnet_national_class.py deleted file mode 100644 index 69bcbd4e8..000000000 --- a/ocf_datapipes/training/metnet/metnet_national_class.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Create the training/validation datapipe for training the national MetNet/-2 Model""" - -import datetime -import logging -from datetime import timedelta -from pathlib import Path -from typing import Union - -import xarray -from torch.utils.data.datapipes.datapipe import IterDataPipe - -from ocf_datapipes.config.model import Configuration -from ocf_datapipes.convert import ConvertGSPToNumpy -from ocf_datapipes.load import ( - OpenConfiguration, - OpenGSP, - OpenNWP, - OpenPVFromNetCDF, - OpenSatellite, - OpenTopography, -) -from ocf_datapipes.select import FilterGSPIDs, PickLocations -from ocf_datapipes.training.metnet.metnet_preprocessor import ( - PreProcessMetNetIterDataPipe as PreProcessMetNet, -) -from ocf_datapipes.utils.consts import RSS_MEAN, RSS_STD, UKV_MEAN, UKV_STD - -xarray.set_options(keep_attrs=True) -logger = logging.getLogger("metnet_datapipe") -logger.setLevel(logging.DEBUG) - - -def _select_non_nan_times(x): - return x.fillna(0.0) - - -def metnet_national_datapipe( - configuration_filename: Union[Path, str], - use_sun: bool = True, - use_nwp: bool = True, - use_sat: bool = True, - use_hrv: bool = True, - use_pv: bool = True, - use_topo: bool = True, - mode: str = "train", - max_num_pv_systems: int = -1, - start_time: datetime.datetime = datetime.datetime(2014, 1, 1), - end_time: datetime.datetime = datetime.datetime(2023, 1, 1), -) -> IterDataPipe: - """ - Make GSP national data pipe - - Currently only has GSP and NWP's in them - - Args: - configuration_filename: the configruation filename for the pipe - use_sun: Whether to add sun features or not - use_pv: Whether to use PV input or not - use_hrv: Whether to use HRV Satellite or not - use_sat: Whether to 
use non-HRV Satellite or not - use_nwp: Whether to use NWP or not - use_topo: Whether to use topographic map or not - mode: Either 'train', where random times are selected, - or 'test' or 'val' where times are sequential - max_num_pv_systems: max number of PV systems to include, <= 0 if no sampling - start_time: Start time to select on - end_time: End time to select from - - Returns: datapipe - """ - - # load configuration - config_datapipe = OpenConfiguration(configuration_filename) - configuration: Configuration = next(iter(config_datapipe)) - - # Check which modalities to use - if use_nwp: - use_nwp = True if configuration.input_data.nwp.nwp_zarr_path != "" else False - if use_pv: - use_pv = True if configuration.input_data.pv.pv_files_groups[0].pv_filename != "" else False - if use_sat: - use_sat = True if configuration.input_data.satellite.satellite_zarr_path != "" else False - if use_hrv: - use_hrv = ( - True if configuration.input_data.hrvsatellite.hrvsatellite_zarr_path != "" else False - ) - if use_topo: - use_topo = ( - True if configuration.input_data.topographic.topographic_filename != "" else False - ) - print( - f"NWP: {use_nwp} Sat: {use_sat}, HRV: {use_hrv} " - f"PV: {use_pv} Sun: {use_sun} Topo: {use_topo}" - ) - # Load GSP national data - logger.debug("Opening GSP Data") - gsp_datapipe = OpenGSP( - gsp_pv_power_zarr_path=configuration.input_data.gsp.gsp_zarr_path - ).filter_times(start_time, end_time) - - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - - logger.debug("Add t0 idx and normalize") - - ( - gsp_datapipe, - gsp_time_periods_datapipe, - gsp_t0_datapipe, - ) = gsp_datapipe.add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), - ).fork( - 3 - ) - # get time periods - # get contiguous time periods - logger.debug("Getting contiguous time periods") - gsp_time_periods_datapipe = gsp_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), - ) - - secondary_datapipes = [] - - # Load NWP data - if use_nwp: - logger.debug("Opening NWP Data") - nwp_datapipe, nwp_time_periods_datapipe = ( - OpenNWP(configuration.input_data.nwp.nwp_zarr_path) - .filter_channels(configuration.input_data.nwp.nwp_channels) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(hours=1), - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - ) - .fork(2) - ) - - nwp_time_periods_datapipe = nwp_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(hours=3), # Init times are 3 hours apart - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), - time_dim="init_time_utc", - ) - secondary_datapipes.append(nwp_time_periods_datapipe) - - if use_sat: - logger.debug("Opening Satellite Data") - sat_datapipe, sat_time_periods_datapipe = ( - OpenSatellite(configuration.input_data.satellite.satellite_zarr_path) - .filter_channels(configuration.input_data.satellite.satellite_channels) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.satellite.history_minutes - ), - 
) - .fork(2) - ) - - sat_time_periods_datapipe = sat_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.satellite.history_minutes), - forecast_duration=timedelta(minutes=1), - ) - secondary_datapipes.append(sat_time_periods_datapipe) - - if use_hrv: - logger.debug("Opening HRV Satellite Data") - sat_hrv_datapipe, sat_hrv_time_periods_datapipe = ( - OpenSatellite(configuration.input_data.hrvsatellite.hrvsatellite_zarr_path) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - ) - .fork(2) - ) - - sat_hrv_time_periods_datapipe = ( - sat_hrv_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - forecast_duration=timedelta(minutes=1), - ) - ) - secondary_datapipes.append(sat_hrv_time_periods_datapipe) - - if use_pv: - logger.debug("Opening PV") - pv_datapipe, pv_time_periods_datapipe = ( - OpenPVFromNetCDF(pv=configuration.input_data.pv) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - ) - .fork(2) - ) - - pv_time_periods_datapipe = pv_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - forecast_duration=timedelta(minutes=1), - ) - secondary_datapipes.append(pv_time_periods_datapipe) - - # find joint overlapping timer periods - logger.debug("Getting joint time periods") - overlapping_datapipe = gsp_time_periods_datapipe.filter_to_overlapping_time_periods( - secondary_datapipes=secondary_datapipes, - ) - - # select time periods - gsp_t0_datapipe = gsp_t0_datapipe.filter_time_periods(time_periods=overlapping_datapipe) - - # select t0 periods - logger.debug("Select t0 joint") - num_t0_datapipes = ( - 1 + len(secondary_datapipes) if mode == "train" else 2 + len(secondary_datapipes) - ) - t0_datapipes = gsp_t0_datapipe.pick_t0_times().fork(num_t0_datapipes) - - # take pv time slices - logger.debug("Take GSP time slices") - gsp_datapipe = gsp_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[0], - history_duration=timedelta(minutes=0), - forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), - sample_period_duration=timedelta(minutes=30), - ) - - if use_nwp: - # take nwp time slices - logger.debug("Take NWP time slices") - nwp_datapipe = nwp_datapipe.select_time_slice_nwp( - t0_datapipe=t0_datapipes[1], - sample_period_duration=timedelta(hours=1), - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), - ).normalize(mean=UKV_MEAN, std=UKV_STD) - - if use_sat: - logger.debug("Take Satellite time slices") - # take sat time slices - sat_datapipe = sat_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat])], - history_duration=timedelta(minutes=configuration.input_data.satellite.history_minutes), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ).normalize(mean=RSS_MEAN, std=RSS_STD) - - if use_hrv: - logger.debug("Take HRV Satellite time slices") 
- sat_hrv_datapipe = sat_hrv_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat, use_hrv])], - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ).normalize(mean=RSS_MEAN.sel(channel="HRV"), std=RSS_STD.sel(channel="HRV")) - - if use_pv: - logger.debug("Take PV Time Slices") - # take pv time slices - pv_datapipe = pv_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat, use_hrv, use_pv])], - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ) - - if use_hrv: - image_datapipe = OpenSatellite( - configuration.input_data.hrvsatellite.hrvsatellite_zarr_path - ) - elif use_sat: - image_datapipe = OpenSatellite(configuration.input_data.satellite.satellite_zarr_path) - elif use_nwp: - image_datapipe = OpenNWP(configuration.input_data.nwp.nwp_zarr_path) - - pv_datapipe = pv_datapipe.create_pv_image( - image_datapipe, - normalize=True, - max_num_pv_systems=max_num_pv_systems, - ) - - if use_topo: - topo_datapipe = OpenTopography( - configuration.input_data.topographic.topographic_filename - ).map(_select_non_nan_times) - - # Now combine in the MetNet format - modalities = [] - if use_nwp: - modalities.append(nwp_datapipe) - if use_hrv: - modalities.append(sat_hrv_datapipe) - if use_sat: - modalities.append(sat_datapipe) - if use_pv: - modalities.append(pv_datapipe) - if use_topo: - modalities.append(topo_datapipe) - - gsp_datapipe, gsp_loc_datapipe = gsp_datapipe.fork(2) - - location_datapipe = PickLocations(gsp_loc_datapipe) - - combined_datapipe = PreProcessMetNet( - modalities, - location_datapipe=location_datapipe, - center_width=500_000, - center_height=1_000_000, - context_height=10_000_000, - context_width=10_000_000, - output_width_pixels=256, - output_height_pixels=256, - add_sun_features=use_sun, - ) - - gsp_datapipe = ConvertGSPToNumpy(gsp_datapipe) - if mode == "train": - return combined_datapipe.zip_ocf(gsp_datapipe) # Makes (Inputs, Label) tuples - else: - start_time_datapipe = t0_datapipes[len(t0_datapipes) - 1] # The one extra one - return combined_datapipe.zip_ocf(gsp_datapipe, start_time_datapipe) diff --git a/ocf_datapipes/training/metnet/metnet_preprocessor.py b/ocf_datapipes/training/metnet/metnet_preprocessor.py deleted file mode 100644 index cb474cc7f..000000000 --- a/ocf_datapipes/training/metnet/metnet_preprocessor.py +++ /dev/null @@ -1,351 +0,0 @@ -"""Preprocessing for MetNet-type inputs""" - -import itertools -from typing import List - -import numpy as np -import pvlib -import xarray as xr -from torch.utils.data import IterDataPipe, functional_datapipe - -from ocf_datapipes.select.select_spatial_slice import convert_coords_to_match_xarray -from ocf_datapipes.utils import Zipper -from ocf_datapipes.utils.consts import ( - AZIMUTH_MEAN, - AZIMUTH_STD, - ELEVATION_MEAN, - ELEVATION_STD, -) -from ocf_datapipes.utils.geospatial import ( - geostationary_area_coords_to_lonlat, - move_lon_lat_by_meters, - osgb_to_lon_lat, - spatial_coord_type, -) -from ocf_datapipes.utils.parallel import run_with_threadpool -from ocf_datapipes.utils.utils import trigonometric_datetime_transformation - - -@functional_datapipe("preprocess_metnet") -class PreProcessMetNetIterDataPipe(IterDataPipe): - """Preprocess set of Xarray datasets similar to MetNet-1""" - - def __init__( - self, - 
source_datapipes: List[IterDataPipe], - location_datapipe: IterDataPipe, - context_width: int, - context_height: int, - center_width: int, - center_height: int, - output_height_pixels: int, - output_width_pixels: int, - add_sun_features: bool = False, - only_sun: bool = False, - ): - """ - - Processes set of Xarray datasets similar to MetNet - - In terms of taking all available source datapipes: - 1. selecting the same context area of interest - 2. Creating a center crop of the center_height, center_width - 3. Downsampling the context area of interest to the same shape as the center crop - 4. Stacking those context images on the center crop. - 5. Add Month, Day, Hour channels for each input time - 6. Add Sun position as well? - - This would be designed originally for NWP+Satellite+Topographic data sources. - To add the PV power for lots of sites, the PV power would - need to be able to be on a grid for the context/center - crops and then for the downsample - - This also appends Lat/Lon coordinates to the stack, - and returns a new Numpy array with the stacked data - - Args: - source_datapipes: Datapipes that emit xarray datasets - with latitude/longitude coordinates included - location_datapipe: Datapipe emitting location coordinate for center of example - context_width: Width of the context area - context_height: Height of the context area - center_width: Center width of the area of interest - center_height: Center height of the area of interest - output_height_pixels: Output height in pixels - output_width_pixels: Output width in pixels - add_sun_features: Whether to calculate and - add Sun elevation and azimuth for each center pixel - only_sun: Whether to only output sun features - Assumes only one input to give the coordinates - """ - self.source_datapipes = source_datapipes - self.location_datapipe = location_datapipe - self.context_width = context_width - self.context_height = context_height - self.center_width = center_width - self.center_height = center_height - self.output_height_pixels = output_height_pixels - self.output_width_pixels = output_width_pixels - self.add_sun_features = add_sun_features - self.only_sun = only_sun - - def __iter__(self) -> np.ndarray: - for xr_datas, location in Zipper(Zipper(*self.source_datapipes), self.location_datapipe): - # TODO Use the Lat/Long coordinates of the center array for the lat/lon stuff - # Do the resampling and cropping in parallel - xr_datas = run_with_threadpool( - zip( - _bicycle(xr_datas), - itertools.repeat(location), - itertools.chain.from_iterable( - zip( - itertools.repeat(self.center_width), - itertools.repeat(self.context_width), - ) - ), - itertools.chain.from_iterable( - zip( - itertools.repeat(self.center_height), - itertools.repeat(self.context_height), - ) - ), - itertools.repeat(self.output_height_pixels), - itertools.repeat(self.output_width_pixels), - ), - _crop_and_resample_wrapper, - max_workers=8, - scheduled_tasks=int(len(xr_datas) * 2), # One for center, one for context - ) - xr_datas = list(xr_datas) - # Output is then list of center, context, center, context, etc. 
- # So we need to split the list into two lists of the same length, - # one with centers, one with contexts - centers = xr_datas[::2] - contexts = xr_datas[1::2] - # Now do the first one for the sun and other features - xr_center = centers[0] - _extra_time_dim = ( - "target_time_utc" if "target_time_utc" in xr_center.dims else "time_utc" - ) - # Add in time features for each timestep - time_image = _create_time_image( - xr_center, - time_dim=_extra_time_dim, - output_height_pixels=self.output_height_pixels, - output_width_pixels=self.output_width_pixels, - ) - contexts.append(time_image) - # Need to add sun features - if self.add_sun_features: - sun_image = _create_sun_image( - image_xr=xr_center, - x_dim="x_osgb" if "x_osgb" in xr_center.dims else "x_geostationary", - y_dim="y_osgb" if "y_osgb" in xr_center.dims else "y_geostationary", - time_dim=_extra_time_dim, - normalize=True, - ) - if self.only_sun: - contexts = [time_image, sun_image] - else: - contexts.append(sun_image) - for xr_index in range(len(centers)): - xr_center = centers[xr_index] - xr_context = contexts[xr_index] - xr_center = xr_center.to_numpy() - xr_context = xr_context.to_numpy() - if len(xr_center.shape) == 2: # Need to add channel dimension - xr_center = np.expand_dims(xr_center, axis=0) - xr_context = np.expand_dims(xr_context, axis=0) - if len(xr_center.shape) == 3: # Need to add channel dimension - xr_center = np.expand_dims(xr_center, axis=1) - xr_context = np.expand_dims(xr_context, axis=1) - centers[xr_index] = xr_center - contexts[xr_index] = xr_context - # Pad out time dimension to be the same, using the largest one - # All should have 4 dimensions at this point - max_time_len = max( - np.max([c.shape[0] for c in centers]), np.max([c.shape[0] for c in contexts]) - ) - for i in range(len(centers)): - centers[i] = np.pad( - centers[i], - pad_width=( - (0, max_time_len - centers[i].shape[0]), - (0, 0), - (0, 0), - (0, 0), - ), - mode="constant", - constant_values=0.0, - ) - for i in range(len(contexts)): - contexts[i] = np.pad( - contexts[i], - pad_width=( - (0, max_time_len - contexts[i].shape[0]), - (0, 0), - (0, 0), - (0, 0), - ), - mode="constant", - constant_values=0.0, - ) - stacked_data = np.concatenate([*centers, *contexts], axis=1) - yield stacked_data - - -def _crop_and_resample_wrapper(args): - return _crop_and_resample(*args) - - -def _bicycle(xr_datas): - for xr_data in xr_datas: - yield xr_data - yield xr_data - - -def _crop_and_resample( - xr_data: xr.Dataset, - location, - context_width, - context_height, - output_height_pixels, - output_width_pixels, -): - xr_context: xr.Dataset = _get_spatial_crop( - xr_data, - location=location, - roi_width_meters=context_width, - roi_height_meters=context_height, - ) - - # Resamples to the same number of pixels for both center and contexts - xr_context = _resample_to_pixel_size(xr_context, output_height_pixels, output_width_pixels) - return xr_context - - -def _get_spatial_crop(xr_data, location, roi_height_meters: int, roi_width_meters: int): - xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) - - # Compute the index for left and right: - half_height = roi_height_meters // 2 - half_width = roi_width_meters // 2 - - # Find the bounding box values for the location in either lon-lat or OSGB coord systems - if location.coordinate_system == "lon_lat": - right, top = move_lon_lat_by_meters( - location.x, - location.y, - half_width, - half_height, - ) - left, bottom = move_lon_lat_by_meters( - location.x, - location.y, - -half_width, - -half_height, - ) 
- - elif location.coordinate_system == "osgb": - left = location.x - half_width - right = location.x + half_width - bottom = location.y - half_height - top = location.y + half_height - - else: - raise ValueError(f"Location coord system not recognized: {location.coordinate_system}") - - (left, right), (bottom, top) = convert_coords_to_match_xarray( - x=np.array([left, right], dtype=np.float32), - y=np.array([bottom, top], dtype=np.float32), - from_coords=location.coordinate_system, - xr_data=xr_data, - ) - - # Select a patch from the xarray data - x_mask = (left <= xr_data[xr_x_dim]) & (xr_data[xr_x_dim] <= right) - y_mask = (bottom <= xr_data[xr_y_dim]) & (xr_data[xr_y_dim] <= top) - selected = xr_data.isel({xr_x_dim: x_mask, xr_y_dim: y_mask}) - - return selected - - -def _resample_to_pixel_size(xr_data, height_pixels, width_pixels) -> np.ndarray: - if "x_geostationary" in xr_data.dims: - x_coords = xr_data["x_geostationary"].values - y_coords = xr_data["y_geostationary"].values - elif "x_osgb" in xr_data.dims: - x_coords = xr_data["x_osgb"].values - y_coords = xr_data["y_osgb"].values - else: - x_coords = xr_data["x"].values - y_coords = xr_data["y"].values - # Resample down to the number of pixels wanted - x_coords = np.linspace(x_coords[0], x_coords[-1], num=width_pixels) - y_coords = np.linspace(y_coords[0], y_coords[-1], num=height_pixels) - if "x_geostationary" in xr_data.dims: - xr_data = xr_data.interp( - x_geostationary=x_coords, y_geostationary=y_coords, method="linear" - ) - elif "x_osgb" in xr_data.dims: - xr_data = xr_data.interp(x_osgb=x_coords, y_osgb=y_coords, method="linear") - else: - xr_data = xr_data.interp(x=x_coords, y=y_coords, method="linear") - # Extract just the data now - return xr_data - - -def _create_time_image(xr_data, time_dim: str, output_height_pixels: int, output_width_pixels: int): - # Create trig decomposition of datetime values, tiled over output height and width - datetimes = xr_data[time_dim].values - trig_decomposition = trigonometric_datetime_transformation(datetimes) - tiled_data = np.expand_dims(trig_decomposition, (2, 3)) - tiled_data = np.tile(tiled_data, (1, 1, output_height_pixels, output_width_pixels)) - return tiled_data - - -def _create_sun_image(image_xr, x_dim, y_dim, time_dim, normalize): - # Create empty image to use for the PV Systems, assumes image has x and y coordinates - sun_image = np.zeros( - ( - 2, # Azimuth and elevation - len(image_xr[y_dim]), - len(image_xr[x_dim]), - len(image_xr[time_dim]), - ), - dtype=np.float32, - ) - if "geostationary" in x_dim: - lons, lats = geostationary_area_coords_to_lonlat( - x=image_xr[x_dim].values, y=image_xr[y_dim].values, xr_data=image_xr - ) - else: - lons, lats = osgb_to_lon_lat(x=image_xr.x_osgb.values, y=image_xr.y_osgb.values) - time_utc = image_xr[time_dim].values - - # Loop round each example to get the Sun's elevation and azimuth: - # Go through each time on its own, lat lons still in order of image - # TODO Make this faster - # dt = pd.DatetimeIndex(dt) # pvlib expects a `pd.DatetimeIndex`. - for example_idx, (lat, lon) in enumerate(zip(lats, lons)): - solpos = pvlib.solarposition.get_solarposition( - time=time_utc, - latitude=lat, - longitude=lon, - # Which `method` to use? - # pyephem seemed to be a good mix between speed and ease but causes segfaults! - # nrel_numba doesn't work when using multiple worker processes. 
- # nrel_c is probably fastest but requires C code to be manually compiled: - # https://midcdmz.nrel.gov/spa/ - ) - sun_image[0][:][example_idx] = solpos["azimuth"] - sun_image[1][example_idx][:] = solpos["elevation"] - - # Flip back to normal ordering - sun_image = np.transpose(sun_image, [3, 0, 1, 2]) - - # Normalize. - if normalize: - sun_image[:, 0] = (sun_image[:, 0] - AZIMUTH_MEAN) / AZIMUTH_STD - sun_image[:, 1] = (sun_image[:, 1] - ELEVATION_MEAN) / ELEVATION_STD - return sun_image diff --git a/ocf_datapipes/training/metnet/metnet_pv_national.py b/ocf_datapipes/training/metnet/metnet_pv_national.py deleted file mode 100644 index ae088e2cb..000000000 --- a/ocf_datapipes/training/metnet/metnet_pv_national.py +++ /dev/null @@ -1,346 +0,0 @@ -"""Create the training/validation datapipe for training the national MetNet/-2 Model""" - -import datetime -import logging -from datetime import timedelta -from pathlib import Path -from typing import Union - -import xarray -from torch.utils.data.datapipes.datapipe import IterDataPipe - -from ocf_datapipes.config.model import Configuration -from ocf_datapipes.convert import ConvertGSPToNumpy -from ocf_datapipes.load import ( - OpenConfiguration, - OpenGSP, - OpenNWP, - OpenPVFromNetCDF, - OpenSatellite, - OpenTopography, -) -from ocf_datapipes.select import FilterGSPIDs, PickLocations -from ocf_datapipes.training.metnet.metnet_preprocessor import ( - PreProcessMetNetIterDataPipe as PreProcessMetNet, -) -from ocf_datapipes.utils.consts import RSS_MEAN, RSS_STD, UKV_MEAN, UKV_STD - -xarray.set_options(keep_attrs=True) -logger = logging.getLogger("metnet_datapipe") -logger.setLevel(logging.DEBUG) - - -def normalize_gsp(x): # So it can be pickled - """ - Normalize the GSP data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x / x.nominal_capacity_mwp - - -def normalize_pv(x): # So it can be pickled - """ - Normalize the GSP data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x / x.observed_capacity_wp - - -def _select_non_nan_times(x): - return x.fillna(0.0) - - -def metnet_national_datapipe( - configuration_filename: Union[Path, str], - use_sun: bool = True, - use_nwp: bool = True, - use_sat: bool = True, - use_hrv: bool = True, - use_pv: bool = True, - use_topo: bool = True, - mode: str = "train", - max_num_pv_systems: int = -1, - start_time: datetime.datetime = datetime.datetime(2014, 1, 1), - end_time: datetime.datetime = datetime.datetime(2023, 1, 1), -) -> IterDataPipe: - """ - Make GSP national data pipe - - Currently only has GSP and NWP's in them - - Args: - configuration_filename: the configruation filename for the pipe - use_sun: Whether to add sun features or not - use_pv: Whether to use PV input or not - use_hrv: Whether to use HRV Satellite or not - use_sat: Whether to use non-HRV Satellite or not - use_nwp: Whether to use NWP or not - use_topo: Whether to use topographic map or not - mode: Either 'train', where random times are selected, - or 'test' or 'val' where times are sequential - max_num_pv_systems: max number of PV systems to include, <= 0 if no sampling - start_time: Start time to select on - end_time: End time to select from - - Returns: datapipe - """ - - # load configuration - config_datapipe = OpenConfiguration(configuration_filename) - configuration: Configuration = next(iter(config_datapipe)) - - # Check which modalities to use - if use_nwp: - use_nwp = True if configuration.input_data.nwp.nwp_zarr_path != "" else False - if use_pv: - use_pv = True if 
configuration.input_data.pv.pv_files_groups[0].pv_filename != "" else False - if use_sat: - use_sat = True if configuration.input_data.satellite.satellite_zarr_path != "" else False - if use_hrv: - use_hrv = ( - True if configuration.input_data.hrvsatellite.hrvsatellite_zarr_path != "" else False - ) - if use_topo: - use_topo = ( - True if configuration.input_data.topographic.topographic_filename != "" else False - ) - print( - f"NWP: {use_nwp} Sat: {use_sat}, HRV: {use_hrv}" - f" PV: {use_pv} Sun: {use_sun} Topo: {use_topo}" - ) - # Load GSP national data - logger.debug("Opening GSP Data") - gsp_datapipe = OpenGSP( - gsp_pv_power_zarr_path=configuration.input_data.gsp.gsp_zarr_path - ).filter_times(start_time, end_time) - - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - - logger.debug("Add t0 idx and normalize") - - gsp_datapipe, gsp_time_periods_datapipe, gsp_t0_datapipe = ( - gsp_datapipe.normalize(normalize_fn=normalize_gsp) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), - ) - .fork(3) - ) - # get time periods - # get contiguous time periods - logger.debug("Getting contiguous time periods") - gsp_time_periods_datapipe = gsp_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), - ) - - secondary_datapipes = [] - - # Load NWP data - if use_nwp: - logger.debug("Opening NWP Data") - nwp_datapipe, nwp_time_periods_datapipe = ( - OpenNWP(configuration.input_data.nwp.nwp_zarr_path) - .filter_channels(configuration.input_data.nwp.nwp_channels) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(hours=1), - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - ) - .fork(2) - ) - - nwp_time_periods_datapipe = nwp_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(hours=3), # Init times are 3 hours apart - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), - time_dim="init_time_utc", - ) - secondary_datapipes.append(nwp_time_periods_datapipe) - - if use_sat: - logger.debug("Opening Satellite Data") - sat_datapipe, sat_time_periods_datapipe = ( - OpenSatellite(configuration.input_data.satellite.satellite_zarr_path) - .filter_channels(configuration.input_data.satellite.satellite_channels) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.satellite.history_minutes - ), - ) - .fork(2) - ) - - sat_time_periods_datapipe = sat_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.satellite.history_minutes), - forecast_duration=timedelta(minutes=1), - ) - secondary_datapipes.append(sat_time_periods_datapipe) - - if use_hrv: - logger.debug("Opening HRV Satellite Data") - sat_hrv_datapipe, sat_hrv_time_periods_datapipe = ( - OpenSatellite(configuration.input_data.hrvsatellite.hrvsatellite_zarr_path) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - 
minutes=configuration.input_data.hrvsatellite.history_minutes - ), - ) - .fork(2) - ) - - sat_hrv_time_periods_datapipe = ( - sat_hrv_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - forecast_duration=timedelta(minutes=1), - ) - ) - secondary_datapipes.append(sat_hrv_time_periods_datapipe) - - if use_pv: - logger.debug("Opening PV") - pv_datapipe, pv_time_periods_datapipe = ( - OpenPVFromNetCDF(pv=configuration.input_data.pv) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - ) - .fork(2) - ) - - pv_time_periods_datapipe = pv_time_periods_datapipe.find_contiguous_t0_time_periods( - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - forecast_duration=timedelta(minutes=1), - ) - secondary_datapipes.append(pv_time_periods_datapipe) - - # find joint overlapping timer periods - logger.debug("Getting joint time periods") - overlapping_datapipe = gsp_time_periods_datapipe.filter_to_overlapping_time_periods( - secondary_datapipes=secondary_datapipes, - ) - - # select time periods - gsp_t0_datapipe = gsp_t0_datapipe.filter_time_periods(time_periods=overlapping_datapipe) - - # select t0 periods - logger.debug("Select t0 joint") - num_t0_datapipes = ( - 1 + len(secondary_datapipes) if mode == "train" else 2 + len(secondary_datapipes) - ) - t0_datapipes = gsp_t0_datapipe.pick_t0_times( - return_all_times=False # if mode == "train" else True - ).fork(num_t0_datapipes) - - # take pv time slices - logger.debug("Take GSP time slices") - gsp_datapipe = gsp_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[0], - history_duration=timedelta(minutes=0), - forecast_duration=timedelta(minutes=configuration.input_data.gsp.forecast_minutes), - sample_period_duration=timedelta(minutes=30), - ) - - if use_nwp: - # take nwp time slices - logger.debug("Take NWP time slices") - nwp_datapipe = nwp_datapipe.select_time_slice_nwp( - t0_datapipe=t0_datapipes[1], - sample_period_duration=timedelta(hours=1), - history_duration=timedelta(minutes=configuration.input_data.nwp.history_minutes), - forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), - ).normalize(mean=UKV_MEAN, std=UKV_STD) - - if use_sat: - logger.debug("Take Satellite time slices") - # take sat time slices - sat_datapipe = sat_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat])], - history_duration=timedelta(minutes=configuration.input_data.satellite.history_minutes), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ).normalize(mean=RSS_MEAN, std=RSS_STD) - - if use_hrv: - logger.debug("Take HRV Satellite time slices") - sat_hrv_datapipe = sat_hrv_datapipe.select_time_slice( - t0_datapipe=t0_datapipes[sum([use_nwp, use_sat, use_hrv])], - history_duration=timedelta( - minutes=configuration.input_data.hrvsatellite.history_minutes - ), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ).normalize(mean=RSS_MEAN.sel(channel="HRV"), std=RSS_STD.sel(channel="HRV")) - - if use_pv: - logger.debug("Take PV Time Slices") - # take pv time slices - pv_datapipe = pv_datapipe.normalize(normalize_fn=normalize_pv) - pv_datapipe = pv_datapipe.select_time_slice( - 
t0_datapipe=t0_datapipes[sum([use_nwp, use_sat, use_hrv, use_pv])], - history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes), - forecast_duration=timedelta(minutes=0), - sample_period_duration=timedelta(minutes=5), - ) - - if use_topo: - topo_datapipe = OpenTopography( - configuration.input_data.topographic.topographic_filename - ).map(_select_non_nan_times) - - # Now combine in the MetNet format - modalities = [] - if use_nwp: - modalities.append(nwp_datapipe) - if use_hrv: - modalities.append(sat_hrv_datapipe) - if use_sat: - modalities.append(sat_datapipe) - if use_topo: - modalities.append(topo_datapipe) - - gsp_datapipe, gsp_loc_datapipe = gsp_datapipe.fork(2) - - location_datapipe = PickLocations(gsp_loc_datapipe) - - metnet_datapipe = PreProcessMetNet( - modalities, - location_datapipe=location_datapipe, - center_width=500_000, - center_height=1_000_000, - context_height=10_000_000, - context_width=10_000_000, - output_width_pixels=256, - output_height_pixels=256, - add_sun_features=use_sun, - ) - - pv_datapipe = ( - pv_datapipe.ensure_n_pv_systems_per_example(n_pv_systems_per_example=max_num_pv_systems) - .map(_select_non_nan_times) - .convert_pv_to_numpy(return_pv_system_row=True) - ) - combined_datapipe = metnet_datapipe.zip_ocf(pv_datapipe) - gsp_datapipe = ConvertGSPToNumpy(gsp_datapipe) - if mode == "train": - return combined_datapipe.zip_ocf(gsp_datapipe) # Makes (Inputs, Label) tuples - else: - start_time_datapipe = t0_datapipes[len(t0_datapipes) - 1] # The one extra one - return combined_datapipe.zip_ocf(gsp_datapipe, start_time_datapipe) diff --git a/ocf_datapipes/training/metnet/metnet_pv_site.py b/ocf_datapipes/training/metnet/metnet_pv_site.py deleted file mode 100644 index 7878a5aa2..000000000 --- a/ocf_datapipes/training/metnet/metnet_pv_site.py +++ /dev/null @@ -1,213 +0,0 @@ -"""Create the training/validation datapipe for training the national MetNet/-2 Model""" - -import datetime -import logging -from pathlib import Path -from typing import Union - -import xarray -from torch.utils.data.datapipes.datapipe import IterDataPipe - -from ocf_datapipes.convert import ConvertPVToNumpy -from ocf_datapipes.select import PickLocations -from ocf_datapipes.training.common import ( - add_selected_time_slices_from_datapipes, - get_and_return_overlapping_time_periods_and_t0, - open_and_return_datapipes, -) -from ocf_datapipes.training.metnet.metnet_preprocessor import ( - PreProcessMetNetIterDataPipe as PreProcessMetNet, -) -from ocf_datapipes.utils.consts import RSS_MEAN, RSS_STD, UKV_MEAN, UKV_STD -from ocf_datapipes.utils.future import ThreadPoolMapperIterDataPipe as ThreadPoolMapper - -xarray.set_options(keep_attrs=True) -logger = logging.getLogger("metnet_datapipe") -logger.setLevel(logging.DEBUG) - - -def normalize_pv(x): # So it can be pickled - """ - Normalize the PV data - - Args: - x: Input DataArray - - Returns: - Normalized DataArray - """ - return x / x.observed_capacity_wp - - -def _select_non_nan_times(x): - return x.fillna(0.0) - - -def _load_xarray_values(x): - return x.load() - - -def metnet_site_datapipe( - configuration_filename: Union[Path, str], - use_sun: bool = True, - use_nwp: bool = True, - use_sat: bool = True, - use_hrv: bool = True, - use_pv: bool = True, - use_topo: bool = True, - output_size: int = 256, - pv_in_image: bool = False, - start_time: datetime.datetime = datetime.datetime(2014, 1, 1), - end_time: datetime.datetime = datetime.datetime(2023, 1, 1), - center_size_meters: int = 64_000, - context_size_meters: int = 
512_000, - batch_size: int = 1, -) -> IterDataPipe: - """ - Make PV data pipe - - Args: - configuration_filename: the configuration filename for the pipe - use_sun: Whether to add sun features or not - use_pv: Whether to use PV input or not - use_hrv: Whether to use HRV Satellite or not - use_sat: Whether to use non-HRV Satellite or not - use_nwp: Whether to use NWP or not - use_topo: Whether to use topographic map or not - start_time: Start time to select on - end_time: End time to select from - output_size: Size, in pixels, of the output image - pv_in_image: Add PV history as channels in MetNet image - center_size_meters: Center size for MetNet cutouts, in meters - context_size_meters: Context area size in meters - batch_size: Batch size for the datapipe - - Returns: datapipe - """ - - # load datasets - used_datapipes = open_and_return_datapipes( - configuration_filename=configuration_filename, - use_nwp=use_nwp, - use_topo=use_topo, - use_sat=use_sat, - use_hrv=use_hrv, - use_gsp=False, - use_pv=use_pv, - ) - # Load PV data - used_datapipes["pv"] = ( - used_datapipes["pv"].filter_times(start_time, end_time).pv_interpolate_infill() - ) - - # Now get overlapping time periods - used_datapipes = get_and_return_overlapping_time_periods_and_t0(used_datapipes, key_for_t0="pv") - - # And now get time slices - used_datapipes = add_selected_time_slices_from_datapipes(used_datapipes) - - # Now do the extra processing - pv_history = used_datapipes["pv"].normalize(normalize_fn=normalize_pv) - pv_datapipe = used_datapipes["pv_future"].normalize(normalize_fn=normalize_pv) - # Split into PV for target, and one for history - pv_datapipe, pv_loc_datapipe = pv_datapipe.fork(2) - pv_loc_datapipe, pv_id_datapipe = PickLocations(pv_loc_datapipe).fork(2) - pv_history = pv_history.select_id(pv_id_datapipe, data_source_name="pv") - - if "nwp" in used_datapipes.keys(): - # take nwp time slices - logger.debug("Take NWP time slices") - nwp_datapipe = used_datapipes["nwp"].normalize(mean=UKV_MEAN, std=UKV_STD) - pv_loc_datapipe, pv_nwp_image_loc_datapipe = pv_loc_datapipe.fork(2) - # context_size is the largest it would need - nwp_datapipe = nwp_datapipe.select_spatial_slice_meters( - pv_nwp_image_loc_datapipe, - roi_height_meters=context_size_meters, - roi_width_meters=context_size_meters, - dim_name=None, - ) - # Multithread the data - nwp_datapipe = ThreadPoolMapper( - nwp_datapipe, _load_xarray_values, max_workers=8, scheduled_tasks=batch_size - ) - - if "sat" in used_datapipes.keys(): - logger.debug("Take Satellite time slices") - # take sat time slices - sat_datapipe = used_datapipes["sat"].normalize(mean=RSS_MEAN, std=RSS_STD) - pv_loc_datapipe, pv_sat_image_loc_datapipe = pv_loc_datapipe.fork(2) - sat_datapipe = sat_datapipe.select_spatial_slice_meters( - pv_sat_image_loc_datapipe, - roi_height_meters=context_size_meters, - roi_width_meters=context_size_meters, - dim_name=None, - ) - sat_datapipe = ThreadPoolMapper( - sat_datapipe, _load_xarray_values, max_workers=8, scheduled_tasks=batch_size - ) - - if "hrv" in used_datapipes.keys(): - logger.debug("Take HRV Satellite time slices") - sat_hrv_datapipe = used_datapipes["hrv"].normalize(mean=RSS_MEAN, std=RSS_STD) - pv_loc_datapipe, pv_hrv_image_loc_datapipe = pv_loc_datapipe.fork(2) - sat_hrv_datapipe = sat_hrv_datapipe.select_spatial_slice_meters( - pv_hrv_image_loc_datapipe, - roi_height_meters=context_size_meters, - roi_width_meters=context_size_meters, - dim_name=None, - ) - sat_hrv_datapipe = ThreadPoolMapper( - sat_hrv_datapipe, _load_xarray_values,
max_workers=8, scheduled_tasks=batch_size - ) - - if "topo" in used_datapipes.keys(): - topo_datapipe = used_datapipes["topo"].map(_select_non_nan_times) - - # Now combine in the MetNet format - modalities = [] - - if pv_in_image and "hrv" in used_datapipes.keys(): - sat_hrv_datapipe, sat_pv_datapipe = sat_hrv_datapipe.fork(2) - pv_history = pv_history.create_pv_history_image(image_datapipe=sat_pv_datapipe) - elif pv_in_image and "sat" in used_datapipes.keys(): - sat_datapipe, sat_pv_datapipe = sat_datapipe.fork(2) - pv_history = pv_history.create_pv_history_image(image_datapipe=sat_pv_datapipe) - elif pv_in_image and "nwp" in used_datapipes.keys(): - nwp_datapipe, nwp_pv_datapipe = nwp_datapipe.fork(2) - pv_history = pv_history.create_pv_history_image( - image_datapipe=nwp_pv_datapipe, image_dim="osgb" - ) - - if "nwp" in used_datapipes.keys(): - modalities.append(nwp_datapipe) - if "hrv" in used_datapipes.keys(): - modalities.append(sat_hrv_datapipe) - if "sat" in used_datapipes.keys(): - modalities.append(sat_datapipe) - if "topo" in used_datapipes.keys(): - modalities.append(topo_datapipe) - if pv_in_image: - modalities.append(pv_history) - - metnet_datapipe = PreProcessMetNet( - modalities, - location_datapipe=pv_loc_datapipe, - center_width=center_size_meters, - center_height=center_size_meters, # 64km - context_height=context_size_meters, - context_width=context_size_meters, # 512km - output_width_pixels=output_size, - output_height_pixels=output_size, - add_sun_features=use_sun, - ) - - pv_datapipe = ConvertPVToNumpy(pv_datapipe) - - if not pv_in_image: - pv_history = pv_history.map(_select_non_nan_times) - pv_history = ConvertPVToNumpy(pv_history, return_pv_id=True) - return metnet_datapipe.batch(batch_size).zip_ocf( - pv_history.batch(batch_size), pv_datapipe.batch(batch_size) - ) - else: - return metnet_datapipe.batch(batch_size).zip_ocf(pv_datapipe.batch(batch_size)) diff --git a/tests/conftest.py b/tests/conftest.py index 94d03972f..f33631f0e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,9 +29,10 @@ OpenGSP, OpenNWP, OpenPVFromNetCDF, - OpenSatellite, OpenTopography, + open_sat_data ) +from ocf_datapipes.training.common import FakeIter xr.set_options(keep_attrs=True) @@ -49,7 +50,7 @@ def top_test_directory(): @pytest.fixture() def sat_hrv_datapipe(): filename = _top_test_directory + "/data/hrv_sat_data.zarr" - return OpenSatellite(zarr_path=filename) + return FakeIter(open_sat_data(zarr_path=filename)) @pytest.fixture() @@ -57,7 +58,7 @@ def sat_datapipe(): filename = f"{_top_test_directory}/data/sat_data.zarr" # The saved data is scaled from 0-1024. Now we use data scaled from 0-1 # Rescale here for subsequent tests - return OpenSatellite(zarr_path=filename).map(lambda da: da / 1024) + return FakeIter(open_sat_data(zarr_path=filename)).map(lambda da: da / 1024) @pytest.fixture() @@ -65,7 +66,7 @@ def sat_15_datapipe(): filename = f"{_top_test_directory}/data/sat_data_15.zarr" # The saved data is scaled from 0-1024. 
Now we use data scaled from 0-1 # Rescale here for subsequent tests - return OpenSatellite(zarr_path=filename).map(_sat_rescale) + return FakeIter(open_sat_data(zarr_path=filename)).map(_sat_rescale) @pytest.fixture() diff --git a/tests/load/test_load_satellite.py b/tests/load/test_load_satellite.py index 52c1b5056..d9e204f75 100644 --- a/tests/load/test_load_satellite.py +++ b/tests/load/test_load_satellite.py @@ -6,23 +6,20 @@ """ -from ocf_datapipes.load import OpenSatellite +from ocf_datapipes.load import open_sat_data from freezegun import freeze_time def test_open_satellite(): - sat_datapipe = OpenSatellite(zarr_path="tests/data/hrv_sat_data.zarr") - metadata = next(iter(sat_datapipe)) - assert metadata is not None + sat_xr = open_sat_data(zarr_path="tests/data/hrv_sat_data.zarr") + assert sat_xr is not None def test_open_hrvsatellite(): - sat_datapipe = OpenSatellite(zarr_path="tests/data/sat_data.zarr") - metadata = next(iter(sat_datapipe)) - assert metadata is not None + sat_xr = open_sat_data(zarr_path="tests/data/sat_data.zarr") + assert sat_xr is not None def test_open_satellite_15(): - sat_datapipe = OpenSatellite(zarr_path="tests/data/sat_data_15.zarr") - metadata = next(iter(sat_datapipe)) - assert metadata is not None + sat_xr = open_sat_data(zarr_path="tests/data/sat_data_15.zarr") + assert sat_xr is not None diff --git a/tests/training/metnet/test_metnet_gsp_national.py b/tests/training/metnet/test_metnet_gsp_national.py deleted file mode 100644 index 62e2ef8eb..000000000 --- a/tests/training/metnet/test_metnet_gsp_national.py +++ /dev/null @@ -1,30 +0,0 @@ -import os - -import numpy as np -import pytest -from torch.utils.data import DataLoader - -import ocf_datapipes -from ocf_datapipes.training.metnet.metnet_gsp_national import metnet_national_datapipe - - -@pytest.mark.skip("Failing at the moment") -def test_metnet_gsp_national_datapipe(): - filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml") - datapipe = metnet_national_datapipe(filename, use_pv=False) - dataloader = DataLoader(datapipe) - for i, batch in enumerate(dataloader): - _ = batch - if i + 1 % 50000 == 0: - break - - -@pytest.mark.skip("Failing at the moment") -def test_metnet_gsp_national_image_datapipe(): - filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml") - datapipe = metnet_national_datapipe(filename, use_pv=False, gsp_in_image=True, output_size=128) - dataloader = iter(datapipe) - batch = next(dataloader) - x, y = batch - assert np.isfinite(x).all() - assert np.isfinite(y).all() diff --git a/tests/training/metnet/test_metnet_national.py b/tests/training/metnet/test_metnet_national.py deleted file mode 100644 index 29329b4c6..000000000 --- a/tests/training/metnet/test_metnet_national.py +++ /dev/null @@ -1,18 +0,0 @@ -import os - -import numpy as np -import pytest - -import ocf_datapipes -from ocf_datapipes.training.metnet.metnet_national import metnet_national_datapipe - - -@pytest.mark.skip("Failing at the moment") -def test_metnet_national_datapipe(): - filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml") - - datapipe = metnet_national_datapipe(filename, max_num_pv_systems=1).set_length(2) - - batch = next(iter(datapipe)) - assert np.isfinite(batch[0]).all() - assert np.isfinite(batch[1]).all() diff --git a/tests/training/metnet/test_metnet_preprocessor.py b/tests/training/metnet/test_metnet_preprocessor.py deleted file mode 100644 index e8f9957c2..000000000 --- 
a/tests/training/metnet/test_metnet_preprocessor.py +++ /dev/null @@ -1,161 +0,0 @@ -from ocf_datapipes.select import FilterGSPIDs, PickLocations -from ocf_datapipes.transform.xarray import CreatePVImage -from ocf_datapipes.training.metnet.metnet_preprocessor import ( - PreProcessMetNetIterDataPipe as PreProcessMetNet, -) - - -def test_metnet_preprocess_no_sun(sat_datapipe, gsp_datapipe): - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - gsp_datapipe = PickLocations(gsp_datapipe) - datapipe = PreProcessMetNet( - [sat_datapipe], - location_datapipe=gsp_datapipe, - center_width=100_000, - center_height=100_000, - context_height=1_000_000, - context_width=1_000_000, - output_width_pixels=100, - output_height_pixels=100, - add_sun_features=False, - ) - data = next(iter(datapipe)) - print(data.shape) - - -def test_metnet_preprocess(sat_datapipe, gsp_datapipe): - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - gsp_datapipe = PickLocations(gsp_datapipe) - datapipe = PreProcessMetNet( - [sat_datapipe], - location_datapipe=gsp_datapipe, - center_width=100_000, - center_height=100_000, - context_height=1_000_000, - context_width=1_000_000, - output_width_pixels=100, - output_height_pixels=100, - add_sun_features=True, - ) - data = next(iter(datapipe)) - print(data.shape) - - -def test_metnet_preprocess_both_sat(sat_datapipe, sat_hrv_datapipe, gsp_datapipe): - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - gsp_datapipe = PickLocations(gsp_datapipe) - datapipe = PreProcessMetNet( - [sat_datapipe, sat_hrv_datapipe], - location_datapipe=gsp_datapipe, - center_width=100_000, - center_height=100_000, - context_height=1_000_000, - context_width=1_000_000, - output_width_pixels=100, - output_height_pixels=100, - add_sun_features=False, - ) - data = next(iter(datapipe)) - print(data.shape) - - -def test_metnet_preprocess_both_sat_other_order(sat_datapipe, sat_hrv_datapipe, gsp_datapipe): - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - gsp_datapipe = PickLocations(gsp_datapipe) - datapipe = PreProcessMetNet( - [sat_hrv_datapipe, sat_datapipe], - location_datapipe=gsp_datapipe, - center_width=100_000, - center_height=100_000, - context_height=1_000_000, - context_width=1_000_000, - output_width_pixels=100, - output_height_pixels=100, - add_sun_features=True, - ) - data = next(iter(datapipe)) - print(data.shape) - - -def test_metnet_preprocess_both_sat_pv( - sat_datapipe, sat_hrv_datapipe, gsp_datapipe, passiv_datapipe -): - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - gsp_datapipe = PickLocations(gsp_datapipe) - sat_datapipe, image_datapipe = sat_datapipe.fork(2) - passiv_datapipe = CreatePVImage(passiv_datapipe, image_datapipe=image_datapipe, normalize=True) - datapipe = PreProcessMetNet( - [sat_hrv_datapipe, sat_datapipe, passiv_datapipe], - location_datapipe=gsp_datapipe, - center_width=100_000, - center_height=100_000, - context_height=1_000_000, - context_width=1_000_000, - output_width_pixels=100, - output_height_pixels=100, - add_sun_features=True, - ) - data = next(iter(datapipe)) - assert data.shape == (289, 14, 100, 100) - - -def test_metnet_preprocess_sat_hrv_pv_nwp( - sat_datapipe, sat_hrv_datapipe, gsp_datapipe, passiv_datapipe, nwp_datapipe -): - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - gsp_datapipe = PickLocations(gsp_datapipe) - sat_datapipe, image_datapipe = sat_datapipe.fork(2) - passiv_datapipe = CreatePVImage(passiv_datapipe, image_datapipe=image_datapipe, normalize=True) - datapipe = 
PreProcessMetNet( - [sat_hrv_datapipe, sat_datapipe, passiv_datapipe], - location_datapipe=gsp_datapipe, - center_width=100_000, - center_height=100_000, - context_height=1_000_000, - context_width=1_000_000, - output_width_pixels=100, - output_height_pixels=100, - add_sun_features=True, - ) - data = next(iter(datapipe)) - assert data.shape == (289, 14, 100, 100) - - -def test_metnet_preprocess_sat_topo(sat_datapipe, gsp_datapipe, topo_datapipe): - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - gsp_datapipe = PickLocations(gsp_datapipe) - datapipe = PreProcessMetNet( - [sat_datapipe, topo_datapipe], - location_datapipe=gsp_datapipe, - center_width=100_000, - center_height=100_000, - context_height=1_000_000, - context_width=1_000_000, - output_width_pixels=100, - output_height_pixels=100, - add_sun_features=True, - ) - data = next(iter(datapipe)) - assert data.shape == (25, 12, 100, 100) - - -def test_metnet_preprocess_sat_hrv_pv_nwp_topo( - sat_datapipe, sat_hrv_datapipe, gsp_datapipe, passiv_datapipe, nwp_datapipe, topo_datapipe -): - gsp_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]) - gsp_datapipe = PickLocations(gsp_datapipe) - sat_datapipe, image_datapipe = sat_datapipe.fork(2) - passiv_datapipe = CreatePVImage(passiv_datapipe, image_datapipe=image_datapipe, normalize=True) - datapipe = PreProcessMetNet( - [sat_hrv_datapipe, sat_datapipe, passiv_datapipe, topo_datapipe], - location_datapipe=gsp_datapipe, - center_width=100_000, - center_height=100_000, - context_height=1_000_000, - context_width=1_000_000, - output_width_pixels=100, - output_height_pixels=100, - add_sun_features=True, - ) - data = next(iter(datapipe)) - assert data.shape == (289, 16, 100, 100) diff --git a/tests/training/metnet/test_metnet_pv_national.py b/tests/training/metnet/test_metnet_pv_national.py deleted file mode 100644 index c4392ce16..000000000 --- a/tests/training/metnet/test_metnet_pv_national.py +++ /dev/null @@ -1,19 +0,0 @@ -import os - -import numpy as np -import pytest - -import ocf_datapipes -from ocf_datapipes.training.metnet.metnet_pv_national import metnet_national_datapipe - - -@pytest.mark.skip("Failing at the moment") -def test_metnet_pv_national_datapipe(): - filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml") - datapipe = metnet_national_datapipe(filename, use_nwp=False, max_num_pv_systems=1) - - batch = next(iter(datapipe)) - assert np.isfinite(batch[0]).all() - assert np.isfinite(batch[1]).all() - assert np.isfinite(batch[2]).all() - assert np.isfinite(batch[3]).all() diff --git a/tests/training/metnet/test_metnet_pv_site.py b/tests/training/metnet/test_metnet_pv_site.py deleted file mode 100644 index 1d2f3ef6b..000000000 --- a/tests/training/metnet/test_metnet_pv_site.py +++ /dev/null @@ -1,17 +0,0 @@ -import os - -import numpy as np -import pytest - -import ocf_datapipes -from ocf_datapipes.training.metnet.metnet_pv_site import metnet_site_datapipe - - -@pytest.mark.skip("Failing at the moment") -def test_metnet_site_datapipe(): - filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml") - datapipe = metnet_site_datapipe(filename, use_nwp=False, pv_in_image=True) - - batch = next(iter(datapipe)) - assert np.isfinite(batch[0]).all() - assert np.isfinite(batch[1]).all() diff --git a/tests/transform/numpy_batch/conftest.py b/tests/transform/numpy_batch/conftest.py index d137ff7cc..dff4276f7 100644 --- a/tests/transform/numpy_batch/conftest.py +++ b/tests/transform/numpy_batch/conftest.py 
@@ -10,7 +10,8 @@ ConvertPVToNumpyBatch, ConvertSatelliteToNumpyBatch, ) -from ocf_datapipes.load import OpenGSP, OpenNWP, OpenSatellite +from ocf_datapipes.load import OpenGSP, OpenNWP, open_sat_data +from ocf_datapipes.training.common import FakeIter from ocf_datapipes.batch import MergeNumpyModalities from ocf_datapipes.transform.xarray import ( @@ -21,7 +22,7 @@ @pytest.fixture() def sat_hrv_np_datapipe(): filename = Path(ocf_datapipes.__file__).parent.parent / "tests" / "data" / "hrv_sat_data.zarr" - dp = OpenSatellite(zarr_path=filename) + dp = FakeIter(open_sat_data(zarr_path=filename)) dp = AddT0IdxAndSamplePeriodDuration( dp, sample_period_duration=timedelta(minutes=5), @@ -34,7 +35,7 @@ def sat_hrv_np_datapipe(): @pytest.fixture() def sat_np_datapipe(): filename = Path(ocf_datapipes.__file__).parent.parent / "tests" / "data" / "sat_data.zarr" - dp = OpenSatellite(zarr_path=filename) + dp = FakeIter(open_sat_data(zarr_path=filename)) dp = AddT0IdxAndSamplePeriodDuration( dp, sample_period_duration=timedelta(minutes=5), From a4b8efa31cb05ffd34fed8a77e0cbaa2e1b11025 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 19:07:02 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocf_datapipes/load/__init__.py | 1 + ocf_datapipes/load/satellite.py | 2 -- ocf_datapipes/training/common.py | 18 ++++++++---------- .../training/example/gsp_pv_nwp_satellite.py | 2 +- tests/conftest.py | 8 +------- 5 files changed, 11 insertions(+), 20 deletions(-) diff --git a/ocf_datapipes/load/__init__.py b/ocf_datapipes/load/__init__.py index 5c2eef9e6..6bea10b3b 100644 --- a/ocf_datapipes/load/__init__.py +++ b/ocf_datapipes/load/__init__.py @@ -13,6 +13,7 @@ from .configuration import OpenConfigurationIterDataPipe as OpenConfiguration from .nwp.nwp import OpenNWPIterDataPipe as OpenNWP + # from .satellite import OpenSatelliteIterDataPipe as OpenSatellite from .satellite import open_sat_data diff --git a/ocf_datapipes/load/satellite.py b/ocf_datapipes/load/satellite.py index fb2818d76..da7982ff1 100644 --- a/ocf_datapipes/load/satellite.py +++ b/ocf_datapipes/load/satellite.py @@ -9,7 +9,6 @@ import pandas as pd import xarray as xr from ocf_blosc2 import Blosc2 # noqa: F401 -from torch.utils.data import IterDataPipe, functional_datapipe _log = logging.getLogger(__name__) @@ -163,4 +162,3 @@ def open_sat_data(zarr_path: Union[Path, str, list[Path], list[str]]) -> xr.Data _log.info("Opened satellite data") return data_array - diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index c3d93cdc5..828a26530 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -19,8 +19,8 @@ OpenNWP, OpenPVFromNetCDF, OpenPVFromPVSitesDB, - open_sat_data, OpenWindFromNetCDF, + open_sat_data, ) from ocf_datapipes.utils.utils import flatten_nwp_source_dict @@ -48,7 +48,6 @@ def __iter__(self) -> xr.DataArray: yield self.data_xr - def is_config_and_path_valid( use_flag: bool, config, @@ -186,14 +185,13 @@ def open_and_return_datapipes( sat_xr = open_sat_data(configuration.input_data.satellite.satellite_zarr_path) sat_pipe = FakeIter(sat_xr) - sat_datapipe = (sat_pipe - .filter_channels(configuration.input_data.satellite.satellite_channels) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=minutes( - configuration.input_data.satellite.time_resolution_minutes - ), - 
history_duration=minutes(configuration.input_data.satellite.history_minutes), - ) + sat_datapipe = sat_pipe.filter_channels( + configuration.input_data.satellite.satellite_channels + ).add_t0_idx_and_sample_period_duration( + sample_period_duration=minutes( + configuration.input_data.satellite.time_resolution_minutes + ), + history_duration=minutes(configuration.input_data.satellite.history_minutes), ) used_datapipes["sat"] = sat_datapipe diff --git a/ocf_datapipes/training/example/gsp_pv_nwp_satellite.py b/ocf_datapipes/training/example/gsp_pv_nwp_satellite.py index 3b4a004b4..23d1b8109 100644 --- a/ocf_datapipes/training/example/gsp_pv_nwp_satellite.py +++ b/ocf_datapipes/training/example/gsp_pv_nwp_satellite.py @@ -13,7 +13,7 @@ from ocf_datapipes.config.load import load_yaml_configuration from ocf_datapipes.config.model import Configuration from ocf_datapipes.load import OpenGSP, OpenNWP, OpenPVFromNetCDF, open_sat_data -from ocf_datapipes.training.common import normalize_gsp, normalize_pv, FakeIter +from ocf_datapipes.training.common import FakeIter, normalize_gsp, normalize_pv from ocf_datapipes.utils.consts import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD logger = logging.getLogger(__name__) diff --git a/tests/conftest.py b/tests/conftest.py index f33631f0e..63b924cd7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,13 +25,7 @@ from ocf_datapipes.config.load import load_yaml_configuration from ocf_datapipes.config.model import PV, PVFiles from ocf_datapipes.config.save import save_yaml_configuration -from ocf_datapipes.load import ( - OpenGSP, - OpenNWP, - OpenPVFromNetCDF, - OpenTopography, - open_sat_data -) +from ocf_datapipes.load import OpenGSP, OpenNWP, OpenPVFromNetCDF, OpenTopography, open_sat_data from ocf_datapipes.training.common import FakeIter xr.set_options(keep_attrs=True) From 12ef3ba92c8bfc4489b1ab58d8eb0ab2f2808a1d Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 25 Jul 2024 20:08:46 +0100 Subject: [PATCH 3/7] tidy --- ocf_datapipes/load/__init__.py | 1 - ocf_datapipes/training/common.py | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ocf_datapipes/load/__init__.py b/ocf_datapipes/load/__init__.py index 5c2eef9e6..3684d8379 100644 --- a/ocf_datapipes/load/__init__.py +++ b/ocf_datapipes/load/__init__.py @@ -13,7 +13,6 @@ from .configuration import OpenConfigurationIterDataPipe as OpenConfiguration from .nwp.nwp import OpenNWPIterDataPipe as OpenNWP -# from .satellite import OpenSatelliteIterDataPipe as OpenSatellite from .satellite import open_sat_data try: diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index c3d93cdc5..7b08a7ec2 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -40,6 +40,12 @@ @functional_datapipe("fake_iter") class FakeIter(IterDataPipe): + """ This makes a fake iter datapipe + + We are using this just to move away from datapipes. 
+ This can be done by changing certain functions to return xarray data, + and for the moment using this FakeIter, to make it back into a datapipe + """ + def __init__(self, data_xr): self.data_xr = data_xr From 6f8f39950d90749b6f389ddec295cc3d4dd591f0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 19:11:35 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocf_datapipes/training/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index fbeb6ee8c..36a81ac62 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -40,12 +40,13 @@ @functional_datapipe("fake_iter") class FakeIter(IterDataPipe): - """ This makes a fake iter datapipe + """This makes a fake iter datapipe We are using this just to move away from datapipes. This can be done by changing certain functions to return xarray data, and for the moment using this FakeIter, to make it back into a datapipe """ + def __init__(self, data_xr): self.data_xr = data_xr From 56d472d62683932e008e7e343a140d7346406f4b Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 25 Jul 2024 20:12:55 +0100 Subject: [PATCH 5/7] remove metnet training --- tests/end2end/test_metnet_training.py | 179 -------------------------- 1 file changed, 179 deletions(-) delete mode 100644 tests/end2end/test_metnet_training.py diff --git a/tests/end2end/test_metnet_training.py b/tests/end2end/test_metnet_training.py deleted file mode 100644 index 86752bd8b..000000000 --- a/tests/end2end/test_metnet_training.py +++ /dev/null @@ -1,179 +0,0 @@ -import numpy as np -import torch -import xarray -from torch.utils.data.datapipes._decorator import functional_datapipe -from torch.utils.data.datapipes.iter import IterableWrapper - - -xarray.set_options(keep_attrs=True) - -from datetime import timedelta - -from ocf_datapipes.select import ( - FilterGSPIDs, - PickLocations, - SelectSpatialSliceMeters, - SelectTimeSliceNWP, -) - -from ocf_datapipes.transform.xarray import ( - AddT0IdxAndSamplePeriodDuration, - CreatePVImage, - Downsample, - Normalize, - ReprojectTopography, -) - from ocf_datapipes.training.metnet.metnet_preprocessor import ( - PreProcessMetNetIterDataPipe as PreProcessMetNet, -) - -from ocf_datapipes.utils.consts import UKV_MEAN, UKV_STD, RSS_MEAN, RSS_STD - -import pytest - - -def last_time(ds, time_dim="time_utc"): - return ds[time_dim].values[-1] - - -# N.B First change which broke this test was changing the NWP data in the test directory to include -# more forecast steps -@pytest.mark.skip(reason="Not maintained for the moment") -def test_metnet_production( - sat_hrv_datapipe, sat_datapipe, passiv_datapipe, topo_datapipe, gsp_datapipe, nwp_datapipe ): - #################################### - # - # Equivalent to PP's loading and filtering methods - # - ##################################### - # Normalize GSP and PV on whole dataset here - pv_datapipe = passiv_datapipe - gsp_datapipe, gsp_loc_datapipe = FilterGSPIDs(gsp_datapipe, gsps_to_keep=[0]).fork(2) - gsp_datapipe = Normalize(gsp_datapipe, normalize_fn=lambda x: x / x.installedcapacity_mwp) - topo_datapipe = ReprojectTopography(topo_datapipe) - sat_hrv_datapipe = AddT0IdxAndSamplePeriodDuration( - sat_hrv_datapipe, - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=60), - ) - sat_datapipe =
AddT0IdxAndSamplePeriodDuration( - sat_datapipe, - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=60), - ) - pv_datapipe = AddT0IdxAndSamplePeriodDuration( - pv_datapipe, - sample_period_duration=timedelta(minutes=5), - history_duration=timedelta(minutes=60), - ) - gsp_datapipe, gsp_t0_datapipe = AddT0IdxAndSamplePeriodDuration( - gsp_datapipe, - sample_period_duration=timedelta(minutes=30), - history_duration=timedelta(hours=2), - ).fork(2) - nwp_datapipe = AddT0IdxAndSamplePeriodDuration( - nwp_datapipe, sample_period_duration=timedelta(hours=1), history_duration=timedelta(hours=2) - ) - - #################################### - # - # Equivalent to PP's xr_batch_processors and normal loading/selecting - # - ##################################### - - ( - location_datapipe1, - location_datapipe2, - location_datapipe3, - location_datapipe4, - location_datapipe5, - ) = PickLocations(gsp_loc_datapipe, return_all_locations=True).fork( - 5 - ) # It's in order then - pv_datapipe, pv_t0_datapipe = SelectSpatialSliceMeters( - pv_datapipe, - location_datapipe=location_datapipe1, - roi_width_meters=100_000, - roi_height_meters=100_000, - ).fork( - 2 - ) # Has to be large as test PV systems aren't in first 20 GSPs it seems - nwp_datapipe, nwp_t0_datapipe = Downsample(nwp_datapipe, y_coarsen=16, x_coarsen=16).fork(2) - nwp_t0_datapipe = nwp_t0_datapipe.map(lambda x: last_time(x, "init_time_utc")) - nwp_datapipe = SelectTimeSliceNWP( - nwp_datapipe, - t0_datapipe=nwp_t0_datapipe, - sample_period_duration=timedelta(hours=1), - history_duration=timedelta(hours=2), - forecast_duration=timedelta(hours=3), - ) - gsp_t0_datapipe = gsp_t0_datapipe.map(last_time) - gsp_datapipe = SelectLiveTimeSlice( - gsp_datapipe, - t0_datapipe=gsp_t0_datapipe, - history_duration=timedelta(hours=2), - ) - sat_t0_datapipe = sat_datapipe.map(last_time) - sat_datapipe, image_datapipe = SelectLiveTimeSlice( - sat_datapipe, - t0_datapipe=sat_t0_datapipe, - history_duration=timedelta(hours=1), - ).fork(2) - sat_hrv_t0_datapipe = sat_hrv_datapipe.map(last_time) - sat_hrv_datapipe = SelectLiveTimeSlice( - sat_hrv_datapipe, - t0_datapipe=sat_hrv_t0_datapipe, - history_duration=timedelta(hours=1), - ) - passiv_t0_datapipe = pv_t0_datapipe.map(last_time) - sat_hrv_t0_datapipe - pv_datapipe = SelectLiveTimeSlice( - pv_datapipe, - t0_datapipe=passiv_t0_datapipe, - history_duration=timedelta(hours=1), - ) - gsp_datapipe = SelectSpatialSliceMeters( - gsp_datapipe, - location_datapipe=location_datapipe4, - dim_name="gsp_id", - roi_width_meters=10, - roi_height_meters=10, - ) - - pv_datapipe = CreatePVImage(pv_datapipe, image_datapipe) - - sat_hrv_datapipe = Normalize( - sat_hrv_datapipe, mean=RSS_MEAN.sel(channel="HRV") / 4, std=RSS_STD.sel(channel="HRV") / 4 - ).map( - lambda x: x.resample(time_utc="5min").interpolate("linear") - ) # Interpolate to 5 minutes in case it's 15 minutes - sat_datapipe = Normalize(sat_datapipe, mean=RSS_MEAN, std=RSS_STD).map( - lambda x: x.resample(time_utc="5min").interpolate("linear") - ) # Interpolate to 5 minutes in case it's 15 minutes - nwp_datapipe = Normalize(nwp_datapipe, mean=UKV_MEAN, std=UKV_STD) - topo_datapipe = Normalize(topo_datapipe, calculate_mean_std_from_example=True) - - # Now combine in the MetNet format - combined_datapipe = PreProcessMetNet( - [ - nwp_datapipe, - sat_hrv_datapipe, - sat_datapipe, - pv_datapipe, - ], - location_datapipe=location_datapipe5, - center_width=500_000, - center_height=1_000_000, - context_height=10_000_000, - context_width=10_000_000,
- output_width_pixels=512, - output_height_pixels=512, - add_sun_features=True, - ) - - batch = next(iter(combined_datapipe)) - assert ~np.isnan(batch).any() - print(batch.shape) - batch = next(iter(gsp_datapipe)) - print(batch.shape) From de65edd4c16022098edf23b8264c568e42663b03 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 25 Jul 2024 20:17:15 +0100 Subject: [PATCH 6/7] lint --- ocf_datapipes/training/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index 36a81ac62..b5e27ca8e 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -47,10 +47,10 @@ class FakeIter(IterDataPipe): and for the moment using this FakeIter, to make it back into a datapipe """ - def __init__(self, data_xr): + def __init__(self, data_xr): # noqa self.data_xr = data_xr - def __iter__(self) -> xr.DataArray: + def __iter__(self) -> xr.DataArray: # noqa while True: yield self.data_xr From f90dcec27d14db5f4c2d4198ec586029ffc7e167 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 19:18:33 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocf_datapipes/training/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index b5e27ca8e..4a9e27d25 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -47,10 +47,10 @@ class FakeIter(IterDataPipe): and for the moment using this FakeIter, to make it back into a datapipe """ - def __init__(self, data_xr): # noqa + def __init__(self, data_xr): # noqa self.data_xr = data_xr - def __iter__(self) -> xr.DataArray: # noqa + def __iter__(self) -> xr.DataArray: # noqa while True: yield self.data_xr
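
Note on the pattern this series converges on: open_sat_data returns a plain xarray DataArray, and FakeIter wraps that array back into an IterDataPipe so the existing functional operators (filter_channels, add_t0_idx_and_sample_period_duration, normalize) keep working unchanged. A minimal usage sketch, using only names these patches introduce; the Zarr path is illustrative:

    from ocf_datapipes.load import open_sat_data
    from ocf_datapipes.training.common import FakeIter
    from ocf_datapipes.utils.consts import RSS_MEAN, RSS_STD

    # Open the satellite Zarr eagerly as a single xarray DataArray
    sat_xr = open_sat_data(zarr_path="tests/data/sat_data.zarr")

    # Wrap it back into a datapipe; FakeIter yields the same DataArray
    # forever, mirroring the behaviour of the removed OpenSatellite datapipe
    sat_datapipe = FakeIter(sat_xr).normalize(mean=RSS_MEAN, std=RSS_STD)

    # Downstream code can keep consuming it like any other datapipe
    normalized = next(iter(sat_datapipe))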