From 8dbdfb20050992837063e4a8be3dc5d9372f4df7 Mon Sep 17 00:00:00 2001 From: Bodong Yang <86948717+Bodong-Yang@users.noreply.github.com> Date: Fri, 29 Nov 2024 17:47:09 +0900 Subject: [PATCH] fix(backport v3.8.x): backport PR#395: retry_task_map: fix mem leak when network is totally cut off (#436) * fix: retry_task_map: fix mem leak when network is totally cut off (#395) Previously, when handling failed task, the retry_task_map first releases the semaphore, and then reschedules the failed task directly without acquire the semaphore again. This results in the concurrent tasks semaphore being escaped when there are a lot of failed tasks. This is especially critical when the network is totally cut off, the failed tasks are pilling up and new tasks are still being scheduled, resulting in memory leaks. This PR fixes the above problem by fixing the task_done_cb, now when handling and re-scheduling the failed tasks, the semaphore will be kept. semaphore will only be released when the task is done successfully, or the thread-pool shutdowns. Other refinements to the retry_task_map includes: 1. watchdog: when watchdog func failed, watchdog thread will try to drain the pending workitem queue. 2. fut_gen: when thread-pool shutdowns, drain the finished futures queue. 3. dispatcher: directly exits when thread-pool shutdowns. 4. watchdog: refine the watchdog thread implementation. --- .github/workflows/test.yaml | 2 + src/otaclient/app/configs.py | 2 +- src/otaclient_common/retry_task_map.py | 177 +++++++++++++++++-------- 3 files changed, 125 insertions(+), 56 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7315f953e..5a6c7c625 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -4,9 +4,11 @@ on: pull_request: branches: - main + - v* push: branches: - main + - v* paths: - "src/**" - "tests/**" diff --git a/src/otaclient/app/configs.py b/src/otaclient/app/configs.py index 9a76ab8b7..94f4c191e 100644 --- a/src/otaclient/app/configs.py +++ b/src/otaclient/app/configs.py @@ -97,7 +97,7 @@ class BaseConfig(_InternalSettings): "otaclient": INFO, "otaclient_api": INFO, "otaclient_common": INFO, - "otaproxy": INFO, + "ota_proxy": INFO, } LOG_FORMAT = ( "[%(asctime)s][%(levelname)s]-%(name)s:%(funcName)s:%(lineno)d,%(message)s" diff --git a/src/otaclient_common/retry_task_map.py b/src/otaclient_common/retry_task_map.py index 37b0e5cf7..414076476 100644 --- a/src/otaclient_common/retry_task_map.py +++ b/src/otaclient_common/retry_task_map.py @@ -24,7 +24,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from functools import partial from queue import Empty, SimpleQueue -from typing import Any, Callable, Generator, Iterable, Optional +from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Optional from otaclient_common.typing import RT, T @@ -35,7 +35,7 @@ class TasksEnsureFailed(Exception): """Exception for tasks ensuring failed.""" -class ThreadPoolExecutorWithRetry(ThreadPoolExecutor): +class _ThreadPoolExecutorWithRetry(ThreadPoolExecutor): def __init__( self, @@ -48,21 +48,6 @@ def __init__( initializer: Callable[..., Any] | None = None, initargs: tuple = (), ) -> None: - """Initialize a ThreadPoolExecutorWithRetry instance. - - Args: - max_concurrent (int): Limit the number pending scheduled tasks. - max_workers (Optional[int], optional): Max number of worker threads in the pool. Defaults to None. - max_total_retry (Optional[int], optional): Max total retry counts before abort. Defaults to None. - thread_name_prefix (str, optional): Defaults to "". - watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, when - this func raises exception, the watchdog will interrupt the tasks execution. Defaults to None. - watchdog_check_interval (int): Defaults to 3(seconds). - initializer (Callable[..., Any] | None): The same param passed through to ThreadPoolExecutor. - Defaults to None. - initargs (tuple): The same param passed through to ThreadPoolExecutor. - Defaults to (). - """ self._start_lock, self._started = threading.Lock(), False self._total_task_num = 0 """ @@ -78,6 +63,13 @@ def __init__( self._concurrent_semaphore = threading.Semaphore(max_concurrent) self._fut_queue: SimpleQueue[Future[Any]] = SimpleQueue() + self._watchdog_check_interval = watchdog_check_interval + self._checker_funcs: list[Callable[[], Any]] = [] + if isinstance(max_total_retry, int) and max_total_retry > 0: + self._checker_funcs.append(partial(self._max_retry_check, max_total_retry)) + if callable(watchdog_func): + self._checker_funcs.append(watchdog_func) + super().__init__( max_workers=max_workers, thread_name_prefix=thread_name_prefix, @@ -85,48 +77,53 @@ def __init__( initargs=initargs, ) - if max_total_retry or callable(watchdog_func): - threading.Thread( - target=self._watchdog, - args=(max_total_retry, watchdog_func, watchdog_check_interval), - daemon=True, - ).start() + def _max_retry_check(self, max_total_retry: int) -> None: + if self._retry_count > max_total_retry: + raise TasksEnsureFailed("exceed max retry count, abort") def _watchdog( self, - max_retry: int | None, - watchdog_func: Callable[..., Any] | None, + _checker_funcs: list[Callable[[], Any]], interval: int, ) -> None: """Watchdog will shutdown the threadpool on certain conditions being met.""" - while not self._shutdown and not concurrent_fut_thread._shutdown: - if max_retry and self._retry_count > max_retry: - logger.warning(f"exceed {max_retry=}, abort") - return self.shutdown(wait=True) - - if callable(watchdog_func): - try: - watchdog_func() - except Exception as e: - logger.warning(f"custom watchdog func failed: {e!r}, abort") - return self.shutdown(wait=True) + while not (self._shutdown or self._broken or concurrent_fut_thread._shutdown): time.sleep(interval) + try: + for _func in _checker_funcs: + _func() + except Exception as e: + logger.warning( + f"watchdog failed: {e!r}, shutdown the pool and draining the workitem queue on shutdown.." + ) + self.shutdown(wait=False) + # drain the worker queues to cancel all the futs + with contextlib.suppress(Empty): + while True: + self._work_queue.get_nowait() def _task_done_cb( self, fut: Future[Any], /, *, item: T, func: Callable[[T], Any] ) -> None: - self._concurrent_semaphore.release() # always release se first + if self._shutdown or self._broken or concurrent_fut_thread._shutdown: + self._concurrent_semaphore.release() # on shutdown, always release a se + return # on shutdown, no need to put done fut into fut_queue self._fut_queue.put_nowait(fut) # ------ on task failed ------ # if fut.exception(): self._retry_count = next(self._retry_counter) - with contextlib.suppress(Exception): # on threadpool shutdown + try: # try to re-schedule the failed task self.submit(func, item).add_done_callback( partial(self._task_done_cb, item=item, func=func) ) + except Exception: # if re-schedule doesn't happen, release se + self._concurrent_semaphore.release() + else: # release semaphore when succeeded + self._concurrent_semaphore.release() def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]: + """Generator which yields the done future, controlled by the caller.""" finished_tasks = 0 while finished_tasks == 0 or finished_tasks != self._total_task_num: if self._total_task_num < 0: @@ -134,9 +131,12 @@ def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]: if self._shutdown or self._broken or concurrent_fut_thread._shutdown: logger.warning( - f"failed to ensure all tasks, {finished_tasks=}, {self._total_task_num=}" + f"dispatcher exits on threadpool shutdown, {finished_tasks=}, {self._total_task_num=}" ) - raise TasksEnsureFailed + with contextlib.suppress(Empty): + while True: # drain the _fut_queue + self._fut_queue.get_nowait() + raise TasksEnsureFailed # raise exc to upper caller try: done_fut = self._fut_queue.get_nowait() @@ -153,20 +153,6 @@ def ensure_tasks( *, ensure_tasks_pull_interval: int = 1, ) -> Generator[Future[RT], None, None]: - """Ensure all the items in are processed by in the pool. - - Args: - func (Callable[[T], RT]): The function to take the item from . - iterable (Iterable[T]): The iterable of items to be processed by . - - Raises: - ValueError: If the pool is shutdown or broken, or this method has already - being called once. - TasksEnsureFailed: If failed to ensure all the tasks are finished. - - Yields: - The Future instance of each processed tasks. - """ with self._start_lock: if self._started: try: @@ -175,11 +161,26 @@ def ensure_tasks( del self, func, iterable self._started = True + if self._checker_funcs: + threading.Thread( + target=self._watchdog, + args=(self._checker_funcs, self._watchdog_check_interval), + daemon=True, + ).start() + # ------ dispatch tasks from iterable ------ # def _dispatcher() -> None: _tasks_count = -1 # means no task is scheduled try: for _tasks_count, item in enumerate(iterable, start=1): + if ( + self._shutdown + or self._broken + or concurrent_fut_thread._shutdown + ): + logger.warning("threadpool is closing, exit") + return # directly exit on shutdown + self._concurrent_semaphore.acquire() fut = self.submit(func, item) fut.add_done_callback( @@ -187,7 +188,7 @@ def _dispatcher() -> None: ) except Exception as e: logger.error(f"tasks dispatcher failed: {e!r}, abort") - self.shutdown(wait=True) + self.shutdown(wait=False) return self._total_task_num = _tasks_count @@ -203,3 +204,69 @@ def _dispatcher() -> None: # a generator so that the first fut will be dispatched before # we start to get from fut_queue. return self._fut_gen(ensure_tasks_pull_interval) + + +# only expose APIs we want to exposed +if TYPE_CHECKING: + + class ThreadPoolExecutorWithRetry: + + def __init__( + self, + max_concurrent: int, + max_workers: Optional[int] = None, + max_total_retry: Optional[int] = None, + thread_name_prefix: str = "", + watchdog_func: Optional[Callable] = None, + watchdog_check_interval: int = 3, # seconds + initializer: Callable[..., Any] | None = None, + initargs: tuple = (), + ) -> None: + """Initialize a ThreadPoolExecutorWithRetry instance. + + Args: + max_concurrent (int): Limit the number pending scheduled tasks. + max_workers (Optional[int], optional): Max number of worker threads in the pool. Defaults to None. + max_total_retry (Optional[int], optional): Max total retry counts before abort. Defaults to None. + thread_name_prefix (str, optional): Defaults to "". + watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, when + this func raises exception, the watchdog will interrupt the tasks execution. Defaults to None. + watchdog_check_interval (int): Defaults to 3(seconds). + initializer (Callable[..., Any] | None): The same param passed through to ThreadPoolExecutor. + Defaults to None. + initargs (tuple): The same param passed through to ThreadPoolExecutor. + Defaults to (). + """ + raise NotImplementedError + + def ensure_tasks( + self, + func: Callable[[T], RT], + iterable: Iterable[T], + *, + ensure_tasks_pull_interval: int = 1, + ) -> Generator[Future[RT], None, None]: + """Ensure all the items in are processed by in the pool. + + Args: + func (Callable[[T], RT]): The function to take the item from . + iterable (Iterable[T]): The iterable of items to be processed by . + + Raises: + ValueError: If the pool is shutdown or broken, or this method has already + being called once. + TasksEnsureFailed: If failed to ensure all the tasks are finished. + + Yields: + The Future instance of each processed tasks. + """ + raise NotImplementedError + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + raise NotImplementedError + +else: + ThreadPoolExecutorWithRetry = _ThreadPoolExecutorWithRetry