From 3b2db1cd9f6bf80e4bf82882bda52ba81f3d6371 Mon Sep 17 00:00:00 2001 From: Bazire Date: Mon, 23 Oct 2023 19:36:23 +0200 Subject: [PATCH] Re-working process worker --- giskard/utils/worker_pool.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/giskard/utils/worker_pool.py b/giskard/utils/worker_pool.py index 039808c84c..fc0a0c2f55 100644 --- a/giskard/utils/worker_pool.py +++ b/giskard/utils/worker_pool.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from enum import Enum from io import StringIO -from multiprocessing import Process, Queue, SimpleQueue, cpu_count, get_context +from multiprocessing import Process, Queue, cpu_count, get_context from multiprocessing.context import SpawnContext, SpawnProcess from multiprocessing.managers import SyncManager from queue import Empty, Full @@ -92,9 +92,7 @@ class GiskardResult: exception: Any = None -def _process_worker( - tasks_queue: Queue[Optional[GiskardTask]], tasks_results: Queue[GiskardResult], running_process: Dict[str, str] -): +def _process_worker(tasks_queue: Queue, tasks_results: Queue, running_process: Dict[str, str]): pid = os.getpid() LOGGER.info("Process %s started", pid) @@ -159,12 +157,12 @@ def __init__(self, nb_workers: Optional[int] = None, name: Optional[str] = None) # Mapping of the running tasks and worker pids self.with_timeout_tasks: List[TimeoutData] = [] # Queue with tasks to run - self.pending_tasks_queue: Queue[GiskardTask] = self._mp_context.Queue() + self.pending_tasks_queue: Queue = self._mp_context.Queue() # Queue with tasks to be consumed asap # As in ProcessPool, add one more to avoid idling process - self.running_tasks_queue: Queue[Optional[GiskardTask]] = self._mp_context.Queue(maxsize=self._nb_workers + 1) + self.running_tasks_queue: Queue = self._mp_context.Queue(maxsize=self._nb_workers + 1) # Queue with results to notify - self.tasks_results: Queue[GiskardResult] = self._mp_context.Queue() + self.tasks_results: Queue = self._mp_context.Queue() # Mapping task_id with future self.futures_mapping: Dict[str, Future] = dict() LOGGER.debug("Starting threads for the WorkerPoolExecutor")