Skip to content

Commit

Permalink
[fix] prevent requests being sent to python model until model is fully loaded (
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk authored Aug 28, 2024
1 parent 9fe4d68 commit 0aab4c2
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.EngineException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
Expand Down Expand Up @@ -63,10 +64,13 @@ public PyPredictor(
@Override
@SuppressWarnings("unchecked")
public List<O> batchPredict(List<I> inputs) throws TranslateException {
if (process.isStopped()) {
if (!process.isReady()) {
// TODO: wait for restart
throw new TranslateException("Backend Python process is stopped.");
}
if (process.isModelUnrecoverable()) {
throw new EngineException("Backend Python process is unrecoverable.");
}
Object first = inputs.get(0);
if (first instanceof Input) {
int size = inputs.size();
Expand Down
14 changes: 12 additions & 2 deletions engines/python/src/main/java/ai/djl/python/engine/PyProcess.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class PyProcess {
private List<Connection> connections;
private CountDownLatch latch;
private volatile boolean started; // NOPMD
private volatile boolean modelLoaded; // NOPMD
private volatile boolean modelUnrecoverable; // NOPMD
private AtomicInteger restartCount;
private CompletableFuture<Void> restartFuture;
private boolean trtLlmMode;
Expand Down Expand Up @@ -123,6 +125,8 @@ Output predict(Input inputs, int timeout, boolean initialLoad) {
if (!initialLoad) {
logger.info("Restart python process ...");
restartFuture = CompletableFuture.runAsync(this::startPythonProcess);
} else {
modelUnrecoverable = true;
}
if (e instanceof EngineException) {
throw (EngineException) e;
Expand All @@ -133,6 +137,7 @@ Output predict(Input inputs, int timeout, boolean initialLoad) {

synchronized void startPythonProcess() {
try {
modelLoaded = false;
int id = restartCount.get();
int port = connections.get(0).getPort();
logger.info("Start process: {} - retry: {}", port, id);
Expand Down Expand Up @@ -168,6 +173,7 @@ synchronized void startPythonProcess() {
Input init = new Input();
init.setProperties(pyEnv.getInitParameters());
predict(init, pyEnv.getModelLoadingTimeout(), true);
modelLoaded = true;
} catch (EngineException e) {
started = false;
throw e;
Expand Down Expand Up @@ -233,8 +239,12 @@ void setStarted(boolean started, int id) {
}
}

boolean isStopped() {
return !started;
/**
 * Returns {@code true} once the backend Python process is usable: the process has
 * started AND the model load completed ({@code modelLoaded} is cleared at the top of
 * {@code startPythonProcess()} and only set after the initial-load predict succeeds).
 */
boolean isReady() {
return started && modelLoaded;
}

/**
 * Returns {@code true} when the initial model load failed; in that case the process
 * is not restarted (see the {@code initialLoad} else-branch in {@code predict()}),
 * so callers should surface a hard failure instead of retrying.
 */
boolean isModelUnrecoverable() {
return modelUnrecoverable;
}

static final class ReaderThread extends Thread {
Expand Down
9 changes: 9 additions & 0 deletions wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,15 @@ public Status getStatus() {
// SIGKILL (9 + 128)
System.exit(137); // NOPMD
}
boolean isHealthCheckOverrideEnabled =
Boolean.parseBoolean(
Utils.getEnvOrSystemProperty("SERVING_HEALTH_CHECK_OVERRIDE"));
if (isHealthCheckOverrideEnabled) {
logger.error(
"SERVING_HEALTH_CHECK_OVERRIDE is enabled. At least 1 model worker"
+ " has exhausted all retries. Not marking model as failed");
return Status.READY;
}

return Status.FAILED;
}
Expand Down

0 comments on commit 0aab4c2

Please sign in to comment.