From 7ed72706f7d3f726c8222648d8ddf5588c1611c6 Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Fri, 27 Oct 2023 11:34:30 -0700
Subject: [PATCH] [python] Update PyProcess to support trtllm (#1228)

---
 .../java/ai/djl/python/engine/PyProcess.java | 48 ++++++++++++-------
 1 file changed, 30 insertions(+), 18 deletions(-)

diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java b/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java
index d98d9bc0f..5a9e2f813 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java
@@ -49,6 +49,7 @@ class PyProcess {
     private volatile boolean started; // NOPMD
     private AtomicInteger restartCount;
    private CompletableFuture<Void> restartFuture;
+    private boolean trtLlmMode;
 
     private static AtomicInteger counter = new AtomicInteger(0);
 
@@ -68,6 +69,8 @@ class PyProcess {
             connections = Collections.singletonList(new Connection(pyEnv, port, -1));
         }
         restartCount = new AtomicInteger(0);
+        // TODO: avoid using this hack when TRT-LLM improves its behavior
+        trtLlmMode = "trtllm".equals(model.getProperty("rolling_batch"));
     }
 
     Output predict(Input inputs, int timeout, boolean initialLoad) {
@@ -80,32 +83,41 @@ Output predict(Input inputs, int timeout, boolean initialLoad) {
             }
 
             List<CompletableFuture<Output>> futures = new ArrayList<>(connections.size());
-            for (Connection conn : connections) {
-                futures.add(conn.send(inputs));
+            if (initialLoad || !trtLlmMode) {
+                for (Connection conn : connections) {
+                    futures.add(conn.send(inputs));
+                }
+            } else {
+                futures.add(connections.get(0).send(inputs));
             }
 
             Output output = null;
-            for (CompletableFuture<Output> future : futures) {
-                output = future.get(timeout, TimeUnit.SECONDS);
-                if (initialLoad) {
-                    int code = output.getCode();
-                    if (code >= 300) {
-                        if (code == 507) {
-                            throw new EngineException("OOM");
-                        }
-                        if (pyEnv.isFailOnInitialize()) {
-                            throw new EngineException(
-                                    "Failed to initialize model: " + output.getMessage());
-                        }
-                        logger.warn("Model doesn't support initialize: {}", output.getMessage());
-                    } else {
-                        logger.info("Model [{}] initialized.", model.getName());
+            if (trtLlmMode) {
+                output = futures.get(0).get(timeout, TimeUnit.SECONDS);
+            } else {
+                for (CompletableFuture<Output> future : futures) {
+                    output = future.get(timeout, TimeUnit.SECONDS);
+                }
+            }
+
+            if (initialLoad && output != null) {
+                int code = output.getCode();
+                if (code >= 300) {
+                    if (code == 507) {
+                        throw new EngineException("OOM");
                     }
+                    if (pyEnv.isFailOnInitialize()) {
+                        throw new EngineException(
+                                "Failed to initialize model: " + output.getMessage());
+                    }
+                    logger.warn("Model doesn't support initialize: {}", output.getMessage());
+                } else {
+                    logger.info("Model [{}] initialized.", model.getName());
                 }
             }
 
             return output;
-        } catch (Exception e) {
+        } catch (Throwable e) { // use Throwable to work around SpotBugs false alarm
             logger.debug("predict[init={}] exception: {}", initialLoad, e.getClass().getName());
             stopPythonProcess(!initialLoad);
             if (!initialLoad) {
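
Note (not part of the patch): the new trtLlmMode flag is driven entirely by the model property "rolling_batch". Below is a minimal, hypothetical sketch of how a caller could exercise this path through the DJL Criteria API. Only the property name "rolling_batch", the value "trtllm", and the "Python" engine come from the diff above; the class name, model path, payload shape, and the assumption that Criteria options surface as model properties in the Python engine are illustrative.

import java.nio.file.Paths;

import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;

public final class TrtLlmSmokeTest {

    public static void main(String[] args) throws Exception {
        // Hypothetical setup: load a model through the Python engine with
        // rolling_batch=trtllm so that PyProcess sets trtLlmMode=true and, after
        // the initial load, routes each request to a single connection only.
        Criteria<Input, Output> criteria =
                Criteria.builder()
                        .setTypes(Input.class, Output.class)
                        .optModelPath(Paths.get("/opt/ml/model")) // assumed model location
                        .optEngine("Python")
                        .optOption("rolling_batch", "trtllm")
                        .build();

        try (ZooModel<Input, Output> model = criteria.loadModel();
                Predictor<Input, Output> predictor = model.newPredictor()) {
            Input input = new Input();
            input.add("{\"inputs\": \"Hello world\"}"); // payload shape is an assumption
            Output output = predictor.predict(input);
            System.out.println(output.getData().getAsString());
        }
    }
}

In DJL Serving the same switch is typically supplied through serving.properties (option.rolling_batch=trtllm) rather than in code.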