diff --git a/ocf_data_sampler/select/geospatial.py b/ocf_data_sampler/select/geospatial.py new file mode 100644 index 0000000..8137f16 --- /dev/null +++ b/ocf_data_sampler/select/geospatial.py @@ -0,0 +1,118 @@ +"""Geospatial functions""" + +from numbers import Number +from typing import Union + +import numpy as np +import pyproj +import xarray as xr + +# OSGB is also called "OSGB 1936 / British National Grid -- United +# Kingdom Ordnance Survey". OSGB is used in many UK electricity +# system maps, and is used by the UK Met Office UKV model. OSGB is a +# Transverse Mercator projection, using 'easting' and 'northing' +# coordinates which are in meters. See https://epsg.io/27700 +OSGB36 = 27700 + +# WGS84 is short for "World Geodetic System 1984", used in GPS. Uses +# latitude and longitude. +WGS84 = 4326 + + +_osgb_to_lon_lat = pyproj.Transformer.from_crs( + crs_from=OSGB36, crs_to=WGS84, always_xy=True +).transform +_lon_lat_to_osgb = pyproj.Transformer.from_crs( + crs_from=WGS84, crs_to=OSGB36, always_xy=True +).transform + + +def osgb_to_lon_lat( + x: Union[Number, np.ndarray], y: Union[Number, np.ndarray] +) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]: + """Change OSGB coordinates to lon, lat. + + Args: + x: osgb east-west + y: osgb north-south + Return: 2-tuple of longitude (east-west), latitude (north-south) + """ + return _osgb_to_lon_lat(xx=x, yy=y) + + +def lon_lat_to_osgb( + x: Union[Number, np.ndarray], + y: Union[Number, np.ndarray], +) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]: + """Change lon-lat coordinates to OSGB. + + Args: + x: longitude east-west + y: latitude north-south + + Return: 2-tuple of OSGB x, y + """ + return _lon_lat_to_osgb(xx=x, yy=y) + + +def osgb_to_geostationary_area_coords( + x: Union[Number, np.ndarray], + y: Union[Number, np.ndarray], + xr_data: xr.DataArray, +) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]: + """Loads geostationary area and transformation from OSGB to geostationary coords + + Args: + x: osgb east-west + y: osgb north-south + xr_data: xarray object with geostationary area + + Returns: + Geostationary coords: x, y + """ + # Only load these if using geostationary projection + import pyresample + + area_definition_yaml = xr_data.attrs["area"] + + geostationary_area_definition = pyresample.area_config.load_area_from_string( + area_definition_yaml + ) + geostationary_crs = geostationary_area_definition.crs + osgb_to_geostationary = pyproj.Transformer.from_crs( + crs_from=OSGB36, crs_to=geostationary_crs, always_xy=True + ).transform + return osgb_to_geostationary(xx=x, yy=y) + + +def _coord_priority(available_coords): + if "longitude" in available_coords: + return "lon_lat", "longitude", "latitude" + elif "x_geostationary" in available_coords: + return "geostationary", "x_geostationary", "y_geostationary" + elif "x_osgb" in available_coords: + return "osgb", "x_osgb", "y_osgb" + else: + raise ValueError(f"Unrecognized coordinate system: {available_coords}") + + +def spatial_coord_type(ds: xr.DataArray): + """Searches the data array to determine the kind of spatial coordinates present. + + This search has a preference for the dimension coordinates of the xarray object. + + Args: + ds: Dataset with spatial coords + + Returns: + str: The kind of the coordinate system + x_coord: Name of the x-coordinate + y_coord: Name of the y-coordinate + """ + if isinstance(ds, xr.DataArray): + # Search dimension coords of dataarray + coords = _coord_priority(ds.xindexes) + else: + raise ValueError(f"Unrecognized input type: {type(ds)}") + + return coords diff --git a/ocf_data_sampler/select/location.py b/ocf_data_sampler/select/location.py new file mode 100644 index 0000000..9cfa9cf --- /dev/null +++ b/ocf_data_sampler/select/location.py @@ -0,0 +1,62 @@ +"""location""" + +from typing import Optional + +import numpy as np +from pydantic import BaseModel, Field, model_validator + + +allowed_coordinate_systems =["osgb", "lon_lat", "geostationary", "idx"] + +class Location(BaseModel): + """Represent a spatial location.""" + + coordinate_system: Optional[str] = "osgb" # ["osgb", "lon_lat", "geostationary", "idx"] + x: float + y: float + id: Optional[int] = Field(None) + + @model_validator(mode='after') + def validate_coordinate_system(self): + """Validate 'coordinate_system'""" + if self.coordinate_system not in allowed_coordinate_systems: + raise ValueError(f"coordinate_system = {self.coordinate_system} is not in {allowed_coordinate_systems}") + return self + + @model_validator(mode='after') + def validate_x(self): + """Validate 'x'""" + min_x: float + max_x: float + + co = self.coordinate_system + if co == "osgb": + min_x, max_x = -103976.3, 652897.98 + if co == "lon_lat": + min_x, max_x = -180, 180 + if co == "geostationary": + min_x, max_x = -5568748.275756836, 5567248.074173927 + if co == "idx": + min_x, max_x = 0, np.inf + if self.x < min_x or self.x > max_x: + raise ValueError(f"x = {self.x} must be within {[min_x, max_x]} for {co} coordinate system") + return self + + @model_validator(mode='after') + def validate_y(self): + """Validate 'y'""" + min_y: float + max_y: float + + co = self.coordinate_system + if co == "osgb": + min_y, max_y = -16703.87, 1199851.44 + if co == "lon_lat": + min_y, max_y = -90, 90 + if co == "geostationary": + min_y, max_y = 1393687.2151494026, 5570748.323202133 + if co == "idx": + min_y, max_y = 0, np.inf + if self.y < min_y or self.y > max_y: + raise ValueError(f"y = {self.y} must be within {[min_y, max_y]} for {co} coordinate system") + return self diff --git a/ocf_data_sampler/select/select_spatial_slice.py b/ocf_data_sampler/select/select_spatial_slice.py index 0e2f47e..9dad8ea 100644 --- a/ocf_data_sampler/select/select_spatial_slice.py +++ b/ocf_data_sampler/select/select_spatial_slice.py @@ -5,15 +5,14 @@ import numpy as np import xarray as xr -from ocf_datapipes.utils import Location -from ocf_datapipes.utils.geospatial import ( - lon_lat_to_geostationary_area_coords, +from ocf_data_sampler.select.location import Location +from ocf_data_sampler.select.geospatial import ( lon_lat_to_osgb, osgb_to_geostationary_area_coords, osgb_to_lon_lat, spatial_coord_type, ) -from ocf_datapipes.utils.utils import searchsorted + logger = logging.getLogger(__name__) @@ -45,9 +44,6 @@ def convert_coords_to_match_xarray( if from_coords == "osgb": x, y = osgb_to_geostationary_area_coords(x, y, da) - elif from_coords == "lon_lat": - x, y = lon_lat_to_geostationary_area_coords(x, y, da) - elif target_coords == "lon_lat": if from_coords == "osgb": x, y = osgb_to_lon_lat(x, y) @@ -130,13 +126,8 @@ def _get_idx_of_pixel_closest_to_poi_geostationary( f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}" # Get the index into x and y nearest to x_center_geostationary and y_center_geostationary: - x_index_at_center = searchsorted( - da[x_dim].values, center_geostationary.x, assume_ascending=True - ) - - y_index_at_center = searchsorted( - da[y_dim].values, center_geostationary.y, assume_ascending=True - ) + x_index_at_center = np.searchsorted(da[x_dim].values, center_geostationary.x) + y_index_at_center = np.searchsorted(da[y_dim].values, center_geostationary.y) return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx") diff --git a/tests/select/test_location.py b/tests/select/test_location.py new file mode 100644 index 0000000..763f372 --- /dev/null +++ b/tests/select/test_location.py @@ -0,0 +1,67 @@ +from ocf_data_sampler.select.location import Location +import pytest + + +def test_make_valid_location_object_with_default_coordinate_system(): + x, y = -1000.5, 50000 + location = Location(x=x, y=y) + assert location.x == x, "location.x value not set correctly" + assert location.y == y, "location.x value not set correctly" + assert ( + location.coordinate_system == "osgb" + ), "location.coordinate_system value not set correctly" + + +def test_make_valid_location_object_with_osgb_coordinate_system(): + x, y, coordinate_system = 1.2, 22.9, "osgb" + location = Location(x=x, y=y, coordinate_system=coordinate_system) + assert location.x == x, "location.x value not set correctly" + assert location.y == y, "location.x value not set correctly" + assert ( + location.coordinate_system == coordinate_system + ), "location.coordinate_system value not set correctly" + + +def test_make_valid_location_object_with_lon_lat_coordinate_system(): + x, y, coordinate_system = 1.2, 1.2, "lon_lat" + location = Location(x=x, y=y, coordinate_system=coordinate_system) + assert location.x == x, "location.x value not set correctly" + assert location.y == y, "location.x value not set correctly" + assert ( + location.coordinate_system == coordinate_system + ), "location.coordinate_system value not set correctly" + + +def test_make_invalid_location_object_with_invalid_osgb_x(): + x, y, coordinate_system = 10000000, 1.2, "osgb" + with pytest.raises(ValueError) as err: + _ = Location(x=x, y=y, coordinate_system=coordinate_system) + assert err.typename == "ValidationError" + + +def test_make_invalid_location_object_with_invalid_osgb_y(): + x, y, coordinate_system = 2.5, 10000000, "osgb" + with pytest.raises(ValueError) as err: + _ = Location(x=x, y=y, coordinate_system=coordinate_system) + assert err.typename == "ValidationError" + + +def test_make_invalid_location_object_with_invalid_lon_lat_x(): + x, y, coordinate_system = 200, 1.2, "lon_lat" + with pytest.raises(ValueError) as err: + _ = Location(x=x, y=y, coordinate_system=coordinate_system) + assert err.typename == "ValidationError" + + +def test_make_invalid_location_object_with_invalid_lon_lat_y(): + x, y, coordinate_system = 2.5, -200, "lon_lat" + with pytest.raises(ValueError) as err: + _ = Location(x=x, y=y, coordinate_system=coordinate_system) + assert err.typename == "ValidationError" + + +def test_make_invalid_location_object_with_invalid_coordinate_system(): + x, y, coordinate_system = 2.5, 1000, "abcd" + with pytest.raises(ValueError) as err: + _ = Location(x=x, y=y, coordinate_system=coordinate_system) + assert err.typename == "ValidationError"