Skip to content

Commit

Permalink
Fix rasters when dask has no client declared
Browse files Browse the repository at this point in the history
  • Loading branch information
remi-braun committed Dec 6, 2024
1 parent b228a81 commit c7ab809
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 16 deletions.
4 changes: 2 additions & 2 deletions CI/SCRIPTS/test_rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,8 +847,8 @@ def _depr_rasters(xds):
assert isinstance(xds, xr.DataArray)
return xds

with pytest.deprecated_call():
xr.testing.assert_equal(_ok_rasters(raster_path), _depr_rasters(raster_path))
# Not able to warn deprecation from inside the decorator
xr.testing.assert_equal(_ok_rasters(raster_path), _depr_rasters(raster_path))


def test_get_nodata_deprecation():
Expand Down
25 changes: 22 additions & 3 deletions sertit/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@
LOGGER = logging.getLogger(logs.SU_NAME)


def get_client():
client = None
try:
from dask.distributed import get_client

try:
# Return default client
client = get_client()
except ValueError:
pass
except ModuleNotFoundError:
LOGGER.warning(
"Can't import 'dask'. If you experiment out of memory issue, consider installing 'dask'."
)

return client


@contextmanager
def get_or_create_dask_client(processes=False):
"""
Expand Down Expand Up @@ -72,13 +90,14 @@ def get_dask_lock(name):
name: The name of the lock
Returns:
"""

lock = None
try:
from dask.distributed import Lock

return Lock(name)
if get_client():
lock = Lock(name)
except ModuleNotFoundError:
LOGGER.warning(
"Can't import 'dask'. If you experiment out of memory issue, consider installing 'dask'."
)
return None
return lock
33 changes: 22 additions & 11 deletions sertit/rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ def wrapper(any_raster_type: AnyRasterType, *args, **kwargs) -> Any:
if any_raster_type is None:
raise ValueError("'any_raster_type' shouldn't be None!")

default_chunks = True if dask.get_client() is not None else None

# By default, try with the input fct
try:
out = function(any_raster_type, *args, **kwargs)
Expand Down Expand Up @@ -227,7 +229,7 @@ def wrapper(any_raster_type: AnyRasterType, *args, **kwargs) -> Any:
any_raster_type,
masked=True,
default_name=ds.name,
chunks=kwargs.pop("chunks", True),
chunks=kwargs.pop("chunks", default_chunks),
) as xds:
out = function(xds, *args, **kwargs)
else:
Expand All @@ -244,7 +246,7 @@ def wrapper(any_raster_type: AnyRasterType, *args, **kwargs) -> Any:
any_raster_type,
masked=True,
default_name=name,
chunks=kwargs.pop("chunks", True),
chunks=kwargs.pop("chunks", default_chunks),
) as xds:
out = function(xds, *args, **kwargs)

Expand Down Expand Up @@ -1905,15 +1907,24 @@ def hillshade(
# replace xarray-spatial fct with GDAL compatible one
from functools import partial

_func = partial(
_run_hillshade,
az_rad=azimuth * DEG_2_RAD,
alt_rad=(90 - zenith) * DEG_2_RAD,
res=np.abs(xds.rio.resolution()),
)
out = xds.data.map_overlap(
_func, depth=(1, 1), boundary=np.nan, meta=np.array(())
)
try:
_func = partial(
_run_hillshade,
az_rad=azimuth * DEG_2_RAD,
alt_rad=(90 - zenith) * DEG_2_RAD,
res=np.abs(xds.rio.resolution()),
)
out = xds.data.map_overlap(
_func, depth=(1, 1), boundary=np.nan, meta=np.array(())
)
except AttributeError:
# Without dask
out = _run_hillshade(
xds.data,
az_rad=azimuth * DEG_2_RAD,
alt_rad=(90 - zenith) * DEG_2_RAD,
res=np.abs(xds.rio.resolution()),
)

xds = xds.copy(data=out).rename(kwargs.get("name", "hillshade"))

Expand Down

0 comments on commit c7ab809

Please sign in to comment.