diff --git a/ocf_datapipes/load/__init__.py b/ocf_datapipes/load/__init__.py index 3061528b8..3684d8379 100644 --- a/ocf_datapipes/load/__init__.py +++ b/ocf_datapipes/load/__init__.py @@ -13,7 +13,7 @@ from .configuration import OpenConfigurationIterDataPipe as OpenConfiguration from .nwp.nwp import OpenNWPIterDataPipe as OpenNWP -from .satellite import OpenSatelliteIterDataPipe as OpenSatellite +from .satellite import open_sat_data try: import rioxarray # Rioxarray is sometimes a pain to install, so only load this if its installed diff --git a/ocf_datapipes/load/satellite.py b/ocf_datapipes/load/satellite.py index a61688e9c..da7982ff1 100644 --- a/ocf_datapipes/load/satellite.py +++ b/ocf_datapipes/load/satellite.py @@ -9,7 +9,6 @@ import pandas as pd import xarray as xr from ocf_blosc2 import Blosc2 # noqa: F401 -from torch.utils.data import IterDataPipe, functional_datapipe _log = logging.getLogger(__name__) @@ -163,24 +162,3 @@ def open_sat_data(zarr_path: Union[Path, str, list[Path], list[str]]) -> xr.Data _log.info("Opened satellite data") return data_array - - -@functional_datapipe("open_satellite") -class OpenSatelliteIterDataPipe(IterDataPipe): - """Open Satellite Zarr""" - - def __init__(self, zarr_path: Union[Path, str]): - """ - Opens the satellite Zarr - - Args: - zarr_path: path to the zarr file - """ - self.zarr_path = zarr_path - super().__init__() - - def __iter__(self) -> xr.DataArray: - """Open the Zarr file""" - data: xr.DataArray = open_sat_data(zarr_path=self.zarr_path) - while True: - yield data diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index d5f0d3d95..4a9e27d25 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -19,8 +19,8 @@ OpenNWP, OpenPVFromNetCDF, OpenPVFromPVSitesDB, - OpenSatellite, OpenWindFromNetCDF, + open_sat_data, ) from ocf_datapipes.utils.utils import flatten_nwp_source_dict @@ -38,6 +38,23 @@ logger = logging.getLogger(__name__) 
+@functional_datapipe("fake_iter") +class FakeIter(IterDataPipe): + """This makes a fake iter datapipe + + We are using this just to move away from datapipes. + This is done by replacing certain datapipes with functions that return xarray data, + and for the moment wrapping the result in this FakeIter to turn it back into a datapipe + """ + + def __init__(self, data_xr): # noqa + self.data_xr = data_xr + + def __iter__(self) -> xr.DataArray: # noqa + while True: + yield self.data_xr + + def is_config_and_path_valid( use_flag: bool, config, @@ -171,24 +188,28 @@ def open_and_return_datapipes( if use_sat: logger.debug("Opening Satellite Data") - sat_datapipe = ( - OpenSatellite(configuration.input_data.satellite.satellite_zarr_path) - .filter_channels(configuration.input_data.satellite.satellite_channels) - .add_t0_idx_and_sample_period_duration( - sample_period_duration=minutes( - configuration.input_data.satellite.time_resolution_minutes - ), - history_duration=minutes(configuration.input_data.satellite.history_minutes), - ) + + sat_xr = open_sat_data(configuration.input_data.satellite.satellite_zarr_path) + sat_pipe = FakeIter(sat_xr) + + sat_datapipe = sat_pipe.filter_channels( + configuration.input_data.satellite.satellite_channels + ).add_t0_idx_and_sample_period_duration( + sample_period_duration=minutes( + configuration.input_data.satellite.time_resolution_minutes + ), + history_duration=minutes(configuration.input_data.satellite.history_minutes), ) used_datapipes["sat"] = sat_datapipe if use_hrv: logger.debug("Opening HRV Satellite Data") - sat_hrv_datapipe = OpenSatellite( - configuration.input_data.hrvsatellite.hrvsatellite_zarr_path - ).add_t0_idx_and_sample_period_duration( + + sat_xr = open_sat_data(configuration.input_data.hrvsatellite.hrvsatellite_zarr_path) + sat_pipe = FakeIter(sat_xr) + + sat_hrv_datapipe = sat_pipe.add_t0_idx_and_sample_period_duration( sample_period_duration=minutes( configuration.input_data.hrvsatellite.time_resolution_minutes ), diff --git 
a/ocf_datapipes/training/example/gsp_pv_nwp_satellite.py b/ocf_datapipes/training/example/gsp_pv_nwp_satellite.py index 9ac221e23..23d1b8109 100644 --- a/ocf_datapipes/training/example/gsp_pv_nwp_satellite.py +++ b/ocf_datapipes/training/example/gsp_pv_nwp_satellite.py @@ -12,8 +12,8 @@ from ocf_datapipes.batch import MergeNumpyModalities, MergeNWPNumpyModalities from ocf_datapipes.config.load import load_yaml_configuration from ocf_datapipes.config.model import Configuration -from ocf_datapipes.load import OpenGSP, OpenNWP, OpenPVFromNetCDF, OpenSatellite -from ocf_datapipes.training.common import normalize_gsp, normalize_pv +from ocf_datapipes.load import OpenGSP, OpenNWP, OpenPVFromNetCDF, open_sat_data +from ocf_datapipes.training.common import FakeIter, normalize_gsp, normalize_pv from ocf_datapipes.utils.consts import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD logger = logging.getLogger(__name__) @@ -53,9 +53,9 @@ def gsp_pv_nwp_satellite_data_pipeline(configuration: Union[Path, str]) -> IterD ) # Load and noralize satellite data - satellite_datapipe = OpenSatellite( - zarr_path=configuration.input_data.satellite.satellite_zarr_path - ).normalize(mean=RSS_MEAN, std=RSS_STD) + sat_xr = open_sat_data(zarr_path=configuration.input_data.satellite.satellite_zarr_path) + satellite_datapipe = FakeIter(sat_xr) + satellite_datapipe = satellite_datapipe.normalize(mean=RSS_MEAN, std=RSS_STD) # Load and normalize NWP data - There may be multiple NWP sources nwp_datapipe_dict = {} diff --git a/tests/conftest.py b/tests/conftest.py index 94d03972f..63b924cd7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,13 +25,8 @@ from ocf_datapipes.config.load import load_yaml_configuration from ocf_datapipes.config.model import PV, PVFiles from ocf_datapipes.config.save import save_yaml_configuration -from ocf_datapipes.load import ( - OpenGSP, - OpenNWP, - OpenPVFromNetCDF, - OpenSatellite, - OpenTopography, -) +from ocf_datapipes.load import OpenGSP, OpenNWP, 
OpenPVFromNetCDF, OpenTopography, open_sat_data +from ocf_datapipes.training.common import FakeIter xr.set_options(keep_attrs=True) @@ -49,7 +44,7 @@ def top_test_directory(): @pytest.fixture() def sat_hrv_datapipe(): filename = _top_test_directory + "/data/hrv_sat_data.zarr" - return OpenSatellite(zarr_path=filename) + return FakeIter(open_sat_data(zarr_path=filename)) @pytest.fixture() @@ -57,7 +52,7 @@ def sat_datapipe(): filename = f"{_top_test_directory}/data/sat_data.zarr" # The saved data is scaled from 0-1024. Now we use data scaled from 0-1 # Rescale here for subsequent tests - return OpenSatellite(zarr_path=filename).map(lambda da: da / 1024) + return FakeIter(open_sat_data(zarr_path=filename)).map(lambda da: da / 1024) @pytest.fixture() @@ -65,7 +60,7 @@ def sat_15_datapipe(): filename = f"{_top_test_directory}/data/sat_data_15.zarr" # The saved data is scaled from 0-1024. Now we use data scaled from 0-1 # Rescale here for subsequent tests - return OpenSatellite(zarr_path=filename).map(_sat_rescale) + return FakeIter(open_sat_data(zarr_path=filename)).map(_sat_rescale) @pytest.fixture() diff --git a/tests/load/test_load_satellite.py b/tests/load/test_load_satellite.py index 52c1b5056..d9e204f75 100644 --- a/tests/load/test_load_satellite.py +++ b/tests/load/test_load_satellite.py @@ -6,23 +6,20 @@ """ -from ocf_datapipes.load import OpenSatellite +from ocf_datapipes.load import open_sat_data from freezegun import freeze_time def test_open_satellite(): - sat_datapipe = OpenSatellite(zarr_path="tests/data/hrv_sat_data.zarr") - metadata = next(iter(sat_datapipe)) - assert metadata is not None + sat_xr = open_sat_data(zarr_path="tests/data/hrv_sat_data.zarr") + assert sat_xr is not None def test_open_hrvsatellite(): - sat_datapipe = OpenSatellite(zarr_path="tests/data/sat_data.zarr") - metadata = next(iter(sat_datapipe)) - assert metadata is not None + sat_xr = open_sat_data(zarr_path="tests/data/sat_data.zarr") + assert sat_xr is not None def 
test_open_satellite_15(): - sat_datapipe = OpenSatellite(zarr_path="tests/data/sat_data_15.zarr") - metadata = next(iter(sat_datapipe)) - assert metadata is not None + sat_xr = open_sat_data(zarr_path="tests/data/sat_data_15.zarr") + assert sat_xr is not None diff --git a/tests/transform/numpy_batch/conftest.py b/tests/transform/numpy_batch/conftest.py index d137ff7cc..dff4276f7 100644 --- a/tests/transform/numpy_batch/conftest.py +++ b/tests/transform/numpy_batch/conftest.py @@ -10,7 +10,8 @@ ConvertPVToNumpyBatch, ConvertSatelliteToNumpyBatch, ) -from ocf_datapipes.load import OpenGSP, OpenNWP, OpenSatellite +from ocf_datapipes.load import OpenGSP, OpenNWP, open_sat_data +from ocf_datapipes.training.common import FakeIter from ocf_datapipes.batch import MergeNumpyModalities from ocf_datapipes.transform.xarray import ( @@ -21,7 +22,7 @@ @pytest.fixture() def sat_hrv_np_datapipe(): filename = Path(ocf_datapipes.__file__).parent.parent / "tests" / "data" / "hrv_sat_data.zarr" - dp = OpenSatellite(zarr_path=filename) + dp = FakeIter(open_sat_data(zarr_path=filename)) dp = AddT0IdxAndSamplePeriodDuration( dp, sample_period_duration=timedelta(minutes=5), @@ -34,7 +35,7 @@ def sat_hrv_np_datapipe(): @pytest.fixture() def sat_np_datapipe(): filename = Path(ocf_datapipes.__file__).parent.parent / "tests" / "data" / "sat_data.zarr" - dp = OpenSatellite(zarr_path=filename) + dp = FakeIter(open_sat_data(zarr_path=filename)) dp = AddT0IdxAndSamplePeriodDuration( dp, sample_period_duration=timedelta(minutes=5),