Satellite normalisation #97

Open · wants to merge 22 commits into base: main
3 changes: 3 additions & 0 deletions .gitignore
@@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

# OS
.DS_Store
38 changes: 38 additions & 0 deletions ocf_data_sampler/constants.py
@@ -28,6 +28,7 @@ def __getitem__(self, key):
f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}"
)


# ------ UKV
# Means and std computed WITH version_7 and higher, MetOffice values
UKV_STD = {
@@ -49,6 +50,7 @@ def __getitem__(self, key):
"prmsl": 1252.71790539,
"prate": 0.00021497,
}

UKV_MEAN = {
"cdcb": 1412.26599062,
"lcc": 50.08362643,
@@ -97,6 +99,7 @@ def __getitem__(self, key):
"diff_duvrs": 81605.25,
"diff_sr": 818950.6875,
}

ECMWF_MEAN = {
"dlwrf": 27187026.0,
"dswrf": 11458988.0,
@@ -133,3 +136,38 @@ def __getitem__(self, key):
ecmwf=ECMWF_MEAN,
)

# ------ Satellite
# RSS mean and std values, computed from a randomised 20% sample of 2020 imagery

RSS_STD = {
"HRV": 0.11405209,
"IR_016": 0.21462157,
"IR_039": 0.04618041,
"IR_087": 0.06687243,
"IR_097": 0.0468558,
"IR_108": 0.17482725,
"IR_120": 0.06115861,
"IR_134": 0.04492306,
"VIS006": 0.12184761,
"VIS008": 0.13090034,
"WV_062": 0.16111417,
"WV_073": 0.12924142,
}

RSS_MEAN = {
"HRV": 0.09298719,
"IR_016": 0.17594202,
"IR_039": 0.86167645,
"IR_087": 0.7719318,
"IR_097": 0.8014212,
"IR_108": 0.71254843,
"IR_120": 0.89058584,
"IR_134": 0.944365,
"VIS006": 0.09633306,
"VIS008": 0.11426069,
"WV_062": 0.7359355,
"WV_073": 0.62479186,
}

RSS_STD = _to_data_array(RSS_STD)
RSS_MEAN = _to_data_array(RSS_MEAN)
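
A minimal sketch (not part of the diff) of how these constants apply, assuming _to_data_array returns an xr.DataArray indexed by a "channel" coordinate, as for the NWP constants above:

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.constants import RSS_MEAN, RSS_STD

# Toy satellite array with two of the twelve RSS channels
da_sat = xr.DataArray(
    np.random.rand(3, 2, 4, 4),
    dims=["time_utc", "channel", "y", "x"],
    coords={
        "time_utc": pd.date_range("2020-06-01 12:00", periods=3, freq="5min"),
        "channel": ["HRV", "IR_016"],
    },
)

# xarray aligns on the shared "channel" coordinate, so each channel is
# shifted and scaled by its own statistics
da_norm = (da_sat - RSS_MEAN) / RSS_STD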
1 change: 0 additions & 1 deletion ocf_data_sampler/numpy_batch/nwp.py
@@ -1,5 +1,4 @@
"""Convert NWP to NumpyBatch"""

import pandas as pd
import xarray as xr

3 changes: 2 additions & 1 deletion ocf_data_sampler/numpy_batch/satellite.py
@@ -13,6 +13,7 @@ class SatelliteBatchKey:

def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
"""Convert from Xarray to NumpyBatch"""

example = {
SatelliteBatchKey.satellite_actual: da.values,
SatelliteBatchKey.time_utc: da.time_utc.values.astype(float),
@@ -27,4 +28,4 @@ def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None
if t0_idx is not None:
example[SatelliteBatchKey.t0_idx] = t0_idx

return example
return example
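
A minimal usage sketch for convert_satellite_to_numpy_batch (not part of the diff; the coordinate layout is copied from the new tests at the end of this PR):

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.numpy_batch.satellite import (
    SatelliteBatchKey,
    convert_satellite_to_numpy_batch,
)

da = xr.DataArray(
    np.random.rand(7, 1, 2, 2),
    dims=["time_utc", "channel", "y", "x"],
    coords={
        "time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
        "channel": ["HRV"],
        "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
        "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]])),
    },
)

# Returns a dict keyed by SatelliteBatchKey; t0_idx is optional
example = convert_satellite_to_numpy_batch(da, t0_idx=3)
assert example[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)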
16 changes: 12 additions & 4 deletions ocf_data_sampler/torch_datasets/process_and_combine.py
@@ -4,7 +4,7 @@
from typing import Tuple

from ocf_data_sampler.config import Configuration
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
from ocf_data_sampler.numpy_batch import (
convert_nwp_to_numpy_batch,
convert_satellite_to_numpy_batch,
@@ -13,6 +13,7 @@
)
from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
from ocf_data_sampler.select.location import Location
from ocf_data_sampler.utils import minutes
@@ -25,8 +26,8 @@ def process_and_combine_datasets(
location: Location,
target_key: str = 'gsp'
) -> dict:
"""Normalize and convert data to numpy arrays"""

"""Normalise and convert data to numpy arrays"""
numpy_modalities = []

if "nwp" in dataset_dict:
@@ -37,19 +38,23 @@
# Standardise
provider = config.input_data.nwp[nwp_key].provider
da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]

# Convert to NumpyBatch
nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)

# Combine the NWPs into NumpyBatch
numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})


if "sat" in dataset_dict:
# Satellite is already in the range [0-1] so no need to standardise
# Standardise
da_sat = dataset_dict["sat"]
da_sat = (da_sat - RSS_MEAN) / RSS_STD

# Convert to NumpyBatch
numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))


gsp_config = config.input_data.gsp

if "gsp" in dataset_dict:
@@ -93,6 +98,7 @@ def process_and_combine_datasets(

return combined_sample


def process_and_combine_site_sample_dict(
dataset_dict: dict,
config: Configuration,
@@ -119,8 +125,9 @@ def process_and_combine_site_sample_dict(
data_arrays.append((f"nwp-{provider}", da_nwp))

if "sat" in dataset_dict:
# TODO add some satellite normalisation
# Satellite normalisation added
Contributor commented: thanks for this

Member commented: @felix-e-h-p this is picky but can we change this comment to just # Standardise or something. I don't think "Satellite normalisation added" makes sense out of context now the TODO is gone
da_sat = dataset_dict["sat"]
da_sat = (da_sat - RSS_MEAN) / RSS_STD
data_arrays.append(("satellite", da_sat))

if "site" in dataset_dict:
Expand All @@ -143,6 +150,7 @@ def merge_dicts(list_of_dicts: list[dict]) -> dict:
combined_dict.update(d)
return combined_dict


def merge_arrays(normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
"""
Combine a list of DataArrays into a single Dataset with unique naming conventions.
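
Taken together, the satellite branch change amounts to the following sketch (a hypothetical helper, not a verbatim excerpt; it uses the same standardisation formula as the NWP branch):

import xarray as xr

from ocf_data_sampler.constants import RSS_MEAN, RSS_STD
from ocf_data_sampler.numpy_batch import convert_satellite_to_numpy_batch

def standardise_and_convert_satellite(da_sat: xr.DataArray) -> dict:
    """Hypothetical helper: standardise satellite data with the RSS
    constants, then convert it to a NumpyBatch dict."""
    da_sat = (da_sat - RSS_MEAN) / RSS_STD
    return convert_satellite_to_numpy_batch(da_sat)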
165 changes: 165 additions & 0 deletions tests/torch_datasets/test_process_and_combine.py
@@ -0,0 +1,165 @@
import pytest
import tempfile

import numpy as np
import pandas as pd
import xarray as xr
import dask.array as da

from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
from ocf_data_sampler.config import Configuration
from ocf_data_sampler.select.location import Location
from ocf_data_sampler.numpy_batch import NWPBatchKey, GSPBatchKey, SatelliteBatchKey
from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset

from ocf_data_sampler.torch_datasets.process_and_combine import (
process_and_combine_datasets,
process_and_combine_site_sample_dict,
merge_dicts,
fill_nans_in_arrays,
compute,
)


def test_process_and_combine_datasets(pvnet_config_filename):

# Load in config for function and define location
config = load_yaml_configuration(pvnet_config_filename)
t0 = pd.Timestamp("2024-01-01 00:00")
location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)

nwp_data = xr.DataArray(
np.random.rand(4, 2, 2, 2),
dims=["time_utc", "channel", "y", "x"],
coords={
"time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
"channel": ["t2m", "dswrf"],
"step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
"init_time_utc": pd.Timestamp("2024-01-01 00:00")
}
)

sat_data = xr.DataArray(
np.random.rand(7, 1, 2, 2),
dims=["time_utc", "channel", "y", "x"],
coords={
"time_utc": pd.date_range("2024-01-01 00:00", periods=7, freq="5min"),
"channel": ["HRV"],
"x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
"y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]]))
}
)

# Combine as dict
dataset_dict = {
"nwp": {"ukv": nwp_data},
"sat": sat_data
}

# Call relevant function
result = process_and_combine_datasets(dataset_dict, config, t0, location)

# Assert result is dict - check and validate
assert isinstance(result, dict)
assert NWPBatchKey.nwp in result
assert result[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
assert result[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)


def test_merge_dicts():
"""Test merge_dicts function"""
dict1 = {"a": 1, "b": 2}
dict2 = {"c": 3, "d": 4}
dict3 = {"e": 5}

result = merge_dicts([dict1, dict2, dict3])
assert result == {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}

# Test key overwriting
dict4 = {"a": 10, "f": 6}
result = merge_dicts([dict1, dict4])
assert result["a"] == 10


def test_fill_nans_in_arrays():
"""Test the fill_nans_in_arrays function"""
array_with_nans = np.array([1.0, np.nan, 3.0, np.nan])
nested_dict = {
"array1": array_with_nans,
"nested": {
"array2": np.array([np.nan, 2.0, np.nan, 4.0])
},
"string_key": "not_an_array"
}

result = fill_nans_in_arrays(nested_dict)

assert not np.isnan(result["array1"]).any()
assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
assert not np.isnan(result["nested"]["array2"]).any()
assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
assert result["string_key"] == "not_an_array"


def test_compute():
"""Test compute function with dask array"""
da_dask = xr.DataArray(da.random.random((5, 5)))

# Create a nested dictionary with dask array
nested_dict = {
"array1": da_dask,
"nested": {
"array2": da_dask
}
}

# Ensure initial data is lazy - i.e. not yet computed
assert not isinstance(nested_dict["array1"].data, np.ndarray)
assert not isinstance(nested_dict["nested"]["array2"].data, np.ndarray)

# Call the compute function
result = compute(nested_dict)

# Assert that the result is an xarray DataArray and no longer lazy
assert isinstance(result["array1"], xr.DataArray)
assert isinstance(result["nested"]["array2"], xr.DataArray)
assert isinstance(result["array1"].data, np.ndarray)
assert isinstance(result["nested"]["array2"].data, np.ndarray)

# Ensure there are no NaN values in computed data
assert not np.isnan(result["array1"].data).any()
assert not np.isnan(result["nested"]["array2"].data).any()


def test_process_and_combine_site_sample_dict(pvnet_config_filename):
# Load config
config = load_yaml_configuration(pvnet_config_filename)

# Specify minimal structure for testing
raw_nwp_values = np.random.rand(4, 1, 2, 2) # Single channel
site_dict = {
"nwp": {
"ukv": xr.DataArray(
raw_nwp_values,
dims=["time_utc", "channel", "y", "x"],
coords={
"time_utc": pd.date_range("2024-01-01 00:00", periods=4, freq="h"),
"channel": ["dswrf"], # Single channel
},
)
}
}
print(f"Input site_dict: {site_dict}")

# Call function
result = process_and_combine_site_sample_dict(site_dict, config)

# Assert to validate output structure
assert isinstance(result, xr.Dataset), "Result should be an xarray.Dataset"
assert len(result.data_vars) > 0, "Dataset should contain data variables"

# Validate the expected variable name and its shape
expected_variable = "nwp-ukv"
assert expected_variable in result.data_vars, f"Expected variable '{expected_variable}' not found"
nwp_result = result[expected_variable]
assert nwp_result.shape == (4, 1, 2, 2), f"Unexpected shape for '{expected_variable}': {nwp_result.shape}"