diff --git a/satip/geospatial.py b/satip/geospatial.py index 0968af76..e4aeb454 100644 --- a/satip/geospatial.py +++ b/satip/geospatial.py @@ -30,7 +30,7 @@ # Geographic bounds for various regions of interest, in order of min_lon, min_lat, max_lon, max_lat # (see https://satpy.readthedocs.io/en/stable/_modules/satpy/scene.html) -GEOGRAPHIC_BOUNDS = {"UK": (-18., 38., 12., 72.), "RSS": (-64, 16, 83, 69), "India": (60, 6, 97, 37)} +GEOGRAPHIC_BOUNDS = {"UK": (-17, 44, 11, 71), "RSS": (-64, 16, 83, 69), "India": (60, 6, 97, 37)} class Transformers: diff --git a/satip/utils.py b/satip/utils.py index 030da0c2..8a3487b6 100644 --- a/satip/utils.py +++ b/satip/utils.py @@ -254,7 +254,7 @@ def convert_scene_to_dataarray( log.debug("Starting scene conversion", memory=get_memory()) if area != "RSS": try: - scene = scene.crop(ll_bbox=GEOGRAPHIC_BOUNDS[area]) + scene = crop(scene=scene, bounds=GEOGRAPHIC_BOUNDS[area]) except NotImplementedError: # 15 minutely data by default doesn't work for some reason, have to resample it scene = scene.resample("msg_seviri_rss_1km" if band == "HRV" else "msg_seviri_rss_3km") @@ -410,6 +410,36 @@ def get_dataset_from_scene( gc.collect() log.debug("Saved HRV to NetCDF", memory=get_memory()) +def crop(scene: satpy.scene, bounds: list[float]) -> satpy.scene: + """Crop the satpy scene to given lon-lat box + + Args: + scene: The satpy scene object + bounds: The bounding box: [min_lon, min_lat, max_lon, max_lat] + """ + + # Get the lons and lats of the first channel - these are 2D arrays + lons, lats = scene[list(scene.keys())[0]].attrs["area"].get_lonlats() + + # Make mask of the lasts and lons + lon_mask = (lon_min < lons) & (lons < lon_max) + lat_mask = (lat_min < lats) & (lats < lat_max) + mask = lon_mask & lat_mask + + # Whether any of the columns need this row: 1D array + mask_i = mask.any(axis=1) + # Whether any of the rows need this column: 1D array + mask_j = mask.any(axis=0) + + # Find the min and max index where the mask is True along each dimension + i0 = mask_i.argmax() + i1 = len(mask_i) - mask_i[::-1].argmax() + j0 = mask_j.argmax() + j1 = len(mask_j) - mask_j[::-1].argmax() + + # return and slice the scene + return scene.slice(((slice(i0, i1), slice(j0, j1)))) + def data_quality_filter(ds: xr.Dataset, threshold_fraction: float = 0.9) -> bool: """