[TrtLLM] Support JIT compilation and dynamic batch for TrtLLM python backend
sindhuvahinis committed Mar 28, 2024
1 parent e3fb4f2 commit b3d7191
Showing 3 changed files with 83 additions and 29 deletions.
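The substance of the change: backend selection now keys off the container-level SERVING_FEATURES flag instead of a per-model option.rolling_batch property. Below is a minimal, self-contained sketch of that lookup pattern; FeatureCheck and its getEnvOrSystemProperty helper are illustrative stand-ins for DJL's Utils.getEnvOrSystemProperty, not the project's actual classes.

// Sketch: detect the TensorRT-LLM container from SERVING_FEATURES.
public final class FeatureCheck {

    // Stand-in for DJL's Utils.getEnvOrSystemProperty (assumed behavior).
    static String getEnvOrSystemProperty(String key) {
        String value = System.getenv(key);
        return value != null ? value : System.getProperty(key);
    }

    public static void main(String[] args) {
        String features = getEnvOrSystemProperty("SERVING_FEATURES");
        // LmiConfigRecommender checks contains("trtllm"); the Python engine's
        // load() compares with "trtllm".equals(features).
        boolean trtLlmContainer = features != null && features.contains("trtllm");
        System.out.println("TensorRT-LLM backend available: " + trtLlmContainer);
    }
}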
@@ -78,7 +78,6 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
"Python engine does not support dynamic blocks");
}
String entryPoint = null;
boolean isTrtLlmBackend = false;
if (options != null) {
logger.debug("options in serving.properties for model: {}", modelName);
for (Map.Entry<String, ?> entry : options.entrySet()) {
@@ -121,9 +120,6 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
case "entryPoint":
entryPoint = value;
break;
case "rolling_batch":
isTrtLlmBackend = "trtllm".equals(value);
break;
case "parallel_loading":
parallelLoading = Boolean.parseBoolean(value);
break;
@@ -158,6 +154,7 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
entryPoint = Utils.getenv("DJL_ENTRY_POINT");
if (entryPoint == null) {
Path modelFile = findModelFile(prefix);
String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES");
// find default entryPoint
String engineName = manager.getEngine().getEngineName();
if (modelFile != null) {
@@ -167,7 +164,7 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
} else if ("nc".equals(manager.getDevice().getDeviceType())
&& pyEnv.getTensorParallelDegree() > 0) {
entryPoint = "djl_python.transformers_neuronx";
} else if (isTrtLlmBackend) {
} else if ("trtllm".equals(features)) {
entryPoint = "djl_python.tensorrt_llm";
} else if (pyEnv.getInitParameters().containsKey("model_id")) {
entryPoint = "djl_python.huggingface";
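With the isTrtLlmBackend flag removed, the default entry point is now resolved from the features string. The condensed sketch below shows the precedence implied by this hunk; defaultEntryPoint is a hypothetical helper, and the real code also consults the model file and engine name before reaching these branches.

public final class EntryPointSketch {

    static String defaultEntryPoint(
            String deviceType,
            int tensorParallelDegree,
            String features,
            java.util.Map<String, String> initParameters) {
        if ("nc".equals(deviceType) && tensorParallelDegree > 0) {
            return "djl_python.transformers_neuronx"; // Neuron devices first
        } else if ("trtllm".equals(features)) {
            return "djl_python.tensorrt_llm"; // TrtLLM container
        } else if (initParameters.containsKey("model_id")) {
            return "djl_python.huggingface"; // fall back to the HF handler
        }
        return null; // resolved elsewhere, e.g. from the model file
    }

    public static void main(String[] args) {
        System.out.println(defaultEntryPoint("gpu", 0, "trtllm", java.util.Map.of()));
    }
}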
wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java (58 changes: 51 additions & 7 deletions)
@@ -56,34 +56,44 @@ public final class LmiConfigRecommender {

private LmiConfigRecommender() {}

static void configure(Properties lmiProperties, LmiUtils.HuggingFaceModelConfig modelConfig) {
setRollingBatch(lmiProperties, modelConfig);
setEngine(lmiProperties);
static void configure(
ModelInfo<?, ?> modelInfo,
Properties lmiProperties,
LmiUtils.HuggingFaceModelConfig modelConfig) {
String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES");
setDynamicBatch(lmiProperties, modelConfig, modelInfo, features);
setRollingBatch(lmiProperties, modelConfig, features);
setEngine(lmiProperties, modelConfig, features);
setTensorParallelDegree(lmiProperties);
}

private static void setRollingBatch(
Properties lmiProperties, LmiUtils.HuggingFaceModelConfig modelConfig) {
Properties lmiProperties,
LmiUtils.HuggingFaceModelConfig modelConfig,
String features) {
// If dynamic batch is enabled, we don't enable rolling batch.
if (Integer.parseInt(lmiProperties.getProperty("batch_size", "1")) > 1) {
lmiProperties.setProperty("option.rolling_batch", "disable");
return;
}
String rollingBatch = lmiProperties.getProperty("option.rolling_batch", "auto");
String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES");
if (!"auto".equals(rollingBatch)) {
return;
} else if (!isTextGenerationModel(modelConfig)) {
// Non text-generation use-cases are not compatible with rolling batch
rollingBatch = "disable";
} else if (isVLLMEnabled(features) && isLmiDistEnabled(features)) {
rollingBatch = MODEL_TO_ROLLING_BATCH.getOrDefault(modelConfig.getModelType(), "auto");
} else if (LmiUtils.isTrtLLM(lmiProperties)) {
} else if (LmiUtils.isTrtLLMRollingBatch(lmiProperties)) {
rollingBatch = "trtllm";
}
lmiProperties.setProperty("option.rolling_batch", rollingBatch);
}
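The rolling-batch resolution above, condensed into a sketch. resolveRollingBatch is a hypothetical helper; the text-generation check and the MODEL_TO_ROLLING_BATCH lookup are simplified to boolean and string parameters.

public final class RollingBatchSketch {

    static String resolveRollingBatch(
            int batchSize,
            String requested,
            boolean isTextGeneration,
            String features,
            boolean isTrtLlmRollingBatch) {
        if (batchSize > 1) {
            return "disable"; // dynamic batch is enabled, so rolling batch is off
        }
        if (!"auto".equals(requested)) {
            return requested; // an explicit customer choice is respected
        }
        if (!isTextGeneration) {
            return "disable"; // rolling batch only fits text generation
        }
        if (features != null && features.contains("vllm") && features.contains("lmi-dist")) {
            return "auto"; // or a model-specific default from MODEL_TO_ROLLING_BATCH
        }
        return isTrtLlmRollingBatch ? "trtllm" : "auto";
    }

    public static void main(String[] args) {
        System.out.println(resolveRollingBatch(1, "auto", true, "trtllm", true));
    }
}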

private static void setEngine(Properties lmiProperties) {
private static void setEngine(
Properties lmiProperties,
LmiUtils.HuggingFaceModelConfig modelConfig,
String features) {
if (lmiProperties.containsKey("engine")) {
return;
}
@@ -93,6 +103,11 @@ private static void setEngine(Properties lmiProperties) {
engine = "MPI";
lmiProperties.setProperty("option.mpi_mode", "true");
}
// TODO TrtLLM python backend: Change it once TrtLLM supports T5 with inflight batching.
if (isT5TrtLLM(modelConfig, features)) {
engine = "MPI";
lmiProperties.setProperty("option.mpi_mode", "true");
}
lmiProperties.setProperty("engine", engine);
}

@@ -107,6 +122,26 @@ private static void setTensorParallelDegree(Properties lmiProperties) {
lmiProperties.setProperty("option.tensor_parallel_degree", tpDegree);
}

private static void setDynamicBatch(
Properties lmiProperties,
LmiUtils.HuggingFaceModelConfig modelConfig,
ModelInfo<?, ?> modelInfo,
String features) {
// TODO TrtLLM python backend: Change it once TrtLLM supports T5 with inflight batching.
if (isT5TrtLLM(modelConfig, features)) {

// Enables runtime (JIT) compilation for the TensorRT-LLM T5 model.
lmiProperties.setProperty("trtllm_python_backend", String.valueOf(true));
lmiProperties.setProperty("option.rolling_batch", "disable");

// Set batch_size only when the customer did not provide one.
if (Integer.parseInt(lmiProperties.getProperty("batch_size", "0")) == 0) {
modelInfo.batchSize = 32;
lmiProperties.setProperty("batch_size", String.valueOf(32));
}
}
}
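Taken together with the T5 branch added to setEngine above, a T5 model on a trtllm container with no customer-supplied batch_size ends up with the configuration sketched below. This illustrates the resulting properties only; the recommender methods themselves are private and are not called like this.

import java.util.Properties;

public final class T5TrtLlmOutcomeSketch {

    public static void main(String[] args) {
        Properties lmiProperties = new Properties();
        // From setDynamicBatch: python backend with JIT compilation, rolling batch off.
        lmiProperties.setProperty("trtllm_python_backend", "true");
        lmiProperties.setProperty("option.rolling_batch", "disable");
        // From setDynamicBatch: default dynamic batch size when none was given.
        if (Integer.parseInt(lmiProperties.getProperty("batch_size", "0")) == 0) {
            lmiProperties.setProperty("batch_size", String.valueOf(32));
        }
        // From setEngine's T5 branch: force MPI until inflight batching lands.
        lmiProperties.setProperty("engine", "MPI");
        lmiProperties.setProperty("option.mpi_mode", "true");
        lmiProperties.forEach((k, v) -> System.out.println(k + "=" + v));
    }
}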

private static boolean isVLLMEnabled(String features) {
return features != null && features.contains("vllm");
}
@@ -115,6 +150,15 @@ private static boolean isLmiDistEnabled(String features) {
return features != null && features.contains("lmi-dist");
}

private static boolean isTrtLLMEnabled(String features) {
return features != null && features.contains("trtllm");
}

private static boolean isT5TrtLLM(
LmiUtils.HuggingFaceModelConfig modelConfig, String features) {
return isTrtLLMEnabled(features) && "t5".equals(modelConfig.getModelType());
}

private static boolean isTextGenerationModel(LmiUtils.HuggingFaceModelConfig modelConfig) {
for (String arch : modelConfig.getArchitectures()) {
boolean isTextGenerationModel =
wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java (47 changes: 30 additions & 17 deletions)
@@ -55,13 +55,13 @@ static String inferLmiEngine(ModelInfo<?, ?> modelInfo) throws ModelException {
Properties prop = modelInfo.getProperties();
HuggingFaceModelConfig modelConfig = getHuggingFaceModelConfig(modelInfo);
if (modelConfig == null) {
String engineName = isTrtLLM(prop) ? "MPI" : "Python";
String engineName = isTrtLLMRollingBatch(prop) ? "MPI" : "Python";
logger.info("No config.json found, use {} engine.", engineName);
return engineName;
}
LmiConfigRecommender.configure(prop, modelConfig);
LmiConfigRecommender.configure(modelInfo, prop, modelConfig);
logger.info(
"Detected engine: {}, rolling_batch: {}, tensor_paralell_degre {}, for modelType:"
"Detected engine: {}, rolling_batch: {}, tensor_parallel_degree {}, for modelType:"
+ " {}",
prop.getProperty("engine"),
prop.getProperty("option.rolling_batch"),
@@ -70,7 +70,7 @@ static String inferLmiEngine(ModelInfo<?, ?> modelInfo) throws ModelException {
return prop.getProperty("engine");
}

static boolean isTrtLLM(Properties properties) {
static boolean isTrtLLMRollingBatch(Properties properties) {
String rollingBatch = properties.getProperty("option.rolling_batch");
if ("trtllm".equals(rollingBatch)) {
return true;
@@ -84,11 +84,12 @@ static boolean isTrtLLM(Properties properties) {
}

static boolean needConvert(ModelInfo<?, ?> info) {
return isTrtLLM(info.getProperties());
Properties properties = info.getProperties();
return isTrtLLMRollingBatch(properties)
|| properties.containsKey("trtllm_python_backend");
}

static void convertTrtLLM(ModelInfo<?, ?> info) throws IOException {
info.prop.put("option.rolling_batch", "trtllm");
Path trtRepo;
String modelId = null;
if (info.downloadDir != null) {
@@ -100,18 +101,30 @@ static void convertTrtLLM(ModelInfo<?, ?> info) throws IOException {
trtRepo = Paths.get(modelId);
}
}
if (!isValidTrtLlmModelRepo(trtRepo)) {
if (modelId == null) {
modelId = trtRepo.toString();
}
String tpDegree = info.prop.getProperty("option.tensor_parallel_degree");
if (tpDegree == null) {
tpDegree = Utils.getenv("TENSOR_PARALLEL_DEGREE", "max");
}
if ("max".equals(tpDegree)) {
tpDegree = String.valueOf(CudaUtils.getGpuCount());
}

if (modelId == null) {
modelId = trtRepo.toString();
}
String tpDegree = info.prop.getProperty("option.tensor_parallel_degree");
if (tpDegree == null) {
tpDegree = Utils.getenv("TENSOR_PARALLEL_DEGREE", "max");
}
if ("max".equals(tpDegree)) {
tpDegree = String.valueOf(CudaUtils.getGpuCount());
}

// TODO TrtLLM python backend: Change it once TrtLLM supports T5 with inflight batching.
if (info.prop.containsKey("trtllm_python_backend")) {
// Inflight batching support is not available for certain models like T5.
// Python backend models use a different model repo format than the C++ backend,
// and repo validity is checked inside tensorrt_llm_toolkit, so there is no need
// to validate it here.
info.downloadDir = buildTrtLlmArtifacts(info.modelDir, modelId, tpDegree);
} else {
info.prop.put("option.rolling_batch", "trtllm");
if (!isValidTrtLlmModelRepo(trtRepo)) {
info.downloadDir = buildTrtLlmArtifacts(info.modelDir, modelId, tpDegree);
}
}
}
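The two conversion paths can be condensed as follows. ConvertSketch and its stubbed helpers are hypothetical, standing in for LmiUtils' buildTrtLlmArtifacts and isValidTrtLlmModelRepo; a null return means the existing download directory is kept.

import java.io.IOException;
import java.nio.file.Path;
import java.util.Properties;

public final class ConvertSketch {

    // Stubs for the real LmiUtils helpers (assumed signatures).
    static boolean isValidTrtLlmModelRepo(Path repo) {
        return false;
    }

    static Path buildTrtLlmArtifacts(Path modelDir, String modelId, String tpDegree)
            throws IOException {
        return modelDir.resolve("trtllm-artifacts");
    }

    /** Returns a new download dir, or null to keep the existing one. */
    static Path convert(Properties prop, Path modelDir, Path trtRepo,
            String modelId, String tpDegree) throws IOException {
        if (prop.containsKey("trtllm_python_backend")) {
            // Python backend (e.g. T5): repo validity is checked inside
            // tensorrt_llm_toolkit, so artifacts are always built here (the JIT path).
            return buildTrtLlmArtifacts(modelDir, modelId, tpDegree);
        }
        prop.put("option.rolling_batch", "trtllm");
        // C++ backend: compile only when there is no prebuilt TrtLLM model repo.
        return isValidTrtLlmModelRepo(trtRepo)
                ? null
                : buildTrtLlmArtifacts(modelDir, modelId, tpDegree);
    }
}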
