Commit 86f81d4

Use startup time in async worker thread instead of worker timeout
mreso committed Sep 13, 2024
1 parent 15952d0 commit 86f81d4
Showing 2 changed files with 16 additions and 7 deletions.
AsyncWorkerThread.java

@@ -33,7 +33,7 @@
 public class AsyncWorkerThread extends WorkerThread {
     // protected ConcurrentHashMap requestsInBackend;
     protected static final Logger logger = LoggerFactory.getLogger(AsyncWorkerThread.class);
-    protected static final long MODEL_LOAD_TIMEOUT = 10L;
+    protected static final long WORKER_TIMEOUT = 2L;
 
     protected boolean loadingFinished;
     protected CountDownLatch latch;
@@ -53,6 +53,7 @@ public AsyncWorkerThread(
     @Override
     public void run() {
         responseTimeout = model.getResponseTimeout();
+        startupTimeout = model.getStartupTimeout();
         Thread thread = Thread.currentThread();
         thread.setName(getWorkerName());
         currentThread.set(thread);
@@ -80,11 +81,11 @@ public void run() {
 
             if (loadingFinished == false) {
                 latch = new CountDownLatch(1);
-                if (!latch.await(MODEL_LOAD_TIMEOUT, TimeUnit.MINUTES)) {
+                if (!latch.await(startupTimeout, TimeUnit.SECONDS)) {
                     throw new WorkerInitializationException(
-                            "Worker did not load the model within"
-                                    + MODEL_LOAD_TIMEOUT
-                                    + " mins");
+                            "Worker did not load the model within "
+                                    + startupTimeout
+                                    + " seconds");
                 }
             }
 
@@ -99,7 +100,7 @@ public void run() {
                 logger.debug("Shutting down the thread .. Scaling down.");
             } else {
                 logger.debug(
-                        "Backend worker monitoring thread interrupted or backend worker process died., responseTimeout:"
+                        "Backend worker monitoring thread interrupted or backend worker process died. responseTimeout:"
                                 + responseTimeout
                                 + "sec",
                         e);
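
Taken together, the Java changes replace the hard-coded 10-minute MODEL_LOAD_TIMEOUT with the model's configurable startupTimeout, awaited in seconds on the CountDownLatch. As a rough illustration of that wait-and-fail pattern, here is a minimal Python sketch using threading.Event; the names and structure are illustrative stand-ins, not TorchServe APIs.

import threading

class WorkerInitializationException(Exception):
    """Illustrative stand-in for the Java exception of the same name."""

def await_model_load(loading_finished: threading.Event, startup_timeout: float) -> None:
    # Analogue of latch.await(startupTimeout, TimeUnit.SECONDS): block until
    # the backend signals that loading finished, or fail once the configured
    # startup timeout (in seconds) elapses.
    if not loading_finished.wait(timeout=startup_timeout):
        raise WorkerInitializationException(
            f"Worker did not load the model within {startup_timeout} seconds"
        )

# The backend's load-completion handler would call loading_finished.set(),
# the counterpart of latch.countDown() in AsyncWorkerThread.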
ts/llm_launcher.py
10 changes: 9 additions & 1 deletion

@@ -24,6 +24,7 @@ def get_model_config(args):
         "batchSize": 1,
         "maxBatchDelay": 100,
         "responseTimeout": 1200,
+        "startupTimeout": args.startup_timeout,
         "deviceType": "gpu",
         "asyncCommunication": True,
         "parallelLevel": torch.cuda.device_count() if torch.cuda.is_available else 1,
@@ -142,7 +143,7 @@ def main(args):
     parser.add_argument(
         "--vllm_engine.max_num_seqs",
         type=int,
-        default=16,
+        default=256,
         help="Max sequences in vllm engine",
     )
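
Separate from the timeout work, this hunk raises the default for --vllm_engine.max_num_seqs from 16 to 256. One quirk worth noting: argparse only maps '-' to '_' when deriving the destination name from a flag, so the dot survives and the parsed value must be read with getattr rather than plain attribute access. A small standalone sketch, not the launcher's actual code:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--vllm_engine.max_num_seqs", type=int, default=256,
                    help="Max sequences in vllm engine")
args = parser.parse_args([])

# The dot is kept in the namespace, so args.vllm_engine.max_num_seqs is not
# valid attribute syntax; getattr is required:
max_num_seqs = getattr(args, "vllm_engine.max_num_seqs")
assert max_num_seqs == 256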

@@ -160,6 +161,13 @@ def main(args):
         help="Cache dir",
     )
 
+    parser.add_argument(
+        "--startup_timeout",
+        type=int,
+        default=1200,
+        help="Model startup timeout in seconds",
+    )
+
     args = parser.parse_args()
 
     main(args)
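
With the new flag in place, the value flows from the command line through get_model_config into the worker's startup wait. A trimmed sketch of that plumbing, assuming only the argparse fragment shown above; the two-entry config dict below is an illustration, not the launcher's full model config:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--startup_timeout", type=int, default=1200,
                    help="Model startup timeout in seconds")

# e.g. give a very large model 40 minutes to load
args = parser.parse_args(["--startup_timeout", "2400"])

# Trimmed analogue of get_model_config(args): the async worker will now wait
# this many seconds for the model to load before failing the worker.
model_config = {
    "responseTimeout": 1200,
    "startupTimeout": args.startup_timeout,
}
assert model_config["startupTimeout"] == 2400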
