Skip to content

Commit

Permalink
[serving] Separate download draft model from downloadS3() (#1693)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Mar 28, 2024
1 parent c3ecb2c commit 3e948d8
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -501,11 +501,12 @@ public void initialize() throws IOException, ModelException {
downloadModel();
loadServingProperties();
downloadS3();
eventManager.onModelDownloaded(this, downloadDir);
downloadDraftModel();

long duration = (System.nanoTime() - begin) / 1000;
Metric metric = new Metric("DownloadModel", duration, Unit.MICROSECONDS, dimension);
MODEL_METRIC.info("{}", metric);

eventManager.onModelDownloaded(this, downloadDir);
if (LmiUtils.needConvert(this)) {
eventManager.onModelConverting(this, "trtllm");
begin = System.nanoTime();
Expand Down Expand Up @@ -1040,7 +1041,7 @@ private Path downloadS3ToDownloadDir(String s3Url) throws IOException, ModelExce
Files.createDirectories(parent);
Path tmp = Files.createTempDirectory(parent, "tmp");
try {
downloadS3(s3Url, tmp.toAbsolutePath().toString());
runS3cmd(s3Url, tmp.toAbsolutePath().toString());
Utils.moveQuietly(tmp, downloadModelDir);
logger.info("{}: Download completed! Files saved to {}", uid, downloadModelDir);
} finally {
Expand All @@ -1052,12 +1053,6 @@ private Path downloadS3ToDownloadDir(String s3Url) throws IOException, ModelExce

void downloadS3() throws ModelException, IOException {
String modelId = prop.getProperty("option.model_id");
String draftModelId = prop.getProperty("option.speculative_draft_model");
if (draftModelId != null && draftModelId.startsWith("s3://")) {
Path draftDownloadDir = downloadS3ToDownloadDir(draftModelId);
prop.setProperty(
"option.speculative_draft_model", draftDownloadDir.toAbsolutePath().toString());
}
if (modelId == null) {
return;
}
Expand All @@ -1071,7 +1066,7 @@ void downloadS3() throws ModelException, IOException {
}
}

private void downloadS3(String src, String dest) throws ModelException {
private void runS3cmd(String src, String dest) throws ModelException {
try {
String[] commands;
if (Files.exists(Paths.get("/opt/djl/bin/s5cmd"))) {
Expand Down Expand Up @@ -1107,6 +1102,15 @@ private void downloadS3(String src, String dest) throws ModelException {
}
}

private void downloadDraftModel() throws ModelException, IOException {
String draftModelId = prop.getProperty("option.speculative_draft_model");
if (draftModelId != null && draftModelId.startsWith("s3://")) {
Path draftDownloadDir = downloadS3ToDownloadDir(draftModelId);
prop.setProperty(
"option.speculative_draft_model", draftDownloadDir.toAbsolutePath().toString());
}
}

private static int intValue(Properties prop, String key, int defValue) {
String value = prop.getProperty(key);
if (value == null) {
Expand Down

0 comments on commit 3e948d8

Please sign in to comment.