diff --git a/giskard/utils/worker_pool.py b/giskard/utils/worker_pool.py index bbe8f82ebb..514794e120 100644 --- a/giskard/utils/worker_pool.py +++ b/giskard/utils/worker_pool.py @@ -2,10 +2,14 @@ import logging import os +import sys import time +import traceback from concurrent.futures import CancelledError, Executor, Future +from contextlib import redirect_stderr, redirect_stdout from dataclasses import dataclass, field from enum import Enum +from io import StringIO from multiprocessing import Process, Queue, SimpleQueue, cpu_count, get_context from multiprocessing.context import SpawnContext, SpawnProcess from multiprocessing.managers import SyncManager @@ -46,13 +50,19 @@ def _stop_processes(p_list: List[Process], timeout: float = 1) -> List[Optional[ return exit_codes -@dataclass +class GiskardFuture(Future): + def __init__(self) -> None: + super().__init__() + self.logs = "" + + +@dataclass(frozen=True) class TimeoutData: id: str end_time: float -@dataclass +@dataclass(frozen=True) class GiskardTask: callable: Callable args: Any @@ -60,11 +70,12 @@ class GiskardTask: id: str = field(default_factory=_generate_task_id) -@dataclass +@dataclass(frozen=True) class GiskardResult: id: str result: Any = None exception: Any = None + logs: str = None def _process_worker(tasks_queue: SimpleQueue, tasks_results: SimpleQueue, running_process: Dict[str, str]): @@ -74,18 +85,26 @@ def _process_worker(tasks_queue: SimpleQueue, tasks_results: SimpleQueue, runnin while True: # Blocking accessor, will wait for a task task: GiskardTask = tasks_queue.get() + # This is how we cleanly stop the workers if task is None: return - try: - LOGGER.debug("Doing task %", task.id) - running_process[task.id] = pid - result = task.callable(*task.args, **task.kwargs) - to_return = GiskardResult(id=task.id, result=result) - except BaseException as e: - to_return = GiskardResult(id=task.id, exception=str(e)) - finally: - running_process.pop(task.id) - tasks_results.put(to_return) + # Capture any log (stdout, stderr + root logger) + with redirect_stdout(StringIO()) as f: + with redirect_stderr(f): + handler = logging.StreamHandler(f) + logging.getLogger().addHandler(handler) + try: + LOGGER.debug("Doing task %", task.id) + running_process[task.id] = pid + result = task.callable(*task.args, **task.kwargs) + to_return = GiskardResult(id=task.id, result=result, logs=f.getvalue()) + except BaseException as e: + exception = "\n".join(traceback.format_exception(type(e), e, e.__traceback__)) + to_return = GiskardResult(id=task.id, exception=exception, logs=f.getvalue() + "\n" + exception) + finally: + running_process.pop(task.id) + tasks_results.put(to_return) + logging.getLogger().removeHandler(handler) # Note: See _on_queue_feeder_error @@ -160,7 +179,7 @@ def schedule( args=None, kwargs=None, timeout: Optional[float] = None, - ): + ) -> GiskardFuture: if args is None: args = [] if kwargs is None: @@ -168,7 +187,7 @@ def schedule( if self._state in FINAL_STATES: raise RuntimeError(f"Cannot submit when pool is {self._state.name}") task = GiskardTask(callable=fn, args=args, kwargs=kwargs) - res = Future() + res = GiskardFuture() self._futures_mapping[task.id] = res self._pending_tasks_queue.put(task) if timeout is not None: @@ -236,7 +255,8 @@ def _results_thread( except BaseException: pass continue - if result.result is not None: + future.logs = result.logs + if result.exception is None: future.set_result(result.result) else: # TODO(Bazire): improve to get Traceback diff --git a/tests/test_worker_pool.py b/tests/test_worker_pool.py index 0056d7adac..a8892c97e4 100644 --- a/tests/test_worker_pool.py +++ b/tests/test_worker_pool.py @@ -1,3 +1,5 @@ +import logging +import sys from multiprocessing.context import SpawnProcess from time import sleep @@ -41,6 +43,48 @@ def sleep_add_one(timer, value): return value + 1 +def print_stuff(): + print("stuff stdout") + print("other stuff", file=sys.stderr) + logging.getLogger().info("info log") + logging.getLogger("truc").error("toto") + logging.getLogger(__name__).warning("Warning") + return + + +def bugged_code(): + print("Before raising") + return 1 / 0 + + +@pytest.mark.concurrency +def test_handle_log(one_worker_pool: WorkerPoolExecutor): + future = one_worker_pool.submit(print_stuff) + assert future.result(timeout=5) is None + print(future.logs) + assert "stuff stdout" in future.logs + assert "other stuff" in future.logs + assert "info log" in future.logs + assert "toto" in future.logs + assert "Warning" in future.logs + + +@pytest.mark.concurrency +def test_handle_exception_log(one_worker_pool: WorkerPoolExecutor): + future = one_worker_pool.submit(bugged_code) + exception = future.exception(timeout=5) + assert exception is not None + print(exception) + assert "ZeroDivisionError: division by zero" in str(exception) + assert "in bugged_code" in str(exception) + assert "return 1 / 0" in str(exception) + print(future.logs) + assert "Before raising" in future.logs + assert "ZeroDivisionError: division by zero" in future.logs + assert "in bugged_code" in future.logs + assert "return 1 / 0" in future.logs + + @pytest.mark.concurrency def test_submit_one_task(one_worker_pool: WorkerPoolExecutor): future = one_worker_pool.submit(add_one, 1)