Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Satellite normalisation #97

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

# OS
.DS_Store
42 changes: 42 additions & 0 deletions ocf_data_sampler/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __getitem__(self, key):
f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}"
)


# ------ UKV
# Means and std computed WITH version_7 and higher, MetOffice values
UKV_STD = {
Expand All @@ -49,6 +50,7 @@ def __getitem__(self, key):
"prmsl": 1252.71790539,
"prate": 0.00021497,
}

UKV_MEAN = {
"cdcb": 1412.26599062,
"lcc": 50.08362643,
Expand Down Expand Up @@ -97,6 +99,7 @@ def __getitem__(self, key):
"diff_duvrs": 81605.25,
"diff_sr": 818950.6875,
}

ECMWF_MEAN = {
"dlwrf": 27187026.0,
"dswrf": 11458988.0,
Expand Down Expand Up @@ -133,3 +136,42 @@ def __getitem__(self, key):
ecmwf=ECMWF_MEAN,
)

# ------ Satellite
# RSS mean and std values computed from a randomised 20% sample of 2020 imagery

SAT_STD = {
    "HRV": 0.11405209,
    "IR_016": 0.21462157,
    "IR_039": 0.04618041,
    "IR_087": 0.06687243,
    "IR_097": 0.0468558,
    "IR_108": 0.17482725,
    "IR_120": 0.06115861,
    "IR_134": 0.04492306,
    "VIS006": 0.12184761,
    "VIS008": 0.13090034,
    "WV_062": 0.16111417,
    "WV_073": 0.12924142,
}

SAT_MEAN = {
    "HRV": 0.09298719,
    "IR_016": 0.17594202,
    "IR_039": 0.86167645,
    "IR_087": 0.7719318,
    "IR_097": 0.8014212,
    "IR_108": 0.71254843,
    "IR_120": 0.89058584,
    "IR_134": 0.944365,
    "VIS006": 0.09633306,
    "VIS008": 0.11426069,
    "WV_062": 0.7359355,
    "WV_073": 0.62479186,
}

# Convert the per-channel dicts to DataArrays so normalisation broadcasts
# along the satellite data's channel dimension
SAT_STD = _to_data_array(SAT_STD)
SAT_MEAN = _to_data_array(SAT_MEAN)

# A provider-keyed wrapper (as used for NWP) is not needed since there is a
# single satellite provider - assign the mean and std directly
SAT_STDS = SAT_STD
SAT_MEANS = SAT_MEAN
1 change: 0 additions & 1 deletion ocf_data_sampler/numpy_batch/nwp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Convert NWP to NumpyBatch"""

import pandas as pd
import xarray as xr

Expand Down
3 changes: 2 additions & 1 deletion ocf_data_sampler/numpy_batch/satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class SatelliteBatchKey:

def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
"""Convert from Xarray to NumpyBatch"""

example = {
SatelliteBatchKey.satellite_actual: da.values,
SatelliteBatchKey.time_utc: da.time_utc.values.astype(float),
Expand All @@ -27,4 +28,4 @@ def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None
if t0_idx is not None:
example[SatelliteBatchKey.t0_idx] = t0_idx

return example
return example
17 changes: 12 additions & 5 deletions ocf_data_sampler/torch_datasets/process_and_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import xarray as xr

from ocf_data_sampler.config import Configuration
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, SAT_MEANS, SAT_STDS
from ocf_data_sampler.numpy_batch import (
convert_nwp_to_numpy_batch,
convert_satellite_to_numpy_batch,
Expand All @@ -13,6 +13,8 @@
)
from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey

from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
from ocf_data_sampler.select.location import Location
from ocf_data_sampler.utils import minutes
Expand All @@ -25,8 +27,8 @@ def process_and_combine_datasets(
location: Location,
target_key: str = 'gsp'
) -> dict:
"""Normalize and convert data to numpy arrays"""

"""Normalise and convert data to numpy arrays"""
felix-e-h-p marked this conversation as resolved.
Show resolved Hide resolved
numpy_modalities = []

if "nwp" in dataset_dict:
Expand All @@ -37,18 +39,23 @@ def process_and_combine_datasets(
# Standardise
provider = config.input_data.nwp[nwp_key].provider
da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]

# Convert to NumpyBatch
nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)

# Combine the NWPs into NumpyBatch
numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})


if "sat" in dataset_dict:
# Satellite is already in the range [0-1] so no need to standardise
# Standardise
da_sat = dataset_dict["sat"]

da_sat = (da_sat - SAT_MEANS) / SAT_STDS
# Convert to NumpyBatch
numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
sat_numpy_modalities = convert_satellite_to_numpy_batch(da_sat)
# Combine the Satellite into NumpyBatch
numpy_modalities.append({SatelliteBatchKey.satellite_actual: sat_numpy_modalities})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line could be changed back to `numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))`, and then I think it'll be fine



gsp_config = config.input_data.gsp

Expand Down
219 changes: 219 additions & 0 deletions tests/torch_datasets/test_process_and_combine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
import pytest
import numpy as np
import pandas as pd
import xarray as xr

from ocf_data_sampler.config import Configuration
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
from ocf_data_sampler.select.location import Location
from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
from ocf_data_sampler.utils import minutes

from ocf_data_sampler.torch_datasets.process_and_combine import (
process_and_combine_datasets,
merge_dicts,
fill_nans_in_arrays,
compute,
)


NWP_FREQ = pd.Timedelta("3h")


@pytest.fixture(scope="module")
def da_sat_like():
    """Return a dummy satellite DataArray: one day of 5-minutely imagery on a 200x200 grid"""
    times = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq="5min")
    xs = np.arange(-100, 100)
    ys = np.arange(-100, 100)

    values = np.random.normal(size=(len(times), len(xs), len(ys)))
    return xr.DataArray(
        values,
        coords={
            "time_utc": (["time_utc"], times),
            "x_geostationary": (["x_geostationary"], xs),
            "y_geostationary": (["y_geostationary"], ys),
        },
    )


@pytest.fixture(scope="module")
def da_nwp_like():
    """Return a dummy NWP DataArray with init-time, step, channel and spatial dims"""
    init_times = pd.date_range("2024-01-02 00:00", "2024-01-03 00:00", freq=NWP_FREQ)
    steps = pd.timedelta_range("0h", "16h", freq="1h")
    channels = ["t", "dswrf"]
    xs = np.arange(-100, 100)
    ys = np.arange(-100, 100)

    shape = (len(init_times), len(steps), len(channels), len(xs), len(ys))
    return xr.DataArray(
        np.random.normal(size=shape),
        coords={
            "init_time_utc": (["init_time_utc"], init_times),
            "step": (["step"], steps),
            "channel": (["channel"], channels),
            "x_osgb": (["x_osgb"], xs),
            "y_osgb": (["y_osgb"], ys),
        },
    )


@pytest.fixture
def mock_constants(monkeypatch):
    """Patch the normalisation constants with simple dummy values.

    NOTE(review): the constants must be patched in the process_and_combine
    module namespace, not (only) in ``ocf_data_sampler.constants``, because
    that module imports them with ``from ... import`` and therefore holds its
    own references - patching the source module alone has no effect there.
    The NWP mocks are built as channel-indexed DataArrays to mirror the real
    ``NWP_MEANS[provider]`` structure so xarray arithmetic broadcasts.
    """
    mock_nwp_means = {
        "ukv": xr.DataArray(
            [10.0, 50.0], coords={"channel": (["channel"], ["t", "dswrf"])}
        )
    }
    mock_nwp_stds = {
        "ukv": xr.DataArray(
            [2.0, 10.0], coords={"channel": (["channel"], ["t", "dswrf"])}
        )
    }
    mock_sat_means = 100.0
    mock_sat_stds = 20.0

    target = "ocf_data_sampler.torch_datasets.process_and_combine"
    monkeypatch.setattr(f"{target}.NWP_MEANS", mock_nwp_means)
    monkeypatch.setattr(f"{target}.NWP_STDS", mock_nwp_stds)
    monkeypatch.setattr(f"{target}.SAT_MEANS", mock_sat_means)
    monkeypatch.setattr(f"{target}.SAT_STDS", mock_sat_stds)


@pytest.fixture
def mock_config():
    """Build a minimal stand-in for the Configuration object used by the pipeline"""

    class _NWPConfig:
        # Matches the attributes process_and_combine reads from config.input_data.nwp[key]
        provider = "ukv"
        interval_start_minutes = -360
        interval_end_minutes = 180
        time_resolution_minutes = 60

    class _GSPConfig:
        interval_start_minutes = -120
        interval_end_minutes = 120
        time_resolution_minutes = 30

    class _InputData:
        def __init__(self):
            self.nwp = {"ukv": _NWPConfig()}
            self.gsp = _GSPConfig()

    class _Config:
        def __init__(self):
            self.input_data = _InputData()

    return _Config()


@pytest.fixture
def mock_location():
    """Create a dummy Location with arbitrary OSGB coordinates"""
    return Location(id=12345, x=400000, y=500000)


def test_merge_dicts():
    """merge_dicts combines dicts, with later keys overwriting earlier ones"""
    merged = merge_dicts([{"a": 1, "b": 2}, {"c": 3, "d": 4}, {"e": 5}])
    assert merged == {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}

    # On key collision the value from the later dict wins
    overwritten = merge_dicts([{"a": 1, "b": 2}, {"a": 10, "f": 6}])
    assert overwritten["a"] == 10


def test_fill_nans_in_arrays():
    """fill_nans_in_arrays replaces NaNs with zeros, recursing into nested dicts
    and leaving non-array values untouched"""
    array_with_nans = np.array([1.0, np.nan, 3.0, np.nan])
    nested_dict = {
        "array1": array_with_nans,
        "nested": {
            "array2": np.array([np.nan, 2.0, np.nan, 4.0])
        },
        "string_key": "not_an_array"
    }

    result = fill_nans_in_arrays(nested_dict)

    assert not np.isnan(result["array1"]).any()
    assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
    # Nested dicts are processed recursively
    assert not np.isnan(result["nested"]["array2"]).any()
    assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
    # Non-array values pass through unchanged
    assert result["string_key"] == "not_an_array"


def test_compute():
    """compute walks a nested dict and materialises each DataArray in place"""
    da = xr.DataArray(np.random.rand(5, 5))

    # Create nested dictionary
    nested_dict = {
        "array1": da,
        "nested": {
            "array2": da
        }
    }

    result = compute(nested_dict)

    # Structure is preserved and values remain DataArrays
    assert isinstance(result["array1"], xr.DataArray)
    assert isinstance(result["nested"]["array2"], xr.DataArray)

    # Backing data is a concrete ndarray (would also hold had the input been
    # a lazy dask-backed array - here the fixture data is already eager)
    assert isinstance(result["array1"].data, np.ndarray)
    assert isinstance(result["nested"]["array2"].data, np.ndarray)

    # Random values in [0, 1) contain no NaN
    assert not np.isnan(result["array1"].data).any()
    assert not np.isnan(result["nested"]["array2"].data).any()


# TO DO - Update the below to include satellite and finalise testing procedure
# Currently for NWP only - awaiting confirmation
@pytest.mark.parametrize("t0_str", ["10:00", "11:00", "12:00"])
def test_full_pipeline(da_nwp_like, mock_config, mock_location, mock_constants, t0_str):
    """Test full pipeline considering time slice selection and then process and combine"""
    t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
    nwp_config = mock_config.input_data.nwp["ukv"]

    # Obtain NWP data slice
    nwp_sample = select_time_slice_nwp(
        da_nwp_like,
        t0,
        sample_period_duration=pd.Timedelta(minutes=nwp_config.time_resolution_minutes),
        interval_start=pd.Timedelta(minutes=nwp_config.interval_start_minutes),
        interval_end=pd.Timedelta(minutes=nwp_config.interval_end_minutes),
        dropout_timedeltas=None,
        dropout_frac=0,
        accum_channels=["dswrf"],
        channel_dim_name="channel",
    )

    # Run the slice through the main processing function
    result = process_and_combine_datasets(
        {"nwp": {"ukv": nwp_sample}},
        mock_config,
        t0,
        mock_location,
        target_key='gsp',
    )

    # Verify results structure
    assert NWPBatchKey.nwp in result

    # Check NWP data normalisation and NaN handling
    ukv_batch = result[NWPBatchKey.nwp]["ukv"]
    assert isinstance(ukv_batch['nwp'], np.ndarray)
    assert not np.isnan(ukv_batch['nwp']).any()
Loading