Skip to content

Commit

Permalink
Client Configurations and Exception handling
Browse files Browse the repository at this point in the history
  • Loading branch information
nimishy committed Mar 28, 2024
1 parent e841c4e commit 34114c7
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 143 deletions.
92 changes: 36 additions & 56 deletions sdt_dask/clients/aws/fargate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,23 @@

finally:
class Fargate(Clients):
"""
Fargate Class for Dask on AWS Fargate

This class simplifies the process of setting up a Fargate cluster and
connecting a Dask client to it, enabling distributed execution
using AWS Fargate.
Requires:
- dask
- dask_cloudprovider
**Important:** Ensure you have appropriate IAM permissions to manage
AWS Fargate resources.
"""
def __init__(self):
pass
def __init__(self,
             image: str = "",
             tags: dict = None,  # optional; treated as {} when omitted
             vpc: str = "",
             region_name: str = "",
             environment: dict = None,  # treated as {} when omitted
             n_workers: int = 10,
             threads_per_worker: int = 2
             ):
    """
    Stores the AWS Fargate cluster configuration for later use by
    ``init_client``.

    Args:
        image (str): Docker image to use for the Fargate tasks.
            Defaults to "".
        tags (dict, optional): Tags to associate with the Fargate
            cluster. Defaults to None, which is treated as an empty dict.
        vpc (str): VPC ID to launch the Fargate cluster in. Defaults to "".
        region_name (str): AWS region to launch the cluster in.
            Defaults to "".
        environment (dict, optional): Environment variables to set for
            the Fargate tasks. Defaults to None, treated as an empty dict.
        n_workers (int): Number of worker nodes in the cluster.
            Defaults to 10.
        threads_per_worker (int): Threads per worker. Defaults to 2.
    """
    self.image = image
    # None-sentinel instead of a mutable {} default: a literal dict default
    # would be shared across every instance constructed without the arg.
    self.tags = tags if tags is not None else {}
    self.vpc = vpc
    self.region_name = region_name
    self.environment = environment if environment is not None else {}
    self.n_workers = n_workers
    self.threads_per_worker = threads_per_worker
def _check_versions(self):
data = self.client.get_versions(check=True)
scheduler_pkgs = data['scheduler']['packages']
Expand All @@ -40,50 +41,29 @@ def _check_versions(self):
msg = 'Please Update the client version to match the Scheduler version'
raise EnvironmentError(f"{c_pkg} version Mismatch:\n\tScheduler: {s_ver} vs Client: {c_ver}\n{msg}")

def init_client(self,
image: str = "",
tags: dict = {}, # optional
vpc: str = "",
region_name: str = "",
environment: dict = {},
n_workers: int = 10,
threads_per_worker: int = 2
) -> Client:
"""
Initializes a Dask Client instance that leverages AWS Fargate for distributed execution.
Args:
image (str, required): Docker image to use for the Fargate tasks. Defaults to "".
tags (dict, optional): Dictionary of tags to associate with the Fargate cluster. Defaults to an empty dictionary.
vpc (str, required): VPC ID to launch the Fargate cluster in. Defaults to "".
region_name (str, required): AWS region to launch the Fargate cluster in. Defaults to "".
environment (dict, required): Environment variables to set for the Fargate tasks. Defaults to an empty dictionary.
n_workers (int, optional): Number of worker nodes in the Fargate cluster. Defaults to 10.
threads_per_worker (int, optional): Number of threads per worker in the Fargate cluster. Defaults to 2.
Returns:
Client: The initialized Dask client object connected to the Fargate cluster.
"""
print("[i] Initilializing Fargate Cluster ...")
def init_client(self) -> tuple:
    """
    Creates a FargateCluster from the stored configuration and connects
    a Dask Client to it.

    Returns:
        tuple: ``(client, cluster)`` — the connected
        ``dask.distributed.Client`` and the ``FargateCluster`` it is
        attached to, so the caller can shut the cluster down.

    Raises:
        EnvironmentError: If ``_check_versions`` finds a package version
            mismatch between the scheduler and the client.
    """
    print("[i] Initializing Fargate Cluster ...")

    cluster = FargateCluster(
        tags=self.tags,
        image=self.image,
        vpc=self.vpc,
        region_name=self.region_name,
        environment=self.environment,
        n_workers=self.n_workers,
        worker_nthreads=self.threads_per_worker
    )

    print("[i] Initialized Fargate Cluster")
    print("[i] Initializing Dask Client ...")

    self.client = Client(cluster)

    # Fail fast if scheduler/client package versions disagree.
    self._check_versions()

    print(f"[>] Dask Dashboard: {self.client.dashboard_link}")

    # NOTE: the previous `except Exception as e: raise Exception(e)`
    # wrapper was removed — it erased the original exception type and
    # traceback while adding no handling. Errors now propagate as-is.
    return self.client, cluster
125 changes: 38 additions & 87 deletions sdt_dask/clients/local.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,8 @@
"""
Local Client for Dask Distributed Computing
============================================
This module provides a class for initializing a Dask client optimized for local execution.
It retrieves system information and configures the client based on resource availability.
Classes:
--------
Local
Manages the creation of a Dask client with local configuration.
Functions:
----------
None
TODO: Change documentation to sphinx format
"""
try:
import os, platform, psutil
import os, platform, psutil, dask.config
from dask.distributed import Client
from sdt_dask.clients.clients import Clients

Expand All @@ -26,100 +13,64 @@
raise ModuleNotFoundError(f"{error}\n[!] Check or reinstall the following packages\n{packages}")

finally:
"""
Initializes a Dask client for local execution with resource-aware configuration.
"""

class Local(Clients):
"""
Initializes class attributes.
"""
def __init__(self):
pass

"""
Retrieves system information for client configuration.
def __init__(self, n_workers: int = 2, threads_per_worker: int = 2, memory_per_worker: int = 5, verbose: bool = False):
    """
    Records the desired local cluster shape and keeps a handle on
    dask's configuration object for later tuning.

    Args:
        n_workers (int): Number of local workers. Defaults to 2.
        threads_per_worker (int): Threads per worker. Defaults to 2.
        memory_per_worker (int): Memory limit per worker, in GB.
            Defaults to 5.
        verbose (bool): If True, print system and client configuration
            details when the client is initialized. Defaults to False.
    """
    # Cluster sizing requested by the caller.
    self.n_workers = n_workers
    self.threads_per_worker = threads_per_worker
    self.memory_per_worker = memory_per_worker
    # Reporting flag and a reference to dask's global config object.
    self.verbose = verbose
    self.dask_config = dask.config

Attributes:
-----------
self.system: str
The operating system name (e.g., "windows", "linux").
self.cpu_count: int
The number of CPU cores available on the system.
self.memory: int
The total system memory in GB.
"""
def _get_variables(self):
def _get_sys_var(self):
    # Collect host facts used to validate the requested cluster size.
    # OS name, lower-cased (e.g. "windows", "linux", "darwin").
    self.system = platform.system().lower()
    # Logical CPU core count.
    self.cpu_count = os.cpu_count()
    # Total physical memory, truncated to whole units of 1024**3 bytes.
    self.memory = int((psutil.virtual_memory().total / (1024.**3)))

"""
Checks if the specified worker configuration is compatible with system resources.
def _config_init(self):
    """
    Applies baseline dask worker-memory settings.

    Only runs when no 'temporary_directory' is configured.
    NOTE(review): this looks like a proxy for "dask config not yet
    customized by the user" — confirm the intent. When it fires, it
    disables spill-to-disk and worker pausing and sets the memory
    target fraction to 0.8.
    """
    # Use the stored config handle consistently (it is dask.config);
    # the original mixed module-level get with self.dask_config.set.
    tmp_dir = self.dask_config.get('temporary_directory')
    if not tmp_dir:
        self.dask_config.set({'distributed.worker.memory.spill': False})
        self.dask_config.set({'distributed.worker.memory.pause': False})
        self.dask_config.set({'distributed.worker.memory.target': 0.8})

Raises:
-------
Exception:
If the configuration exceeds available resources.
"""
def _check(self):
    """
    Validates the requested worker configuration against host resources
    and enables memory spilling when the total worker memory exceeds
    system memory.

    Raises:
        Exception: If n_workers * threads_per_worker exceeds the CPU
            core count, or if memory_per_worker is below the 5 GB
            minimum.
    """
    self._get_sys_var()
    # Total thread count must fit within the host's CPU cores.
    if self.n_workers * self.threads_per_worker > self.cpu_count:
        raise Exception(f"workers and threads exceed local resources, {self.cpu_count} cores present")
    # Each worker needs at least 5 GB of memory.
    elif self.memory_per_worker < 5:
        raise Exception("memory per worker too small, minimum memory size per worker 5 GB")
    # If the combined worker memory exceeds system memory, let dask
    # spill excess to disk rather than over-commit RAM.
    if self.n_workers * self.memory_per_worker > self.memory:
        self.dask_config.set({'distributed.worker.memory.spill': True})
        print(f"[!] memory per worker exceeds system memory ({self.memory} GB), activating memory spill fraction\n")

def init_client(self) -> Client:
self._config_init()
self._check()

if self.workers * self.memory_per_worker > self.memory:
print(f"[!] memory per worker exceeds system memory ({self.memory} GB), activating memory spill fraction\n")
memory_spill_fraction = 0.8

if self.system == "windows":
self.client = Client(processes=False,
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=False,
memory_target_fraction=0.8, # 0.8
n_workers=self.workers,
self.client = Client(processes=False,
n_workers=self.n_workers,
threads_per_worker=self.threads_per_worker,
memory_limit=f"{self.memory_per_worker:.2f}GiB"
)
else:
self.client = Client(memory_limit=f"{self.memory}GB")
self.client = Client(processes=True,
n_workers=self.n_workers,
threads_per_worker=self.threads_per_worker,
memory_limit=f"{self.memory_per_worker:.2f}GiB"
)

if verbose:
if self.verbose:
print(f"[i] System: {self.system}")
print(f"[i] CPU Count: {self.cpu_count}")
print(f"[i] Memory: {self.memory}")
print(f"[i] Workers: {self.workers}")
print(f"[i] System Memory: {self.memory}")
print(f"[i] Workers: {self.n_workers}")
print(f"[i] Threads per Worker: {self.threads_per_worker}")
print(f"[i] Memory per Worker: {self.memory_per_worker}")
print(f"[i] Dask worker config:")
for key, value in self.dask_config.get('distributed.worker').items():
print(f"{key} : {value}")

print(f"\n[>] Dask Dashboard: {self.client.dashboard_link}\n")

Expand Down

0 comments on commit 34114c7

Please sign in to comment.