Support Async Load #1285

Draft: wants to merge 2 commits into main
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.56rc3"
version = "0.9.56dev100"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
54 changes: 37 additions & 17 deletions truss/templates/server/model_wrapper.py
@@ -16,7 +16,6 @@
from functools import cached_property
from multiprocessing import Lock
from pathlib import Path
from threading import Thread
from typing import (
Any,
Callable,
@@ -39,6 +38,7 @@
from shared import dynamic_config_resolver, serialization
from shared.lazy_data_resolver import LazyDataResolver
from shared.secrets_resolver import SecretsResolver
from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed

if sys.version_info >= (3, 9):
from typing import AsyncGenerator, Generator
@@ -338,23 +338,19 @@ def ready(self) -> bool:
def _model_file_name(self) -> str:
return self._config["model_class_filename"]

def start_load_thread(self):
# Don't retry failed loads.
if self._status == ModelWrapper.Status.NOT_READY:
thread = Thread(target=self.load)
thread.start()

def load(self):
async def load(self):
if self.ready:
return

# if we are already loading, block on acquiring the lock;
# this worker will return 503 while the worker with the lock is loading
with self._load_lock:
self._status = ModelWrapper.Status.LOADING
self._logger.info("Executing model.load()...")
try:
start_time = time.perf_counter()
self._load_impl()
await self.try_load()

self._status = ModelWrapper.Status.READY
self._logger.info(
f"Completed model.load() execution in {_elapsed_ms(start_time)} ms"
@@ -363,7 +359,15 @@ def load(self):
self._logger.exception("Exception while loading model")
self._status = ModelWrapper.Status.FAILED

def _load_impl(self):
async def start_load(self):
if self.should_load():
asyncio.create_task(self.load())

def should_load(self) -> bool:
# Don't retry failed loads.
return self._status != ModelWrapper.Status.FAILED and not self.ready

def _initialize_model(self):
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

@@ -446,17 +450,33 @@ def _load_impl(self):

self._maybe_model_descriptor = ModelDescriptor.from_model(self._model)

async def try_load(self):
await to_thread.run_sync(self._initialize_model)

if self._maybe_model_descriptor.setup_environment:
self._initialize_environment_before_load()

if hasattr(self._model, "load"):
retry(
self._model.load,
NUM_LOAD_RETRIES,
self._logger.warning,
"Failed to load model.",
gap_seconds=1.0,
)
if inspect.iscoroutinefunction(self._model.load):
async for attempt in AsyncRetrying(
stop=stop_after_attempt(NUM_LOAD_RETRIES),
wait=wait_fixed(1),
before_sleep=lambda retry_state: self._logger.info(
f"Model load failed (attempt {retry_state.attempt_number})...retrying"
),
):
with attempt:
await self._model.load()

else:
await to_thread.run_sync(
retry,
self._model.load,
NUM_LOAD_RETRIES,
self._logger.warning,
"Failed to load model.",
1.0,
)

def setup_polling_for_environment_updates(self):
self._poll_for_environment_updates_task = asyncio.create_task(
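For readers unfamiliar with tenacity, here is a minimal, self-contained sketch of the AsyncRetrying pattern that try_load now applies when the user's load() is a coroutine. The load_model and load_with_retries names are hypothetical illustrations, not code from this PR:

import asyncio
import logging

from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed

logger = logging.getLogger(__name__)


async def load_model() -> None:
    # Hypothetical async load that may fail transiently.
    await asyncio.sleep(0.1)


async def load_with_retries(num_retries: int = 3) -> None:
    # Retry the coroutine up to num_retries times with a fixed 1 s wait,
    # logging before each retry; this mirrors the shape of model_wrapper.try_load.
    async for attempt in AsyncRetrying(
        stop=stop_after_attempt(num_retries),
        wait=wait_fixed(1),
        before_sleep=lambda state: logger.info(
            f"Model load failed (attempt {state.attempt_number})...retrying"
        ),
    ):
        with attempt:
            await load_model()


if __name__ == "__main__":
    asyncio.run(load_with_retries())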
1 change: 1 addition & 0 deletions truss/templates/server/requirements.txt
@@ -16,6 +16,7 @@ psutil==5.9.4
python-json-logger==2.0.2
pyyaml==6.0.0
requests==2.31.0
tenacity==9.0.0
uvicorn==0.24.0
uvloop==0.19.0
aiofiles==24.1.0
5 changes: 3 additions & 2 deletions truss/templates/server/truss_server.py
@@ -275,15 +275,16 @@ def cleanup(self):
if INFERENCE_SERVER_FAILED_FILE.exists():
INFERENCE_SERVER_FAILED_FILE.unlink()

def on_startup(self):
async def on_startup(self):
"""
This method runs inside the main process, so this is where
we want to set up our logging and model.
"""
self.cleanup()
if self._setup_json_logger:
setup_logging()
self._model.start_load_thread()

await self._model.start_load()
asyncio.create_task(self._shutdown_if_load_fails())
self._model.setup_polling_for_environment_updates()

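For intuition about the new startup flow, here is a minimal standalone sketch (DummyModel and Server are made-up names, not the PR's classes) of scheduling the load with asyncio.create_task so that startup returns immediately and readiness only flips once the task finishes:

import asyncio


class DummyModel:
    def __init__(self) -> None:
        self.ready = False

    async def load(self) -> None:
        await asyncio.sleep(2)  # stand-in for a slow model load
        self.ready = True


class Server:
    def __init__(self, model: DummyModel) -> None:
        self._model = model
        self._load_task = None

    async def on_startup(self) -> None:
        # Fire-and-forget: liveness can be served while the model loads.
        self._load_task = asyncio.create_task(self._model.load())


async def main() -> None:
    server = Server(DummyModel())
    await server.on_startup()
    print("started, ready =", server._model.ready)  # False right after startup
    await asyncio.sleep(2.5)
    print("later, ready =", server._model.ready)  # True once the task completes


if __name__ == "__main__":
    asyncio.run(main())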
17 changes: 11 additions & 6 deletions truss/tests/templates/server/test_model_wrapper.py
@@ -52,15 +52,15 @@ async def test_model_wrapper_load_error_once(app_path):
config = yaml.safe_load((app_path / "config.yaml").read_text())
os.chdir(app_path)
model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
model_wrapper.load()
await model_wrapper.load()
# Allow load thread to execute
time.sleep(1)
output = await model_wrapper.predict({}, MagicMock(spec=Request))
assert output == {}
assert model_wrapper._model.load_count == 2


def test_model_wrapper_load_error_more_than_allowed(app_path, helpers):
async def test_model_wrapper_load_error_more_than_allowed(app_path, helpers):
with helpers.env_var("NUM_LOAD_RETRIES_TRUSS", "0"):
if "model_wrapper" in sys.modules:
model_wrapper_module = sys.modules["model_wrapper"]
@@ -71,7 +71,7 @@ def test_model_wrapper_load_error_more_than_allowed(app_path, helpers):
config = yaml.safe_load((app_path / "config.yaml").read_text())
os.chdir(app_path)
model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
model_wrapper.load()
await model_wrapper.load()
# Allow load thread to execute
time.sleep(1)
assert model_wrapper.load_failed
@@ -109,7 +109,8 @@ async def test_trt_llm_truss_init_extension(trt_llm_truss_container_fs, helpers)
model_wrapper_module, "_init_extension", return_value=mock_extension
) as mock_init_extension:
model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
model_wrapper.load()
await model_wrapper.load()

called_with_specific_extension = any(
call_args[0][0] == "trt_llm"
for call_args in mock_init_extension.call_args_list
@@ -146,8 +147,10 @@ async def mock_predict(return_value, request):
model_wrapper_module, "_init_extension", return_value=mock_extension
):
model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
model_wrapper.load()
await model_wrapper.load()

resp = await model_wrapper.predict({}, MagicMock(spec=Request))

mock_extension.load.assert_called()
mock_extension.model_args.assert_called()
assert mock_predict_called
@@ -183,8 +186,10 @@ async def mock_predict(return_value, request: Request):
model_wrapper_module, "_init_extension", return_value=mock_extension
):
model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
model_wrapper.load()
await model_wrapper.load()

resp = await model_wrapper.predict({}, MagicMock(spec=Request))

mock_extension.load.assert_called()
mock_extension.model_override.assert_called()
assert mock_predict_called
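Since load() is now a coroutine, callers in tests have to await it. Below is a self-contained sketch of that pattern; it assumes pytest-asyncio as the async test runner and uses a FakeWrapper stand-in, neither of which comes from this PR:

import asyncio

import pytest


class FakeWrapper:
    # Stand-in that mimics only the async load surface of ModelWrapper.
    def __init__(self) -> None:
        self.ready = False

    async def load(self) -> None:
        await asyncio.sleep(0)  # pretend to do async work
        self.ready = True


@pytest.mark.asyncio  # assumes pytest-asyncio; the real suite may use another runner
async def test_async_load_sets_ready():
    wrapper = FakeWrapper()
    await wrapper.load()  # load() must now be awaited
    assert wrapper.ready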
66 changes: 66 additions & 0 deletions truss/tests/test_model_inference.py
@@ -738,6 +738,72 @@ def test_truss_with_error_stacktrace(test_data_path):
)


@pytest.mark.integration
def test_async_load_truss():
model = """
import asyncio

class Model:
async def load(self):
await asyncio.sleep(5)

def predict(self, request):
return {"a": "b"}
"""

config = "model_name: async-load-truss"

with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
truss_dir = Path(tmp_work_dir, "truss")

create_truss(truss_dir, config, textwrap.dedent(model))

tr = TrussHandle(truss_dir)
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=False)

truss_server_addr = "http://localhost:8090"

def _test_liveness_probe(expected_code):
live = requests.get(f"{truss_server_addr}/", timeout=1)
assert live.status_code == expected_code

def _test_readiness_probe(expected_code):
ready = requests.get(f"{truss_server_addr}/v1/models/model", timeout=1)
assert ready.status_code == expected_code

def _test_ping(expected_code):
ping = requests.get(f"{truss_server_addr}/ping", timeout=1)
assert ping.status_code == expected_code

def _test_predict(expected_code):
invocations = requests.post(
f"{truss_server_addr}/v1/models/model:predict", json={}, timeout=1
)
assert invocations.status_code == expected_code

SERVER_WARMUP_TIME = 3
LOAD_TEST_TIME = 2
LOAD_BUFFER_TIME = 5

# Sleep a few seconds to give the server some time to wake up
time.sleep(SERVER_WARMUP_TIME)

# The truss takes about 5 seconds to load.
# We want to make sure that it's not ready during that time.
for _ in range(LOAD_TEST_TIME):
_test_liveness_probe(200)
_test_readiness_probe(503)
_test_ping(503)
_test_predict(503)
time.sleep(1)

time.sleep(LOAD_BUFFER_TIME)
_test_liveness_probe(200)
_test_readiness_probe(200)
_test_ping(200)
_test_predict(200)


@pytest.mark.integration
def test_slow_truss(test_data_path):
with ensure_kill_all():
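The integration test above shows the expected status codes while an async load is in flight. Outside of pytest, the same behavior can be checked by hand with a small polling client; the sketch below is illustrative only (wait_until_ready is a made-up helper, the URLs match the test):

import time

import requests

BASE = "http://localhost:8090"


def wait_until_ready(timeout_s: float = 30.0) -> None:
    # Poll the readiness endpoint until it stops returning 503.
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        if requests.get(f"{BASE}/v1/models/model", timeout=1).status_code == 200:
            return
        time.sleep(0.5)
    raise TimeoutError("model did not become ready in time")


if __name__ == "__main__":
    wait_until_ready()
    resp = requests.post(f"{BASE}/v1/models/model:predict", json={}, timeout=5)
    print(resp.json())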