From 8ea85d0c6b514082084d87d71487807116360e1e Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 17 Nov 2022 13:50:04 +0000 Subject: [PATCH 1/3] add option to pick closest pv systems, not random + tests --- .../pv/ensure_n_pv_systems_per_example.py | 46 +++++++++++++-- .../xarray/{ => pv}/test_create_pv_image.py | 0 .../test_ensure_n_pv_systems_per_example.py | 57 +++++++++++++++++++ .../test_fill_night_time_nans_with_zeros.py | 0 .../{ => pv}/test_pv_power_rolling_window.py | 0 .../{ => pv}/test_pv_remove_zero_data.py | 0 6 files changed, 97 insertions(+), 6 deletions(-) rename tests/transform/xarray/{ => pv}/test_create_pv_image.py (100%) create mode 100644 tests/transform/xarray/pv/test_ensure_n_pv_systems_per_example.py rename tests/transform/xarray/{ => pv}/test_fill_night_time_nans_with_zeros.py (100%) rename tests/transform/xarray/{ => pv}/test_pv_power_rolling_window.py (100%) rename tests/transform/xarray/{ => pv}/test_pv_remove_zero_data.py (100%) diff --git a/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py b/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py index d509fe4ec..f4a58d3e6 100644 --- a/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py +++ b/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py @@ -1,6 +1,8 @@ """Ensure there is N PV systems per example""" import logging +from typing import Optional + import numpy as np import xarray as xr from torchdata.datapipes import functional_datapipe @@ -13,7 +15,14 @@ class EnsureNPVSystemsPerExampleIterDataPipe(IterDataPipe): """Ensure there is N PV systems per example""" - def __init__(self, source_datapipe: IterDataPipe, n_pv_systems_per_example: int, seed=None): + def __init__( + self, + source_datapipe: IterDataPipe, + n_pv_systems_per_example: int, + seed=None, + method: str = "random", + locations_datapipe: Optional[IterDataPipe] = None, + ): """ Ensure there is N PV systems per example @@ -21,21 +30,46 @@ def
__init__(self, source_datapipe: IterDataPipe, n_pv_systems_per_example: int, source_datapipe: Datapipe of PV data n_pv_systems_per_example: Number of PV systems to have in example seed: Random seed for choosing + method: method for picking PV systems. Can be 'random' or 'closest' + locations_datapipe: location of this example. Can be None as its only needed for 'closest' """ self.source_datapipe = source_datapipe self.n_pv_systems_per_example = n_pv_systems_per_example self.rng = np.random.default_rng(seed=seed) + self.method = method + self.locations_datapipe = locations_datapipe + + assert method in ["random", "closest"] + + if method == "closest": + assert ( + locations_datapipe is not None + ), f"If you are slect closest PV systems, then a location data pipe is needed" def __iter__(self): for xr_data in self.source_datapipe: if len(xr_data.pv_system_id) > self.n_pv_systems_per_example: logger.debug(f"Reducing PV systems to {self.n_pv_systems_per_example}") # More PV systems are available than we need. 
Reduce by randomly sampling: - subset_of_pv_system_ids = self.rng.choice( - xr_data.pv_system_id, - size=self.n_pv_systems_per_example, - replace=False, - ) + if self.method == "random": + subset_of_pv_system_ids = self.rng.choice( + xr_data.pv_system_id, + size=self.n_pv_systems_per_example, + replace=False, + ) + elif self.method == "closest": + + location = next(self.locations_datapipe) + + # get distance + delta_x = xr_data.x_osgb - location.x + delta_y = xr_data.y_osgb - location.y + r2 = delta_x ** 2 + delta_y ** 2 + + # order and select closest + r2 = r2.sortby(r2) + subset_of_pv_system_ids = r2.pv_system_id[: self.n_pv_systems_per_example] + xr_data = xr_data.sel(pv_system_id=subset_of_pv_system_ids) elif len(xr_data.pv_system_id) < self.n_pv_systems_per_example: logger.debug("Padding out PV systems") diff --git a/tests/transform/xarray/test_create_pv_image.py b/tests/transform/xarray/pv/test_create_pv_image.py similarity index 100% rename from tests/transform/xarray/test_create_pv_image.py rename to tests/transform/xarray/pv/test_create_pv_image.py diff --git a/tests/transform/xarray/pv/test_ensure_n_pv_systems_per_example.py b/tests/transform/xarray/pv/test_ensure_n_pv_systems_per_example.py new file mode 100644 index 000000000..8c764585a --- /dev/null +++ b/tests/transform/xarray/pv/test_ensure_n_pv_systems_per_example.py @@ -0,0 +1,57 @@ +import numpy as np +import pytest + +from ocf_datapipes.transform.xarray import EnsureNPVSystemsPerExample +from ocf_datapipes.utils.consts import Location + + +def test_ensure_n_pv_systems_per_example_expand(passiv_datapipe): + + data_before = next(iter(passiv_datapipe)) + + passiv_datapipe = EnsureNPVSystemsPerExample(passiv_datapipe, n_pv_systems_per_example=12) + data_after = next(iter(passiv_datapipe)) + + assert len(data_before[0, :]) == 2 + assert len(data_after[0, :]) == 12 + + +def test_ensure_n_pv_systems_per_example_random(passiv_datapipe): + + data_before = next(iter(passiv_datapipe)) + + passiv_datapipe = 
EnsureNPVSystemsPerExample(passiv_datapipe, n_pv_systems_per_example=1) + data_after = next(iter(passiv_datapipe)) + + assert len(data_before[0, :]) == 2 + assert len(data_after[0, :]) == 1 + + +def test_ensure_n_pv_systems_per_example_closest_error(passiv_datapipe): + + with pytest.raises(Exception): + _ = EnsureNPVSystemsPerExample( + passiv_datapipe, n_pv_systems_per_example=1, method="closest" + ) + + +def test_ensure_n_pv_systems_per_example_closest(passiv_datapipe): + + # make fake location datapipe + location = Location(x=2.687e05, y=6.267e05) + location_datapipe = iter([location]) + + data_before = next(iter(passiv_datapipe)) + + passiv_datapipe = EnsureNPVSystemsPerExample( + passiv_datapipe, + n_pv_systems_per_example=1, + method="closest", + locations_datapipe=location_datapipe, + ) + + data_after = next(iter(passiv_datapipe)) + + assert len(data_before[0, :]) == 2 + assert len(data_after[0, :]) == 1 + assert data_after.pv_system_id[0] == 9960 diff --git a/tests/transform/xarray/test_fill_night_time_nans_with_zeros.py b/tests/transform/xarray/pv/test_fill_night_time_nans_with_zeros.py similarity index 100% rename from tests/transform/xarray/test_fill_night_time_nans_with_zeros.py rename to tests/transform/xarray/pv/test_fill_night_time_nans_with_zeros.py diff --git a/tests/transform/xarray/test_pv_power_rolling_window.py b/tests/transform/xarray/pv/test_pv_power_rolling_window.py similarity index 100% rename from tests/transform/xarray/test_pv_power_rolling_window.py rename to tests/transform/xarray/pv/test_pv_power_rolling_window.py diff --git a/tests/transform/xarray/test_pv_remove_zero_data.py b/tests/transform/xarray/pv/test_pv_remove_zero_data.py similarity index 100% rename from tests/transform/xarray/test_pv_remove_zero_data.py rename to tests/transform/xarray/pv/test_pv_remove_zero_data.py From d718cc93e73fb232605601dfbd9c022808f73cff Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 17 Nov 2022 13:54:02 +0000 Subject: [PATCH 2/3] lint --- 
.../transform/xarray/pv/ensure_n_pv_systems_per_example.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py b/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py index f4a58d3e6..777245c27 100644 --- a/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py +++ b/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py @@ -31,7 +31,8 @@ def __init__( n_pv_systems_per_example: Number of PV systems to have in example seed: Random seed for choosing method: method for picking PV systems. Can be 'random' or 'closest' - locations_datapipe: location of this example. Can be None as its only needed for 'closest' + locations_datapipe: location of this example. + Can be None as its only needed for 'closest' """ self.source_datapipe = source_datapipe self.n_pv_systems_per_example = n_pv_systems_per_example @@ -44,7 +45,7 @@ def __init__( if method == "closest": assert ( locations_datapipe is not None - ), f"If you are slect closest PV systems, then a location data pipe is needed" + ), "If you are slect closest PV systems, then a location data pipe is needed" def __iter__(self): for xr_data in self.source_datapipe: From 0aa505b8fc6d2509029376b6926d00d0e0b1afda Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Nov 2022 14:35:01 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../transform/xarray/pv/ensure_n_pv_systems_per_example.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py b/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py index 777245c27..0d9a4d671 100644 --- a/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py +++ 
b/ocf_datapipes/transform/xarray/pv/ensure_n_pv_systems_per_example.py @@ -1,6 +1,5 @@ """Ensure there is N PV systems per example""" import logging - from typing import Optional import numpy as np @@ -65,7 +64,7 @@ def __iter__(self): # get distance delta_x = xr_data.x_osgb - location.x delta_y = xr_data.y_osgb - location.y - r2 = delta_x ** 2 + delta_y ** 2 + r2 = delta_x**2 + delta_y**2 # order and select closest r2 = r2.sortby(r2)