Skip to content

Commit

Permalink
FIX: Better separability of dask (it has its own module now): don't…
Browse files Browse the repository at this point in the history
… create a client if the user doesn't specify it (as it is not required anymore in `Lock`). This should remove the force-use of `dask`.
  • Loading branch information
remi-braun committed Dec 3, 2024
1 parent 10e7da7 commit 29724e3
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 76 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## 1.43.5 (2024-mm-dd)

- FIX: Fix the ability to save COGs with any dtype with Dask, with the workaround described [here](https://github.com/opendatacube/odc-geo/issues/189#issuecomment-2513450481) (don't compute statistics for problematic dtypes)
- FIX: Better separability of `dask` (it has its own module now): don't create a client if the user doesn't specify it (as it is not required anymore in `Lock`). This should remove the force-use of `dask`.

## 1.43.4 (2024-11-28)

Expand Down
6 changes: 2 additions & 4 deletions CI/SCRIPTS/script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from enum import unique
from functools import wraps

from sertit import AnyPath, unistra
from sertit import AnyPath, dask, unistra
from sertit.misc import ListEnum

CI_SERTIT_S3 = "CI_SERTIT_USE_S3"
Expand Down Expand Up @@ -72,9 +72,7 @@ def dask_env(function):
def dask_env_wrapper(*_args, **_kwargs):
"""S3 environment wrapper"""
try:
from dask.distributed import Client, LocalCluster

with LocalCluster() as cluster, Client(cluster):
with dask.get_or_create_dask_client():
print("Using DASK")
return function(*_args, **_kwargs)
except ImportError:
Expand Down
76 changes: 76 additions & 0 deletions sertit/dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import logging
from contextlib import contextmanager

import psutil

from sertit import logs

LOGGER = logging.getLogger(logs.SU_NAME)


@contextmanager
def get_or_create_dask_client(processes=False):
"""
Return default Dask client or create a local cluster and linked client if not existing
Returns:
"""

try:
from dask.distributed import Client, get_client

ram_info = psutil.virtual_memory()
available_ram = ram_info.available / 1024 / 1024 / 1024
available_ram = 0.9 * available_ram

n_workers = 1
memory_limit = f"{available_ram}Gb"
if available_ram >= 16:
n_workers = available_ram // 16
memory_limit = f"{16}Gb"
try:
# Return default client
yield get_client()
except ValueError:
if processes:
# Create a local cluster and return client
LOGGER.warning(
f"Init local cluster with {n_workers} workers and {memory_limit} per worker"
)
yield Client(
n_workers=int(n_workers),
threads_per_worker=4,
memory_limit=memory_limit,
)
else:
# Create a local cluster (threaded)
LOGGER.warning("Init local cluster (threaded)")
yield Client(
processes=processes,
)

except ModuleNotFoundError:
LOGGER.warning(
"Can't import 'dask'. If you experiment out of memory issue, consider installing 'dask'."
)

return None


def get_dask_lock(name):
"""
Get a dask lock with given name. This lock uses the default client if existing;
or create a local cluster (get_or_create_dask_client) otherwise.
Args:
name: The name of the lock
Returns:
"""

try:
from dask.distributed import Lock

return Lock(name)
except ModuleNotFoundError:
LOGGER.warning(
"Can't import 'dask'. If you experiment out of memory issue, consider installing 'dask'."
)
return None
74 changes: 2 additions & 72 deletions sertit/rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import geopandas as gpd
import numpy as np
import psutil
import xarray as xr
from shapely.geometry import Polygon

Expand All @@ -40,7 +39,7 @@
"Please install 'rioxarray' to use the 'rasters' package."
) from ex

from sertit import geometry, logs, misc, path, rasters_rio, vectors
from sertit import dask, geometry, logs, misc, path, rasters_rio, vectors
from sertit.types import AnyPathStrType, AnyPathType, AnyRasterType, AnyXrDataStructure

MAX_CORES = rasters_rio.MAX_CORES
Expand Down Expand Up @@ -97,75 +96,6 @@ def get_nodata_value_from_xr(xds: AnyXrDataStructure) -> float:
return nodata


def get_or_create_dask_client(processes=False):
"""
Return default Dask client or create a local cluster and linked client if not existing
Returns:
"""

try:
from dask.distributed import Client, get_client # noqa

ram_info = psutil.virtual_memory()
available_ram = ram_info.available / 1024 / 1024 / 1024
available_ram = 0.9 * available_ram

n_workers = 1
memory_limit = f"{available_ram}Gb"
if available_ram >= 16:
n_workers = available_ram // 16
memory_limit = f"{16}Gb"
try:
# Return default client
return get_client() # noqa
except ValueError:
if processes:
# Create a local cluster and return client
LOGGER.warning(
f"Init local cluster with {n_workers} workers and {memory_limit} per worker"
)
return Client(
n_workers=int(n_workers),
threads_per_worker=4,
memory_limit=memory_limit,
)
else:
# Create a local cluster (threaded)
LOGGER.warning("Init local cluster (threaded)")
return Client(
processes=processes,
)

except ModuleNotFoundError:
LOGGER.warning(
"Can't import dask. If you experiment out of memory issue, consider installing dask."
)

return None


def get_dask_lock(name):
"""
Get a dask lock with given name. This lock uses the default client if existing;
or create a local cluster (get_or_create_dask_client) otherwise.
Args:
name: The name of the lock
Returns:
"""

try:
client = get_or_create_dask_client()
from dask.diagnostics import ProgressBar # noqa
from dask.distributed import Client, Lock, get_client # noqa

return Lock(name, client=client) # noqa
except ModuleNotFoundError:
LOGGER.warning(
"Can't import dask. If you experiment out of memory issue, consider installing dask."
)
return None


def get_nodata_value_from_dtype(dtype) -> float:
"""
Get default nodata value from any given dtype.
Expand Down Expand Up @@ -1157,7 +1087,7 @@ def write(

else:
# Get default client's lock
kwargs["lock"] = kwargs.get("lock", get_dask_lock("rio"))
kwargs["lock"] = kwargs.get("lock", dask.get_dask_lock("rio"))

# Set tiles by default
kwargs["tiled"] = kwargs.get("tiled", True)
Expand Down

0 comments on commit 29724e3

Please sign in to comment.