diff --git a/sdt_dask/clients/aws/fargate.py b/sdt_dask/clients/aws/fargate.py index c45a8f4d..841fc526 100644 --- a/sdt_dask/clients/aws/fargate.py +++ b/sdt_dask/clients/aws/fargate.py @@ -14,22 +14,23 @@ finally: class Fargate(Clients): - """ - Fargate Class for Dask on AWS Fargate - This class simplifies the process of setting up a Fargate cluster and - connecting a Dask client to it, enabling distributed execution - using AWS Fargate. - - Requires: - - dask - - dask_cloudprovider - - **Important:** Ensure you have appropriate IAM permissions to manage - AWS Fargate resources. - """ - def __init__(self): - pass + def __init__(self, + image: str = "", + tags: dict = {}, # optional + vpc: str = "", + region_name: str = "", + environment: dict = {}, + n_workers: int = 10, + threads_per_worker: int = 2 + ): + self.image = image + self.tags = tags + self.vpc = vpc + self.region_name = region_name + self.environment = environment + self.n_workers = n_workers + self.threads_per_worker = threads_per_worker def _check_versions(self): data = self.client.get_versions(check=True) scheduler_pkgs = data['scheduler']['packages'] @@ -40,50 +41,29 @@ def _check_versions(self): msg = 'Please Update the client version to match the Scheduler version' raise EnvironmentError(f"{c_pkg} version Mismatch:\n\tScheduler: {s_ver} vs Client: {c_ver}\n{msg}") - def init_client(self, - image: str = "", - tags: dict = {}, # optional - vpc: str = "", - region_name: str = "", - environment: dict = {}, - n_workers: int = 10, - threads_per_worker: int = 2 - ) -> Client: - """ - Initializes a Dask Client instance that leverages AWS Fargate for distributed execution. - - Args: - image (str, required): Docker image to use for the Fargate tasks. Defaults to "". - tags (dict, optional): Dictionary of tags to associate with the Fargate cluster. Defaults to an empty dictionary. - vpc (str, required): VPC ID to launch the Fargate cluster in. Defaults to "". - region_name (str, required): AWS region to launch the Fargate cluster in. Defaults to "". - environment (dict, required): Environment variables to set for the Fargate tasks. Defaults to an empty dictionary. - n_workers (int, optional): Number of worker nodes in the Fargate cluster. Defaults to 10. - threads_per_worker (int, optional): Number of threads per worker in the Fargate cluster. Defaults to 2. - - Returns: - Client: The initialized Dask client object connected to the Fargate cluster. - """ - print("[i] Initilializing Fargate Cluster ...") + def init_client(self) -> tuple: + try: + print("[i] Initilializing Fargate Cluster ...") - cluster = FargateCluster( - tags = tags, - image = image, - vpc = vpc, - region_name = region_name, - environment = environment, - n_workers = n_workers, - worker_nthreads = threads_per_worker - ) + cluster = FargateCluster( + tags = self.tags, + image = self.image, + vpc = self.vpc, + region_name = self.region_name, + environment = self.environment, + n_workers = self.n_workers, + worker_nthreads = self.threads_per_worker + ) - print("[i] Initialized Fargate Cluster") - print("[i] Initilializing Dask Client ...") + print("[i] Initialized Fargate Cluster") + print("[i] Initilializing Dask Client ...") - self.client = Client(cluster) + self.client = Client(cluster) - self._check_versions() + self._check_versions() - print(f"[>] Dask Dashboard: {self.client.dashboard_link}") + print(f"[>] Dask Dashboard: {self.client.dashboard_link}") - return self.client - \ No newline at end of file + return self.client, cluster + except Exception as e: + raise Exception(e) \ No newline at end of file diff --git a/sdt_dask/clients/local.py b/sdt_dask/clients/local.py index 27ac700a..ce140340 100644 --- a/sdt_dask/clients/local.py +++ b/sdt_dask/clients/local.py @@ -1,21 +1,8 @@ """ -Local Client for Dask Distributed Computing -============================================ - -This module provides a class for initializing a Dask client optimized for local execution. -It retrieves system information and configures the client based on resource availability. - -Classes: --------- -Local - Manages the creation of a Dask client with local configuration. - -Functions: ----------- -None +TODO: Change documentation to sphinx format """ try: - import os, platform, psutil + import os, platform, psutil, dask.config from dask.distributed import Client from sdt_dask.clients.clients import Clients @@ -26,100 +13,64 @@ raise ModuleNotFoundError(f"{error}\n[!] Check or reinstall the following packages\n{packages}") finally: - """ - Initializes a Dask client for local execution with resource-aware configuration. - """ + class Local(Clients): - """ - Initializes class attributes. - """ - def __init__(self): - pass - - """ - Retrieves system information for client configuration. + def __init__(self, n_workers: int = 2, threads_per_worker: int = 2, memory_per_worker: int = 5, verbose: bool = False): + self.verbose = verbose + self.n_workers = n_workers + self.threads_per_worker = threads_per_worker + self.memory_per_worker = memory_per_worker + self.dask_config = dask.config - Attributes: - ----------- - self.system: str - The operating system name (e.g., "windows", "linux"). - self.cpu_count: int - The number of CPU cores available on the system. - self.memory: int - The total system memory in GB. - """ - def _get_variables(self): + def _get_sys_var(self): self.system = platform.system().lower() self.cpu_count = os.cpu_count() self.memory = int((psutil.virtual_memory().total / (1024.**3))) - """ - Checks if the specified worker configuration is compatible with system resources. + def _config_init(self): + tmp_dir = dask.config.get('temporary_directory') + if not tmp_dir: + self.dask_config.set({'distributed.worker.memory.spill': False}) + self.dask_config.set({'distributed.worker.memory.pause': False}) + self.dask_config.set({'distributed.worker.memory.target': 0.8}) - Raises: - ------- - Exception: - If the configuration exceeds available resources. - """ def _check(self): - if self.workers * self.threads_per_worker > self.cpu_count: + self._get_sys_var() + # workers and threads need to be less than cpu core count + # memory per worker >= 5 GB but total memory use should be less than the system memory available + if self.n_workers * self.threads_per_worker > self.cpu_count: raise Exception(f"workers and threads exceed local resources, {self.cpu_count} cores present") - elif self.memory_per_worker < 5: - raise Exception(f"memory per worker too small, minimum memory size per worker 5 GB") - - """ - Initializes a Dask client with local configuration. - - Args: - ----- - n_workers: int, optional - The number of Dask workers to create (default: 2). - threads_per_worker: int, optional - The number of threads to use per worker (default: 2). - memory_per_worker: int, optional - The memory limit for each worker in GB (default: 5). - verbose: bool, optional - If True, prints system and client configuration information. - - Returns: - -------- - Client: - The initialized Dask client object. - """ - def init_client(self, n_workers: int = 2, threads_per_worker: int = 2, memory_per_worker: int = 5, verbose: bool = False) -> Client: - self._get_variables() - - - self.workers = n_workers - self.threads_per_worker = threads_per_worker - self.memory_per_worker = memory_per_worker - memory_spill_fraction = False + if self.n_workers * self.memory_per_worker > self.memory: + self.dask_config.set({'distributed.worker.memory.spill': True}) + print(f"[!] memory per worker exceeds system memory ({self.memory} GB), activating memory spill fraction\n") + def init_client(self) -> Client: + self._config_init() self._check() - - if self.workers * self.memory_per_worker > self.memory: - print(f"[!] memory per worker exceeds system memory ({self.memory} GB), activating memory spill fraction\n") - memory_spill_fraction = 0.8 if self.system == "windows": - self.client = Client(processes=False, - memory_spill_fraction=memory_spill_fraction, - memory_pause_fraction=False, - memory_target_fraction=0.8, # 0.8 - n_workers=self.workers, + self.client = Client(processes=False, + n_workers=self.n_workers, threads_per_worker=self.threads_per_worker, memory_limit=f"{self.memory_per_worker:.2f}GiB" ) else: - self.client = Client(memory_limit=f"{self.memory}GB") + self.client = Client(processes=True, + n_workers=self.n_workers, + threads_per_worker=self.threads_per_worker, + memory_limit=f"{self.memory_per_worker:.2f}GiB" + ) - if verbose: + if self.verbose: print(f"[i] System: {self.system}") print(f"[i] CPU Count: {self.cpu_count}") - print(f"[i] Memory: {self.memory}") - print(f"[i] Workers: {self.workers}") + print(f"[i] System Memory: {self.memory}") + print(f"[i] Workers: {self.n_workers}") print(f"[i] Threads per Worker: {self.threads_per_worker}") print(f"[i] Memory per Worker: {self.memory_per_worker}") + print(f"[i] Dask worker config:") + for key, value in self.dask_config.get('distributed.worker').items(): + print(f"{key} : {value}") print(f"\n[>] Dask Dashboard: {self.client.dashboard_link}\n")