Adding log capture in worker pool
Hartorn committed Oct 16, 2023
1 parent 6dfeae9 commit a3b8bde
Showing 2 changed files with 79 additions and 16 deletions.
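The core of the change: while a task runs in a worker process, its stdout, its stderr, and anything routed through the root logger are funneled into one in-memory buffer, which travels back with the result and is exposed on the future. A minimal standalone sketch of that capture pattern, assuming nothing beyond the standard library (the function name here is illustrative, not Giskard's API):

import logging
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO


def run_with_captured_logs(fn, *args, **kwargs):
    # Everything printed to stdout or stderr while fn runs lands in one buffer.
    with redirect_stdout(StringIO()) as buffer:
        with redirect_stderr(buffer):
            # A temporary handler mirrors root-logger records into the same
            # buffer (records must still pass the root logger's level).
            handler = logging.StreamHandler(buffer)
            logging.getLogger().addHandler(handler)
            try:
                result = fn(*args, **kwargs)
            finally:
                logging.getLogger().removeHandler(handler)
    return result, buffer.getvalue()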
51 changes: 35 additions & 16 deletions giskard/utils/worker_pool.py
@@ -3,9 +3,12 @@
 import logging
 import os
 import time
+import traceback
 from concurrent.futures import CancelledError, Executor, Future
+from contextlib import redirect_stderr, redirect_stdout
 from dataclasses import dataclass, field
 from enum import Enum
+from io import StringIO
 from multiprocessing import Process, Queue, SimpleQueue, cpu_count, get_context
 from multiprocessing.context import SpawnContext, SpawnProcess
 from multiprocessing.managers import SyncManager
@@ -46,25 +49,32 @@ def _stop_processes(p_list: List[Process], timeout: float = 1) -> List[Optional[
     return exit_codes


-@dataclass
+class GiskardFuture(Future):
+    def __init__(self) -> None:
+        super().__init__()
+        self.logs = ""
+
+
+@dataclass(frozen=True)
 class TimeoutData:
     id: str
     end_time: float


-@dataclass
+@dataclass(frozen=True)
 class GiskardTask:
     callable: Callable
     args: Any
     kwargs: Any
     id: str = field(default_factory=_generate_task_id)


-@dataclass
+@dataclass(frozen=True)
 class GiskardResult:
     id: str
     result: Any = None
     exception: Any = None
+    logs: str = None


 def _process_worker(tasks_queue: SimpleQueue, tasks_results: SimpleQueue, running_process: Dict[str, str]):
@@ -74,18 +84,26 @@ def _process_worker(tasks_queue: SimpleQueue, tasks_results: SimpleQueue, running_process: Dict[str, str]):
     while True:
         # Blocking accessor, will wait for a task
         task: GiskardTask = tasks_queue.get()
         # This is how we cleanly stop the workers
         if task is None:
             return
-        try:
-            LOGGER.debug("Doing task %", task.id)
-            running_process[task.id] = pid
-            result = task.callable(*task.args, **task.kwargs)
-            to_return = GiskardResult(id=task.id, result=result)
-        except BaseException as e:
-            to_return = GiskardResult(id=task.id, exception=str(e))
-        finally:
-            running_process.pop(task.id)
-            tasks_results.put(to_return)
+        # Capture any log (stdout, stderr + root logger)
+        with redirect_stdout(StringIO()) as f:
+            with redirect_stderr(f):
+                handler = logging.StreamHandler(f)
+                logging.getLogger().addHandler(handler)
+                try:
+                    LOGGER.debug("Doing task %s", task.id)
+                    running_process[task.id] = pid
+                    result = task.callable(*task.args, **task.kwargs)
+                    to_return = GiskardResult(id=task.id, result=result, logs=f.getvalue())
+                except BaseException as e:
+                    exception = "\n".join(traceback.format_exception(type(e), e, e.__traceback__))
+                    to_return = GiskardResult(id=task.id, exception=exception, logs=f.getvalue() + "\n" + exception)
+                finally:
+                    running_process.pop(task.id)
+                    tasks_results.put(to_return)
+                    logging.getLogger().removeHandler(handler)


 # Note: See _on_queue_feeder_error
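On failure, the worker now ships the fully formatted traceback back to the caller instead of just str(e). A quick illustration of what that traceback.format_exception call produces (a sketch, not repository code):

import traceback

try:
    1 / 0
except ZeroDivisionError as e:
    # Same three-argument form as in the hunk above. format_exception returns
    # a list of strings that already end in newlines, so joining on "\n" adds
    # blank lines between frames; the substance is the full traceback text.
    text = "\n".join(traceback.format_exception(type(e), e, e.__traceback__))
    print(text)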
@@ -160,15 +178,15 @@ def schedule(
         args=None,
         kwargs=None,
         timeout: Optional[float] = None,
-    ):
+    ) -> GiskardFuture:
         if args is None:
             args = []
         if kwargs is None:
             kwargs = {}
         if self._state in FINAL_STATES:
             raise RuntimeError(f"Cannot submit when pool is {self._state.name}")
         task = GiskardTask(callable=fn, args=args, kwargs=kwargs)
-        res = Future()
+        res = GiskardFuture()
         self._futures_mapping[task.id] = res
         self._pending_tasks_queue.put(task)
         if timeout is not None:
@@ -236,7 +254,8 @@ def _results_thread(
             except BaseException:
                 pass
             continue
-        if result.result is not None:
+        future.logs = result.logs
+        if result.exception is None:
             future.set_result(result.result)
         else:
             # TODO(Bazire): improve to get Traceback
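Taken together, a scheduled task's future now carries its captured output. A hedged usage sketch: only submit() and the new logs attribute are taken from this diff and the tests below; the constructor call and shutdown() are assumptions about the surrounding API.

from giskard.utils.worker_pool import WorkerPoolExecutor

pool = WorkerPoolExecutor()  # constructor arguments are an assumption here
future = pool.submit(print, "hello from the worker")  # submit() as used in the tests below
future.result(timeout=5)
print(future.logs)  # the task's captured stdout, e.g. "hello from the worker"
pool.shutdown(wait=True)  # standard Executor API; the pool subclasses Executor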
44 changes: 44 additions & 0 deletions tests/test_worker_pool.py
@@ -1,3 +1,5 @@
+import logging
+import sys
 from multiprocessing.context import SpawnProcess
 from time import sleep

@@ -41,6 +43,48 @@ def sleep_add_one(timer, value):
     return value + 1


+def print_stuff():
+    print("stuff stdout")
+    print("other stuff", file=sys.stderr)
+    logging.getLogger().info("info log")
+    logging.getLogger("truc").error("toto")
+    logging.getLogger(__name__).warning("Warning")
+    return
+
+
+def bugged_code():
+    print("Before raising")
+    return 1 / 0
+
+
+@pytest.mark.concurrency
+def test_handle_log(one_worker_pool: WorkerPoolExecutor):
+    future = one_worker_pool.submit(print_stuff)
+    assert future.result(timeout=5) is None
+    print(future.logs)
+    assert "stuff stdout" in future.logs
+    assert "other stuff" in future.logs
+    assert "info log" in future.logs
+    assert "toto" in future.logs
+    assert "Warning" in future.logs
+
+
+@pytest.mark.concurrency
+def test_handle_exception_log(one_worker_pool: WorkerPoolExecutor):
+    future = one_worker_pool.submit(bugged_code)
+    exception = future.exception(timeout=5)
+    assert exception is not None
+    print(exception)
+    assert "ZeroDivisionError: division by zero" in str(exception)
+    assert "in bugged_code" in str(exception)
+    assert "return 1 / 0" in str(exception)
+    print(future.logs)
+    assert "Before raising" in future.logs
+    assert "ZeroDivisionError: division by zero" in future.logs
+    assert "in bugged_code" in future.logs
+    assert "return 1 / 0" in future.logs
+
+
 @pytest.mark.concurrency
 def test_submit_one_task(one_worker_pool: WorkerPoolExecutor):
     future = one_worker_pool.submit(add_one, 1)
