Satellite normalisation #97

Open · wants to merge 22 commits into main (showing changes from 20 commits)
3 changes: 3 additions & 0 deletions .gitignore
@@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

# OS
.DS_Store
38 changes: 38 additions & 0 deletions ocf_data_sampler/constants.py
@@ -28,6 +28,7 @@ def __getitem__(self, key):
            f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}"
        )


# ------ UKV
# Means and std computed WITH version_7 and higher, MetOffice values
UKV_STD = {
@@ -49,6 +50,7 @@ def __getitem__(self, key):
    "prmsl": 1252.71790539,
    "prate": 0.00021497,
}

UKV_MEAN = {
    "cdcb": 1412.26599062,
    "lcc": 50.08362643,
@@ -97,6 +99,7 @@ def __getitem__(self, key):
    "diff_duvrs": 81605.25,
    "diff_sr": 818950.6875,
}

ECMWF_MEAN = {
    "dlwrf": 27187026.0,
    "dswrf": 11458988.0,
@@ -133,3 +136,38 @@ def __getitem__(self, key):
    ecmwf=ECMWF_MEAN,
)

# ------ Satellite
# RSS Mean and std values from randomised 20% of 2020 imagery

RSS_STD = {
    "HRV": 0.11405209,
    "IR_016": 0.21462157,
    "IR_039": 0.04618041,
    "IR_087": 0.06687243,
    "IR_097": 0.0468558,
    "IR_108": 0.17482725,
    "IR_120": 0.06115861,
    "IR_134": 0.04492306,
    "VIS006": 0.12184761,
    "VIS008": 0.13090034,
    "WV_062": 0.16111417,
    "WV_073": 0.12924142,
}

RSS_MEAN = {
    "HRV": 0.09298719,
    "IR_016": 0.17594202,
    "IR_039": 0.86167645,
    "IR_087": 0.7719318,
    "IR_097": 0.8014212,
    "IR_108": 0.71254843,
    "IR_120": 0.89058584,
    "IR_134": 0.944365,
    "VIS006": 0.09633306,
    "VIS008": 0.11426069,
    "WV_062": 0.7359355,
    "WV_073": 0.62479186,
}

RSS_STD = _to_data_array(RSS_STD)
RSS_MEAN = _to_data_array(RSS_MEAN)
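
For reference, below is a minimal sketch of what the _to_data_array wrapping is assumed to do, based on how the constants are used elsewhere in this file: wrap a {channel: value} dict as an xarray DataArray indexed by a "channel" dimension, so that arithmetic against satellite data broadcasts per channel. The helper itself is defined earlier in constants.py and is not shown in this diff.

import numpy as np
import xarray as xr

def _to_data_array_sketch(d: dict) -> xr.DataArray:
    # Assumed behaviour: a DataArray of the dict values, indexed by channel name
    return xr.DataArray(
        np.asarray(list(d.values()), dtype=np.float32),
        coords={"channel": list(d.keys())},
        dims=["channel"],
    )

rss_mean = _to_data_array_sketch({"HRV": 0.09298719, "VIS006": 0.09633306})
print(float(rss_mean.sel(channel="HRV")))  # 0.0929871...
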
1 change: 0 additions & 1 deletion ocf_data_sampler/numpy_batch/nwp.py
@@ -1,5 +1,4 @@
"""Convert NWP to NumpyBatch"""

import pandas as pd
import xarray as xr

3 changes: 2 additions & 1 deletion ocf_data_sampler/numpy_batch/satellite.py
@@ -13,6 +13,7 @@ class SatelliteBatchKey:

def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
"""Convert from Xarray to NumpyBatch"""

example = {
SatelliteBatchKey.satellite_actual: da.values,
SatelliteBatchKey.time_utc: da.time_utc.values.astype(float),
@@ -27,4 +28,4 @@ def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None
    if t0_idx is not None:
        example[SatelliteBatchKey.t0_idx] = t0_idx

    return example
    return example
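
A short usage sketch of the function above. The synthetic DataArray mirrors the one built in the new tests at the bottom of this PR; the t0_idx value is an arbitrary illustration.

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.numpy_batch.satellite import (
    SatelliteBatchKey,
    convert_satellite_to_numpy_batch,
)

# 7 timesteps of 5-minutely data, one channel, 2x2 pixels
da = xr.DataArray(
    np.random.rand(7, 1, 2, 2),
    dims=["time_utc", "channel", "y", "x"],
    coords={
        "time_utc": pd.date_range("2024-01-01", periods=7, freq="5min"),
        "channel": ["HRV"],
        "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
        "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]])),
    },
)

example = convert_satellite_to_numpy_batch(da, t0_idx=3)
assert example[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
assert example[SatelliteBatchKey.t0_idx] == 3
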
16 changes: 12 additions & 4 deletions ocf_data_sampler/torch_datasets/process_and_combine.py
@@ -4,7 +4,7 @@
from typing import Tuple

from ocf_data_sampler.config import Configuration
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
from ocf_data_sampler.numpy_batch import (
    convert_nwp_to_numpy_batch,
    convert_satellite_to_numpy_batch,
@@ -13,6 +13,7 @@
)
from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
from ocf_data_sampler.select.location import Location
from ocf_data_sampler.utils import minutes
@@ -25,8 +26,8 @@ def process_and_combine_datasets(
    location: Location,
    target_key: str = 'gsp'
) -> dict:
    """Normalize and convert data to numpy arrays"""

    """Normalise and convert data to numpy arrays"""
    numpy_modalities = []

    if "nwp" in dataset_dict:
@@ -37,19 +38,23 @@
            # Standardise
            provider = config.input_data.nwp[nwp_key].provider
            da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]

            # Convert to NumpyBatch
            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)

        # Combine the NWPs into NumpyBatch
        numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})


    if "sat" in dataset_dict:
        # Satellite is already in the range [0-1] so no need to standardise
        # Standardise
        da_sat = dataset_dict["sat"]
        da_sat = (da_sat - RSS_MEAN) / RSS_STD

        # Convert to NumpyBatch
        numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))


    gsp_config = config.input_data.gsp

    if "gsp" in dataset_dict:
@@ -93,6 +98,7 @@ def process_and_combine_datasets(

    return combined_sample


def process_and_combine_site_sample_dict(
    dataset_dict: dict,
    config: Configuration,
@@ -119,8 +125,9 @@ def process_and_combine_site_sample_dict(
            data_arrays.append((f"nwp-{provider}", da_nwp))

    if "sat" in dataset_dict:
        # TODO add some satellite normalisation
        # Satellite normalisation added
Contributor: thanks for this

Member: @felix-e-h-p this is picky but can we change this comment to just # Standardise or something. I don't think "Satellite normalisation added" makes sense out of context now the TODO is gone

        da_sat = dataset_dict["sat"]
        da_sat = (da_sat - RSS_MEAN) / RSS_STD
        data_arrays.append(("satellite", da_sat))

if "site" in dataset_dict:
@@ -143,6 +150,7 @@ def merge_dicts(list_of_dicts: list[dict]) -> dict:
        combined_dict.update(d)
    return combined_dict


def merge_arrays(normalised_data_arrays: list[Tuple[str, xr.DataArray]]) -> xr.Dataset:
"""
Combine a list of DataArrays into a single Dataset with unique naming conventions.
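
As a sanity check on the standardisation added above, here is a small sketch (synthetic data, not part of this PR) showing that (da_sat - RSS_MEAN) / RSS_STD aligns on the "channel" coordinate, broadcasts over the remaining dimensions, and is exactly invertible.

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.constants import RSS_MEAN, RSS_STD

# Synthetic satellite data: 7 timesteps, 2 channels, 2x2 pixels
da_sat = xr.DataArray(
    np.random.rand(7, 2, 2, 2),
    dims=["time_utc", "channel", "y", "x"],
    coords={
        "time_utc": pd.date_range("2024-01-01", periods=7, freq="5min"),
        "channel": ["HRV", "VIS006"],
    },
)

# Standardise: the per-channel constants are matched by label, not
# position, so only the channels present in da_sat are used
da_norm = (da_sat - RSS_MEAN) / RSS_STD

# Undoing the transform recovers the original values
da_round = da_norm * RSS_STD + RSS_MEAN
np.testing.assert_allclose(da_round.values, da_sat.values)
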
217 changes: 217 additions & 0 deletions tests/torch_datasets/test_process_and_combine.py
@@ -0,0 +1,217 @@
import pytest
import tempfile

import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
from ocf_data_sampler.config import Configuration
from ocf_data_sampler.select.location import Location
from ocf_data_sampler.numpy_batch import NWPBatchKey, GSPBatchKey, SatelliteBatchKey
from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset

from ocf_data_sampler.torch_datasets.process_and_combine import (
    process_and_combine_datasets,
    process_and_combine_site_sample_dict,
    merge_dicts,
    fill_nans_in_arrays,
    compute,
)


# Currently leaving here for reference purposes - not strictly needed
def test_pvnet(pvnet_config_filename):

    # Create dataset object
    dataset = PVNetUKRegionalDataset(pvnet_config_filename)

    assert len(dataset.locations) == 317  # no of GSPs not including the National level
    # NB. I have not checked this value is in fact correct, but it does seem to stay constant
    assert len(dataset.valid_t0_times) == 39
    assert len(dataset) == 317*39

    # Generate a sample
    sample = dataset[0]

    assert isinstance(sample, dict)

    for key in [
        NWPBatchKey.nwp, SatelliteBatchKey.satellite_actual, GSPBatchKey.gsp,
        GSPBatchKey.solar_azimuth, GSPBatchKey.solar_elevation,
    ]:
        assert key in sample

    for nwp_source in ["ukv"]:
        assert nwp_source in sample[NWPBatchKey.nwp]

    # check the shape of the data is correct
    # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
    assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
    # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
    assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
    # 3 hours of 30 minute data (inclusive)
    assert sample[GSPBatchKey.gsp].shape == (7,)
    # Solar angles have same shape as GSP data
    assert sample[GSPBatchKey.solar_azimuth].shape == (7,)
    assert sample[GSPBatchKey.solar_elevation].shape == (7,)


# Currently leaving here for reference purposes - not strictly needed
def test_pvnet_no_gsp(pvnet_config_filename):

    # load config
    config = load_yaml_configuration(pvnet_config_filename)
    # remove gsp
    config.input_data.gsp.zarr_path = ''

    # save temp config file
    with tempfile.NamedTemporaryFile() as temp_config_file:
        save_yaml_configuration(config, temp_config_file.name)
        # Create dataset object
        dataset = PVNetUKRegionalDataset(temp_config_file.name)

        # Generate a sample
        _ = dataset[0]


def test_process_and_combine_datasets(pvnet_config_filename):

    # Load in config for function and define location
    config = load_yaml_configuration(pvnet_config_filename)
    t0 = pd.Timestamp("2024-01-01")
    location = Location(coordinate_system="osgb", x=1234, y=5678, id=1)

    nwp_data = xr.DataArray(
        np.random.rand(4, 2, 2, 2),
        dims=["time_utc", "channel", "y", "x"],
        coords={
            "time_utc": pd.date_range("2024-01-01", periods=4, freq="h"),
            "channel": ["t2m", "dswrf"],
            "step": ("time_utc", pd.timedelta_range(start='0h', periods=4, freq='h')),
            "init_time_utc": pd.Timestamp("2024-01-01")
        }
    )

    sat_data = xr.DataArray(
        np.random.rand(7, 1, 2, 2),
        dims=["time_utc", "channel", "y", "x"],
        coords={
            "time_utc": pd.date_range("2024-01-01", periods=7, freq="5min"),
            "channel": ["HRV"],
            "x_geostationary": (["y", "x"], np.array([[1, 2], [1, 2]])),
            "y_geostationary": (["y", "x"], np.array([[1, 1], [2, 2]]))
        }
    )

    # Combine as dict
    dataset_dict = {
        "nwp": {"ukv": nwp_data},
        "sat": sat_data
    }

    # Call relevant function
    result = process_and_combine_datasets(dataset_dict, config, t0, location)

    # Assert result is dict - check and validate
    assert isinstance(result, dict)
    assert NWPBatchKey.nwp in result
    assert result[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
    assert result[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)


def test_merge_dicts():
    """Test merge_dicts function"""
    dict1 = {"a": 1, "b": 2}
    dict2 = {"c": 3, "d": 4}
    dict3 = {"e": 5}

    result = merge_dicts([dict1, dict2, dict3])
    assert result == {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}

    # Test key overwriting
    dict4 = {"a": 10, "f": 6}
    result = merge_dicts([dict1, dict4])
    assert result["a"] == 10


def test_fill_nans_in_arrays():
    """Test the fill_nans_in_arrays function"""
    array_with_nans = np.array([1.0, np.nan, 3.0, np.nan])
    nested_dict = {
        "array1": array_with_nans,
        "nested": {
            "array2": np.array([np.nan, 2.0, np.nan, 4.0])
        },
        "string_key": "not_an_array"
    }

    result = fill_nans_in_arrays(nested_dict)

    assert not np.isnan(result["array1"]).any()
    assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
    assert not np.isnan(result["nested"]["array2"]).any()
    assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
    assert result["string_key"] == "not_an_array"


def test_compute():
    """Test the compute function"""
    da = xr.DataArray(np.random.rand(5, 5))

    # Create nested dictionary
    nested_dict = {
        "array1": da,
        "nested": {
            "array2": da
        }
    }

    result = compute(nested_dict)

    # Ensure the function was applied - data should no longer be a lazy
    # array and the nested structure should be preserved
    # Check that result is an xarray DataArray
    assert isinstance(result["array1"], xr.DataArray)
    assert isinstance(result["nested"]["array2"], xr.DataArray)

    # Check data is no longer a lazy object
    assert isinstance(result["array1"].data, np.ndarray)
    assert isinstance(result["nested"]["array2"].data, np.ndarray)

    # Check for NaN
    assert not np.isnan(result["array1"].data).any()
    assert not np.isnan(result["nested"]["array2"].data).any()


def test_process_and_combine_site_sample_dict(pvnet_config_filename):
    # Load config
    config = load_yaml_configuration(pvnet_config_filename)

    # Specify minimal structure for testing
    raw_nwp_values = np.random.rand(4, 1, 2, 2)  # Single channel
    site_dict = {
        "nwp": {
            "ukv": xr.DataArray(
                raw_nwp_values,
                dims=["time_utc", "channel", "y", "x"],
                coords={
                    "time_utc": pd.date_range("2024-01-01", periods=4, freq="h"),
                    "channel": ["dswrf"],  # Single channel
                },
            )
        }
    }
    print(f"Input site_dict: {site_dict}")

    # Call function
    result = process_and_combine_site_sample_dict(site_dict, config)

    # Assert to validate output structure
    assert isinstance(result, xr.Dataset), "Result should be an xarray.Dataset"
    assert len(result.data_vars) > 0, "Dataset should contain data variables"

    # Validate expected variable name and shape
    expected_variable = "nwp-ukv"
    assert expected_variable in result.data_vars, f"Expected variable '{expected_variable}' not found"
    nwp_result = result[expected_variable]
    assert nwp_result.shape == (4, 1, 2, 2), f"Unexpected shape for '{expected_variable}': {nwp_result.shape}"