Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove satellite datapipe, remove metnet #345

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion ocf_datapipes/load/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from .configuration import OpenConfigurationIterDataPipe as OpenConfiguration
from .nwp.nwp import OpenNWPIterDataPipe as OpenNWP
from .satellite import OpenSatelliteIterDataPipe as OpenSatellite
from .satellite import open_sat_data

try:
import rioxarray # Rioxarray is sometimes a pain to install, so only load this if it's installed
Expand Down
22 changes: 0 additions & 22 deletions ocf_datapipes/load/satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pandas as pd
import xarray as xr
from ocf_blosc2 import Blosc2 # noqa: F401
from torch.utils.data import IterDataPipe, functional_datapipe

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -163,24 +162,3 @@ def open_sat_data(zarr_path: Union[Path, str, list[Path], list[str]]) -> xr.Data
_log.info("Opened satellite data")

return data_array


@functional_datapipe("open_satellite")
class OpenSatelliteIterDataPipe(IterDataPipe):
    """Datapipe that opens a satellite Zarr and yields it endlessly"""

    def __init__(self, zarr_path: Union[Path, str]):
        """
        Set up the datapipe

        Args:
            zarr_path: location of the satellite Zarr to open
        """
        self.zarr_path = zarr_path
        super().__init__()

    def __iter__(self) -> xr.DataArray:
        """Open the Zarr once, then yield the same DataArray on every step"""
        # The open call sits outside the loop, so the data is loaded a single
        # time per iterator and then handed out repeatedly without end.
        sat_data = open_sat_data(zarr_path=self.zarr_path)
        while True:
            yield sat_data
47 changes: 34 additions & 13 deletions ocf_datapipes/training/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
OpenNWP,
OpenPVFromNetCDF,
OpenPVFromPVSitesDB,
OpenSatellite,
OpenWindFromNetCDF,
open_sat_data,
)
from ocf_datapipes.utils.utils import flatten_nwp_source_dict

Expand All @@ -38,6 +38,23 @@
logger = logging.getLogger(__name__)


@functional_datapipe("fake_iter")
class FakeIter(IterDataPipe):
    """Wrap an already-loaded object in an endlessly repeating datapipe

    Transitional helper for moving away from datapipes: loading functions can
    return xarray data directly, and this class turns that data back into an
    IterDataPipe so the existing functional transforms can still be chained
    onto it.
    """

    def __init__(self, data_xr):
        """
        Store the object that will be yielded

        Args:
            data_xr: object handed out, unchanged, on every iteration
        """
        # Initialise the IterDataPipe base class, as the datapipe this
        # replaces did, before attaching our own state.
        super().__init__()
        self.data_xr = data_xr

    def __iter__(self) -> "xr.DataArray":
        """Yield the stored object forever; this stream never terminates"""
        while True:
            yield self.data_xr


def is_config_and_path_valid(
use_flag: bool,
config,
Expand Down Expand Up @@ -171,24 +188,28 @@ def open_and_return_datapipes(

if use_sat:
logger.debug("Opening Satellite Data")
sat_datapipe = (
OpenSatellite(configuration.input_data.satellite.satellite_zarr_path)
.filter_channels(configuration.input_data.satellite.satellite_channels)
.add_t0_idx_and_sample_period_duration(
sample_period_duration=minutes(
configuration.input_data.satellite.time_resolution_minutes
),
history_duration=minutes(configuration.input_data.satellite.history_minutes),
)

sat_xr = open_sat_data(configuration.input_data.satellite.satellite_zarr_path)
sat_pipe = FakeIter(sat_xr)

sat_datapipe = sat_pipe.filter_channels(
configuration.input_data.satellite.satellite_channels
).add_t0_idx_and_sample_period_duration(
sample_period_duration=minutes(
configuration.input_data.satellite.time_resolution_minutes
),
history_duration=minutes(configuration.input_data.satellite.history_minutes),
)

used_datapipes["sat"] = sat_datapipe

if use_hrv:
logger.debug("Opening HRV Satellite Data")
sat_hrv_datapipe = OpenSatellite(
configuration.input_data.hrvsatellite.hrvsatellite_zarr_path
).add_t0_idx_and_sample_period_duration(

# HRV imagery has its own Zarr path in the config; read that one rather than
# the non-HRV satellite path, otherwise the HRV pipe serves the wrong data.
sat_xr = open_sat_data(configuration.input_data.hrvsatellite.hrvsatellite_zarr_path)
sat_pipe = FakeIter(sat_xr)

sat_hrv_datapipe = sat_pipe.add_t0_idx_and_sample_period_duration(
sample_period_duration=minutes(
configuration.input_data.hrvsatellite.time_resolution_minutes
),
Expand Down
10 changes: 5 additions & 5 deletions ocf_datapipes/training/example/gsp_pv_nwp_satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from ocf_datapipes.batch import MergeNumpyModalities, MergeNWPNumpyModalities
from ocf_datapipes.config.load import load_yaml_configuration
from ocf_datapipes.config.model import Configuration
from ocf_datapipes.load import OpenGSP, OpenNWP, OpenPVFromNetCDF, OpenSatellite
from ocf_datapipes.training.common import normalize_gsp, normalize_pv
from ocf_datapipes.load import OpenGSP, OpenNWP, OpenPVFromNetCDF, open_sat_data
from ocf_datapipes.training.common import FakeIter, normalize_gsp, normalize_pv
from ocf_datapipes.utils.consts import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -53,9 +53,9 @@ def gsp_pv_nwp_satellite_data_pipeline(configuration: Union[Path, str]) -> IterD
)

# Load and normalize satellite data
satellite_datapipe = OpenSatellite(
zarr_path=configuration.input_data.satellite.satellite_zarr_path
).normalize(mean=RSS_MEAN, std=RSS_STD)
sat_xr = open_sat_data(zarr_path=configuration.input_data.satellite.satellite_zarr_path)
satellite_datapipe = FakeIter(sat_xr)
satellite_datapipe = satellite_datapipe.normalize(mean=RSS_MEAN, std=RSS_STD)

# Load and normalize NWP data - There may be multiple NWP sources
nwp_datapipe_dict = {}
Expand Down
15 changes: 5 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,8 @@
from ocf_datapipes.config.load import load_yaml_configuration
from ocf_datapipes.config.model import PV, PVFiles
from ocf_datapipes.config.save import save_yaml_configuration
from ocf_datapipes.load import (
OpenGSP,
OpenNWP,
OpenPVFromNetCDF,
OpenSatellite,
OpenTopography,
)
from ocf_datapipes.load import OpenGSP, OpenNWP, OpenPVFromNetCDF, OpenTopography, open_sat_data
from ocf_datapipes.training.common import FakeIter

xr.set_options(keep_attrs=True)

Expand All @@ -49,23 +44,23 @@ def top_test_directory():
@pytest.fixture()
def sat_hrv_datapipe():
filename = _top_test_directory + "/data/hrv_sat_data.zarr"
return OpenSatellite(zarr_path=filename)
return FakeIter(open_sat_data(zarr_path=filename))


@pytest.fixture()
def sat_datapipe():
filename = f"{_top_test_directory}/data/sat_data.zarr"
# The saved data is scaled from 0-1024. Now we use data scaled from 0-1
# Rescale here for subsequent tests
return OpenSatellite(zarr_path=filename).map(lambda da: da / 1024)
return FakeIter(open_sat_data(zarr_path=filename)).map(lambda da: da / 1024)


@pytest.fixture()
def sat_15_datapipe():
filename = f"{_top_test_directory}/data/sat_data_15.zarr"
# The saved data is scaled from 0-1024. Now we use data scaled from 0-1
# Rescale here for subsequent tests
return OpenSatellite(zarr_path=filename).map(_sat_rescale)
return FakeIter(open_sat_data(zarr_path=filename)).map(_sat_rescale)


@pytest.fixture()
Expand Down
17 changes: 7 additions & 10 deletions tests/load/test_load_satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,20 @@

"""

from ocf_datapipes.load import OpenSatellite
from ocf_datapipes.load import open_sat_data
from freezegun import freeze_time


def test_open_satellite():
sat_datapipe = OpenSatellite(zarr_path="tests/data/hrv_sat_data.zarr")
metadata = next(iter(sat_datapipe))
assert metadata is not None
sat_xr = open_sat_data(zarr_path="tests/data/hrv_sat_data.zarr")
assert sat_xr is not None


def test_open_hrvsatellite():
sat_datapipe = OpenSatellite(zarr_path="tests/data/sat_data.zarr")
metadata = next(iter(sat_datapipe))
assert metadata is not None
sat_xr = open_sat_data(zarr_path="tests/data/sat_data.zarr")
assert sat_xr is not None


def test_open_satellite_15():
sat_datapipe = OpenSatellite(zarr_path="tests/data/sat_data_15.zarr")
metadata = next(iter(sat_datapipe))
assert metadata is not None
sat_xr = open_sat_data(zarr_path="tests/data/sat_data_15.zarr")
assert sat_xr is not None
7 changes: 4 additions & 3 deletions tests/transform/numpy_batch/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
ConvertPVToNumpyBatch,
ConvertSatelliteToNumpyBatch,
)
from ocf_datapipes.load import OpenGSP, OpenNWP, OpenSatellite
from ocf_datapipes.load import OpenGSP, OpenNWP, open_sat_data
from ocf_datapipes.training.common import FakeIter
from ocf_datapipes.batch import MergeNumpyModalities

from ocf_datapipes.transform.xarray import (
Expand All @@ -21,7 +22,7 @@
@pytest.fixture()
def sat_hrv_np_datapipe():
filename = Path(ocf_datapipes.__file__).parent.parent / "tests" / "data" / "hrv_sat_data.zarr"
dp = OpenSatellite(zarr_path=filename)
dp = FakeIter(open_sat_data(zarr_path=filename))
dp = AddT0IdxAndSamplePeriodDuration(
dp,
sample_period_duration=timedelta(minutes=5),
Expand All @@ -34,7 +35,7 @@ def sat_hrv_np_datapipe():
@pytest.fixture()
def sat_np_datapipe():
filename = Path(ocf_datapipes.__file__).parent.parent / "tests" / "data" / "sat_data.zarr"
dp = OpenSatellite(zarr_path=filename)
dp = FakeIter(open_sat_data(zarr_path=filename))
dp = AddT0IdxAndSamplePeriodDuration(
dp,
sample_period_duration=timedelta(minutes=5),
Expand Down
Loading