From 6b582d6e5938f5ae493c3340ec1d45804b4c572a Mon Sep 17 00:00:00 2001
From: Felix Peretz
Date: Mon, 16 Dec 2024 18:32:14 +0000
Subject: [PATCH 01/20] Satellite normalisation and related testing

---
 .DS_Store                                     | Bin 0 -> 10244 bytes
 ocf_data_sampler/.DS_Store                    | Bin 0 -> 6148 bytes
 ocf_data_sampler/constants.py                 |  67 +++++++++
 ocf_data_sampler/numpy_batch/nwp.py           |   7 +-
 ocf_data_sampler/numpy_batch/satellite.py     |   9 +-
 .../torch_datasets/process_and_combine.py     |  25 +++-
 requirements.txt                              | 136 ++++++++++++++++++
 tests/.DS_Store                               | Bin 0 -> 6148 bytes
 tests/torch_datasets/.DS_Store                | Bin 0 -> 6148 bytes
 .../test_process_and_combine.py               | 136 ++++++++++++++++++
 10 files changed, 372 insertions(+), 8 deletions(-)
 create mode 100644 .DS_Store
 create mode 100644 ocf_data_sampler/.DS_Store
 create mode 100644 requirements.txt
 create mode 100644 tests/.DS_Store
 create mode 100644 tests/torch_datasets/.DS_Store
 create mode 100644 tests/torch_datasets/test_process_and_combine.py

diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..47564d99da74d156096f22186be649ed86012e73
GIT binary patch
literal 10244
[binary data omitted]
literal 0
HcmV?d00001

diff --git a/ocf_data_sampler/.DS_Store b/ocf_data_sampler/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..0b09e2d677d065a693bcc2d23e20e997189dc058
GIT binary patch
literal 6148
[binary data omitted]
literal 0
HcmV?d00001

diff --git a/ocf_data_sampler/constants.py b/ocf_data_sampler/constants.py
index d0c9a18..d4bc4e7 100644
--- a/ocf_data_sampler/constants.py
+++ b/ocf_data_sampler/constants.py
@@ -7,6 +7,10 @@
     "ecmwf",
 ]
 
+SAT_PROVIDERS = [
+    "rss",
+]
+
 
 def _to_data_array(d):
     return xr.DataArray(
@@ -28,6 +32,21 @@ def __getitem__(self, key):
             f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}"
         )
 
+
+class SatStatDict(dict):
+    """Custom dictionary class to hold Satellite normalization stats"""
+
+    def __getitem__(self, key):
+        if key not in SAT_PROVIDERS:
+            raise KeyError(f"{key} is not a supported Satellite provider - {SAT_PROVIDERS}")
+        elif key in self.keys():
+            return super().__getitem__(key)
+        else:
+            raise KeyError(
+                f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}"
+            )
+
+
 # ------ UKV
 # Means and std computed WITH version_7 and higher, MetOffice values
 UKV_STD = {
@@ -49,6 +68,7 @@ def __getitem__(self, key):
     "prmsl": 1252.71790539,
     "prate": 0.00021497,
 }
+
 UKV_MEAN = {
     "cdcb": 1412.26599062,
     "lcc": 50.08362643,
@@ -97,6 +117,7 @@ def __getitem__(self, key):
     "diff_duvrs": 81605.25,
     "diff_sr": 818950.6875,
 }
+
 ECMWF_MEAN = {
     "dlwrf": 27187026.0,
     "dswrf": 11458988.0,
@@ -133,3 +154,49 @@ def __getitem__(self, key):
     ecmwf=ECMWF_MEAN,
 )
 
+# ------ Satellite
+# RSS Mean and std values from randomised 20% of 2020 imagery
+
+RSS_STD = {
+    "HRV": 0.11405209,
+    "IR_016": 0.21462157,
+    "IR_039": 0.04618041,
+    "IR_087": 0.06687243,
+    "IR_097": 0.0468558,
+    "IR_108": 0.17482725,
+    "IR_120": 0.06115861,
+    "IR_134": 0.04492306,
+    "VIS006": 0.12184761,
+    "VIS008": 0.13090034,
+    "WV_062": 0.16111417,
+    "WV_073": 0.12924142,
+}
+
+RSS_MEAN = {
+    "HRV": 0.09298719,
+    "IR_016": 0.17594202,
+    "IR_039": 0.86167645,
+    "IR_087": 0.7719318,
+    "IR_097": 0.8014212,
+    "IR_108": 0.71254843,
+    "IR_120": 0.89058584,
+    "IR_134": 0.944365,
+    "VIS006": 0.09633306,
+    "VIS008": 0.11426069,
+    "WV_062": 0.7359355,
+    "WV_073": 0.62479186,
+}
+
+# Specified to ensure calculation stability
+EPSILON = 1e-8
+
+RSS_STD = _to_data_array(RSS_STD)
+RSS_MEAN = _to_data_array(RSS_MEAN)
+
+SAT_STDS = SatStatDict(
+    rss=RSS_STD,
+)
+
+SAT_MEANS = SatStatDict(
+    rss=RSS_MEAN,
+)
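Note: a short usage sketch of the provider-keyed stats added above (illustrative only;
it assumes this patch is applied and the package is importable — "rss" is the sole
registered provider, so any other key raises the custom KeyError):

    from ocf_data_sampler.constants import SAT_MEANS, SAT_STDS

    rss_mean = SAT_MEANS["rss"]    # per-channel xr.DataArray (HRV, IR_016, ...)
    try:
        SAT_MEANS["goes"]          # "goes" is not listed in SAT_PROVIDERS
    except KeyError as err:
        print(err)                 # mentions the supported providers, i.e. ['rss']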
diff --git a/ocf_data_sampler/numpy_batch/nwp.py b/ocf_data_sampler/numpy_batch/nwp.py
index 8eae117..ef4e25d 100644
--- a/ocf_data_sampler/numpy_batch/nwp.py
+++ b/ocf_data_sampler/numpy_batch/nwp.py
@@ -1,5 +1,4 @@
 """Convert NWP to NumpyBatch"""
-
 import pandas as pd
 import xarray as xr
 
@@ -19,6 +18,12 @@ class NWPBatchKey:
 
 def convert_nwp_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
     """Convert from Xarray to NWP NumpyBatch"""
+    # Missing coordinate checking stage
+    required_coords = ["y_osgb", "x_osgb"]
+    for coord in required_coords:
+        if coord not in da.coords:
+            raise ValueError(f"Input DataArray missing '{coord}'")
+
     example = {
         NWPBatchKey.nwp: da.values,
         NWPBatchKey.channel_names: da.channel.values,

diff --git a/ocf_data_sampler/numpy_batch/satellite.py b/ocf_data_sampler/numpy_batch/satellite.py
index 0a0b7bb..6696ef0 100644
--- a/ocf_data_sampler/numpy_batch/satellite.py
+++ b/ocf_data_sampler/numpy_batch/satellite.py
@@ -13,6 +13,13 @@ class SatelliteBatchKey:
 
 def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
     """Convert from Xarray to NumpyBatch"""
+
+    # Missing coordinate checking stage
+    required_coords = ["x_geostationary", "y_geostationary"]
+    for coord in required_coords:
+        if coord not in da.coords:
+            raise ValueError(f"Input DataArray missing '{coord}'")
+
     example = {
         SatelliteBatchKey.satellite_actual: da.values,
         SatelliteBatchKey.time_utc: da.time_utc.values.astype(float),
@@ -27,4 +34,4 @@ def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None
     if t0_idx is not None:
         example[SatelliteBatchKey.t0_idx] = t0_idx
 
-    return example
\ No newline at end of file
+    return example

diff --git a/ocf_data_sampler/torch_datasets/process_and_combine.py b/ocf_data_sampler/torch_datasets/process_and_combine.py
index a732ef0..bd405ed 100644
--- a/ocf_data_sampler/torch_datasets/process_and_combine.py
+++ b/ocf_data_sampler/torch_datasets/process_and_combine.py
@@ -3,7 +3,7 @@
 import xarray as xr
 
 from ocf_data_sampler.config import Configuration
-from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
+from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, SAT_MEANS, SAT_STDS, EPSILON
 from ocf_data_sampler.numpy_batch import (
     convert_nwp_to_numpy_batch,
     convert_satellite_to_numpy_batch,
@@ -13,6 +13,8 @@
 )
 from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
 from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
+from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
+
 from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
 from ocf_data_sampler.select.location import Location
 from ocf_data_sampler.utils import minutes
@@ -25,8 +27,8 @@ def process_and_combine_datasets(
     location: Location,
     target_key: str = 'gsp'
 ) -> dict:
-    """Normalize and convert data to numpy arrays"""
+
+    """Normalise and convert data to numpy arrays"""
     numpy_modalities = []
 
     if "nwp" in dataset_dict:
@@ -37,18 +39,29 @@ def process_and_combine_datasets(
         # Standardise
         provider = config.input_data.nwp[nwp_key].provider
         da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
+
         # Convert to NumpyBatch
         nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
 
     # Combine the NWPs into NumpyBatch
     numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})
 
+
     if "sat" in dataset_dict:
-        # Satellite is already in the range [0-1] so no need to standardise
-        da_sat = dataset_dict["sat"]
-        # Convert to NumpyBatch
-        numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
+        sat_numpy_modalities = dict()
+
+        for sat_key, da_sat in dataset_dict["sat"].items():
+            # Standardise
+            provider = config.input_data.satellite[sat_key].provider
+            da_sat = (da_sat - SAT_MEANS[provider]) / (SAT_STDS[provider] + EPSILON)
+
+            # Convert to NumpyBatch
+            sat_numpy_modalities[sat_key] = convert_satellite_to_numpy_batch(da_sat)
+
+        # Combine the Satellites into NumpyBatch
+        numpy_modalities.append({SatelliteBatchKey.satellite_actual: sat_numpy_modalities})
+
 
     gsp_config = config.input_data.gsp
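Aside: a toy, self-contained illustration (not library code) of how the per-channel
stats broadcast against a satellite DataArray in the standardisation above, and why
EPSILON guards the divide:

    import numpy as np
    import xarray as xr

    EPSILON = 1e-8
    mean = xr.DataArray([0.093, 0.176], coords={"channel": ["HRV", "IR_016"]}, dims=["channel"])
    std = xr.DataArray([0.114, 0.215], coords={"channel": ["HRV", "IR_016"]}, dims=["channel"])

    da_sat = xr.DataArray(
        np.random.rand(3, 2, 4),  # (time, channel, x)
        coords={"channel": ["HRV", "IR_016"]},
        dims=["time_utc", "channel", "x_geostationary"],
    )
    # xarray aligns on the shared "channel" dim and broadcasts over the rest;
    # EPSILON keeps the division finite even if a std were ever zero.
    da_norm = (da_sat - mean) / (std + EPSILON)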
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..a8f49ce
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,136 @@
+aiobotocore
+aiohttp
+aioitertools
+aiosignal
+alabaster
+altair
+anaconda-client
+anyio
+appdirs==1.4.4
+argon2-cffi
+arrow
+astroid
+astropy
+attrs
+autopep8
+Babel
+bcrypt
+beautifulsoup4
+black
+blosc2>=2.7.1,<3.0.0
+bokeh
+botocore
+Bottleneck
+cachetools
+certifi
+cffi
+chardet
+charset-normalizer
+click
+cloudpickle
+colorama
+cycler
+cytoolz
+dask
+datashader
+debugpy
+decorator
+defusedxml
+dill
+distributed
+fsspec
+gensim
+greenlet
+h5py
+holoviews
+hvplot
+idna
+imagecodecs
+imageio
+imbalanced-learn
+importlib-metadata
+ipykernel
+ipython
+ipywidgets
+isort
+jedi
+joblib
+jupyter
+jupyter-client
+jupyter-core
+jupyterlab
+jupyterlab-widgets
+kiwisolver
+lazy-object-proxy
+llvmlite
+lmdb
+locket
+lxml
+lz4
+Markdown
+matplotlib==3.9.2
+mccabe
+mistune
+more-itertools
+mpmath
+msgpack
+multidict
+networkx
+nltk
+notebook
+numba
+numcodecs==0.13.1
+numexpr
+numpy
+numpydoc
+openpyxl
+packaging
+pandas
+param
+patsy
+pexpect
+pickleshare
+pillow
+pkginfo
+platformdirs
+pluggy
+prompt-toolkit
+protobuf==4.25.3
+psutil
+ptyprocess
+PyArrow
+pydantic
+PyYAML
+pyzmq
+pathy
+qtconsole
+queuelib
+regex
+requests
+scikit-image
+scikit-learn
+scipy
+seaborn
+setuptools==75.1.0
+shapely
+six
+sqlalchemy
+statsmodels
+sympy
+tables
+tabulate
+threadpoolctl
+tifffile
+toolz
+torch==2.5.1
+tornado
+tqdm
+traitlets
+typing-extensions
+urllib3
+watchdog
+wcwidth
+xarray
+zarr==2.18.3
+zict
+zstandard

diff --git a/tests/.DS_Store b/tests/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..918e6124fc19858560b92bb6f3b327ddc001014c
GIT binary patch
literal 6148
[binary data omitted]
literal 0
HcmV?d00001

diff --git a/tests/torch_datasets/.DS_Store b/tests/torch_datasets/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..330862ab50d7f8d728b6eb1091d509d39c3300c5
GIT binary patch
literal 6148
[binary data omitted]
literal 0
HcmV?d00001

diff --git a/tests/torch_datasets/test_process_and_combine.py b/tests/torch_datasets/test_process_and_combine.py
new file mode 100644
index 0000000..5908a5d
--- /dev/null
+++ b/tests/torch_datasets/test_process_and_combine.py
@@ -0,0 +1,136 @@
+import pytest
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+from unittest.mock import MagicMock
+
+from ocf_data_sampler.constants import SAT_MEANS, SAT_STDS, NWP_MEANS, NWP_STDS, EPSILON
+from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
+from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
+from ocf_data_sampler.torch_datasets.process_and_combine import (
+    process_and_combine_datasets,
+    merge_dicts,
+    fill_nans_in_arrays,
+    compute,
+)
+
+
+@pytest.fixture(scope="module")
+def da_sat_like():
+    """ Create dummy satellite DataArray """
+    data = np.random.rand(3, 4, 4).astype(np.float32)
+    coords = {
+        "time_utc": pd.date_range("2023-01-01", periods=3),
+        "x_geostationary": np.arange(4),
+        "y_geostationary": np.arange(4),
+    }
+    return xr.DataArray(
+        data,
+        coords=coords,
+        dims=["time_utc", "y_geostationary", "x_geostationary"],
+    )
+
+
+@pytest.fixture(scope="module")
+def da_nwp_like():
+    """ Create dummy NWP DataArray """
+    data = np.random.rand(3, 3, 4, 4).astype(np.float32)
+    coords = {
+        "init_time_utc": pd.date_range("2023-01-01", periods=3),
+        "step": pd.timedelta_range("0h", periods=3, freq="1h"),
+        "x_osgb": np.arange(4),
+        "y_osgb": np.arange(4),
+    }
+    return xr.DataArray(
+        data,
+        coords=coords,
+        dims=["init_time_utc", "step", "y_osgb", "x_osgb"],
+    )
+
+
+def test_merge_dicts():
+    dicts = [
+        {"a": 1, "b": 2},
+        {"c": 3, "d": 4},
+        {"b": 5, "e": 6},
+    ]
+    result = merge_dicts(dicts)
+    expected = {"a": 1, "b": 5, "c": 3, "d": 4, "e": 6}
+    assert result == expected
+
+
+def test_fill_nans_in_arrays():
+    batch = {
+        "array_1": np.array([1, np.nan, 3]),
+        "array_2": np.array([[np.nan, 2], [3, np.nan]]),
+        "nested_dict": {
+            "array_3": np.array([np.nan, 0]),
+        },
+        "non_numeric": "keep_this",
+    }
+
+    result = fill_nans_in_arrays(batch)
+    expected = {
+        "array_1": np.array([1, 0, 3]),
+        "array_2": np.array([[0, 2], [3, 0]]),
+        "nested_dict": {
+            "array_3": np.array([0, 0]),
+        },
+        "non_numeric": "keep_this",
+    }
+
+    np.testing.assert_array_equal(result["array_1"], expected["array_1"])
+    np.testing.assert_array_equal(result["array_2"], expected["array_2"])
+    np.testing.assert_array_equal(result["nested_dict"]["array_3"], expected["nested_dict"]["array_3"])
+    assert result["non_numeric"] == expected["non_numeric"]
+
+
+def test_compute():
+    mock_dataarray = MagicMock()
+    mock_dataarray.compute = MagicMock(side_effect=lambda scheduler=None: mock_dataarray)
+
+    xarray_dict = {
+        "level1": {
+            "level2": mock_dataarray,
+        },
+        "another_level1": mock_dataarray,
+    }
+
+    result = compute(xarray_dict)
+    assert result["level1"]["level2"] == mock_dataarray
+    assert result["another_level1"] == mock_dataarray
+    mock_dataarray.compute.assert_called()
+
+
+def test_process_and_combine_datasets(da_sat_like, da_nwp_like):
+    # Dummy config with valid integer values
+    mock_config = MagicMock()
+    mock_config.input_data.nwp = {"ukv": MagicMock(provider="ukv")}
+    mock_config.input_data.satellite = {"rss": MagicMock(provider="rss")}
+    mock_config.input_data.gsp = MagicMock(
+        interval_start_minutes=-30,
+        interval_end_minutes=30,
+        time_resolution_minutes=15,
+    )
+
+    t0 = pd.Timestamp("2023-01-01 00:00:00")
+    mock_location = MagicMock(x=12345.6, y=65432.1, id=1)
+
+    dataset_dict = {
+        "nwp": {"ukv": da_nwp_like},
+        "sat": {"rss": da_sat_like},
+    }
+
+    # Run function
+    result = process_and_combine_datasets(dataset_dict, mock_config, t0, mock_location)
+
+    # Assertion currently only for satellite and NWP
+    assert isinstance(result, dict)
+    assert SatelliteBatchKey.satellite_actual in result
+    assert NWPBatchKey.nwp in result
+
+    # Assert no NaNs remain
+    for key, value in result.items():
+        if isinstance(value, np.ndarray):
+            assert not np.isnan(value).any()
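Note: fill_nans_in_arrays itself is not shown anywhere in this series; a minimal
sketch consistent with the test expectations above (NaNs become 0.0, applied
recursively through nested dicts, non-array values pass through untouched):

    import numpy as np

    def fill_nans_in_arrays(batch: dict) -> dict:
        """Replace NaNs with 0.0 in all numeric numpy arrays of a (nested) dict."""
        for key, value in batch.items():
            if isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.number):
                batch[key] = np.nan_to_num(value, nan=0.0)
            elif isinstance(value, dict):
                fill_nans_in_arrays(value)
        return batch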
From 3c50f9e5cdb68f58129f32dab561580b95d9979b Mon Sep 17 00:00:00 2001
From: Felix <137530077+felix-e-h-p@users.noreply.github.com>
Date: Tue, 17 Dec 2024 08:28:28 +0000
Subject: [PATCH 02/20] Update workflows.yaml

Override anaconda-client issue via setting version to 3.11
---
 .github/workflows/workflows.yaml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.github/workflows/workflows.yaml b/.github/workflows/workflows.yaml
index 9905fc0..8e60944 100644
--- a/.github/workflows/workflows.yaml
+++ b/.github/workflows/workflows.yaml
@@ -18,3 +18,5 @@ jobs:
     #sudo_apt_install: "libgeos++-dev libproj-dev proj-data proj-bin"
    # brew_install: "proj geos librttopo"
     os_list: '["ubuntu-latest"]'
+    env:
+      PYTHON_VERSION: "3.11"

From e0e141b0ab16ec6abf5e2fb7c3fff60dee5e463f Mon Sep 17 00:00:00 2001
From: Felix
<137530077+felix-e-h-p@users.noreply.github.com> Date: Fri, 20 Dec 2024 10:32:46 +0000 Subject: [PATCH 03/20] Update process_and_combine.py --- ocf_data_sampler/torch_datasets/process_and_combine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ocf_data_sampler/torch_datasets/process_and_combine.py b/ocf_data_sampler/torch_datasets/process_and_combine.py index bd405ed..78e8de5 100644 --- a/ocf_data_sampler/torch_datasets/process_and_combine.py +++ b/ocf_data_sampler/torch_datasets/process_and_combine.py @@ -54,6 +54,9 @@ def process_and_combine_datasets( for sat_key, da_sat in dataset_dict["sat"].items(): # Standardise provider = config.input_data.satellite[sat_key].provider + + # Not entirely sure if epsilon is necessary considering mean and std values are consistently non-zero + # Purely a safety measure da_sat = (da_sat - SAT_MEANS[provider]) / (SAT_STDS[provider] + EPSILON) # Convert to NumpyBatch From b84f9cfc320d12459b6c81c77991fe2ceba6661e Mon Sep 17 00:00:00 2001 From: Felix <137530077+felix-e-h-p@users.noreply.github.com> Date: Fri, 20 Dec 2024 10:37:07 +0000 Subject: [PATCH 04/20] Update test_process_and_combine.py --- tests/torch_datasets/test_process_and_combine.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/torch_datasets/test_process_and_combine.py b/tests/torch_datasets/test_process_and_combine.py index 5908a5d..bae0b58 100644 --- a/tests/torch_datasets/test_process_and_combine.py +++ b/tests/torch_datasets/test_process_and_combine.py @@ -50,17 +50,22 @@ def da_nwp_like(): def test_merge_dicts(): + # Example dict list with overlap dicts = [ {"a": 1, "b": 2}, {"c": 3, "d": 4}, {"b": 5, "e": 6}, ] + result = merge_dicts(dicts) + + # Prioritise later values for overlap expected = {"a": 1, "b": 5, "c": 3, "d": 4, "e": 6} assert result == expected def test_fill_nans_in_arrays(): + # Arrays with NaN, nested dict and non numerical key batch = { "array_1": np.array([1, np.nan, 3]), "array_2": np.array([[np.nan, 2], [3, np.nan]]), @@ -71,6 +76,8 @@ def test_fill_nans_in_arrays(): } result = fill_nans_in_arrays(batch) + + # NaN should be filled with zeros expected = { "array_1": np.array([1, 0, 3]), "array_2": np.array([[0, 2], [3, 0]]), From 07f821359a922f8cc37517dee0b093b796daab15 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Fri, 20 Dec 2024 11:00:19 +0000 Subject: [PATCH 05/20] run on python 3.11 --- .github/workflows/workflows.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/workflows.yaml b/.github/workflows/workflows.yaml index 8e60944..6d9d722 100644 --- a/.github/workflows/workflows.yaml +++ b/.github/workflows/workflows.yaml @@ -18,5 +18,4 @@ jobs: #sudo_apt_install: "libgeos++-dev libproj-dev proj-data proj-bin" # brew_install: "proj geos librttopo" os_list: '["ubuntu-latest"]' - env: - PYTHON_VERSION: "3.11" + python_version: "3.11" From 8e2d65744484eb52a6405dc5d9b7251fe12d2c47 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Fri, 20 Dec 2024 11:01:50 +0000 Subject: [PATCH 06/20] trigger tests --- .github/workflows/workflows.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/workflows.yaml b/.github/workflows/workflows.yaml index 6d9d722..fbaa78d 100644 --- a/.github/workflows/workflows.yaml +++ b/.github/workflows/workflows.yaml @@ -18,4 +18,4 @@ jobs: #sudo_apt_install: "libgeos++-dev libproj-dev proj-data proj-bin" # brew_install: "proj geos librttopo" os_list: '["ubuntu-latest"]' - python_version: "3.11" + python_version: "['3.11']" From 
33c1fc318f29abf117cd57f12766ba2d46a42ced Mon Sep 17 00:00:00 2001
From: peterdudfield
Date: Fri, 20 Dec 2024 11:27:29 +0000
Subject: [PATCH 07/20] fix

---
 .github/workflows/workflows.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/workflows.yaml b/.github/workflows/workflows.yaml
index fbaa78d..3e79031 100644
--- a/.github/workflows/workflows.yaml
+++ b/.github/workflows/workflows.yaml
@@ -18,4 +18,4 @@ jobs:
     #sudo_apt_install: "libgeos++-dev libproj-dev proj-data proj-bin"
    # brew_install: "proj geos librttopo"
     os_list: '["ubuntu-latest"]'
-    python_version: "['3.11']"
+    python-version: "['3.11']"

From d318c973382305ba3c25bbf4e06aeaa64c748b34 Mon Sep 17 00:00:00 2001
From: Felix Peretz
Date: Sun, 22 Dec 2024 12:21:38 +0000
Subject: [PATCH 08/20] Unnecessary requirements.txt removed

---
 requirements.txt | 136 -----------------------------------------------
 1 file changed, 136 deletions(-)
 delete mode 100644 requirements.txt

diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index a8f49ce..0000000
--- a/requirements.txt
+++ /dev/null
@@ -1,136 +0,0 @@
-aiobotocore
-aiohttp
-aioitertools
-aiosignal
-alabaster
-altair
-anaconda-client
-anyio
-appdirs==1.4.4
-argon2-cffi
-arrow
-astroid
-astropy
-attrs
-autopep8
-Babel
-bcrypt
-beautifulsoup4
-black
-blosc2>=2.7.1,<3.0.0
-bokeh
-botocore
-Bottleneck
-cachetools
-certifi
-cffi
-chardet
-charset-normalizer
-click
-cloudpickle
-colorama
-cycler
-cytoolz
-dask
-datashader
-debugpy
-decorator
-defusedxml
-dill
-distributed
-fsspec
-gensim
-greenlet
-h5py
-holoviews
-hvplot
-idna
-imagecodecs
-imageio
-imbalanced-learn
-importlib-metadata
-ipykernel
-ipython
-ipywidgets
-isort
-jedi
-joblib
-jupyter
-jupyter-client
-jupyter-core
-jupyterlab
-jupyterlab-widgets
-kiwisolver
-lazy-object-proxy
-llvmlite
-lmdb
-locket
-lxml
-lz4
-Markdown
-matplotlib==3.9.2
-mccabe
-mistune
-more-itertools
-mpmath
-msgpack
-multidict
-networkx
-nltk
-notebook
-numba
-numcodecs==0.13.1
-numexpr
-numpy
-numpydoc
-openpyxl
-packaging
-pandas
-param
-patsy
-pexpect
-pickleshare
-pillow
-pkginfo
-platformdirs
-pluggy
-prompt-toolkit
-protobuf==4.25.3
-psutil
-ptyprocess
-PyArrow
-pydantic
-PyYAML
-pyzmq
-pathy
-qtconsole
-queuelib
-regex
-requests
-scikit-image
-scikit-learn
-scipy
-seaborn
-setuptools==75.1.0
-shapely
-six
-sqlalchemy
-statsmodels
-sympy
-tables
-tabulate
-threadpoolctl
-tifffile
-toolz
-torch==2.5.1
-tornado
-tqdm
-traitlets
-typing-extensions
-urllib3
-watchdog
-wcwidth
-xarray
-zarr==2.18.3
-zict
-zstandard

From bf89df13f32a469a2f88888b86bc8f59af1bfca2 Mon Sep 17 00:00:00 2001
From: Felix Peretz
Date: Sun, 22 Dec 2024 17:38:24 +0000
Subject: [PATCH 09/20] Specified changes undertaken

---
 .../test_process_and_combine.py | 231 +++++++++---------
 1 file changed, 120 insertions(+), 111 deletions(-)

diff --git a/tests/torch_datasets/test_process_and_combine.py b/tests/torch_datasets/test_process_and_combine.py
index bae0b58..32004ba 100644
--- a/tests/torch_datasets/test_process_and_combine.py
+++ b/tests/torch_datasets/test_process_and_combine.py
@@ -2,142 +2,151 @@
 import numpy as np
 import pandas as pd
 import xarray as xr
+from ocf_data_sampler.config import Configuration
+from ocf_data_sampler.select.location import Location
+from ocf_data_sampler.select.select_time_slice import select_time_slice_nwp
 
-from unittest.mock import MagicMock
-
-from ocf_data_sampler.constants import SAT_MEANS, SAT_STDS, NWP_MEANS, NWP_STDS, EPSILON
-from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
-from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey from ocf_data_sampler.torch_datasets.process_and_combine import ( process_and_combine_datasets, merge_dicts, fill_nans_in_arrays, compute, ) +from ocf_data_sampler.numpy_batch import NWPBatchKey, SatelliteBatchKey, GSPBatchKey -@pytest.fixture(scope="module") -def da_sat_like(): - """ Create dummy satellite DataArray """ - data = np.random.rand(3, 4, 4).astype(np.float32) - coords = { - "time_utc": pd.date_range("2023-01-01", periods=3), - "x_geostationary": np.arange(4), - "y_geostationary": np.arange(4), +@pytest.fixture +def mock_configuration(): + config = Configuration() + config.input_data.nwp = { + "ukv": type("Provider", (object,), {"provider": "ukv"}), + "ecmwf": type("Provider", (object,), {"provider": "ecmwf"}), } - return xr.DataArray( - data, - coords=coords, - dims=["time_utc", "y_geostationary", "x_geostationary"], + config.input_data.gsp = type( + "GSPConfig", + (object,), + { + "interval_start_minutes": -180, + "interval_end_minutes": 180, + "time_resolution_minutes": 30, + }, + )() + config.input_data.site = type( + "SiteConfig", + (object,), + { + "interval_start_minutes": -120, + "interval_end_minutes": 120, + "time_resolution_minutes": 15, + }, + )() + return config + + +@pytest.fixture +def mock_dataset_dict(): + x_osgb = np.linspace(0, 10, 2) + y_osgb = np.linspace(0, 10, 2) + init_time_utc = pd.date_range("2023-01-01T00:00", periods=10, freq="1h") + step = pd.to_timedelta(np.arange(10), unit="h") # Create a valid step coordinate + + # Create mock NWP data with valid `step` and `init_time_utc` + nwp_data = xr.DataArray( + np.random.rand(10, 5, 2, 2), + dims=["step", "channel", "x", "y"], + coords={ + "step": step, # Properly define `step` as a coordinate + "init_time_utc": ("step", init_time_utc), # Link `init_time_utc` to `step` + "channel": ["cdcb", "lcc", "mcc", "hcc", "sde"], + "x_osgb": ("x", x_osgb), + "y_osgb": ("y", y_osgb), + }, ) + # Ensure step remains accessible even after dimension swapping + nwp_data = nwp_data.swap_dims({"step": "init_time_utc"}).reset_coords("step", drop=False) + + # Create mock satellite data + sat_data = xr.DataArray( + np.random.rand(10, 1, 2, 2), + dims=["time", "channel", "x", "y"], + coords={ + "time": pd.date_range("2023-01-01", periods=10, freq="30min"), + "channel": ["HRV"], + "x_osgb": ("x", x_osgb), + "y_osgb": ("y", y_osgb), + }, + ) + + # Create mock GSP data + gsp_data = xr.DataArray( + np.random.rand(10), + dims=["time"], + coords={"time": pd.date_range("2023-01-01", periods=10, freq="30min")}, + ) + gsp_future_data = xr.DataArray( + np.random.rand(10), + dims=["time"], + coords={"time": pd.date_range("2023-01-01T05:00", periods=10, freq="30min")}, + ) -@pytest.fixture(scope="module") -def da_nwp_like(): - """ Create dummy NWP DataArray """ - data = np.random.rand(3, 3, 4, 4).astype(np.float32) - coords = { - "init_time_utc": pd.date_range("2023-01-01", periods=3), - "step": pd.timedelta_range("0h", periods=3, freq="1h"), - "x_osgb": np.arange(4), - "y_osgb": np.arange(4), + return { + "nwp": {"ukv": nwp_data}, + "sat": sat_data, + "gsp": gsp_data, + "gsp_future": gsp_future_data, } - return xr.DataArray( - data, - coords=coords, - dims=["init_time_utc", "step", "y_osgb", "x_osgb"], + + +def test_process_and_combine_datasets(mock_configuration, mock_dataset_dict): + location = Location(x=0, y=0, id=1) + t0 = pd.Timestamp("2023-01-01 06:00") + + # Apply time slicing to the NWP data + for nwp_key, da_nwp in mock_dataset_dict["nwp"].items(): + 
mock_dataset_dict["nwp"][nwp_key] = select_time_slice_nwp(
+            da=da_nwp,
+            t0=t0,
+            interval_start=pd.Timedelta(hours=-3),
+            interval_end=pd.Timedelta(hours=3),
+            sample_period_duration=pd.Timedelta(minutes=30),
+        )
+
+    result = process_and_combine_datasets(
+        dataset_dict=mock_dataset_dict,
+        config=mock_configuration,
+        t0=t0,
+        location=location,
+        target_key="gsp",
+    )
+
+    assert isinstance(result, dict)
+    assert GSPBatchKey.gsp in result
+    assert NWPBatchKey.nwp in result
+    assert SatelliteBatchKey.satellite_actual in result
 
 
 def test_merge_dicts():
-    # Example dict list with overlap
-    dicts = [
-        {"a": 1, "b": 2},
-        {"c": 3, "d": 4},
-        {"b": 5, "e": 6},
-    ]
-
-    result = merge_dicts(dicts)
-
-    # Prioritise later values for overlap
-    expected = {"a": 1, "b": 5, "c": 3, "d": 4, "e": 6}
-    assert result == expected
+    dicts = [{"a": 1, "b": 2}, {"b": 3, "c": 4}]
+    merged = merge_dicts(dicts)
+    assert merged == {"a": 1, "b": 3, "c": 4}
 
 
 def test_fill_nans_in_arrays():
-    # Arrays with NaN, nested dict and non numerical key
     batch = {
-        "array_1": np.array([1, np.nan, 3]),
-        "array_2": np.array([[np.nan, 2], [3, np.nan]]),
-        "nested_dict": {
-            "array_3": np.array([np.nan, 0]),
-        },
-        "non_numeric": "keep_this",
+        "a": np.array([1.0, np.nan, 3.0]),
+        "b": {"nested": np.array([np.nan, 5.0])},
     }
-
-    result = fill_nans_in_arrays(batch)
-
-    # NaN should be filled with zeros
-    expected = {
-        "array_1": np.array([1, 0, 3]),
-        "array_2": np.array([[0, 2], [3, 0]]),
-        "nested_dict": {
-            "array_3": np.array([0, 0]),
-        },
-        "non_numeric": "keep_this",
-    }
-
-    np.testing.assert_array_equal(result["array_1"], expected["array_1"])
-    np.testing.assert_array_equal(result["array_2"], expected["array_2"])
-    np.testing.assert_array_equal(result["nested_dict"]["array_3"], expected["nested_dict"]["array_3"])
-    assert result["non_numeric"] == expected["non_numeric"]
+    filled_batch = fill_nans_in_arrays(batch)
+    assert np.array_equal(filled_batch["a"], np.array([1.0, 0.0, 3.0]))
+    assert np.array_equal(filled_batch["b"]["nested"], np.array([0.0, 5.0]))
 
 
 def test_compute():
-    mock_dataarray = MagicMock()
-    mock_dataarray.compute = MagicMock(side_effect=lambda scheduler=None: mock_dataarray)
-
-    xarray_dict = {
-        "level1": {
-            "level2": mock_dataarray,
-        },
-        "another_level1": mock_dataarray,
-    }
-
-    result = compute(xarray_dict)
-    assert result["level1"]["level2"] == mock_dataarray
-    assert result["another_level1"] == mock_dataarray
-    mock_dataarray.compute.assert_called()
-
-
-def test_process_and_combine_datasets(da_sat_like, da_nwp_like):
-    # Dummy config with valid integer values
-    mock_config = MagicMock()
-    mock_config.input_data.nwp = {"ukv": MagicMock(provider="ukv")}
-    mock_config.input_data.satellite = {"rss": MagicMock(provider="rss")}
-    mock_config.input_data.gsp = MagicMock(
-        interval_start_minutes=-30,
-        interval_end_minutes=30,
-        time_resolution_minutes=15,
+    data = xr.DataArray(
+        np.random.rand(10), dims=["time"], coords={"time": pd.date_range("2023-01-01", periods=10)}
     )
+    nested_dict = {"level1": {"level2": data}}
+    computed_dict = compute(nested_dict)
 
-    t0 = pd.Timestamp("2023-01-01 00:00:00")
-    mock_location = MagicMock(x=12345.6, y=65432.1, id=1)
-
-    dataset_dict = {
-        "nwp": {"ukv": da_nwp_like},
-        "sat": {"rss": da_sat_like},
-    }
-
-    # Run function
-    result = process_and_combine_datasets(dataset_dict, mock_config, t0, mock_location)
-
-    # Assertion currently only for satellite and NWP
-    assert isinstance(result, dict)
-    assert SatelliteBatchKey.satellite_actual in result
-    assert NWPBatchKey.nwp in result
-
-    # Assert no NaNs remain
-    for key, value in result.items():
-        if isinstance(value, np.ndarray):
-            assert not np.isnan(value).any()
+    assert computed_dict["level1"]["level2"].equals(data)
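Note: compute() is exercised by the rewritten test above but its body is never shown
in this series (it is imported from ocf_data_sampler.torch_datasets.process_and_combine);
a minimal sketch consistent with the test — recursively materialise every xarray
object in a nested dict:

    def compute(xarray_dict: dict) -> dict:
        """Eagerly load each xarray object in a (possibly nested) dict."""
        return {
            key: compute(value) if isinstance(value, dict) else value.compute()
            for key, value in xarray_dict.items()
        }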
From 584459c3fe422daea430b122d348c3760578b99c Mon Sep 17 00:00:00 2001
From: Felix Peretz
Date: Sun, 22 Dec 2024 17:41:39 +0000
Subject: [PATCH 10/20]

---
 .DS_Store                                 | Bin 10244 -> 6148 bytes
 .gitignore                                |   3 ++
 ocf_data_sampler/constants.py             |  39 ++++----------------
 ocf_data_sampler/numpy_batch/nwp.py       |   6 ---
 ocf_data_sampler/numpy_batch/satellite.py |   6 ---
 .../torch_datasets/process_and_combine.py |  23 ++++-------
 6 files changed, 17 insertions(+), 60 deletions(-)

diff --git a/.DS_Store b/.DS_Store
index 47564d99da74d156096f22186be649ed86012e73..d5c257a446ccff182e41b1b7d4c855bc3c611125 100644
GIT binary patch
[binary data omitted]

diff --git a/ocf_data_sampler/constants.py b/ocf_data_sampler/constants.py
index d4bc4e7..1aa4ff3 100644
--- a/ocf_data_sampler/constants.py
+++ b/ocf_data_sampler/constants.py
@@ -7,10 +7,6 @@
     "ecmwf",
 ]
 
-SAT_PROVIDERS = [
-    "rss",
-]
-
 
 def _to_data_array(d):
     return xr.DataArray(
@@ -32,21 +28,6 @@ def __getitem__(self, key):
             f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}"
         )
 
-
-class SatStatDict(dict):
-    """Custom dictionary class to hold Satellite normalization stats"""
-
-    def __getitem__(self, key):
-        if key not in SAT_PROVIDERS:
-            raise KeyError(f"{key} is not a supported Satellite provider - {SAT_PROVIDERS}")
-        elif key in self.keys():
-            return super().__getitem__(key)
-        else:
-            raise KeyError(
-                f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}"
-            )
-
-
 # ------ UKV
 # Means and std computed WITH version_7 and higher, MetOffice values
 UKV_STD = {
@@ -136,10 +117,10 @@ def __getitem__(self, key):
 # ------ Satellite
 # RSS Mean and std values from randomised 20% of 2020 imagery
 
-RSS_STD = {
+SAT_STD = {
     "HRV": 0.11405209,
     "IR_016": 0.21462157,
     "IR_039": 0.04618041,
@@ -153,7 +134,7 @@ def __getitem__(self, key):
     "WV_073": 0.12924142,
 }
 
-RSS_MEAN = {
+SAT_MEAN = {
     "HRV": 0.09298719,
     "IR_016": 0.17594202,
@@ -166,15 +147,8 @@ def __getitem__(self, key):
     "WV_062": 0.7359355,
     "WV_073": 0.62479186,
 }
 
-# Specified to ensure calculation stability
-EPSILON = 1e-8
-
-RSS_STD = _to_data_array(RSS_STD)
-RSS_MEAN = _to_data_array(RSS_MEAN)
-
-SAT_STDS = SatStatDict(
-    rss=RSS_STD,
-)
-
-SAT_MEANS = SatStatDict(
-    rss=RSS_MEAN,
-)
+SAT_STD = _to_data_array(SAT_STD)
+SAT_MEAN = _to_data_array(SAT_MEAN)
+
+# SatStatDict wrapper not needed due to singular provider - direct assignment of meand and std
+SAT_STDS = SAT_STD
+SAT_MEANS = SAT_MEAN
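With this simplification SAT_MEANS / SAT_STDS become plain per-channel DataArrays
rather than provider-keyed dicts, so the lookup step disappears; a hedged sketch of
the resulting call shape (names as in the diff above, assuming the patch is applied):

    import xarray as xr
    from ocf_data_sampler.constants import SAT_MEANS, SAT_STDS

    def normalise_sat(da_sat: xr.DataArray) -> xr.DataArray:
        # Per-channel stats broadcast across the time/x/y dims of da_sat
        return (da_sat - SAT_MEANS) / SAT_STDS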
diff --git a/ocf_data_sampler/numpy_batch/nwp.py b/ocf_data_sampler/numpy_batch/nwp.py
--- a/ocf_data_sampler/numpy_batch/nwp.py
+++ b/ocf_data_sampler/numpy_batch/nwp.py
@@ -18,12 +18,6 @@ class NWPBatchKey:
 
 def convert_nwp_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
     """Convert from Xarray to NWP NumpyBatch"""
-    # Missing coordinate checking stage
-    required_coords = ["y_osgb", "x_osgb"]
-    for coord in required_coords:
-        if coord not in da.coords:
-            raise ValueError(f"Input DataArray missing '{coord}'")
-
     example = {
         NWPBatchKey.nwp: da.values,
         NWPBatchKey.channel_names: da.channel.values,

diff --git a/ocf_data_sampler/numpy_batch/satellite.py b/ocf_data_sampler/numpy_batch/satellite.py
index 6696ef0..d55ce4f 100644
--- a/ocf_data_sampler/numpy_batch/satellite.py
+++ b/ocf_data_sampler/numpy_batch/satellite.py
@@ -14,12 +14,6 @@ class SatelliteBatchKey:
 def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
     """Convert from Xarray to NumpyBatch"""
 
-    # Missing coordinate checking stage
-    required_coords = ["x_geostationary", "y_geostationary"]
-    for coord in required_coords:
-        if coord not in da.coords:
-            raise ValueError(f"Input DataArray missing '{coord}'")
-
     example = {
         SatelliteBatchKey.satellite_actual: da.values,
         SatelliteBatchKey.time_utc: da.time_utc.values.astype(float),

diff --git a/ocf_data_sampler/torch_datasets/process_and_combine.py b/ocf_data_sampler/torch_datasets/process_and_combine.py
index 78e8de5..8b59da7 100644
--- a/ocf_data_sampler/torch_datasets/process_and_combine.py
+++ b/ocf_data_sampler/torch_datasets/process_and_combine.py
@@ -3,7 +3,7 @@
 import xarray as xr
 
 from ocf_data_sampler.config import Configuration
-from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, SAT_MEANS, SAT_STDS, EPSILON
+from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, SAT_MEANS, SAT_STDS
 from ocf_data_sampler.numpy_batch import (
     convert_nwp_to_numpy_batch,
     convert_satellite_to_numpy_batch,
@@ -48,21 +48,12 @@ def process_and_combine_datasets(
 
     if "sat" in dataset_dict:
-
-        sat_numpy_modalities = dict()
-
-        for sat_key, da_sat in dataset_dict["sat"].items():
-            # Standardise
-            provider = config.input_data.satellite[sat_key].provider
-
-            # Not entirely sure if epsilon is necessary considering mean and std values are consistently non-zero
-            # Purely a safety measure
-            da_sat = (da_sat - SAT_MEANS[provider]) / (SAT_STDS[provider] + EPSILON)
-
-            # Convert to NumpyBatch
-            sat_numpy_modalities[sat_key] = convert_satellite_to_numpy_batch(da_sat)
-
-        # Combine the Satellites into NumpyBatch
+        # Standardise
+        da_sat = dataset_dict["sat"]
+        da_sat = (da_sat - SAT_MEANS) / SAT_STDS
+        # Convert to NumpyBatch
+        sat_numpy_modalities = convert_satellite_to_numpy_batch(da_sat)
+        # Combine the Satellite into NumpyBatch
         numpy_modalities.append({SatelliteBatchKey.satellite_actual: sat_numpy_modalities})

From 7fa176759f6a03d2ff53fb945e86e65dabc9e9ef Mon Sep 17 00:00:00 2001
From: Felix <137530077+felix-e-h-p@users.noreply.github.com>
Date: Sun, 22 Dec 2024 17:45:28 +0000
Subject: [PATCH 11/20] Removal of DS Store (in ignore)

---
 .DS_Store | Bin 6148 -> 0 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 .DS_Store

diff --git a/.DS_Store b/.DS_Store
deleted file mode 100644
index d5c257a446ccff182e41b1b7d4c855bc3c611125..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 6148
[binary data omitted]

From fcb7555f0928048b694945b6e3f7624579a0796c Mon Sep 17 00:00:00 2001
From: Felix <137530077+felix-e-h-p@users.noreply.github.com>
Date: Sun, 22 Dec 2024 17:45:45 +0000
Subject: [PATCH 12/20] Removal of DS Store (in ignore)

---
 tests/.DS_Store | Bin 6148 -> 0 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 tests/.DS_Store

diff --git a/tests/.DS_Store b/tests/.DS_Store
deleted file mode 100644
index 918e6124fc19858560b92bb6f3b327ddc001014c..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 6148
[binary data omitted]
From cbc1b4b265bfa3648a4af4e47c8d6c0af49a0135 Mon Sep 17 00:00:00 2001
From: Felix <137530077+felix-e-h-p@users.noreply.github.com>
Date: Sun, 22 Dec 2024 17:46:16 +0000
Subject: [PATCH 13/20] Removal of DS Store (in ignore)

---
 ocf_data_sampler/.DS_Store | Bin 6148 -> 0 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 ocf_data_sampler/.DS_Store

diff --git a/ocf_data_sampler/.DS_Store b/ocf_data_sampler/.DS_Store
deleted file mode 100644
index 0b09e2d677d065a693bcc2d23e20e997189dc058..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 6148
[binary data omitted]

From 59b3fa4b04c6177600a51adbfba9e051393bc8aa Mon Sep 17 00:00:00 2001
From: Felix <137530077+felix-e-h-p@users.noreply.github.com>
Date: Sun, 22 Dec 2024 17:53:51 +0000
Subject: [PATCH 14/20] Update workflows.yaml

Python version set removed
---
 .github/workflows/workflows.yaml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.github/workflows/workflows.yaml b/.github/workflows/workflows.yaml
index 3e79031..9905fc0 100644
--- a/.github/workflows/workflows.yaml
+++ b/.github/workflows/workflows.yaml
@@ -18,4 +18,3 @@ jobs:
     #sudo_apt_install: "libgeos++-dev libproj-dev proj-data proj-bin"
    # brew_install: "proj geos librttopo"
     os_list: '["ubuntu-latest"]'
-    python-version: "['3.11']"

From b30ad79594780444b5395aab08a1afcb3e556d46 Mon Sep 17 00:00:00 2001
From: Felix Peretz
Date: Sun, 22 Dec 2024 19:20:50 +0000
Subject: [PATCH 15/20] Refactor of testing functionality

---
 .../test_process_and_combine.py | 297 +++++++++++-------
 1 file changed, 182 insertions(+), 115 deletions(-)

diff --git a/tests/torch_datasets/test_process_and_combine.py b/tests/torch_datasets/test_process_and_combine.py
index 32004ba..76d9ab9 100644
--- a/tests/torch_datasets/test_process_and_combine.py
+++ b/tests/torch_datasets/test_process_and_combine.py
@@ -2,9 +2,13 @@
 import numpy as np
 import pandas as pd
 import xarray as xr
+
 from ocf_data_sampler.config import Configuration
+from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
+from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
 from ocf_data_sampler.select.location import Location
-from ocf_data_sampler.select.select_time_slice import select_time_slice_nwp
+from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
+from ocf_data_sampler.utils import minutes
 
 from ocf_data_sampler.torch_datasets.process_and_combine import (
     process_and_combine_datasets,
     merge_dicts,
     fill_nans_in_arrays,
     compute,
 )
-from ocf_data_sampler.numpy_batch import NWPBatchKey, SatelliteBatchKey, GSPBatchKey
+
+NWP_FREQ = pd.Timedelta("3h")
 
-@pytest.fixture
-def mock_configuration():
-    config = Configuration()
-    config.input_data.nwp = {
-        "ukv": type("Provider", (object,), {"provider": "ukv"}),
-        "ecmwf": type("Provider", (object,), {"provider": "ecmwf"}),
-    }
-    config.input_data.gsp =
type( - "GSPConfig", - (object,), - { - "interval_start_minutes": -180, - "interval_end_minutes": 180, - "time_resolution_minutes": 30, - }, - )() - config.input_data.site = type( - "SiteConfig", - (object,), - { - "interval_start_minutes": -120, - "interval_end_minutes": 120, - "time_resolution_minutes": 15, - }, - )() - return config - +NWP_FREQ = pd.Timedelta("3h") -@pytest.fixture -def mock_dataset_dict(): - x_osgb = np.linspace(0, 10, 2) - y_osgb = np.linspace(0, 10, 2) - init_time_utc = pd.date_range("2023-01-01T00:00", periods=10, freq="1h") - step = pd.to_timedelta(np.arange(10), unit="h") # Create a valid step coordinate - - # Create mock NWP data with valid `step` and `init_time_utc` - nwp_data = xr.DataArray( - np.random.rand(10, 5, 2, 2), - dims=["step", "channel", "x", "y"], - coords={ - "step": step, # Properly define `step` as a coordinate - "init_time_utc": ("step", init_time_utc), # Link `init_time_utc` to `step` - "channel": ["cdcb", "lcc", "mcc", "hcc", "sde"], - "x_osgb": ("x", x_osgb), - "y_osgb": ("y", y_osgb), - }, - ) - # Ensure step remains accessible even after dimension swapping - nwp_data = nwp_data.swap_dims({"step": "init_time_utc"}).reset_coords("step", drop=False) - - # Create mock satellite data - sat_data = xr.DataArray( - np.random.rand(10, 1, 2, 2), - dims=["time", "channel", "x", "y"], - coords={ - "time": pd.date_range("2023-01-01", periods=10, freq="30min"), - "channel": ["HRV"], - "x_osgb": ("x", x_osgb), - "y_osgb": ("y", y_osgb), - }, - ) +@pytest.fixture(scope="module") +def da_sat_like(): + """Create dummy satellite data""" + x = np.arange(-100, 100) + y = np.arange(-100, 100) + datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq="5min") - # Create mock GSP data - gsp_data = xr.DataArray( - np.random.rand(10), - dims=["time"], - coords={"time": pd.date_range("2023-01-01", periods=10, freq="30min")}, + da_sat = xr.DataArray( + np.random.normal(size=(len(datetimes), len(x), len(y))), + coords=dict( + time_utc=(["time_utc"], datetimes), + x_geostationary=(["x_geostationary"], x), + y_geostationary=(["y_geostationary"], y), + ) ) - gsp_future_data = xr.DataArray( - np.random.rand(10), - dims=["time"], - coords={"time": pd.date_range("2023-01-01T05:00", periods=10, freq="30min")}, + return da_sat + + +@pytest.fixture(scope="module") +def da_nwp_like(): + """Create dummy NWP data""" + x = np.arange(-100, 100) + y = np.arange(-100, 100) + datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq=NWP_FREQ) + steps = pd.timedelta_range("0h", "16h", freq="1h") + channels = ["t", "dswrf"] + + da_nwp = xr.DataArray( + np.random.normal(size=(len(datetimes), len(steps), len(channels), len(x), len(y))), + coords=dict( + init_time_utc=(["init_time_utc"], datetimes), + step=(["step"], steps), + channel=(["channel"], channels), + x_osgb=(["x_osgb"], x), + y_osgb=(["y_osgb"], y), + ) ) + return da_nwp - return { - "nwp": {"ukv": nwp_data}, - "sat": sat_data, - "gsp": gsp_data, - "gsp_future": gsp_future_data, - } +@pytest.fixture +def mock_constants(monkeypatch): + """Creation of dummy constants used in normalisation process""" + mock_nwp_means = {"ukv": { + "t": 10.0, + "dswrf": 50.0 + }} + mock_nwp_stds = {"ukv": { + "t": 2.0, + "dswrf": 10.0 + }} + mock_sat_means = 100.0 + mock_sat_stds = 20.0 + + monkeypatch.setattr("ocf_data_sampler.constants.NWP_MEANS", mock_nwp_means) + monkeypatch.setattr("ocf_data_sampler.constants.NWP_STDS", mock_nwp_stds) + monkeypatch.setattr("ocf_data_sampler.constants.SAT_MEANS", mock_sat_means) + 
monkeypatch.setattr("ocf_data_sampler.constants.SAT_STDS", mock_sat_stds) -def test_process_and_combine_datasets(mock_configuration, mock_dataset_dict): - location = Location(x=0, y=0, id=1) - t0 = pd.Timestamp("2023-01-01 06:00") - # Apply time slicing to the NWP data - for nwp_key, da_nwp in mock_dataset_dict["nwp"].items(): - mock_dataset_dict["nwp"][nwp_key] = select_time_slice_nwp( - da=da_nwp, - t0=t0, - interval_start=pd.Timedelta(hours=-3), - interval_end=pd.Timedelta(hours=3), - sample_period_duration=pd.Timedelta(minutes=30), - ) +@pytest.fixture +def mock_config(): + """Specify dummy configuration""" + class MockConfig: + class InputData: + class NWP: + provider = "ukv" + interval_start_minutes = -360 + interval_end_minutes = 180 + time_resolution_minutes = 60 + + class GSP: + interval_start_minutes = -120 + interval_end_minutes = 120 + time_resolution_minutes = 30 + + def __init__(self): + self.nwp = {"ukv": self.NWP()} + self.gsp = self.GSP() + + def __init__(self): + self.input_data = self.InputData() + + return MockConfig() - result = process_and_combine_datasets( - dataset_dict=mock_dataset_dict, - config=mock_configuration, - t0=t0, - location=location, - target_key="gsp", - ) - assert isinstance(result, dict) - assert GSPBatchKey.gsp in result - assert NWPBatchKey.nwp in result - assert SatelliteBatchKey.satellite_actual in result +@pytest.fixture +def mock_location(): + """Create dummy location""" + return Location(id=12345, x=400000, y=500000) def test_merge_dicts(): - dicts = [{"a": 1, "b": 2}, {"b": 3, "c": 4}] - merged = merge_dicts(dicts) - assert merged == {"a": 1, "b": 3, "c": 4} + """Test merge_dicts function""" + dict1 = {"a": 1, "b": 2} + dict2 = {"c": 3, "d": 4} + dict3 = {"e": 5} + + result = merge_dicts([dict1, dict2, dict3]) + assert result == {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5} + + # Test key overwriting + dict4 = {"a": 10, "f": 6} + result = merge_dicts([dict1, dict4]) + assert result["a"] == 10 def test_fill_nans_in_arrays(): - batch = { - "a": np.array([1.0, np.nan, 3.0]), - "b": {"nested": np.array([np.nan, 5.0])}, + """Test the fill_nans_in_arrays function""" + array_with_nans = np.array([1.0, np.nan, 3.0, np.nan]) + nested_dict = { + "array1": array_with_nans, + "nested": { + "array2": np.array([np.nan, 2.0, np.nan, 4.0]) + }, + "string_key": "not_an_array" } - filled_batch = fill_nans_in_arrays(batch) - assert np.array_equal(filled_batch["a"], np.array([1.0, 0.0, 3.0])) - assert np.array_equal(filled_batch["b"]["nested"], np.array([0.0, 5.0])) + + result = fill_nans_in_arrays(nested_dict) + + assert not np.isnan(result["array1"]).any() + assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0])) + assert not np.isnan(result["nested"]["array2"]).any() + assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0])) + assert result["string_key"] == "not_an_array" def test_compute(): - data = xr.DataArray( - np.random.rand(10), dims=["time"], coords={"time": pd.date_range("2023-01-01", periods=10)} + """Test the compute function""" + da = xr.DataArray(np.random.rand(5, 5)) + + # Create nested dictionary + nested_dict = { + "array1": da, + "nested": { + "array2": da + } + } + + result = compute(nested_dict) + + # Ensure function applied - check if data is no longer lazy array and determine structural alterations + # Check that result is an xarray DataArray + assert isinstance(result["array1"], xr.DataArray) + assert isinstance(result["nested"]["array2"], xr.DataArray) + + # Check data is no longer lazy object + 
assert isinstance(result["array1"].data, np.ndarray)
+    assert isinstance(result["nested"]["array2"].data, np.ndarray)
+
+    # Check for NaN
+    assert not np.isnan(result["array1"].data).any()
+    assert not np.isnan(result["nested"]["array2"].data).any()
+
+
+# TO DO - Update the below to include satellite and finalise testing procedure
+# Currently for NWP only - awaiting confirmation
+@pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
+def test_full_pipeline(da_nwp_like, mock_config, mock_location, mock_constants, t0_str):
+    """Test full pipeline considering time slice selection and then process and combine"""
+    t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
+
+    # Obtain NWP data slice
+    nwp_sample = select_time_slice_nwp(
+        da_nwp_like,
+        t0,
+        sample_period_duration=pd.Timedelta(minutes=mock_config.input_data.nwp["ukv"].time_resolution_minutes),
+        interval_start=pd.Timedelta(minutes=mock_config.input_data.nwp["ukv"].interval_start_minutes),
+        interval_end=pd.Timedelta(minutes=mock_config.input_data.nwp["ukv"].interval_end_minutes),
+        dropout_timedeltas=None,
+        dropout_frac=0,
+        accum_channels=["dswrf"],
+        channel_dim_name="channel",
+    )
+
+    # Prepare dataset dictionary
+    dataset_dict = {
+        "nwp": {"ukv": nwp_sample},
+    }
+
+    # Process data with main function
+    result = process_and_combine_datasets(
+        dataset_dict,
+        mock_config,
+        t0,
+        mock_location,
+        target_key='gsp'
+    )
+
+    # Verify results structure
+    assert NWPBatchKey.nwp in result
+
+    # Check NWP data normalisation and NaN handling
+    nwp_data = result[NWPBatchKey.nwp]["ukv"]
+    assert isinstance(nwp_data['nwp'], np.ndarray)
+    assert not np.isnan(nwp_data['nwp']).any()

From ac4de5ebdc720703b0bb8c31c0dc1c3232222216 Mon Sep 17 00:00:00 2001
From: Felix <137530077+felix-e-h-p@users.noreply.github.com>
Date: Sun, 22 Dec 2024 20:30:13 +0000
Subject: [PATCH 16/20] DS_Store gone

---
 tests/torch_datasets/.DS_Store | Bin 6148 -> 0 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 tests/torch_datasets/.DS_Store

diff --git a/tests/torch_datasets/.DS_Store b/tests/torch_datasets/.DS_Store
deleted file mode 100644
index 330862ab50d7f8d728b6eb1091d509d39c3300c5..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 6148
[binary data omitted]

From abccd15acaf15450719ecc2564ba09adff238ef2 Mon Sep 17 00:00:00 2001
From: Felix Peretz
Date: Mon, 23 Dec 2024 11:25:01 +0000
Subject: [PATCH 17/20] Updated process_and_combine.py

---
 ocf_data_sampler/torch_datasets/process_and_combine.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/ocf_data_sampler/torch_datasets/process_and_combine.py b/ocf_data_sampler/torch_datasets/process_and_combine.py
index 8b59da7..79a76ed 100644
--- a/ocf_data_sampler/torch_datasets/process_and_combine.py
+++ b/ocf_data_sampler/torch_datasets/process_and_combine.py
@@ -51,10 +51,9 @@ def process_and_combine_datasets(
         # Standardise
         da_sat = dataset_dict["sat"]
         da_sat = (da_sat - SAT_MEANS) / SAT_STDS
+
         # Convert to NumpyBatch
-        sat_numpy_modalities = convert_satellite_to_numpy_batch(da_sat)
-        # Combine the Satellite into NumpyBatch
-        numpy_modalities.append({SatelliteBatchKey.satellite_actual: sat_numpy_modalities})
+
        numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))

    gsp_config = config.input_data.gsp

From 2aea746dfbbdaf83152e2fb27adeb911a8fb1726 Mon Sep 17 00:00:00 2001
From: Felix Peretz
Date: Mon, 23 Dec 2024 15:34:02 +0000
Subject: [PATCH 18/20] Updated constants.py, process_and_combine.py with
 required changes - reinstated testing logic also

---
 ocf_data_sampler/constants.py | 12 +-
 .../torch_datasets/process_and_combine.py | 10 +-
 .../test_process_and_combine.py | 258 +++++++++---------
 3 files changed, 138 insertions(+), 142 deletions(-)

diff --git a/ocf_data_sampler/constants.py b/ocf_data_sampler/constants.py
index 1aa4ff3..7616d96 100644
--- a/ocf_data_sampler/constants.py
+++ b/ocf_data_sampler/constants.py
@@ -139,7 +139,7 @@ def __getitem__(self, key):
 # ------ Satellite
 # RSS Mean and std values from randomised 20% of 2020 imagery
 
-SAT_STD = {
+RSS_STD = {
     "HRV": 0.11405209,
     "IR_016": 0.21462157,
     "IR_039": 0.04618041,
@@ -154,7 +154,7 @@ def __getitem__(self, key):
     "WV_073": 0.12924142,
 }
 
-SAT_MEAN = {
+RSS_MEAN = {
     "HRV": 0.09298719,
     "IR_016": 0.17594202,
     "IR_039": 0.86167645,
@@ -169,9 +169,5 @@ def __getitem__(self, key):
     "WV_073": 0.62479186,
 }
 
-SAT_STD = _to_data_array(SAT_STD)
-SAT_MEAN = _to_data_array(SAT_MEAN)
-
-# SatStatDict wrapper not needed due to singular provider - direct assignment of meand and std
-SAT_STDS = SAT_STD
-SAT_MEANS = SAT_MEAN
+RSS_STD = _to_data_array(RSS_STD)
+RSS_MEAN = _to_data_array(RSS_MEAN)
diff --git a/ocf_data_sampler/torch_datasets/process_and_combine.py b/ocf_data_sampler/torch_datasets/process_and_combine.py
index 41dc381..0eb7261 100644
--- a/ocf_data_sampler/torch_datasets/process_and_combine.py
+++ b/ocf_data_sampler/torch_datasets/process_and_combine.py
@@ -4,7 +4,7 @@
 from typing import Tuple
 
 from ocf_data_sampler.config import Configuration
-from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, SAT_MEANS, SAT_STDS
+from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
 from ocf_data_sampler.numpy_batch import (
     convert_nwp_to_numpy_batch,
     convert_satellite_to_numpy_batch,
@@ -14,7 +14,6 @@
 from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
 from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
 from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
-
 from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
 from ocf_data_sampler.select.location import Location
 from ocf_data_sampler.utils import minutes
@@ -50,7 +49,7 @@ def process_and_combine_datasets(
     if "sat" in dataset_dict:
         # Standardise
         da_sat = dataset_dict["sat"]
-        da_sat = (da_sat - SAT_MEANS) / SAT_STDS
+        da_sat = (da_sat - RSS_MEAN) / RSS_STD
 
         # Convert to NumpyBatch
         numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
@@ -99,6 +98,7 @@ def process_and_combine_datasets(
 
     return combined_sample
 
+
 def process_and_combine_site_sample_dict(
     dataset_dict: dict,
     config: Configuration,
@@ -125,8 +125,9 @@
         data_arrays.append((f"nwp-{provider}", da_nwp))
 
     if "sat" in dataset_dict:
-        # TODO add some satellite normalisation
+        # Standardise satellite data
         da_sat = dataset_dict["sat"]
+        da_sat = (da_sat - RSS_MEAN) / RSS_STD
         data_arrays.append(("satellite", da_sat))
 
     if "site" in dataset_dict:
@@ -149,6 +150,7 @@ def merge_dicts(list_of_dicts: list[dict]) -> dict:
         combined_dict.update(d)
     return combined_dict
 
+
 def merge_arrays(normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
     """
     Combine a list of DataArrays into a single Dataset with unique naming conventions.
diff --git a/tests/torch_datasets/test_process_and_combine.py b/tests/torch_datasets/test_process_and_combine.py
index 76d9ab9..7102d1b 100644
--- a/tests/torch_datasets/test_process_and_combine.py
+++ b/tests/torch_datasets/test_process_and_combine.py
@@ -1,116 +1,123 @@
 import pytest
+import tempfile
+
 import numpy as np
 import pandas as pd
 import xarray as xr
 
+from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
 from ocf_data_sampler.config import Configuration
-from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
-from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
 from ocf_data_sampler.select.location import Location
-from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
-from ocf_data_sampler.utils import minutes
+from ocf_data_sampler.numpy_batch import NWPBatchKey, GSPBatchKey, SatelliteBatchKey
+from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset
 from ocf_data_sampler.torch_datasets.process_and_combine import (
     process_and_combine_datasets,
+    process_and_combine_site_sample_dict,
     merge_dicts,
     fill_nans_in_arrays,
     compute,
 )
 
-NWP_FREQ = pd.Timedelta("3h")
+# Currently leaving here for reference purpose - not strictly needed
+def test_pvnet(pvnet_config_filename):
+
+    # Create dataset object
+    dataset = PVNetUKRegionalDataset(pvnet_config_filename)
 
+    assert len(dataset.locations) == 317 # no of GSPs not including the National level
+    # NB. I have not checked this value is in fact correct, but it does seem to stay constant
+    assert len(dataset.valid_t0_times) == 39
+    assert len(dataset) == 317*39
 
-@pytest.fixture(scope="module")
-def da_sat_like():
-    """Create dummy satellite data"""
-    x = np.arange(-100, 100)
-    y = np.arange(-100, 100)
-    datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq="5min")
+    # Generate a sample
+    sample = dataset[0]
 
-    da_sat = xr.DataArray(
-        np.random.normal(size=(len(datetimes), len(x), len(y))),
-        coords=dict(
-            time_utc=(["time_utc"], datetimes),
-            x_geostationary=(["x_geostationary"], x),
-            y_geostationary=(["y_geostationary"], y),
-        )
+    assert isinstance(sample, dict)
+
+    for key in [
+        NWPBatchKey.nwp, SatelliteBatchKey.satellite_actual, GSPBatchKey.gsp,
+        GSPBatchKey.solar_azimuth, GSPBatchKey.solar_elevation,
+    ]:
+        assert key in sample
+
+    for nwp_source in ["ukv"]:
+        assert nwp_source in sample[NWPBatchKey.nwp]
+
+    # check the shape of the data is correct
+    # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
+    assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
+    # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
+    assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
+    # 3 hours of 30 minute data (inclusive)
+    assert sample[GSPBatchKey.gsp].shape == (7,)
+    # Solar angles have same shape as GSP data
+    assert sample[GSPBatchKey.solar_azimuth].shape == (7,)
+    assert sample[GSPBatchKey.solar_elevation].shape == (7,)
+
+
+# Currently leaving here for reference purpose - not strictly needed
+def test_pvnet_no_gsp(pvnet_config_filename):
+
+    # load config
+    config = load_yaml_configuration(pvnet_config_filename)
+    # remove gsp
+    config.input_data.gsp.zarr_path = ''
+
+    # save temp config file
+    with tempfile.NamedTemporaryFile() as temp_config_file:
+        save_yaml_configuration(config, temp_config_file.name)
+        # Create dataset object
+        dataset = PVNetUKRegionalDataset(temp_config_file.name)
+
+        # Generate a sample
+        _ = dataset[0]
+
+
+def test_process_and_combine_datasets(pvnet_config_filename):
+
+    # Load in config for function and define location
+    config = load_yaml_configuration(pvnet_config_filename)
+    t0 = pd.Timestamp("2024-01-01")
+    location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
+
+    nwp_data = xr.DataArray(
+        np.random.rand(4, 2, 2, 2),
+        dims=["time_utc", "channel", "y", "x"],
+        coords={
+            "time_utc": pd.date_range("2024-01-01", periods=4, freq="h"),
+            "channel": ["t2m", "dswrf"],
+            "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
+            "init_time_utc": pd.Timestamp("2024-01-01")
+        }
     )
-    return da_sat
-
-
-@pytest.fixture(scope="module")
-def da_nwp_like():
-    """Create dummy NWP data"""
-    x = np.arange(-100, 100)
-    y = np.arange(-100, 100)
-    datetimes = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq=NWP_FREQ)
-    steps = pd.timedelta_range("0h", "16h", freq="1h")
-    channels = ["t", "dswrf"]
-
-    da_nwp = xr.DataArray(
-        np.random.normal(size=(len(datetimes), len(steps), len(channels), len(x), len(y))),
-        coords=dict(
-            init_time_utc=(["init_time_utc"], datetimes),
-            step=(["step"], steps),
-            channel=(["channel"], channels),
-            x_osgb=(["x_osgb"], x),
-            y_osgb=(["y_osgb"], y),
-        )
+
+    sat_data = xr.DataArray(
+        np.random.rand(7, 1, 2, 2),
+        dims=["time_utc", "channel", "y", "x"],
+        coords={
+            "time_utc": pd.date_range("2024-01-01", periods=7, freq="5min"),
+            "channel": ["HRV"],
+            "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
+            "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]]))
+        }
     )
-    return da_nwp
-
-
-@pytest.fixture
-def mock_constants(monkeypatch):
-    """Creation of dummy constants used in normalisation process"""
-    mock_nwp_means = {"ukv": {
-        "t": 10.0,
-        "dswrf": 50.0
-    }}
-    mock_nwp_stds = {"ukv": {
-        "t": 2.0,
-        "dswrf": 10.0
-    }}
-    mock_sat_means = 100.0
-    mock_sat_stds = 20.0
-
-    monkeypatch.setattr("ocf_data_sampler.constants.NWP_MEANS", mock_nwp_means)
-    monkeypatch.setattr("ocf_data_sampler.constants.NWP_STDS", mock_nwp_stds)
-    monkeypatch.setattr("ocf_data_sampler.constants.SAT_MEANS", mock_sat_means)
-    monkeypatch.setattr("ocf_data_sampler.constants.SAT_STDS", mock_sat_stds)
-
-
-@pytest.fixture
-def mock_config():
-    """Specify dummy configuration"""
-    class MockConfig:
-        class InputData:
-            class NWP:
-                provider = "ukv"
-                interval_start_minutes = -360
-                interval_end_minutes = 180
-                time_resolution_minutes = 60
-
-            class GSP:
-                interval_start_minutes = -120
-                interval_end_minutes = 120
-                time_resolution_minutes = 30
-
-            def __init__(self):
-                self.nwp = {"ukv": self.NWP()}
-                self.gsp = self.GSP()
-
-        def __init__(self):
-            self.input_data = self.InputData()
-
-    return MockConfig()
+    # Combine as dict
+    dataset_dict = {
+        "nwp": {"ukv": nwp_data},
+        "sat": sat_data
+    }
+
+    # Call relevant function
+    result = process_and_combine_datasets(dataset_dict, config, t0, location)
 
-@pytest.fixture
-def mock_location():
-    """Create dummy location"""
-    return Location(id=12345, x=400000, y=500000)
+    # Assert result is dicr - check and validate
+    assert isinstance(result, dict)
+    assert NWPBatchKey.nwp in result
+    assert result[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
+    assert result[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
 
 
 def test_merge_dicts():
@@ -176,44 +183,35 @@ def test_compute():
     assert not np.isnan(result["nested"]["array2"].data).any()
 
 
-# TO DO - Update the below to include satellite and finalise testing procedure
-# Currently for NWP only - awaiting confirmation
-@pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
-def test_full_pipeline(da_nwp_like, mock_config, mock_location, mock_constants, t0_str):
-    """Test full pipeline considering time slice selection and then process and combine"""
-    t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
-
-    # Obtain NWP data slice
-    nwp_sample = select_time_slice_nwp(
-        da_nwp_like,
-        t0,
-        sample_period_duration=pd.Timedelta(minutes=mock_config.input_data.nwp["ukv"].time_resolution_minutes),
-        interval_start=pd.Timedelta(minutes=mock_config.input_data.nwp["ukv"].interval_start_minutes),
-        interval_end=pd.Timedelta(minutes=mock_config.input_data.nwp["ukv"].interval_end_minutes),
-        dropout_timedeltas=None,
-        dropout_frac=0,
-        accum_channels=["dswrf"],
-        channel_dim_name="channel",
-    )
-
-    # Prepare dataset dictionary
-    dataset_dict = {
-        "nwp": {"ukv": nwp_sample},
+def test_process_and_combine_site_sample_dict(pvnet_config_filename):
+    # Load config
+    config = load_yaml_configuration(pvnet_config_filename)
+
+    # Specify minimal structure for testing
+    raw_nwp_values = np.random.rand(4, 1, 2, 2)  # Single channel
+    site_dict = {
+        "nwp": {
+            "ukv": xr.DataArray(
+                raw_nwp_values,
+                dims=["time_utc", "channel", "y", "x"],
+                coords={
+                    "time_utc": pd.date_range("2024-01-01", periods=4, freq="h"),
+                    "channel": ["dswrf"],  # Single channel
+                },
+            )
+        }
     }
-
-    # Process data with main function
-    result = process_and_combine_datasets(
-        dataset_dict,
-        mock_config,
-        t0,
-        mock_location,
-        target_key='gsp'
-    )
-
-    # Verify results structure
-    assert NWPBatchKey.nwp in result
-
-    # Check NWP data normalisation and NaN handling
-    nwp_data = result[NWPBatchKey.nwp]["ukv"]
-    assert isinstance(nwp_data['nwp'], np.ndarray)
-    assert not np.isnan(nwp_data['nwp']).any()
+    print(f"Input site_dict: {site_dict}")
+
+    # Call function
+    result = process_and_combine_site_sample_dict(site_dict, config)
+
+    # Assertions to validate output structure
+    assert isinstance(result, xr.Dataset), "Result should be an xarray.Dataset"
+    assert len(result.data_vars) > 0, "Dataset should contain data variables"
+
+    # Validate the expected variable is present and has the expected shape
+    expected_variable = "nwp-ukv"
+    assert expected_variable in result.data_vars, f"Expected variable '{expected_variable}' not found"
+    nwp_result = result[expected_variable]
+    assert nwp_result.shape == (4, 1, 2, 2), f"Unexpected shape for '{expected_variable}': {nwp_result.shape}"

From a2b9b8cc0cf5d1fc092c159beaf3410d6378c032 Mon Sep 17 00:00:00 2001
From: Felix Peretz
Date: Mon, 23 Dec 2024 16:52:43 +0000
Subject: [PATCH 19/20] Removal of reference tests

---
 .../test_process_and_combine.py | 55 -------------------
 1 file changed, 55 deletions(-)

diff --git a/tests/torch_datasets/test_process_and_combine.py b/tests/torch_datasets/test_process_and_combine.py
index 7102d1b..04f124e 100644
--- a/tests/torch_datasets/test_process_and_combine.py
+++ b/tests/torch_datasets/test_process_and_combine.py
@@ -20,61 +20,6 @@
 )
 
 
-# Currently leaving here for reference purpose - not strictly needed
-def test_pvnet(pvnet_config_filename):
-
-    # Create dataset object
-    dataset = PVNetUKRegionalDataset(pvnet_config_filename)
-
-    assert len(dataset.locations) == 317 # no of GSPs not including the National level
-    # NB. I have not checked this value is in fact correct, but it does seem to stay constant
-    assert len(dataset.valid_t0_times) == 39
-    assert len(dataset) == 317*39
-
-    # Generate a sample
-    sample = dataset[0]
-
-    assert isinstance(sample, dict)
-
-    for key in [
-        NWPBatchKey.nwp, SatelliteBatchKey.satellite_actual, GSPBatchKey.gsp,
-        GSPBatchKey.solar_azimuth, GSPBatchKey.solar_elevation,
-    ]:
-        assert key in sample
-
-    for nwp_source in ["ukv"]:
-        assert nwp_source in sample[NWPBatchKey.nwp]
-
-    # check the shape of the data is correct
-    # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
-    assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
-    # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
-    assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
-    # 3 hours of 30 minute data (inclusive)
-    assert sample[GSPBatchKey.gsp].shape == (7,)
-    # Solar angles have same shape as GSP data
-    assert sample[GSPBatchKey.solar_azimuth].shape == (7,)
-    assert sample[GSPBatchKey.solar_elevation].shape == (7,)
-
-
-# Currently leaving here for reference purpose - not strictly needed
-def test_pvnet_no_gsp(pvnet_config_filename):
-
-    # load config
-    config = load_yaml_configuration(pvnet_config_filename)
-    # remove gsp
-    config.input_data.gsp.zarr_path = ''
-
-    # save temp config file
-    with tempfile.NamedTemporaryFile() as temp_config_file:
-        save_yaml_configuration(config, temp_config_file.name)
-        # Create dataset object
-        dataset = PVNetUKRegionalDataset(temp_config_file.name)
-
-        # Generate a sample
-        _ = dataset[0]
-
-
 def test_process_and_combine_datasets(pvnet_config_filename):
 
     # Load in config for function and define location

From 94c9cafc3781ff565f7bd65be23b3282f8ea052d Mon Sep 17 00:00:00 2001
From: Felix Peretz
Date: Mon, 23 Dec 2024 19:16:27 +0000
Subject: [PATCH 20/20] Updated timestamps and test compute

---
 .../test_process_and_combine.py | 41 ++++++++++---------
 1 file changed, 22 insertions(+), 19 deletions(-)

diff --git a/tests/torch_datasets/test_process_and_combine.py b/tests/torch_datasets/test_process_and_combine.py
index 04f124e..1d01449 100644
--- a/tests/torch_datasets/test_process_and_combine.py
+++ b/tests/torch_datasets/test_process_and_combine.py
@@ -4,6 +4,7 @@
 import numpy as np
 import pandas as pd
 import xarray as xr
+import dask.array as da
 
 from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
 from ocf_data_sampler.config import Configuration
@@ -24,17 +25,17 @@ def test_process_and_combine_datasets(pvnet_config_filename):
 
     # Load in config for function and define location
     config = load_yaml_configuration(pvnet_config_filename)
-    t0 = pd.Timestamp("2024-01-01")
+    t0 = pd.Timestamp("2024-01-01 00:00")
     location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)
 
     nwp_data = xr.DataArray(
         np.random.rand(4, 2, 2, 2),
         dims=["time_utc", "channel", "y", "x"],
         coords={
-            "time_utc": pd.date_range("2024-01-01", periods=4, freq="h"),
+            "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
             "channel": ["t2m", "dswrf"],
             "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
-            "init_time_utc": pd.Timestamp("2024-01-01")
+            "init_time_utc": pd.Timestamp("2024-01-01 00:00")
         }
     )
 
@@ -42,7 +43,7 @@ def test_process_and_combine_datasets(pvnet_config_filename):
         np.random.rand(7, 1, 2, 2),
         dims=["time_utc", "channel", "y", "x"],
         coords={
-            "time_utc": pd.date_range("2024-01-01", periods=7, freq="5min"),
+            "time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
             "channel": ["HRV"],
             "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
             "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]]))
@@ -58,7 +59,7 @@ def test_process_and_combine_datasets(pvnet_config_filename):
     # Call relevant function
     result = process_and_combine_datasets(dataset_dict, config, t0, location)
 
-    # Assert result is dicr - check and validate
+    # Assert result is dict - check and validate
     assert isinstance(result, dict)
     assert NWPBatchKey.nwp in result
     assert result[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
@@ -101,29 +102,31 @@ def test_fill_nans_in_arrays():
 
 
 def test_compute():
-    """Test the compute function"""
-    da = xr.DataArray(np.random.rand(5, 5))
+    """Test compute function with dask array"""
+    da_dask = xr.DataArray(da.random.random((5, 5)))
 
-    # Create nested dictionary
+    # Create a nested dictionary with dask array
     nested_dict = {
-        "array1": da,
+        "array1": da_dask,
         "nested": {
-            "array2": da
+            "array2": da_dask
         }
     }
-
+
+    # Ensure initial data is lazy - i.e. not yet computed
+    assert not isinstance(nested_dict["array1"].data, np.ndarray)
+    assert not isinstance(nested_dict["nested"]["array2"].data, np.ndarray)
+
+    # Call the compute function
     result = compute(nested_dict)
-
-    # Ensure function applied - check if data is no longer lazy array and determine structural alterations
-    # Check that result is an xarray DataArray
+
+    # Assert that the result is an xarray DataArray and no longer lazy
     assert isinstance(result["array1"], xr.DataArray)
     assert isinstance(result["nested"]["array2"], xr.DataArray)
-
-    # Check data is no longer lazy object
     assert isinstance(result["array1"].data, np.ndarray)
     assert isinstance(result["nested"]["array2"].data, np.ndarray)
-
-    # Check for NaN
+
+    # Ensure there are no NaN values in computed data
     assert not np.isnan(result["array1"].data).any()
     assert not np.isnan(result["nested"]["array2"].data).any()
 
@@ -140,7 +143,7 @@ def test_process_and_combine_site_sample_dict(pvnet_config_filename):
                 raw_nwp_values,
                 dims=["time_utc", "channel", "y", "x"],
                 coords={
-                    "time_utc": pd.date_range("2024-01-01", periods=4, freq="h"),
+                    "time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
                     "channel": ["dswrf"],  # Single channel
                 },
             )
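For readers following the normalisation change above, it reduces to a few lines of xarray. A minimal sketch, assuming ocf-data-sampler is installed and that RSS_MEAN and RSS_STD are the channel-indexed DataArrays patch 18 leaves in ocf_data_sampler.constants; the dummy data's shape, timestamps and channel subset are illustrative only, not values from the library:

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.constants import RSS_MEAN, RSS_STD

# Dummy satellite data - shape, timestamps and channel subset are illustrative
channels = ["HRV", "VIS006"]
da_sat = xr.DataArray(
    np.random.rand(7, len(channels), 2, 2),
    dims=["time_utc", "channel", "y", "x"],
    coords={
        "time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
        "channel": channels,
    },
)

# Standardise per channel - xarray aligns the stats with the data on the
# "channel" coordinate, mirroring the step in process_and_combine_datasets
da_sat_norm = (da_sat - RSS_MEAN) / RSS_STD

Because the arithmetic aligns on the "channel" coordinate, only channels present in both the data and the stats survive, which is why the constants carry one entry per RSS channel.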
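The reworked test_compute relies on compute() materialising every dask-backed array inside a nested dictionary. A sketch of that pattern follows; compute_sketch is a hypothetical stand-in that mirrors only the behaviour the test asserts, not the library's actual internals:

import dask.array as da
import numpy as np
import xarray as xr

def compute_sketch(obj):
    """Recursively materialise dask-backed DataArrays in a nested dict.

    Hypothetical stand-in for the library's compute() - it reproduces
    only what test_compute checks, not the real implementation.
    """
    if isinstance(obj, dict):
        return {key: compute_sketch(value) for key, value in obj.items()}
    return obj.compute()

lazy = xr.DataArray(da.random.random((5, 5)))
result = compute_sketch({"array1": lazy, "nested": {"array2": lazy}})

# After computation the data is a concrete numpy array, not a lazy dask graph
assert isinstance(result["array1"].data, np.ndarray)
assert isinstance(result["nested"]["array2"].data, np.ndarray)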