diff --git a/CHANGES.md b/CHANGES.md index d24ae43..7e0c6b1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,6 +3,7 @@ ## 1.43.5 (2024-mm-dd) - FIX: Fix the ability to save COGs with any dtype with Dask, with the workaround described [here](https://github.com/opendatacube/odc-geo/issues/189#issuecomment-2513450481) (don't compute statistics for problematic dtypes) +- FIX: Better separability of `dask` (it has its own module now): don't create a client if the user doesn't specify it (as it is not required anymore in `Lock`). This should remove the force-use of `dask`. ## 1.43.4 (2024-11-28) diff --git a/CI/SCRIPTS/script_utils.py b/CI/SCRIPTS/script_utils.py index 87dee1e..59f67d4 100644 --- a/CI/SCRIPTS/script_utils.py +++ b/CI/SCRIPTS/script_utils.py @@ -19,7 +19,7 @@ from enum import unique from functools import wraps -from sertit import AnyPath, unistra +from sertit import AnyPath, dask, unistra from sertit.misc import ListEnum CI_SERTIT_S3 = "CI_SERTIT_USE_S3" @@ -72,9 +72,7 @@ def dask_env(function): def dask_env_wrapper(*_args, **_kwargs): """S3 environment wrapper""" try: - from dask.distributed import Client, LocalCluster - - with LocalCluster() as cluster, Client(cluster): + with dask.get_or_create_dask_client(): print("Using DASK") return function(*_args, **_kwargs) except ImportError: diff --git a/sertit/dask.py b/sertit/dask.py new file mode 100644 index 0000000..8615652 --- /dev/null +++ b/sertit/dask.py @@ -0,0 +1,76 @@ +import logging +from contextlib import contextmanager + +import psutil + +from sertit import logs + +LOGGER = logging.getLogger(logs.SU_NAME) + + +@contextmanager +def get_or_create_dask_client(processes=False): + """ + Return default Dask client or create a local cluster and linked client if not existing + Returns: + """ + + try: + from dask.distributed import Client, get_client + + ram_info = psutil.virtual_memory() + available_ram = ram_info.available / 1024 / 1024 / 1024 + available_ram = 0.9 * available_ram + + n_workers = 1 + 
memory_limit = f"{available_ram}Gb" + if available_ram >= 16: + n_workers = available_ram // 16 + memory_limit = f"{16}Gb" + try: + # Return default client + yield get_client() + except ValueError: + if processes: + # Create a local cluster and return client + LOGGER.warning( + f"Init local cluster with {n_workers} workers and {memory_limit} per worker" + ) + yield Client( + n_workers=int(n_workers), + threads_per_worker=4, + memory_limit=memory_limit, + ) + else: + # Create a local cluster (threaded) + LOGGER.warning("Init local cluster (threaded)") + yield Client( + processes=processes, + ) + + except ModuleNotFoundError: + LOGGER.warning( + "Can't import 'dask'. If you experience out-of-memory issues, consider installing 'dask'." + ) + + # NOTE(review): a @contextmanager generator must yield exactly once; + # 'return None' here would make 'with get_or_create_dask_client():' raise + # RuntimeError("generator didn't yield") when dask is absent, and the CI + # wrapper only catches ImportError. Yield None so the body still runs. + yield None + + + def get_dask_lock(name): + """ + Get a dask lock with given name. This lock uses the default client if existing; + or create a local cluster (get_or_create_dask_client) otherwise. + Args: + name: The name of the lock + Returns: + """ + + try: + from dask.distributed import Lock + + return Lock(name) + except ModuleNotFoundError: + LOGGER.warning( + "Can't import 'dask'. If you experience out-of-memory issues, consider installing 'dask'." + ) + return None diff --git a/sertit/rasters.py b/sertit/rasters.py index 675f9c8..37000ce 100644 --- a/sertit/rasters.py +++ b/sertit/rasters.py @@ -25,7 +25,6 @@ import geopandas as gpd import numpy as np -import psutil import xarray as xr from shapely.geometry import Polygon @@ -40,7 +39,7 @@ "Please install 'rioxarray' to use the 'rasters' package." 
) from ex -from sertit import geometry, logs, misc, path, rasters_rio, vectors +from sertit import dask, geometry, logs, misc, path, rasters_rio, vectors from sertit.types import AnyPathStrType, AnyPathType, AnyRasterType, AnyXrDataStructure MAX_CORES = rasters_rio.MAX_CORES @@ -97,75 +96,6 @@ def get_nodata_value_from_xr(xds: AnyXrDataStructure) -> float: return nodata -def get_or_create_dask_client(processes=False): - """ - Return default Dask client or create a local cluster and linked client if not existing - Returns: - """ - - try: - from dask.distributed import Client, get_client # noqa - - ram_info = psutil.virtual_memory() - available_ram = ram_info.available / 1024 / 1024 / 1024 - available_ram = 0.9 * available_ram - - n_workers = 1 - memory_limit = f"{available_ram}Gb" - if available_ram >= 16: - n_workers = available_ram // 16 - memory_limit = f"{16}Gb" - try: - # Return default client - return get_client() # noqa - except ValueError: - if processes: - # Create a local cluster and return client - LOGGER.warning( - f"Init local cluster with {n_workers} workers and {memory_limit} per worker" - ) - return Client( - n_workers=int(n_workers), - threads_per_worker=4, - memory_limit=memory_limit, - ) - else: - # Create a local cluster (threaded) - LOGGER.warning("Init local cluster (threaded)") - return Client( - processes=processes, - ) - - except ModuleNotFoundError: - LOGGER.warning( - "Can't import dask. If you experiment out of memory issue, consider installing dask." - ) - - return None - - -def get_dask_lock(name): - """ - Get a dask lock with given name. This lock uses the default client if existing; - or create a local cluster (get_or_create_dask_client) otherwise. 
- Args: - name: The name of the lock - Returns: - """ - - try: - client = get_or_create_dask_client() - from dask.diagnostics import ProgressBar # noqa - from dask.distributed import Client, Lock, get_client # noqa - - return Lock(name, client=client) # noqa - except ModuleNotFoundError: - LOGGER.warning( - "Can't import dask. If you experiment out of memory issue, consider installing dask." - ) - return None - - def get_nodata_value_from_dtype(dtype) -> float: """ Get default nodata value from any given dtype. @@ -1157,7 +1087,7 @@ def write( else: # Get default client's lock - kwargs["lock"] = kwargs.get("lock", get_dask_lock("rio")) + kwargs["lock"] = kwargs.get("lock", dask.get_dask_lock("rio")) # Set tiles by default kwargs["tiled"] = kwargs.get("tiled", True)