diff --git a/ocf_datapipes/select/select_spatial_slice.py b/ocf_datapipes/select/select_spatial_slice.py index 5cf883d03..df511f443 100644 --- a/ocf_datapipes/select/select_spatial_slice.py +++ b/ocf_datapipes/select/select_spatial_slice.py @@ -103,24 +103,35 @@ def _get_idx_of_pixel_closest_to_poi( def _get_idx_of_pixel_closest_to_poi_geostationary( xr_data: xr.DataArray, - center_osgb: Location, + center_coordinate: Location, ) -> Location: """ Return x and y index location of pixel at center of region of interest. Args: xr_data: Xarray dataset - center_osgb: Center in OSGB coordinates + center_coordinate: Central coordinate Returns: Location for the center pixel in geostationary coordinates """ - xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) + if center_coordinate.coordinate_system == "osgb": + x, y = osgb_to_geostationary_area_coords( + x=center_coordinate.x, y=center_coordinate.y, xr_data=xr_data + ) + elif center_coordinate.coordinate_system == "lon_lat": + x, y = lon_lat_to_geostationary_area_coords( + x=center_coordinate.x, y=center_coordinate.y, xr_data=xr_data + ) + else: + raise NotImplementedError( + f"Only 'osgb' and 'lon_lat' location coordinates are \ + supported in conversion to geostationary \ + - not '{center_coordinate.coordinate_system}'" + ) - x, y = osgb_to_geostationary_area_coords(x=center_osgb.x, y=center_osgb.y, xr_data=xr_data) center_geostationary = Location(x=x, y=y, coordinate_system="geostationary") - # Check that the requested point lies within the data assert xr_data[xr_x_dim].min() < x < xr_data[xr_x_dim].max() assert xr_data[xr_y_dim].min() < y < xr_data[xr_y_dim].max() @@ -390,7 +401,7 @@ def select_spatial_slice_pixels( if xr_coords == "geostationary": center_idx: Location = _get_idx_of_pixel_closest_to_poi_geostationary( xr_data=xr_data, - center_osgb=location, + center_coordinate=location, ) else: center_idx: Location = _get_idx_of_pixel_closest_to_poi( diff --git a/tests/select/test_select_spatial_slice.py b/tests/select/test_select_spatial_slice.py index cdb3a89f2..4e24a58ed 100644 --- a/tests/select/test_select_spatial_slice.py +++ b/tests/select/test_select_spatial_slice.py @@ -1,14 +1,16 @@ import numpy as np import xarray as xr -from ocf_datapipes.utils import Location from ocf_datapipes.select import ( PickLocations, SelectSpatialSliceMeters, SelectSpatialSlicePixels, ) - -from ocf_datapipes.select.select_spatial_slice import slice_spatial_pixel_window_from_xarray +from ocf_datapipes.select.select_spatial_slice import ( + _get_idx_of_pixel_closest_to_poi_geostationary, + slice_spatial_pixel_window_from_xarray, +) +from ocf_datapipes.utils import Location def test_slice_spatial_pixel_window_from_xarray_function(): @@ -158,3 +160,31 @@ def test_select_spatial_slice_meters_icon_global(passiv_datapipe, icon_global_da # ICON global has roughly 13km spacing, so this should be around 7x7 grid assert len(data.longitude) == 49 assert len(data.latitude) == 49 + + +def test_get_idx_of_pixel_closest_to_poi_geostationary_lon_lat_location(): + # Create dummy data + x = np.arange(5000000, -5000000, -5000) + y = np.arange(5000000, -5000000, -5000)[::-1] + + xr_data = xr.Dataset( + data_vars=dict( + data=(["x_geostationary", "y_geostationary"], np.random.normal(size=(len(x), len(y)))), + ), + coords=dict( + x_geostationary=(["x_geostationary"], x), + y_geostationary=(["y_geostationary"], y), + ), + ) + xr_data.attrs["area"] = ( + "msg_seviri_iodc_3km:\n description: MSG SEVIRI Indian Ocean Data Coverage service area definition with\n 3 km resolution\n projection:\n proj: geos\n lon_0: 41.5\n h: 35785831\n x_0: 0\n y_0: 0\n a: 6378169\n rf: 295.488065897014\n no_defs: null\n type: crs\n shape:\n height: 3712\n width: 3712\n area_extent:\n lower_left_xy: [5000000, 5000000]\n upper_right_xy: [-5000000, -5000000]\n units: m\n" + ) + + center = Location(x=77.1, y=28.6, coordinate_system="lon_lat") + + location_center_idx = _get_idx_of_pixel_closest_to_poi_geostationary( + xr_data=xr_data, center_coordinate=center + ) + + assert location_center_idx.coordinate_system == "idx" + assert location_center_idx.x == 2000