diff --git a/xpublish_wms/grid.py b/xpublish_wms/grid.py index af7d7e0..c6b2949 100644 --- a/xpublish_wms/grid.py +++ b/xpublish_wms/grid.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Optional, Tuple +from typing import Any, Optional, Tuple, Union import cartopy.geodesic import cf_xarray # noqa @@ -91,7 +91,9 @@ def select_by_elevation( da = da.cf.sel({"vertical": elevation}, method="nearest") return da - def mask(self, da: xr.DataArray | xr.Dataset) -> xr.DataArray | xr.Dataset: + def mask( + self, da: Union[xr.DataArray, xr.Dataset] + ) -> Union[xr.DataArray, xr.Dataset]: """Mask the given data array""" return da @@ -166,7 +168,9 @@ def render_method(self) -> RenderMethod: def crs(self) -> str: return "EPSG:4326" - def mask(self, da: xr.DataArray | xr.Dataset) -> xr.DataArray | xr.Dataset: + def mask( + self, da: Union[xr.DataArray, xr.Dataset] + ) -> Union[xr.DataArray, xr.Dataset]: mask = self.ds[f'mask_{da.cf["latitude"].name.split("_")[1]}'] mask = mask.cf.isel(time=0).squeeze(drop=True).cf.drop_vars("time") mask[:-1, :] = mask[:-1, :].where(mask[1:, :] == 1, 0) @@ -529,7 +533,9 @@ def select_by_elevation( else: return self._grid.select_by_elevation(da, elevation) - def mask(self, da: xr.DataArray) -> xr.DataArray: + def mask( + self, da: Union[xr.DataArray, xr.Dataset] + ) -> Union[xr.DataArray, xr.Dataset]: if self._grid is None: return None else: