diff --git a/.gitignore b/.gitignore
index fa4c1b7..0c023ff 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,3 +127,6 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# OS
+.DS_Store
\ No newline at end of file
diff --git a/ocf_data_sampler/constants.py b/ocf_data_sampler/constants.py
index d0c9a18..7616d96 100644
--- a/ocf_data_sampler/constants.py
+++ b/ocf_data_sampler/constants.py
@@ -28,6 +28,7 @@ def __getitem__(self, key):
                 f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}"
             )
 
+
 # ------ UKV
 # Means and std computed WITH version_7 and higher, MetOffice values
 UKV_STD = {
@@ -49,6 +50,7 @@ def __getitem__(self, key):
     "prmsl": 1252.71790539,
     "prate": 0.00021497,
 }
+
 UKV_MEAN = {
     "cdcb": 1412.26599062,
     "lcc": 50.08362643,
@@ -97,6 +99,7 @@ def __getitem__(self, key):
     "diff_duvrs": 81605.25,
     "diff_sr": 818950.6875,
 }
+
 ECMWF_MEAN = {
     "dlwrf": 27187026.0,
     "dswrf": 11458988.0,
@@ -133,3 +136,38 @@ def __getitem__(self, key):
     ecmwf=ECMWF_MEAN,
 )
 
+# ------ Satellite
+# RSS Mean and std values from randomised 20% of 2020 imagery
+
+RSS_STD = {
+    "HRV": 0.11405209,
+    "IR_016": 0.21462157,
+    "IR_039": 0.04618041,
+    "IR_087": 0.06687243,
+    "IR_097": 0.0468558,
+    "IR_108": 0.17482725,
+    "IR_120": 0.06115861,
+    "IR_134": 0.04492306,
+    "VIS006": 0.12184761,
+    "VIS008": 0.13090034,
+    "WV_062": 0.16111417,
+    "WV_073": 0.12924142,
+}
+
+RSS_MEAN = {
+    "HRV": 0.09298719,
+    "IR_016": 0.17594202,
+    "IR_039": 0.86167645,
+    "IR_087": 0.7719318,
+    "IR_097": 0.8014212,
+    "IR_108": 0.71254843,
+    "IR_120": 0.89058584,
+    "IR_134": 0.944365,
+    "VIS006": 0.09633306,
+    "VIS008": 0.11426069,
+    "WV_062": 0.7359355,
+    "WV_073": 0.62479186,
+}
+
+RSS_STD = _to_data_array(RSS_STD)
+RSS_MEAN = _to_data_array(RSS_MEAN)
diff --git a/ocf_data_sampler/numpy_batch/nwp.py b/ocf_data_sampler/numpy_batch/nwp.py
index 8eae117..3206e4d 100644
--- a/ocf_data_sampler/numpy_batch/nwp.py
+++ b/ocf_data_sampler/numpy_batch/nwp.py
@@ -1,5 +1,4 @@
 """Convert NWP to NumpyBatch"""
-
 import pandas as pd
 import xarray as xr
 
diff --git a/ocf_data_sampler/numpy_batch/satellite.py b/ocf_data_sampler/numpy_batch/satellite.py
index 0a0b7bb..d55ce4f 100644
--- a/ocf_data_sampler/numpy_batch/satellite.py
+++ b/ocf_data_sampler/numpy_batch/satellite.py
@@ -13,6 +13,7 @@ class SatelliteBatchKey:
 
 def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
     """Convert from Xarray to NumpyBatch"""
+
     example = {
         SatelliteBatchKey.satellite_actual: da.values,
         SatelliteBatchKey.time_utc: da.time_utc.values.astype(float),
@@ -27,4 +28,4 @@ def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None
 
     if t0_idx is not None:
         example[SatelliteBatchKey.t0_idx] = t0_idx
-    return example
\ No newline at end of file
+    return example
diff --git a/ocf_data_sampler/torch_datasets/process_and_combine.py b/ocf_data_sampler/torch_datasets/process_and_combine.py
index cae1e5f..e509ca5 100644
--- a/ocf_data_sampler/torch_datasets/process_and_combine.py
+++ b/ocf_data_sampler/torch_datasets/process_and_combine.py
@@ -4,7 +4,7 @@
 from typing import Tuple
 
 from ocf_data_sampler.config import Configuration
-from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
+from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
 from ocf_data_sampler.numpy_batch import (
     convert_nwp_to_numpy_batch,
     convert_satellite_to_numpy_batch,
@@ -25,8 +25,8 @@ def process_and_combine_datasets(
     location: Location,
     target_key: str = 'gsp'
 ) -> dict:
-    """Normalize and convert data to numpy arrays"""
+    """Normalise and convert data to numpy arrays"""
 
     numpy_modalities = []
 
     if "nwp" in dataset_dict:
@@ -37,19 +37,23 @@ def process_and_combine_datasets(
             # Standardise
             provider = config.input_data.nwp[nwp_key].provider
             da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
+
             # Convert to NumpyBatch
             nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
 
         # Combine the NWPs into NumpyBatch
         numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})
 
+
     if "sat" in dataset_dict:
-        # Satellite is already in the range [0-1] so no need to standardise
+        # Standardise
         da_sat = dataset_dict["sat"]
+        da_sat = (da_sat - RSS_MEAN) / RSS_STD
 
         # Convert to NumpyBatch
         numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
 
+
     gsp_config = config.input_data.gsp
 
     if "gsp" in dataset_dict:
@@ -93,6 +97,7 @@ def process_and_combine_datasets(
 
     return combined_sample
 
+
 def process_and_combine_site_sample_dict(
     dataset_dict: dict,
     config: Configuration,
@@ -119,8 +124,9 @@ def process_and_combine_site_sample_dict(
             data_arrays.append((f"nwp-{provider}", da_nwp))
 
     if "sat" in dataset_dict:
-        # TODO add some satellite normalisation
+        # Standardise
         da_sat = dataset_dict["sat"]
+        da_sat = (da_sat - RSS_MEAN) / RSS_STD
         data_arrays.append(("satellite", da_sat))
 
     if "site" in dataset_dict:
@@ -143,6 +149,7 @@ def merge_dicts(list_of_dicts: list[dict]) -> dict:
         combined_dict.update(d)
     return combined_dict
 
+
 def merge_arrays(normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
     """
     Combine a list of DataArrays into a single Dataset with unique naming conventions.
diff --git a/tests/torch_datasets/test_process_and_combine.py b/tests/torch_datasets/test_process_and_combine.py
new file mode 100644
index 0000000..1d01449
--- /dev/null
+++ b/tests/torch_datasets/test_process_and_combine.py
@@ -0,0 +1,165 @@
+import pytest
+import tempfile
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+import dask.array as da
+
+from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
+from ocf_data_sampler.config import Configuration
+from ocf_data_sampler.select.location import Location
+from ocf_data_sampler.numpy_batch import NWPBatchKey, GSPBatchKey, SatelliteBatchKey
+from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
+
+from ocf_data_sampler.torch_datasets.process_and_combine import (
+    process_and_combine_datasets,
+    process_and_combine_site_sample_dict,
+    merge_dicts,
+    fill_nans_in_arrays,
+    compute,
+)
+
+
+def test_process_and_combine_datasets(pvnet_config_filename):
+
+    # Load in config for function and define location
+    config = load_yaml_configuration(pvnet_config_filename)
+    t0 = pd.Timestamp("2024-01-01 00:00")
+    location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
+
+    nwp_data = xr.DataArray(
+        np.random.rand(4, 2, 2, 2),
+        dims=["time_utc", "channel", "y", "x"],
+        coords={
+            "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
+            "channel": ["t2m", "dswrf"],
+            "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
+            "init_time_utc": pd.Timestamp("2024-01-01 00:00")
+        }
+    )
+
+    sat_data = xr.DataArray(
+        np.random.rand(7, 1, 2, 2),
+        dims=["time_utc", "channel", "y", "x"],
+        coords={
+            "time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
+            "channel": ["HRV"],
+            "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
+            "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]]))
+        }
+    )
+
+    # Combine as dict
+    dataset_dict = {
+        "nwp": {"ukv": nwp_data},
"sat": sat_data + } + + # Call relevant function + result = process_and_combine_datasets(dataset_dict, config, t0, location) + + # Assert result is dict - check and validate + assert isinstance(result, dict) + assert NWPBatchKey.nwp in result + assert result[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2) + assert result[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2) + + +def test_merge_dicts(): + """Test merge_dicts function""" + dict1 = {"a": 1, "b": 2} + dict2 = {"c": 3, "d": 4} + dict3 = {"e": 5} + + result = merge_dicts([dict1, dict2, dict3]) + assert result == {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5} + + # Test key overwriting + dict4 = {"a": 10, "f": 6} + result = merge_dicts([dict1, dict4]) + assert result["a"] == 10 + + +def test_fill_nans_in_arrays(): + """Test the fill_nans_in_arrays function""" + array_with_nans = np.array([1.0, np.nan, 3.0, np.nan]) + nested_dict = { + "array1": array_with_nans, + "nested": { + "array2": np.array([np.nan, 2.0, np.nan, 4.0]) + }, + "string_key": "not_an_array" + } + + result = fill_nans_in_arrays(nested_dict) + + assert not np.isnan(result["array1"]).any() + assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0])) + assert not np.isnan(result["nested"]["array2"]).any() + assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0])) + assert result["string_key"] == "not_an_array" + + +def test_compute(): + """Test compute function with dask array""" + da_dask = xr.DataArray(da.random.random((5, 5))) + + # Create a nested dictionary with dask array + nested_dict = { + "array1": da_dask, + "nested": { + "array2": da_dask + } + } + + # Ensure initial data is lazy - i.e. not yet computed + assert not isinstance(nested_dict["array1"].data, np.ndarray) + assert not isinstance(nested_dict["nested"]["array2"].data, np.ndarray) + + # Call the compute function + result = compute(nested_dict) + + # Assert that the result is an xarray DataArray and no longer lazy + assert isinstance(result["array1"], xr.DataArray) + assert isinstance(result["nested"]["array2"], xr.DataArray) + assert isinstance(result["array1"].data, np.ndarray) + assert isinstance(result["nested"]["array2"].data, np.ndarray) + + # Ensure there no NaN values in computed data + assert not np.isnan(result["array1"].data).any() + assert not np.isnan(result["nested"]["array2"].data).any() + + +def test_process_and_combine_site_sample_dict(pvnet_config_filename): + # Load config + config = load_yaml_configuration(pvnet_config_filename) + + # Specify minimal structure for testing + raw_nwp_values = np.random.rand(4, 1, 2, 2) # Single channel + site_dict = { + "nwp": { + "ukv": xr.DataArray( + raw_nwp_values, + dims=["time_utc", "channel", "y", "x"], + coords={ + "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"), + "channel": ["dswrf"], # Single channel + }, + ) + } + } + print(f"Input site_dict: {site_dict}") + + # Call function + result = process_and_combine_site_sample_dict(site_dict, config) + + # Assert to validate output structure + assert isinstance(result, xr.Dataset), "Result should be an xarray.Dataset" + assert len(result.data_vars) > 0, "Dataset should contain data variables" + + # Validate variable via assertion and shape of such + expected_variable = "nwp-ukv" + assert expected_variable in result.data_vars, f"Expected variable '{expected_variable}' not found" + nwp_result = result[expected_variable] + assert nwp_result.shape == (4, 1, 2, 2), f"Unexpected shape for '{expected_variable}': {nwp_result.shape}"