Skip to content

Commit

Permalink
Fix dual import
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Nov 15, 2023
1 parent 566681c commit b84fdec
Showing 1 changed file with 91 additions and 51 deletions.
142 changes: 91 additions & 51 deletions ocf_datapipes/training/windnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import xarray as xr
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.iter import IterDataPipe, IterableWrapper

from ocf_datapipes.batch import MergeNumpyModalities
from ocf_datapipes.config.model import Configuration
Expand All @@ -23,6 +23,9 @@
BatchKey,
NumpyBatch,
)
from ocf_datapipes.load import (
OpenConfiguration,
)
from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset

xr.set_options(keep_attrs=True)
Expand Down Expand Up @@ -94,47 +97,6 @@ def gsp_drop_national(x: Union[xr.DataArray, xr.Dataset]):
return x.where(x.gsp_id != 0, drop=True)


@functional_datapipe("pvnet_select_pv_by_ml_id")
class PVNetSelectPVbyMLIDIterDataPipe(IterDataPipe):
    """Keep only the PV systems whose ML ID is in a requested set."""

    def __init__(self, source_datapipe: IterDataPipe, ml_ids: np.array):
        """Keep only the PV systems whose ML ID is in a requested set.

        Args:
            source_datapipe: Datapipe emitting PV xarray data
            ml_ids: List-like of ML IDs to select

        Returns:
            Filtered data source
        """
        self.ml_ids = ml_ids
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for pv_data in self.source_datapipe:
            # Flag every requested ID that the incoming data does not contain
            is_missing = np.isin(self.ml_ids, pv_data.ml_id, invert=True)
            if is_missing.any():
                missing_ml_ids = np.array(self.ml_ids)[is_missing]
                logger.warning(
                    f"The following ML IDs were mising in the PV site-level input data: "
                    f"{missing_ml_ids}. The values for these IDs will be set to NaN."
                )

            # Many ML-IDs are null, so drop those entries before re-indexing
            has_ml_id = ~pv_data.ml_id.isnull()
            with_ml_id = pv_data.where(has_ml_id, drop=True)

            # Index by ml_id so the requested IDs can be selected by coordinate
            by_ml_id = with_ml_id.swap_dims({"pv_system_id": "ml_id"})

            # reindex keeps the requested order; IDs absent from the data become NaN
            selected = by_ml_id.reindex(ml_id=self.ml_ids)

            # Restore the original dimension before handing the data on
            yield selected.swap_dims({"ml_id": "pv_system_id"})


def fill_nans_in_pv(x: Union[xr.DataArray, xr.Dataset]):
"""Fill NaNs in PV data with the value -1
Expand Down Expand Up @@ -326,6 +288,35 @@ def __iter__(self):
yield {k: v for k, v in zip(self.keys, data)}


@functional_datapipe("load_dict_datasets")
class LoadDictDatasetIterDataPipe(IterDataPipe):
    """Load combined dataset files and yield them as dictionaries of datasets.

    Each file is opened with ``xr.open_dataset``, split back into its parts
    with ``uncombine_from_single_dataset`` and reduced to the requested keys.
    The list of files is cycled through indefinitely.
    """

    # Paths of the combined dataset files to load
    filenames: List[str]
    # Keys to pick out of each uncombined dataset
    keys: List[str]
    # Configuration object handed in by the caller (stored, not used here)
    configuration: Configuration

    def __init__(self, filenames: List[str], keys: List[str], configuration: Configuration):
        """Store the files to load, the keys to extract and the configuration.

        Args:
            filenames: Paths of the combined dataset files to load
            keys: Keys to extract from each uncombined dataset
            configuration: Configuration to keep alongside the data
        """
        super().__init__()
        self.keys = keys
        self.filenames = filenames
        # The original line was a bare attribute access, which raised
        # AttributeError instead of storing the argument.
        self.configuration = configuration

    def __iter__(self):
        """Yield one ``{key: dataset}`` dictionary per file, cycling forever."""
        # With no files the endless loop below would spin without ever
        # yielding, so stop straight away instead of hanging the consumer.
        if not self.filenames:
            return
        # Iterate through each filename, loading it, uncombining it, and then yielding it
        while True:
            for filename in self.filenames:
                dataset = xr.open_dataset(filename)
                datasets = uncombine_from_single_dataset(dataset)
                # Yield a dictionary of the data, using the keys in self.keys
                yield {key: datasets[key] for key in self.keys}


def _get_datapipes_dict(
config_filename: str,
block_sat: bool,
Expand Down Expand Up @@ -356,11 +347,6 @@ def _get_datapipes_dict(
if "pv" in datapipes_dict:
datapipes_dict["pv"] = OpenPVFromPVSitesDB(config.input_data.pv.history_minutes)

if "pv" in datapipes_dict and config.input_data.pv.pv_ml_ids != []:
datapipes_dict["pv"] = datapipes_dict["pv"].pvnet_select_pv_by_ml_id(
config.input_data.pv.pv_ml_ids
)

return datapipes_dict


Expand Down Expand Up @@ -709,12 +695,12 @@ def construct_sliced_data_pipeline(


def convert_to_numpy_batch(
datapipes_dict: dict,
datapipes_dict: dict[str, Union[IterDataPipe, Configuration]],
block_sat: bool = False,
block_nwp: bool = False,
check_satellite_no_zeros: bool = False,
):
configuration = datapipes_dict["config"]
configuration: Configuration = datapipes_dict["config"]
# Spatially slice, normalize, and convert data to numpy arrays
numpy_modalities = []
# Unpack for convenience
Expand Down Expand Up @@ -817,6 +803,54 @@ def windnet_datapipe(
)


def split_dataset_dict_dp(element):
    """
    Wrap every non-config entry of a dictionary in its own datapipe

    Args:
        element: Dictionary mapping keys to loaded data

    Returns:
        Dictionary with the same keys, minus "config", where each value is an
        IterableWrapper around a one-item list holding the original value
    """
    wrapped = {}
    for key, value in element.items():
        # The configuration entry is not data, so it is left out
        if key == "config":
            continue
        wrapped[key] = IterableWrapper([value])
    return wrapped


def windnet_netcdf_datapipe(
    config_filename: str,
    keys: List[str],
    filenames: List[str],
    block_sat: bool = False,
    block_nwp: bool = False,
) -> IterDataPipe:
    """
    Load the saved Datapipes from windnet, and transform to numpy batch

    Args:
        config_filename: Path to config file.
        keys: List of keys to extract from the single NetCDF files
        filenames: List of saved files to load
        block_sat: Whether to load zeroes for satellite data.
        block_nwp: Whether to load zeroes for NWP data.

    Returns:
        Datapipe that transforms the NetCDF files to numpy batch
    """
    logger.info("Constructing windnet file pipeline")
    config_datapipe = OpenConfiguration(config_filename)
    # The configuration datapipe is only read once, for its first element
    configuration: Configuration = next(iter(config_datapipe))
    # Load files
    datapipe_dict_dp: IterDataPipe = LoadDictDatasetIterDataPipe(
        filenames=filenames,
        keys=keys,
        configuration=configuration,
    )
    # Map every loaded dictionary to a dictionary of single-item datapipes.
    # The result of .map is itself a datapipe, not a dict.
    datapipe_dict: IterDataPipe = datapipe_dict_dp.map(split_dataset_dict_dp)

    # Convert to numpy batch
    # NOTE(review): convert_to_numpy_batch reads datapipes_dict["config"], but
    # the object passed here is a datapipe without a "config" entry - confirm
    # this call works as intended.
    datapipe = convert_to_numpy_batch(
        datapipe_dict,
        block_sat,
        block_nwp,
    )

    return datapipe


def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch:
"""
Check if there are any Nans values in the satellite data.
Expand Down Expand Up @@ -861,5 +895,11 @@ def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch:
)
datasets = next(iter(dp))
dataset = combine_to_single_dataset(datasets)
multiple_datasets = uncombine_from_single_dataset(dataset)
print(multiple_datasets)
dataset.to_zarr("test.nc", mode="w", compute=True)
dp = windnet_netcdf_datapipe(
config_filename=configuration_filename,
filenames=["test.zarr"],
keys=["gsp", "nwp", "sat", "pv"],
)
datasets = next(iter(dp))
print(datasets)

0 comments on commit b84fdec

Please sign in to comment.