Skip to content

Commit

Permalink
Merge branch 'main' into issue/visulaize
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Nov 29, 2024
2 parents 859546e + a74fbb4 commit 1b336d4
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[bumpversion]
commit = True
tag = True
current_version = 3.3.46
current_version = 3.3.52
message = Bump version: {current_version} → {new_version} [skip ci]

[bumpversion:file:setup.py]
Expand Down
5 changes: 5 additions & 0 deletions ocf_datapipes/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,11 @@ class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
description="The temporal resolution (in minutes) of the data."
"Note that this needs to be divisible by 5.",
)
satellite_scaling_methods: Optional[List[str]] = Field(
["mean_std"],
description="There are few ways to scale the satellite data. "
"1. None, 2. mean_std, 3. min_max",
)


class HRVSatellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
Expand Down
2 changes: 1 addition & 1 deletion ocf_datapipes/load/nwp/nwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
self.open_nwp = open_icon_eu
elif provider.lower() == "icon-global":
self.open_nwp = open_icon_global
elif provider.lower() == "ecmwf":
elif provider.lower() in ("ecmwf", "mo_global"): # same schema so using the same loader
self.open_nwp = open_ifs
elif provider.lower() == "gfs":
self.open_nwp = open_gfs
Expand Down
23 changes: 17 additions & 6 deletions ocf_datapipes/select/select_spatial_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,24 +103,35 @@ def _get_idx_of_pixel_closest_to_poi(

def _get_idx_of_pixel_closest_to_poi_geostationary(
xr_data: xr.DataArray,
center_osgb: Location,
center_coordinate: Location,
) -> Location:
"""
Return x and y index location of pixel at center of region of interest.
Args:
xr_data: Xarray dataset
center_osgb: Center in OSGB coordinates
center_coordinate: Central coordinate
Returns:
Location for the center pixel in geostationary coordinates
"""

xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data)
if center_coordinate.coordinate_system == "osgb":
x, y = osgb_to_geostationary_area_coords(
x=center_coordinate.x, y=center_coordinate.y, xr_data=xr_data
)
elif center_coordinate.coordinate_system == "lon_lat":
x, y = lon_lat_to_geostationary_area_coords(
x=center_coordinate.x, y=center_coordinate.y, xr_data=xr_data
)
else:
raise NotImplementedError(
f"Only 'osgb' and 'lon_lat' location coordinates are \
supported in conversion to geostationary \
- not '{center_coordinate.coordinate_system}'"
)

x, y = osgb_to_geostationary_area_coords(x=center_osgb.x, y=center_osgb.y, xr_data=xr_data)
center_geostationary = Location(x=x, y=y, coordinate_system="geostationary")

# Check that the requested point lies within the data
assert xr_data[xr_x_dim].min() < x < xr_data[xr_x_dim].max()
assert xr_data[xr_y_dim].min() < y < xr_data[xr_y_dim].max()
Expand Down Expand Up @@ -390,7 +401,7 @@ def select_spatial_slice_pixels(
if xr_coords == "geostationary":
center_idx: Location = _get_idx_of_pixel_closest_to_poi_geostationary(
xr_data=xr_data,
center_osgb=location,
center_coordinate=location,
)
else:
center_idx: Location = _get_idx_of_pixel_closest_to_poi(
Expand Down
29 changes: 8 additions & 21 deletions ocf_datapipes/training/pvnet_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
NWP_MEANS,
NWP_STDS,
RSS_MEAN,
RSS_RAW_MAX,
RSS_RAW_MIN,
RSS_STD,
)
from ocf_datapipes.utils.utils import (
Expand All @@ -36,29 +38,10 @@
xr.set_options(keep_attrs=True)
logger = logging.getLogger("pvnet_site_datapipe")

normalization_values = {
2019: 3185.0,
2020: 2678.0,
2021: 3196.0,
2022: 3575.0,
2023: 3773.0,
2024: 3773.0,
}


def normalize_pv(x: xr.DataArray):
"""Normalize PV data"""
# This is after the data has been temporally sliced, so have the year
return x / normalization_values[2024]

year = x.time_utc.dt.year

# Add the effective_capacity_mwp to the dataset, indexed on the time_utc
return (
x / normalization_values[year]
if year in normalization_values
else x / normalization_values[2024]
)
return x / x.nominal_capacity_wp


class DictDatasetIterDataPipe(IterDataPipe):
Expand Down Expand Up @@ -273,7 +256,11 @@ def construct_sliced_data_pipeline(
roi_height_pixels=conf_sat.satellite_image_size_pixels_height,
roi_width_pixels=conf_sat.satellite_image_size_pixels_width,
)
sat_datapipe = sat_datapipe.normalize(mean=RSS_MEAN, std=RSS_STD)
scaling_methods = conf_sat.satellite_scaling_methods
if "min_max" in scaling_methods:
sat_datapipe = sat_datapipe.normalize(min_values=RSS_RAW_MIN, max_values=RSS_RAW_MAX)
if "mean_std" in scaling_methods:
sat_datapipe = sat_datapipe.normalize(mean=RSS_MEAN, std=RSS_STD)

if "pv" in datapipes_dict:
# Recombine Sensor arrays - see function doc for further explanation
Expand Down
8 changes: 8 additions & 0 deletions ocf_datapipes/transform/xarray/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def __init__(
max_value: Optional[Union[int, float]] = None,
calculate_mean_std_from_example: bool = False,
normalize_fn: Optional[Callable] = None,
min_values: Optional[Union[xr.Dataset, xr.DataArray, np.ndarray]] = None,
max_values: Optional[Union[xr.Dataset, xr.DataArray, np.ndarray]] = None,
):
"""
Normalize the data with either given mean/std,
Expand All @@ -37,13 +39,17 @@ def __init__(
calculate_mean_std_from_example: Whether to calculate the
mean/std from the input data or not
normalize_fn: Callable function to apply to the data to normalize it
min_values: Min values for each channel
max_values: Max values for each channel
"""
self.source_datapipe = source_datapipe
self.mean = mean
self.std = std
self.max_value = max_value
self.calculate_mean_std_from_example = calculate_mean_std_from_example
self.normalize_fn = normalize_fn
self.min_values = min_values
self.max_values = max_values

def __iter__(self) -> Union[xr.Dataset, xr.DataArray]:
"""Normalize the data depending on the init arguments"""
Expand All @@ -61,6 +67,8 @@ def __iter__(self) -> Union[xr.Dataset, xr.DataArray]:
# For Topo data for example
xr_data -= xr_data.mean().item()
xr_data /= xr_data.std().item()
elif (self.min_values is not None) and (self.max_values is not None):
xr_data = (xr_data - self.min_values) / (self.max_values - self.min_values)
else:
try:
logger.debug(f"Normalizing by {self.normalize_fn}")
Expand Down
65 changes: 65 additions & 0 deletions ocf_datapipes/utils/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __getitem__(self, key):
"excarta",
"merra2",
"merra2_uk",
"mo_global",
]

# ------ UKV
Expand Down Expand Up @@ -131,6 +132,24 @@ def __getitem__(self, key):
UKV_STD = _to_data_array(UKV_STD)
UKV_MEAN = _to_data_array(UKV_MEAN)

# These were calculated from 200 random init times (step 0s) from the MO global data
MO_GLOBAL_INDIA_MEAN = {
"temperature_sl": 298.2,
"wind_u_component_10m": 0.5732,
"wind_v_component_10m": -0.2831,
}

MO_GLOBAL_INDIA_STD = {
"temperature_sl": 8.473,
"wind_u_component_10m": 2.599,
"wind_v_component_10m": 2.016,
}


MO_GLOBAL_VARIABLE_NAMES = tuple(MO_GLOBAL_INDIA_MEAN.keys())
MO_GLOBAL_INDIA_STD = _to_data_array(MO_GLOBAL_INDIA_STD)
MO_GLOBAL_INDIA_MEAN = _to_data_array(MO_GLOBAL_INDIA_MEAN)


# ------ GFS
GFS_STD = {
Expand Down Expand Up @@ -250,6 +269,10 @@ def __getitem__(self, key):
"v10": 0.02332865633070469,
"v100": -0.07577426731586456,
"v200": -0.1255049854516983,
"diff_dlwrf": 1340142.4,
"diff_dswrf": 820569.5,
"diff_duvrs": 94480.24,
"diff_sr": 814910.1,
}

INDIA_ECMWF_STD = {
Expand All @@ -270,6 +293,10 @@ def __getitem__(self, key):
"v10": 2.401158571243286,
"v100": 3.5278923511505127,
"v200": 3.974159002304077,
"diff_dlwrf": 292804.8,
"diff_dswrf": 1082344.9,
"diff_duvrs": 125904.18,
"diff_sr": 1088536.2,
}


Expand Down Expand Up @@ -347,6 +374,7 @@ def __getitem__(self, key):
excarta=EXCARTA_VARIABLE_NAMES,
merra2=MERRA2_VARIABLE_NAMES,
merra2_uk=UK_MERRA2_VARIABLE_NAMES,
mo_global=MO_GLOBAL_VARIABLE_NAMES,
)
NWP_STDS = NWPStatDict(
ukv=UKV_STD,
Expand All @@ -356,6 +384,7 @@ def __getitem__(self, key):
excarta=EXCARTA_STD,
merra2=MERRA2_STD,
merra2_uk=UK_MERRA2_STD,
mo_global=MO_GLOBAL_INDIA_STD,
)
NWP_MEANS = NWPStatDict(
ukv=UKV_MEAN,
Expand All @@ -365,6 +394,7 @@ def __getitem__(self, key):
excarta=EXCARTA_MEAN,
merra2=MERRA2_MEAN,
merra2_uk=UK_MERRA2_MEAN,
mo_global=MO_GLOBAL_INDIA_MEAN,
)

# --------------------------- SATELLITE ------------------------------
Expand Down Expand Up @@ -405,6 +435,41 @@ def __getitem__(self, key):
RSS_STD = _to_data_array(RSS_STD)
RSS_MEAN = _to_data_array(RSS_MEAN)

# normalizing from raw values

RSS_RAW_MIN = {
"IR_016": -2.5118103,
"IR_039": -64.83977,
"IR_087": 63.404694,
"IR_097": 2.844452,
"IR_108": 199.10002,
"IR_120": -17.254883,
"IR_134": -26.29155,
"VIS006": -1.1009827,
"VIS008": -2.4184198,
"WV_062": 199.57048,
"WV_073": 198.95093,
"HRV": -1.2278595,
}

RSS_RAW_MAX = {
"IR_016": 69.60857,
"IR_039": 339.15588,
"IR_087": 340.26526,
"IR_097": 317.86752,
"IR_108": 313.2767,
"IR_120": 315.99194,
"IR_134": 274.82297,
"VIS006": 93.786545,
"VIS008": 101.34922,
"WV_062": 249.91806,
"WV_073": 286.96323,
"HRV": 103.90016,
}

RSS_RAW_MIN = _to_data_array(RSS_RAW_MIN)
RSS_RAW_MAX = _to_data_array(RSS_RAW_MAX)


# --------------------------- SENSORS --------------------------------

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torch>=2.0.0
torch>=2.0.0, <2.5.0
Cartopy>=0.20.3
xarray
zarr
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

setup(
name="ocf_datapipes",
version="3.3.46",
version="3.3.52",
license="MIT",
description="Pytorch Datapipes built for use in Open Climate Fix's forecasting work",
author="Jacob Bieker, Jack Kelly, Peter Dudfield, James Fulton",
Expand Down
36 changes: 33 additions & 3 deletions tests/select/test_select_spatial_slice.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import numpy as np
import xarray as xr
from ocf_datapipes.utils import Location

from ocf_datapipes.select import (
PickLocations,
SelectSpatialSliceMeters,
SelectSpatialSlicePixels,
)

from ocf_datapipes.select.select_spatial_slice import slice_spatial_pixel_window_from_xarray
from ocf_datapipes.select.select_spatial_slice import (
_get_idx_of_pixel_closest_to_poi_geostationary,
slice_spatial_pixel_window_from_xarray,
)
from ocf_datapipes.utils import Location


def test_slice_spatial_pixel_window_from_xarray_function():
Expand Down Expand Up @@ -158,3 +160,31 @@ def test_select_spatial_slice_meters_icon_global(passiv_datapipe, icon_global_da
# ICON global has roughly 13km spacing, so this should be around 7x7 grid
assert len(data.longitude) == 49
assert len(data.latitude) == 49


def test_get_idx_of_pixel_closest_to_poi_geostationary_lon_lat_location():
# Create dummy data
x = np.arange(5000000, -5000000, -5000)
y = np.arange(5000000, -5000000, -5000)[::-1]

xr_data = xr.Dataset(
data_vars=dict(
data=(["x_geostationary", "y_geostationary"], np.random.normal(size=(len(x), len(y)))),
),
coords=dict(
x_geostationary=(["x_geostationary"], x),
y_geostationary=(["y_geostationary"], y),
),
)
xr_data.attrs["area"] = (
"msg_seviri_iodc_3km:\n description: MSG SEVIRI Indian Ocean Data Coverage service area definition with\n 3 km resolution\n projection:\n proj: geos\n lon_0: 41.5\n h: 35785831\n x_0: 0\n y_0: 0\n a: 6378169\n rf: 295.488065897014\n no_defs: null\n type: crs\n shape:\n height: 3712\n width: 3712\n area_extent:\n lower_left_xy: [5000000, 5000000]\n upper_right_xy: [-5000000, -5000000]\n units: m\n"
)

center = Location(x=77.1, y=28.6, coordinate_system="lon_lat")

location_center_idx = _get_idx_of_pixel_closest_to_poi_geostationary(
xr_data=xr_data, center_coordinate=center
)

assert location_center_idx.coordinate_system == "idx"
assert location_center_idx.x == 2000

0 comments on commit 1b336d4

Please sign in to comment.