Skip to content

Commit

Permalink
fix: add single use tests and fixes (#156)
Browse files Browse the repository at this point in the history
* fix: add single use tests and fixes

* other_test
  • Loading branch information
chamini2 authored Sep 23, 2024
1 parent a431995 commit daa8adb
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 27 deletions.
82 changes: 63 additions & 19 deletions src/isolate/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,8 @@ def abort_with_msg(
return None

def cancel_tasks(self):
for task in self.background_tasks.values():
tasks_copy = self.background_tasks.copy()
for task in tasks_copy.values():
task.cancel()


Expand Down Expand Up @@ -534,6 +535,16 @@ class SingleTaskInterceptor(ServerBoundInterceptor):
"""Sets server to terminate after the first Submit/Run task."""

_done: bool = False
_task_id: str | None = None

def __init__(self):
def terminate(request: Any, context: grpc.ServicerContext) -> Any:
context.abort(
grpc.StatusCode.RESOURCE_EXHAUSTED,
"Server has already served one Run/Submit task.",
)

self._terminator = grpc.unary_unary_rpc_method_handler(terminate)

def intercept_service(self, continuation, handler_call_details):
handler = continuation(handler_call_details)
Expand All @@ -542,29 +553,62 @@ def intercept_service(self, continuation, handler_call_details):
is_run = handler_call_details.method == "/Isolate/Run"
is_new_task = is_submit or is_run

if is_new_task and self._done:
raise grpc.RpcError(
grpc.StatusCode.UNAVAILABLE,
"Server has already served one Run/Submit task.",
)
elif is_new_task:
self._done = True
else:
if not is_new_task:
# Let other requests like List/Cancel/etc pass through
return continuation(handler_call_details)
return handler

if self._done:
# Fail the request if the server has already served or is serving
# a Run/Submit task.
return self._terminator

self._done = True

def wrapper(method_impl):
@functools.wraps(method_impl)
def _wrapper(request, context):
def _stop():
if is_submit:
# Wait for the task to finish
while self.server.servicer.background_tasks:
def _wrapper(request: Any, context: grpc.ServicerContext) -> Any:
def termination() -> None:
if is_run:
print("Stopping server since run is finished")
# Stop the server after the Run task is finished
self.server.stop(grace=0.1)

elif is_submit:
# Wait until the task_id is assigned
while self._task_id is None:
time.sleep(0.1)
self.server.stop(grace=0.1)

context.add_callback(_stop)
return method_impl(request, context)
# Get the task from the background tasks
task = self.servicer.background_tasks.get(self._task_id)

if task is not None:
# Wait until the task future is assigned
tries = 0
while task.future is None:
time.sleep(0.1)
tries += 1
if tries > 100:
raise RuntimeError(
"Task future was not assigned in time."
)

def _stop(*args):
# Small sleep to make sure the cancellation is processed
time.sleep(0.1)
print("Stopping server since the task is finished")
self.server.stop(grace=0.1)

# Add a callback which will stop the server
# after the task is finished
task.future.add_done_callback(_stop)

context.add_callback(termination)
res = method_impl(request, context)

if is_submit:
self._task_id = cast(definitions.SubmitResponse, res).task_id

return res

return _wrapper

Expand Down Expand Up @@ -598,7 +642,7 @@ def main(argv: list[str] | None = None) -> None:
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=options.num_workers),
options=get_default_options(),
interceptors=interceptors,
interceptors=interceptors, # type: ignore
)

for interceptor in interceptors:
Expand Down
115 changes: 107 additions & 8 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, List, Optional, cast
from typing import Any, Iterator, List, Optional, cast

import grpc
import pytest
Expand All @@ -15,7 +15,12 @@
from isolate.server import definitions, health
from isolate.server.health_server import HealthServicer
from isolate.server.interface import from_grpc, to_serialized_object
from isolate.server.server import BridgeManager, IsolateServicer
from isolate.server.server import (
BridgeManager,
IsolateServicer,
ServerBoundInterceptor,
SingleTaskInterceptor,
)

REPO_DIR = Path(__file__).parent.parent
assert (
Expand All @@ -34,14 +39,32 @@ class Stubs:
health_stub: health.HealthStub


@pytest.fixture
def interceptors():
return []


@contextmanager
def make_server(tmp_path):
def make_server(
tmp_path: Path, interceptors: Optional[List[ServerBoundInterceptor]] = None
) -> Iterator[Stubs]:
interceptors = interceptors or []
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=1), options=get_default_options()
futures.ThreadPoolExecutor(max_workers=1),
options=get_default_options(),
interceptors=interceptors, # type: ignore
)

for interceptor in interceptors:
interceptor.register_server(server)

test_settings = IsolateSettings(cache_dir=tmp_path / "cache")
with BridgeManager() as bridge:
servicer = IsolateServicer(bridge, test_settings)

for interceptor in interceptors:
interceptor.register_servicer(servicer)

definitions.register_isolate(servicer, server)
health.register_health(HealthServicer(), server)
host, port = "localhost", server.add_insecure_port("[::]:0")
Expand Down Expand Up @@ -69,14 +92,14 @@ def make_server(tmp_path):


@pytest.fixture
def stub(tmp_path):
with make_server(tmp_path) as stubs:
def stub(tmp_path, interceptors):
with make_server(tmp_path, interceptors) as stubs:
yield stubs.isolate_stub


@pytest.fixture
def health_stub(tmp_path):
with make_server(tmp_path) as stubs:
def health_stub(tmp_path, interceptors):
with make_server(tmp_path, interceptors) as stubs:
yield stubs.health_stub


Expand Down Expand Up @@ -719,3 +742,79 @@ def test_server_submit_server(
stub.Cancel(definitions.CancelRequest(task_id=task_id))

assert not list(stub.List(definitions.ListRequest()).tasks)


@pytest.mark.parametrize(
"interceptors",
[
[SingleTaskInterceptor()],
],
)
def test_server_single_use_submit(
stub: definitions.IsolateStub,
monkeypatch: Any,
) -> None:
import time

inherit_from_local(monkeypatch)

request = definitions.SubmitRequest(function=prepare_request(myserver))
task_id = stub.Submit(request).task_id

tasks = [task.task_id for task in stub.List(definitions.ListRequest()).tasks]
assert task_id in tasks

# Now try to Submit again
with pytest.raises(grpc.RpcError) as exc_info:
stub.Submit(request)
assert exc_info.value.code() == grpc.StatusCode.RESOURCE_EXHAUSTED

# And try to Run a task
with pytest.raises(grpc.RpcError) as exc_info:
run_request(stub, prepare_request(myserver))
assert exc_info.value.code() == grpc.StatusCode.RESOURCE_EXHAUSTED

stub.Cancel(definitions.CancelRequest(task_id=task_id))
time.sleep(1)

with pytest.raises(grpc.RpcError) as exc_info:
stub.List(definitions.ListRequest())

# Server should be shutting down
assert exc_info.value.code() == grpc.StatusCode.UNAVAILABLE


@pytest.mark.parametrize(
"interceptors",
[
[SingleTaskInterceptor()],
],
)
def test_server_single_use_run(
stub: definitions.IsolateStub,
monkeypatch: Any,
) -> None:
import time

inherit_from_local(monkeypatch)

run_function(stub, check_machine)
time.sleep(1)

# Now try to Submit again
with pytest.raises(grpc.RpcError) as exc_info:
submit_request = definitions.SubmitRequest(function=prepare_request(myserver))
stub.Submit(submit_request)

assert exc_info.value.code() == grpc.StatusCode.UNAVAILABLE

# And try to Run a task
with pytest.raises(grpc.RpcError) as exc_info:
run_function(stub, check_machine)

assert exc_info.value.code() == grpc.StatusCode.UNAVAILABLE

with pytest.raises(grpc.RpcError) as exc_info:
stub.List(definitions.ListRequest())

assert exc_info.value.code() == grpc.StatusCode.UNAVAILABLE

0 comments on commit daa8adb

Please sign in to comment.