diff --git a/src/isolate/server/server.py b/src/isolate/server/server.py index 054567f..0b649ee 100644 --- a/src/isolate/server/server.py +++ b/src/isolate/server/server.py @@ -464,7 +464,8 @@ def abort_with_msg( return None def cancel_tasks(self): - for task in self.background_tasks.values(): + tasks_copy = self.background_tasks.copy() + for task in tasks_copy.values(): task.cancel() @@ -534,6 +535,16 @@ class SingleTaskInterceptor(ServerBoundInterceptor): """Sets server to terminate after the first Submit/Run task.""" _done: bool = False + _task_id: str | None = None + + def __init__(self): + def terminate(request: Any, context: grpc.ServicerContext) -> Any: + context.abort( + grpc.StatusCode.RESOURCE_EXHAUSTED, + "Server has already served one Run/Submit task.", + ) + + self._terminator = grpc.unary_unary_rpc_method_handler(terminate) def intercept_service(self, continuation, handler_call_details): handler = continuation(handler_call_details) @@ -542,29 +553,62 @@ def intercept_service(self, continuation, handler_call_details): is_run = handler_call_details.method == "/Isolate/Run" is_new_task = is_submit or is_run - if is_new_task and self._done: - raise grpc.RpcError( - grpc.StatusCode.UNAVAILABLE, - "Server has already served one Run/Submit task.", - ) - elif is_new_task: - self._done = True - else: + if not is_new_task: # Let other requests like List/Cancel/etc pass through - return continuation(handler_call_details) + return handler + + if self._done: + # Fail the request if the server has already served or is serving + # a Run/Submit task. + return self._terminator + + self._done = True def wrapper(method_impl): @functools.wraps(method_impl) - def _wrapper(request, context): - def _stop(): - if is_submit: - # Wait for the task to finish - while self.server.servicer.background_tasks: + def _wrapper(request: Any, context: grpc.ServicerContext) -> Any: + def termination() -> None: + if is_run: + print("Stopping server since run is finished") + # Stop the server after the Run task is finished + self.server.stop(grace=0.1) + + elif is_submit: + # Wait until the task_id is assigned + while self._task_id is None: time.sleep(0.1) - self.server.stop(grace=0.1) - context.add_callback(_stop) - return method_impl(request, context) + # Get the task from the background tasks + task = self.servicer.background_tasks.get(self._task_id) + + if task is not None: + # Wait until the task future is assigned + tries = 0 + while task.future is None: + time.sleep(0.1) + tries += 1 + if tries > 100: + raise RuntimeError( + "Task future was not assigned in time." + ) + + def _stop(*args): + # Small sleep to make sure the cancellation is processed + time.sleep(0.1) + print("Stopping server since the task is finished") + self.server.stop(grace=0.1) + + # Add a callback which will stop the server + # after the task is finished + task.future.add_done_callback(_stop) + + context.add_callback(termination) + res = method_impl(request, context) + + if is_submit: + self._task_id = cast(definitions.SubmitResponse, res).task_id + + return res return _wrapper @@ -598,7 +642,7 @@ def main(argv: list[str] | None = None) -> None: server = grpc.server( futures.ThreadPoolExecutor(max_workers=options.num_workers), options=get_default_options(), - interceptors=interceptors, + interceptors=interceptors, # type: ignore ) for interceptor in interceptors: diff --git a/tests/test_server.py b/tests/test_server.py index 778fb75..40fadda 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Any, List, Optional, cast +from typing import Any, Iterator, List, Optional, cast import grpc import pytest @@ -15,7 +15,12 @@ from isolate.server import definitions, health from isolate.server.health_server import HealthServicer from isolate.server.interface import from_grpc, to_serialized_object -from isolate.server.server import BridgeManager, IsolateServicer +from isolate.server.server import ( + BridgeManager, + IsolateServicer, + ServerBoundInterceptor, + SingleTaskInterceptor, +) REPO_DIR = Path(__file__).parent.parent assert ( @@ -34,14 +39,31 @@ class Stubs: health_stub: health.HealthStub +@pytest.fixture +def interceptors(): + return [] + + @contextmanager -def make_server(tmp_path): +def make_server( + tmp_path: Path, interceptors: List[ServerBoundInterceptor] +) -> Iterator[Stubs]: server = grpc.server( - futures.ThreadPoolExecutor(max_workers=1), options=get_default_options() + futures.ThreadPoolExecutor(max_workers=1), + options=get_default_options(), + interceptors=interceptors, # type: ignore ) + + for interceptor in interceptors: + interceptor.register_server(server) + test_settings = IsolateSettings(cache_dir=tmp_path / "cache") with BridgeManager() as bridge: servicer = IsolateServicer(bridge, test_settings) + + for interceptor in interceptors: + interceptor.register_servicer(servicer) + definitions.register_isolate(servicer, server) health.register_health(HealthServicer(), server) host, port = "localhost", server.add_insecure_port("[::]:0") @@ -69,14 +91,14 @@ def make_server(tmp_path): @pytest.fixture -def stub(tmp_path): - with make_server(tmp_path) as stubs: +def stub(tmp_path, interceptors): + with make_server(tmp_path, interceptors) as stubs: yield stubs.isolate_stub @pytest.fixture -def health_stub(tmp_path): - with make_server(tmp_path) as stubs: +def health_stub(tmp_path, interceptors): + with make_server(tmp_path, interceptors) as stubs: yield stubs.health_stub @@ -719,3 +741,39 @@ def test_server_submit_server( stub.Cancel(definitions.CancelRequest(task_id=task_id)) assert not list(stub.List(definitions.ListRequest()).tasks) + + +@pytest.mark.parametrize( + "interceptors", + [ + [SingleTaskInterceptor()], + ], +) +def test_server_single_use_submit( + stub: definitions.IsolateStub, + monkeypatch: Any, +) -> None: + inherit_from_local(monkeypatch) + + request = definitions.SubmitRequest(function=prepare_request(myserver)) + task_id = stub.Submit(request).task_id + + tasks = [task.task_id for task in stub.List(definitions.ListRequest()).tasks] + assert task_id in tasks + + # Now try to Submit again + with pytest.raises(grpc.RpcError) as exc_info: + stub.Submit(request) + assert exc_info.value.code() == grpc.StatusCode.RESOURCE_EXHAUSTED + + # And try to Run a task + with pytest.raises(grpc.RpcError) as exc_info: + run_request(stub, prepare_request(myserver)) + assert exc_info.value.code() == grpc.StatusCode.RESOURCE_EXHAUSTED + + stub.Cancel(definitions.CancelRequest(task_id=task_id)) + + with pytest.raises(grpc.RpcError) as exc_info: + stub.List(definitions.ListRequest()) + # Server should be shutting down + assert exc_info.value.code() == grpc.StatusCode.UNAVAILABLE