Skip to content

Commit

Permalink
[python] Update PyProcess to support trtllm (#1228)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Oct 27, 2023
1 parent 61d413d commit 7ed7270
Showing 1 changed file with 30 additions and 18 deletions.
48 changes: 30 additions & 18 deletions engines/python/src/main/java/ai/djl/python/engine/PyProcess.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class PyProcess {
private volatile boolean started; // NOPMD
private AtomicInteger restartCount;
private CompletableFuture<Void> restartFuture;
private boolean trtLlmMode;

private static AtomicInteger counter = new AtomicInteger(0);

Expand All @@ -68,6 +69,8 @@ class PyProcess {
connections = Collections.singletonList(new Connection(pyEnv, port, -1));
}
restartCount = new AtomicInteger(0);
// TODO: avoid using this hack when TRT-LLM improve its behavior
trtLlmMode = "trtllm".equals(model.getProperty("rolling_batch"));
}

Output predict(Input inputs, int timeout, boolean initialLoad) {
Expand All @@ -80,32 +83,41 @@ Output predict(Input inputs, int timeout, boolean initialLoad) {
}

List<CompletableFuture<Output>> futures = new ArrayList<>(connections.size());
for (Connection conn : connections) {
futures.add(conn.send(inputs));
if (initialLoad || !trtLlmMode) {
for (Connection conn : connections) {
futures.add(conn.send(inputs));
}
} else {
futures.add(connections.get(0).send(inputs));
}

Output output = null;
for (CompletableFuture<Output> future : futures) {
output = future.get(timeout, TimeUnit.SECONDS);
if (initialLoad) {
int code = output.getCode();
if (code >= 300) {
if (code == 507) {
throw new EngineException("OOM");
}
if (pyEnv.isFailOnInitialize()) {
throw new EngineException(
"Failed to initialize model: " + output.getMessage());
}
logger.warn("Model doesn't support initialize: {}", output.getMessage());
} else {
logger.info("Model [{}] initialized.", model.getName());
if (trtLlmMode) {
output = futures.get(0).get(timeout, TimeUnit.SECONDS);
} else {
for (CompletableFuture<Output> future : futures) {
output = future.get(timeout, TimeUnit.SECONDS);
}
}

if (initialLoad && output != null) {
int code = output.getCode();
if (code >= 300) {
if (code == 507) {
throw new EngineException("OOM");
}
if (pyEnv.isFailOnInitialize()) {
throw new EngineException(
"Failed to initialize model: " + output.getMessage());
}
logger.warn("Model doesn't support initialize: {}", output.getMessage());
} else {
logger.info("Model [{}] initialized.", model.getName());
}
}

return output;
} catch (Exception e) {
} catch (Throwable e) { // use Throwable to workaround spotbug false alarm
logger.debug("predict[init={}] exception: {}", initialLoad, e.getClass().getName());
stopPythonProcess(!initialLoad);
if (!initialLoad) {
Expand Down

0 comments on commit 7ed7270

Please sign in to comment.