diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 7315f953e..5a6c7c625 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -4,9 +4,11 @@ on:
   pull_request:
     branches:
       - main
+      - v*
   push:
     branches:
       - main
+      - v*
     paths:
       - "src/**"
       - "tests/**"
diff --git a/src/otaclient/app/configs.py b/src/otaclient/app/configs.py
index 9a76ab8b7..94f4c191e 100644
--- a/src/otaclient/app/configs.py
+++ b/src/otaclient/app/configs.py
@@ -97,7 +97,7 @@ class BaseConfig(_InternalSettings):
         "otaclient": INFO,
         "otaclient_api": INFO,
         "otaclient_common": INFO,
-        "otaproxy": INFO,
+        "ota_proxy": INFO,
     }
     LOG_FORMAT = (
         "[%(asctime)s][%(levelname)s]-%(name)s:%(funcName)s:%(lineno)d,%(message)s"
diff --git a/src/otaclient_common/retry_task_map.py b/src/otaclient_common/retry_task_map.py
index 37b0e5cf7..414076476 100644
--- a/src/otaclient_common/retry_task_map.py
+++ b/src/otaclient_common/retry_task_map.py
@@ -24,7 +24,7 @@
 from concurrent.futures import Future, ThreadPoolExecutor
 from functools import partial
 from queue import Empty, SimpleQueue
-from typing import Any, Callable, Generator, Iterable, Optional
+from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Optional
 
 from otaclient_common.typing import RT, T
 
@@ -35,7 +35,7 @@ class TasksEnsureFailed(Exception):
     """Exception for tasks ensuring failed."""
 
 
-class ThreadPoolExecutorWithRetry(ThreadPoolExecutor):
+class _ThreadPoolExecutorWithRetry(ThreadPoolExecutor):
 
     def __init__(
         self,
@@ -48,21 +48,6 @@ def __init__(
         initializer: Callable[..., Any] | None = None,
         initargs: tuple = (),
     ) -> None:
-        """Initialize a ThreadPoolExecutorWithRetry instance.
-
-        Args:
-            max_concurrent (int): Limit the number pending scheduled tasks.
-            max_workers (Optional[int], optional): Max number of worker threads in the pool. Defaults to None.
-            max_total_retry (Optional[int], optional): Max total retry counts before abort. Defaults to None.
-            thread_name_prefix (str, optional): Defaults to "".
-            watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, when
-                this func raises exception, the watchdog will interrupt the tasks execution. Defaults to None.
-            watchdog_check_interval (int): Defaults to 3(seconds).
-            initializer (Callable[..., Any] | None): The same param passed through to ThreadPoolExecutor.
-                Defaults to None.
-            initargs (tuple): The same param passed through to ThreadPoolExecutor.
-                Defaults to ().
- """ self._start_lock, self._started = threading.Lock(), False self._total_task_num = 0 """ @@ -78,6 +63,13 @@ def __init__( self._concurrent_semaphore = threading.Semaphore(max_concurrent) self._fut_queue: SimpleQueue[Future[Any]] = SimpleQueue() + self._watchdog_check_interval = watchdog_check_interval + self._checker_funcs: list[Callable[[], Any]] = [] + if isinstance(max_total_retry, int) and max_total_retry > 0: + self._checker_funcs.append(partial(self._max_retry_check, max_total_retry)) + if callable(watchdog_func): + self._checker_funcs.append(watchdog_func) + super().__init__( max_workers=max_workers, thread_name_prefix=thread_name_prefix, @@ -85,48 +77,53 @@ def __init__( initargs=initargs, ) - if max_total_retry or callable(watchdog_func): - threading.Thread( - target=self._watchdog, - args=(max_total_retry, watchdog_func, watchdog_check_interval), - daemon=True, - ).start() + def _max_retry_check(self, max_total_retry: int) -> None: + if self._retry_count > max_total_retry: + raise TasksEnsureFailed("exceed max retry count, abort") def _watchdog( self, - max_retry: int | None, - watchdog_func: Callable[..., Any] | None, + _checker_funcs: list[Callable[[], Any]], interval: int, ) -> None: """Watchdog will shutdown the threadpool on certain conditions being met.""" - while not self._shutdown and not concurrent_fut_thread._shutdown: - if max_retry and self._retry_count > max_retry: - logger.warning(f"exceed {max_retry=}, abort") - return self.shutdown(wait=True) - - if callable(watchdog_func): - try: - watchdog_func() - except Exception as e: - logger.warning(f"custom watchdog func failed: {e!r}, abort") - return self.shutdown(wait=True) + while not (self._shutdown or self._broken or concurrent_fut_thread._shutdown): time.sleep(interval) + try: + for _func in _checker_funcs: + _func() + except Exception as e: + logger.warning( + f"watchdog failed: {e!r}, shutdown the pool and draining the workitem queue on shutdown.." 
+                )
+                self.shutdown(wait=False)
+                # drain the worker queues to cancel all the futs
+                with contextlib.suppress(Empty):
+                    while True:
+                        self._work_queue.get_nowait()
 
     def _task_done_cb(
         self, fut: Future[Any], /, *, item: T, func: Callable[[T], Any]
     ) -> None:
-        self._concurrent_semaphore.release()  # always release se first
+        if self._shutdown or self._broken or concurrent_fut_thread._shutdown:
+            self._concurrent_semaphore.release()  # on shutdown, always release a se
+            return  # on shutdown, no need to put done fut into fut_queue
         self._fut_queue.put_nowait(fut)
 
         # ------ on task failed ------ #
         if fut.exception():
             self._retry_count = next(self._retry_counter)
-            with contextlib.suppress(Exception):  # on threadpool shutdown
+            try:  # try to re-schedule the failed task
                 self.submit(func, item).add_done_callback(
                     partial(self._task_done_cb, item=item, func=func)
                 )
+            except Exception:  # if re-schedule doesn't happen, release se
+                self._concurrent_semaphore.release()
+        else:  # release semaphore when succeeded
+            self._concurrent_semaphore.release()
 
     def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]:
+        """Generator which yields the done future, controlled by the caller."""
         finished_tasks = 0
         while finished_tasks == 0 or finished_tasks != self._total_task_num:
             if self._total_task_num < 0:
@@ -134,9 +131,12 @@ def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]:
 
             if self._shutdown or self._broken or concurrent_fut_thread._shutdown:
                 logger.warning(
-                    f"failed to ensure all tasks, {finished_tasks=}, {self._total_task_num=}"
+                    f"dispatcher exits on threadpool shutdown, {finished_tasks=}, {self._total_task_num=}"
                 )
-                raise TasksEnsureFailed
+                with contextlib.suppress(Empty):
+                    while True:  # drain the _fut_queue
+                        self._fut_queue.get_nowait()
+                raise TasksEnsureFailed  # raise exc to upper caller
 
             try:
                 done_fut = self._fut_queue.get_nowait()
@@ -153,20 +153,6 @@ def ensure_tasks(
         *,
         ensure_tasks_pull_interval: int = 1,
     ) -> Generator[Future[RT], None, None]:
-        """Ensure all the items in <iterable> are processed by <func> in the pool.
-
-        Args:
-            func (Callable[[T], RT]): The function to take the item from <iterable>.
-            iterable (Iterable[T]): The iterable of items to be processed by <func>.
-
-        Raises:
-            ValueError: If the pool is shutdown or broken, or this method has already
-                being called once.
-            TasksEnsureFailed: If failed to ensure all the tasks are finished.
-
-        Yields:
-            The Future instance of each processed tasks.
- """ with self._start_lock: if self._started: try: @@ -175,11 +161,26 @@ def ensure_tasks( del self, func, iterable self._started = True + if self._checker_funcs: + threading.Thread( + target=self._watchdog, + args=(self._checker_funcs, self._watchdog_check_interval), + daemon=True, + ).start() + # ------ dispatch tasks from iterable ------ # def _dispatcher() -> None: _tasks_count = -1 # means no task is scheduled try: for _tasks_count, item in enumerate(iterable, start=1): + if ( + self._shutdown + or self._broken + or concurrent_fut_thread._shutdown + ): + logger.warning("threadpool is closing, exit") + return # directly exit on shutdown + self._concurrent_semaphore.acquire() fut = self.submit(func, item) fut.add_done_callback( @@ -187,7 +188,7 @@ def _dispatcher() -> None: ) except Exception as e: logger.error(f"tasks dispatcher failed: {e!r}, abort") - self.shutdown(wait=True) + self.shutdown(wait=False) return self._total_task_num = _tasks_count @@ -203,3 +204,69 @@ def _dispatcher() -> None: # a generator so that the first fut will be dispatched before # we start to get from fut_queue. return self._fut_gen(ensure_tasks_pull_interval) + + +# only expose APIs we want to exposed +if TYPE_CHECKING: + + class ThreadPoolExecutorWithRetry: + + def __init__( + self, + max_concurrent: int, + max_workers: Optional[int] = None, + max_total_retry: Optional[int] = None, + thread_name_prefix: str = "", + watchdog_func: Optional[Callable] = None, + watchdog_check_interval: int = 3, # seconds + initializer: Callable[..., Any] | None = None, + initargs: tuple = (), + ) -> None: + """Initialize a ThreadPoolExecutorWithRetry instance. + + Args: + max_concurrent (int): Limit the number pending scheduled tasks. + max_workers (Optional[int], optional): Max number of worker threads in the pool. Defaults to None. + max_total_retry (Optional[int], optional): Max total retry counts before abort. Defaults to None. + thread_name_prefix (str, optional): Defaults to "". + watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, when + this func raises exception, the watchdog will interrupt the tasks execution. Defaults to None. + watchdog_check_interval (int): Defaults to 3(seconds). + initializer (Callable[..., Any] | None): The same param passed through to ThreadPoolExecutor. + Defaults to None. + initargs (tuple): The same param passed through to ThreadPoolExecutor. + Defaults to (). + """ + raise NotImplementedError + + def ensure_tasks( + self, + func: Callable[[T], RT], + iterable: Iterable[T], + *, + ensure_tasks_pull_interval: int = 1, + ) -> Generator[Future[RT], None, None]: + """Ensure all the items in are processed by in the pool. + + Args: + func (Callable[[T], RT]): The function to take the item from . + iterable (Iterable[T]): The iterable of items to be processed by . + + Raises: + ValueError: If the pool is shutdown or broken, or this method has already + being called once. + TasksEnsureFailed: If failed to ensure all the tasks are finished. + + Yields: + The Future instance of each processed tasks. + """ + raise NotImplementedError + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + raise NotImplementedError + +else: + ThreadPoolExecutorWithRetry = _ThreadPoolExecutorWithRetry