diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/AsyncWorkerThread.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/AsyncWorkerThread.java
index 2dd1272b01..526f3b9232 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/wlm/AsyncWorkerThread.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/AsyncWorkerThread.java
@@ -33,7 +33,7 @@ public class AsyncWorkerThread extends WorkerThread {
 
     // protected ConcurrentHashMap requestsInBackend;
     protected static final Logger logger = LoggerFactory.getLogger(AsyncWorkerThread.class);
-    protected static final long MODEL_LOAD_TIMEOUT = 10L;
+    protected static final long WORKER_TIMEOUT = 2L;
 
     protected boolean loadingFinished;
     protected CountDownLatch latch;
@@ -53,6 +53,7 @@ public AsyncWorkerThread(
     @Override
     public void run() {
         responseTimeout = model.getResponseTimeout();
+        startupTimeout = model.getStartupTimeout();
         Thread thread = Thread.currentThread();
         thread.setName(getWorkerName());
         currentThread.set(thread);
@@ -80,11 +81,11 @@ public void run() {
 
                 if (loadingFinished == false) {
                     latch = new CountDownLatch(1);
-                    if (!latch.await(MODEL_LOAD_TIMEOUT, TimeUnit.MINUTES)) {
+                    if (!latch.await(startupTimeout, TimeUnit.SECONDS)) {
                         throw new WorkerInitializationException(
-                                "Worker did not load the model within"
-                                        + MODEL_LOAD_TIMEOUT
-                                        + " mins");
+                                "Worker did not load the model within "
+                                        + startupTimeout
+                                        + " seconds");
                     }
                 }
 
@@ -99,7 +100,7 @@ public void run() {
                 logger.debug("Shutting down the thread .. Scaling down.");
             } else {
                 logger.debug(
-                        "Backend worker monitoring thread interrupted or backend worker process died., responseTimeout:"
+                        "Backend worker monitoring thread interrupted or backend worker process died. responseTimeout:"
                                 + responseTimeout
                                 + "sec",
                         e);
diff --git a/ts/llm_launcher.py b/ts/llm_launcher.py
index 89248ce9f4..466497a5e6 100644
--- a/ts/llm_launcher.py
+++ b/ts/llm_launcher.py
@@ -24,6 +24,7 @@ def get_model_config(args):
         "batchSize": 1,
         "maxBatchDelay": 100,
         "responseTimeout": 1200,
+        "startupTimeout": args.startup_timeout,
         "deviceType": "gpu",
         "asyncCommunication": True,
         "parallelLevel": torch.cuda.device_count() if torch.cuda.is_available else 1,
@@ -142,7 +143,7 @@ def main(args):
     parser.add_argument(
         "--vllm_engine.max_num_seqs",
         type=int,
-        default=16,
+        default=256,
         help="Max sequences in vllm engine",
     )
 
@@ -160,6 +161,13 @@ def main(args):
         help="Cache dir",
     )
 
+    parser.add_argument(
+        "--startup_timeout",
+        type=int,
+        default=1200,
+        help="Model startup timeout in seconds",
+    )
+
     args = parser.parse_args()
 
     main(args)
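
For context, a minimal standalone sketch (not part of the patch) of how the new flag is threaded through: the launcher copies `args.startup_timeout` into the model config under `"startupTimeout"`, which the frontend's `AsyncWorkerThread` reads via `model.getStartupTimeout()` and uses as the model-load latch wait, in seconds. The `build_model_config` helper below is hypothetical, a trimmed-down stand-in for `get_model_config` in `ts/llm_launcher.py` that keeps only the keys relevant to this change.

    # Hypothetical, trimmed-down stand-in for get_model_config() in
    # ts/llm_launcher.py; only the keys touched by this diff are shown.
    from argparse import ArgumentParser, Namespace

    def build_model_config(args: Namespace) -> dict:
        return {
            "responseTimeout": 1200,
            # New in this diff: forwarded to the frontend, where
            # AsyncWorkerThread waits on its model-load latch for at
            # most this many seconds before failing initialization.
            "startupTimeout": args.startup_timeout,
        }

    parser = ArgumentParser()
    # Mirrors the argument the diff adds (same name, type, default, help).
    parser.add_argument(
        "--startup_timeout",
        type=int,
        default=1200,
        help="Model startup timeout in seconds",
    )

    # Default is 1200 seconds (20 minutes); slow-loading models can raise it.
    assert build_model_config(parser.parse_args([]))["startupTimeout"] == 1200
    override = parser.parse_args(["--startup_timeout", "3600"])
    assert build_model_config(override)["startupTimeout"] == 3600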