Skip to content

Commit

Permalink
Merge pull request #57 from openclimatefix/issue/location
Browse files Browse the repository at this point in the history
copy files over ocf_datapipes
  • Loading branch information
peterdudfield authored Oct 1, 2024
2 parents 4073bd4 + 24d8162 commit a708e9f
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 14 deletions.
118 changes: 118 additions & 0 deletions ocf_data_sampler/select/geospatial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Geospatial functions"""

from numbers import Number
from typing import Union

import numpy as np
import pyproj
import xarray as xr

# OSGB ("OSGB 1936 / British National Grid -- United Kingdom Ordnance
# Survey") is the grid used by many UK electricity system maps and by the
# UK Met Office UKV model. It is a Transverse Mercator projection whose
# 'easting' / 'northing' coordinates are in meters.
# See https://epsg.io/27700
OSGB36 = 27700

# WGS84 ("World Geodetic System 1984") is the latitude / longitude system
# used by GPS.
WGS84 = 4326


# The transformers are built once, at import time, and only their bound
# ``transform`` methods are kept. ``always_xy=True`` makes both directions
# take and return coordinates in (x, y) order, i.e. (easting, northing) for
# OSGB and (longitude, latitude) for WGS84.
_osgb_to_lon_lat = pyproj.Transformer.from_crs(OSGB36, WGS84, always_xy=True).transform
_lon_lat_to_osgb = pyproj.Transformer.from_crs(WGS84, OSGB36, always_xy=True).transform


def osgb_to_lon_lat(
    x: Union[Number, np.ndarray], y: Union[Number, np.ndarray]
) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
    """Convert OSGB coordinates to longitude and latitude.

    Args:
        x: OSGB east-west coordinate(s)
        y: OSGB north-south coordinate(s)

    Returns:
        2-tuple of longitude (east-west), latitude (north-south)
    """
    longitude, latitude = _osgb_to_lon_lat(xx=x, yy=y)
    return longitude, latitude


def lon_lat_to_osgb(
    x: Union[Number, np.ndarray],
    y: Union[Number, np.ndarray],
) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
    """Convert longitude and latitude to OSGB coordinates.

    Args:
        x: longitude (east-west)
        y: latitude (north-south)

    Returns:
        2-tuple of OSGB x, y
    """
    osgb_x, osgb_y = _lon_lat_to_osgb(xx=x, yy=y)
    return osgb_x, osgb_y


def osgb_to_geostationary_area_coords(
    x: Union[Number, np.ndarray],
    y: Union[Number, np.ndarray],
    xr_data: xr.DataArray,
) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
    """Transform OSGB coordinates into the geostationary CRS of ``xr_data``.

    The target CRS is taken from the area definition stored under
    ``xr_data.attrs["area"]``, so that attribute must be present.

    Args:
        x: OSGB east-west coordinate(s)
        y: OSGB north-south coordinate(s)
        xr_data: xarray object carrying the geostationary area definition

    Returns:
        Geostationary coords: x, y
    """
    # Imported lazily so pyresample is only needed for geostationary data
    import pyresample

    # presumably a YAML string, going by the original variable name -- the
    # value is handed to pyresample unchanged
    area = pyresample.area_config.load_area_from_string(xr_data.attrs["area"])

    # A new transformer is built on every call because the target CRS
    # depends on the data that was passed in
    transformer = pyproj.Transformer.from_crs(OSGB36, area.crs, always_xy=True)
    return transformer.transform(xx=x, yy=y)


def _coord_priority(available_coords):
    """Pick the spatial coordinate system to use from the names available.

    The systems are tried in a fixed order of preference: lon/lat first,
    then geostationary, then OSGB. Only the x-like name of each system is
    looked up in ``available_coords``.

    Args:
        available_coords: Container of coordinate names (anything that
            supports the ``in`` operator)

    Returns:
        3-tuple of (coordinate system kind, x-coordinate name,
        y-coordinate name)

    Raises:
        ValueError: If none of the known coordinate names is present
    """
    # (name to look for, result to return), highest priority first
    candidates = (
        ("longitude", ("lon_lat", "longitude", "latitude")),
        ("x_geostationary", ("geostationary", "x_geostationary", "y_geostationary")),
        ("x_osgb", ("osgb", "x_osgb", "y_osgb")),
    )
    for marker, result in candidates:
        if marker in available_coords:
            return result
    raise ValueError(f"Unrecognized coordinate system: {available_coords}")


def spatial_coord_type(ds: xr.DataArray):
    """Determine which kind of spatial coordinates a data array uses.

    Only the coordinates listed in ``ds.xindexes`` are considered when
    searching for a known coordinate system.

    Args:
        ds: Data array with spatial coords

    Returns:
        str: The kind of the coordinate system
        x_coord: Name of the x-coordinate
        y_coord: Name of the y-coordinate

    Raises:
        ValueError: If ``ds`` is not an ``xr.DataArray``, or if no known
            spatial coordinate system is found
    """
    if not isinstance(ds, xr.DataArray):
        raise ValueError(f"Unrecognized input type: {type(ds)}")

    # Search the indexed coords of the data array
    return _coord_priority(ds.xindexes)
62 changes: 62 additions & 0 deletions ocf_data_sampler/select/location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""location"""

from typing import Optional

import numpy as np
from pydantic import BaseModel, Field, model_validator


allowed_coordinate_systems = ["osgb", "lon_lat", "geostationary", "idx"]

# Inclusive (min, max) limits accepted for ``x``, keyed by coordinate system.
_X_BOUNDS = {
    "osgb": (-103976.3, 652897.98),
    "lon_lat": (-180, 180),
    "geostationary": (-5568748.275756836, 5567248.074173927),
    "idx": (0, np.inf),
}

# Inclusive (min, max) limits accepted for ``y``, keyed by coordinate system.
_Y_BOUNDS = {
    "osgb": (-16703.87, 1199851.44),
    "lon_lat": (-90, 90),
    "geostationary": (1393687.2151494026, 5570748.323202133),
    "idx": (0, np.inf),
}


def _bounds_for(coordinate_system, table):
    """Return the (min, max) pair for a coordinate system from a bounds table.

    Raises:
        ValueError: If the coordinate system has no entry in the table, so a
            bad value surfaces as a validation error rather than a lookup
            failure.
    """
    try:
        return table[coordinate_system]
    except KeyError:
        raise ValueError(
            f"coordinate_system = {coordinate_system} is not in {allowed_coordinate_systems}"
        ) from None


class Location(BaseModel):
    """Represent a spatial location.

    ``x`` and ``y`` are interpreted in ``coordinate_system``, which must be
    one of ``allowed_coordinate_systems``. Each system has its own inclusive
    range for ``x`` and ``y``; values outside that range are rejected when
    the model is validated.
    """

    # One of allowed_coordinate_systems; any other value (including None)
    # fails validation.
    coordinate_system: Optional[str] = "osgb"
    x: float
    y: float
    id: Optional[int] = Field(None)

    @model_validator(mode="after")
    def validate_coordinate_system(self):
        """Validate 'coordinate_system'"""
        if self.coordinate_system not in allowed_coordinate_systems:
            raise ValueError(
                f"coordinate_system = {self.coordinate_system} is not in {allowed_coordinate_systems}"
            )
        return self

    @model_validator(mode="after")
    def validate_x(self):
        """Validate 'x' against the range allowed for the coordinate system"""
        co = self.coordinate_system
        min_x, max_x = _bounds_for(co, _X_BOUNDS)
        # NOTE: a NaN value satisfies neither comparison and is therefore
        # accepted, as before.
        if self.x < min_x or self.x > max_x:
            raise ValueError(
                f"x = {self.x} must be within {[min_x, max_x]} for {co} coordinate system"
            )
        return self

    @model_validator(mode="after")
    def validate_y(self):
        """Validate 'y' against the range allowed for the coordinate system"""
        co = self.coordinate_system
        min_y, max_y = _bounds_for(co, _Y_BOUNDS)
        # NOTE: a NaN value satisfies neither comparison and is therefore
        # accepted, as before.
        if self.y < min_y or self.y > max_y:
            raise ValueError(
                f"y = {self.y} must be within {[min_y, max_y]} for {co} coordinate system"
            )
        return self
19 changes: 5 additions & 14 deletions ocf_data_sampler/select/select_spatial_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
import numpy as np
import xarray as xr

from ocf_datapipes.utils import Location
from ocf_datapipes.utils.geospatial import (
lon_lat_to_geostationary_area_coords,
from ocf_data_sampler.select.location import Location
from ocf_data_sampler.select.geospatial import (
lon_lat_to_osgb,
osgb_to_geostationary_area_coords,
osgb_to_lon_lat,
spatial_coord_type,
)
from ocf_datapipes.utils.utils import searchsorted


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -45,9 +44,6 @@ def convert_coords_to_match_xarray(
if from_coords == "osgb":
x, y = osgb_to_geostationary_area_coords(x, y, da)

elif from_coords == "lon_lat":
x, y = lon_lat_to_geostationary_area_coords(x, y, da)

elif target_coords == "lon_lat":
if from_coords == "osgb":
x, y = osgb_to_lon_lat(x, y)
Expand Down Expand Up @@ -130,13 +126,8 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}"

# Get the index into x and y nearest to x_center_geostationary and y_center_geostationary:
x_index_at_center = searchsorted(
da[x_dim].values, center_geostationary.x, assume_ascending=True
)

y_index_at_center = searchsorted(
da[y_dim].values, center_geostationary.y, assume_ascending=True
)
x_index_at_center = np.searchsorted(da[x_dim].values, center_geostationary.x)
y_index_at_center = np.searchsorted(da[y_dim].values, center_geostationary.y)

return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx")

Expand Down
67 changes: 67 additions & 0 deletions tests/select/test_location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from ocf_data_sampler.select.location import Location
import pytest


def test_make_valid_location_object_with_default_coordinate_system():
    """A Location built without a coordinate system keeps x/y and uses "osgb"."""
    x, y = -1000.5, 50000
    location = Location(x=x, y=y)
    assert location.x == x, "location.x value not set correctly"
    # The failure message names the attribute that is actually being checked
    assert location.y == y, "location.y value not set correctly"
    assert (
        location.coordinate_system == "osgb"
    ), "location.coordinate_system value not set correctly"


def test_make_valid_location_object_with_osgb_coordinate_system():
    """A Location with in-range values and an explicit "osgb" system is accepted."""
    x, y, coordinate_system = 1.2, 22.9, "osgb"
    location = Location(x=x, y=y, coordinate_system=coordinate_system)
    assert location.x == x, "location.x value not set correctly"
    # The failure message names the attribute that is actually being checked
    assert location.y == y, "location.y value not set correctly"
    assert (
        location.coordinate_system == coordinate_system
    ), "location.coordinate_system value not set correctly"


def test_make_valid_location_object_with_lon_lat_coordinate_system():
    """A Location with in-range values and the "lon_lat" system is accepted."""
    x, y, coordinate_system = 1.2, 1.2, "lon_lat"
    location = Location(x=x, y=y, coordinate_system=coordinate_system)
    assert location.x == x, "location.x value not set correctly"
    # The failure message names the attribute that is actually being checked
    assert location.y == y, "location.y value not set correctly"
    assert (
        location.coordinate_system == coordinate_system
    ), "location.coordinate_system value not set correctly"


def test_make_invalid_location_object_with_invalid_osgb_x():
    """Building a Location with x=10000000 in the "osgb" system must fail."""
    with pytest.raises(ValueError) as err:
        Location(x=10000000, y=1.2, coordinate_system="osgb")
    # Checked once the ``with`` block has exited: a statement placed after
    # the raising call inside the block would never run.
    assert err.typename == "ValidationError"


def test_make_invalid_location_object_with_invalid_osgb_y():
    """Building a Location with y=10000000 in the "osgb" system must fail."""
    with pytest.raises(ValueError) as err:
        Location(x=2.5, y=10000000, coordinate_system="osgb")
    # Checked once the ``with`` block has exited: a statement placed after
    # the raising call inside the block would never run.
    assert err.typename == "ValidationError"


def test_make_invalid_location_object_with_invalid_lon_lat_x():
    """Building a Location with x=200 in the "lon_lat" system must fail."""
    with pytest.raises(ValueError) as err:
        Location(x=200, y=1.2, coordinate_system="lon_lat")
    # Checked once the ``with`` block has exited: a statement placed after
    # the raising call inside the block would never run.
    assert err.typename == "ValidationError"


def test_make_invalid_location_object_with_invalid_lon_lat_y():
    """Building a Location with y=-200 in the "lon_lat" system must fail."""
    with pytest.raises(ValueError) as err:
        Location(x=2.5, y=-200, coordinate_system="lon_lat")
    # Checked once the ``with`` block has exited: a statement placed after
    # the raising call inside the block would never run.
    assert err.typename == "ValidationError"


def test_make_invalid_location_object_with_invalid_coordinate_system():
    """Building a Location with the unknown system "abcd" must fail."""
    with pytest.raises(ValueError) as err:
        Location(x=2.5, y=1000, coordinate_system="abcd")
    # Checked once the ``with`` block has exited: a statement placed after
    # the raising call inside the block would never run.
    assert err.typename == "ValidationError"

0 comments on commit a708e9f

Please sign in to comment.