diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py
index 7600074241..ff69120336 100644
--- a/engines/python/setup/djl_python/huggingface.py
+++ b/engines/python/setup/djl_python/huggingface.py
@@ -302,6 +302,38 @@ def _streaming_inference(self, batch: List, request_input: RequestInput,
self.hf_configs.device, **parameters))
return outputs
+ def add_lora(self, lora_name: str, lora_alias: str, lora_path: str):
+ if not is_rolling_batch_enabled(self.hf_configs.rolling_batch):
+ raise NotImplementedError(
+ "LoRA adapter API is only supported for rolling batch.")
+
+ loaded = self.rolling_batch.add_lora(lora_name, lora_path)
+ if not loaded:
+ raise RuntimeError(f"Failed to load LoRA adapter {lora_alias}")
+ return loaded
+
+ def remove_lora(self, lora_name: str, lora_alias: str):
+ if not is_rolling_batch_enabled(self.hf_configs.rolling_batch):
+ raise NotImplementedError(
+ "LoRA adapter API is only supported for rolling batch.")
+
+ removed = self.rolling_batch.remove_lora(lora_name)
+ if not removed:
+ logging.info(
+ f"Remove LoRA adapter {lora_alias} returned false, the adapter may have already been evicted."
+ )
+ return removed
+
+ def pin_lora(self, lora_name: str, lora_alias: str):
+ if not is_rolling_batch_enabled(self.hf_configs.rolling_batch):
+ raise NotImplementedError(
+ "LoRA adapter API is only supported for rolling batch.")
+
+ pinned = self.rolling_batch.pin_lora(lora_name)
+ if not pinned:
+ raise RuntimeError(f"Failed to pin LoRA adapter {lora_alias}")
+ return pinned
+
def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
# define tokenizer or feature extractor as kwargs to load it the pipeline correctly
if task in {
@@ -496,41 +528,36 @@ def register_adapter(inputs: Input):
"""
Registers lora adapter with the model.
"""
- if not is_rolling_batch_enabled(_service.hf_configs.rolling_batch):
- raise NotImplementedError(
- "LoRA adapter API is only supported for rolling batch.")
-
adapter_name = inputs.get_property("name")
adapter_alias = inputs.get_property("alias") or adapter_name
adapter_path = inputs.get_property("src")
+ adapter_load = inputs.get_as_string(
+ "load").lower() == "true" if inputs.contains_key("load") else True
adapter_pin = inputs.get_as_string(
"pin").lower() == "true" if inputs.contains_key("pin") else False
- added = False
+ loaded = False
try:
if not os.path.exists(adapter_path):
raise ValueError(
f"Only local LoRA models are supported. {adapter_path} is not a valid path"
)
- added = _service.rolling_batch.add_lora(adapter_name, adapter_path)
- if not added:
- raise RuntimeError(
- f"Failed to register LoRA adapter {adapter_alias}")
+ if adapter_load:
+ loaded = _service.add_lora(adapter_name, adapter_alias,
+ adapter_path)
if adapter_pin:
- pinned = _service.rolling_batch.pin_lora(adapter_name)
- if not pinned:
- raise RuntimeError(
- f"Failed to pin LoRA adapter {adapter_alias}")
-
+ if not adapter_load:
+ raise RuntimeError("Need to set load=true to set pin=true")
+ _service.pin_lora(adapter_name, adapter_alias)
_service.adapter_registry[adapter_name] = inputs
except Exception as e:
- if added:
+ if loaded:
logging.info(
f"LoRA adapter {adapter_alias} was successfully added, but failed to pin, removing ..."
)
- _service.rolling_batch.remove_lora(adapter_name)
+ _service.remove_lora(adapter_name, adapter_alias)
if any(msg in str(e)
for msg in ("No free lora slots",
"greater than the number of GPU LoRA slots")):
@@ -546,13 +573,11 @@ def update_adapter(inputs: Input):
"""
Updates lora adapter with the model.
"""
- if not is_rolling_batch_enabled(_service.hf_configs.rolling_batch):
- raise NotImplementedError(
- "LoRA adapter API is only supported for rolling batch.")
-
adapter_name = inputs.get_property("name")
adapter_alias = inputs.get_property("alias") or adapter_name
adapter_path = inputs.get_property("src")
+ adapter_load = inputs.get_as_string(
+ "load").lower() == "true" if inputs.contains_key("load") else True
adapter_pin = inputs.get_as_string(
"pin").lower() == "true" if inputs.contains_key("pin") else False
@@ -566,17 +591,27 @@ def update_adapter(inputs: Input):
raise NotImplementedError(
f"Updating adapter path is not supported.")
- old_adapter_pin = inputs.get_as_string(
- "pin").lower() == "true" if inputs.contains_key("pin") else False
+ old_adapter_load = old_adapter.get_as_string("load").lower(
+ ) == "true" if old_adapter.contains_key("load") else True
+ if old_adapter_load != adapter_load:
+ if adapter_load:
+ _service.add_lora(adapter_name, adapter_alias, adapter_path)
+ else:
+ if adapter_pin:
+ raise RuntimeError(
+ "Need to set pin=false to set load=false")
+ _service.remove_lora(adapter_name, adapter_alias)
+
+ old_adapter_pin = old_adapter.get_as_string("pin").lower(
+ ) == "true" if old_adapter.contains_key("pin") else False
if old_adapter_pin and not adapter_pin:
raise NotImplementedError(f"Unpin adapter is not supported.")
- if adapter_pin:
- pinned = _service.rolling_batch.pin_lora(adapter_name)
- if not pinned:
- raise RuntimeError(
- f"Failed to pin LoRA adapter {adapter_alias}")
-
+ if old_adapter_pin != adapter_pin:
+ if adapter_pin:
+ if not adapter_load:
+ raise RuntimeError("Need to set load=true to set pin=true")
+ _service.pin_lora(adapter_name, adapter_alias)
_service.adapter_registry[adapter_name] = inputs
except Exception as e:
if any(msg in str(e)
@@ -593,10 +628,6 @@ def unregister_adapter(inputs: Input):
"""
Unregisters lora adapter from the model.
"""
- if not is_rolling_batch_enabled(_service.hf_configs.rolling_batch):
- raise NotImplementedError(
- "LoRA adapter API is only supported for rolling batch.")
-
adapter_name = inputs.get_property("name")
adapter_alias = inputs.get_property("alias") or adapter_name
@@ -604,12 +635,7 @@ def unregister_adapter(inputs: Input):
raise ValueError(f"Adapter {adapter_alias} not registered.")
try:
- removed = _service.rolling_batch.remove_lora(adapter_name)
- if not removed:
- logging.info(
- f"Remove LoRA adapter {adapter_alias} returned false, the adapter may have already been evicted."
- )
-
+ _service.remove_lora(adapter_name, adapter_alias)
del _service.adapter_registry[adapter_name]
except Exception as e:
return Output().error("remove_adapter_error", message=str(e))
diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
index 613f1a5def..32d3a02635 100644
--- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
@@ -283,5 +283,5 @@ def pin_lora(self, lora_name):
# 2) An adapter is not evicted, call add_lora() is not necessary.
# But since whether an adapter is evicted is not exposed outside of engine,
# and add_lora() in this case will take negligible time, we will still call add_lora().
- self.engine.add_lora(lora_request)
- return self.engine.pin_lora(lora_request.lora_int_id)
+ return self.engine.add_lora(lora_request) and self.engine.pin_lora(
+ lora_request.lora_int_id)
diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
index 96dc3bf73c..f5ca38dd92 100644
--- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
@@ -198,5 +198,5 @@ def pin_lora(self, lora_name):
# 2) An adapter is not evicted, call add_lora() is not necessary.
# But since whether an adapter is evicted is not exposed outside of engine,
# and add_lora() in this case will take negligible time, we will still call add_lora().
- self.engine.add_lora(lora_request)
- return self.engine.pin_lora(lora_request.lora_int_id)
+ return self.engine.add_lora(lora_request) and self.engine.pin_lora(
+ lora_request.lora_int_id)
diff --git a/serving/src/main/java/ai/djl/serving/http/AdapterManagementRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/AdapterManagementRequestHandler.java
index 8530a2562e..f2b8acecd7 100644
--- a/serving/src/main/java/ai/djl/serving/http/AdapterManagementRequestHandler.java
+++ b/serving/src/main/java/ai/djl/serving/http/AdapterManagementRequestHandler.java
@@ -168,7 +168,7 @@ private void handleListAdapters(
for (int i = pagination.getPageToken(); i < pagination.getLast(); ++i) {
String adapterName = keys.get(i);
Adapter adapter = modelInfo.getAdapter(adapterName);
- list.addAdapter(adapter.getName(), adapter.getSrc(), adapter.isPin());
+ list.addAdapter(adapter.getName(), adapter.getSrc(), adapter.isLoad(), adapter.isPin());
}
NettyUtils.sendJsonResponse(ctx, list);
diff --git a/serving/src/main/java/ai/djl/serving/http/DescribeAdapterResponse.java b/serving/src/main/java/ai/djl/serving/http/DescribeAdapterResponse.java
index 3020b19a0d..94fe4ef631 100644
--- a/serving/src/main/java/ai/djl/serving/http/DescribeAdapterResponse.java
+++ b/serving/src/main/java/ai/djl/serving/http/DescribeAdapterResponse.java
@@ -20,6 +20,7 @@
public class DescribeAdapterResponse {
private String name;
private String src;
+ private boolean load;
private boolean pin;
/**
@@ -30,6 +31,7 @@ public class DescribeAdapterResponse {
public DescribeAdapterResponse(Adapter adapter) {
this.name = adapter.getName();
this.src = adapter.getSrc();
+ this.load = adapter.isLoad();
this.pin = adapter.isPin();
}
@@ -51,6 +53,15 @@ public String getSrc() {
return src;
}
+ /**
+ * Returns whether to load the adapter weights.
+ *
+ * @return whether to load the adapter weights
+ */
+ public boolean isLoad() {
+ return load;
+ }
+
/**
* Returns whether to pin the adapter.
*
diff --git a/serving/src/main/java/ai/djl/serving/http/list/ListAdaptersResponse.java b/serving/src/main/java/ai/djl/serving/http/list/ListAdaptersResponse.java
index abde232572..35678fdc37 100644
--- a/serving/src/main/java/ai/djl/serving/http/list/ListAdaptersResponse.java
+++ b/serving/src/main/java/ai/djl/serving/http/list/ListAdaptersResponse.java
@@ -58,21 +58,24 @@ public List getAdapters() {
*
* @param name the adapter name
* @param src the adapter source
+ * @param load whether to load the adapter weights
* @param pin whether to pin the adapter
*/
- public void addAdapter(String name, String src, boolean pin) {
- adapters.add(new AdapterItem(name, src, pin));
+ public void addAdapter(String name, String src, boolean load, boolean pin) {
+ adapters.add(new AdapterItem(name, src, load, pin));
}
/** A class that holds the adapter response. */
public static final class AdapterItem {
private String name;
private String src;
+ private boolean load;
private boolean pin;
- private AdapterItem(String name, String src, boolean pin) {
+ private AdapterItem(String name, String src, boolean load, boolean pin) {
this.name = name;
this.src = src;
+ this.load = load;
this.pin = pin;
}
diff --git a/serving/src/test/java/ai/djl/serving/ModelServerTest.java b/serving/src/test/java/ai/djl/serving/ModelServerTest.java
index 2dfebb4d1e..29ac2add28 100644
--- a/serving/src/test/java/ai/djl/serving/ModelServerTest.java
+++ b/serving/src/test/java/ai/djl/serving/ModelServerTest.java
@@ -1002,12 +1002,25 @@ private void testRegisterAdapter(Channel channel, boolean registerModel, boolean
testAdapterMissing();
String strModelPrefix = modelPrefix ? "/models/adaptecho" : "";
- url = strModelPrefix + "/adapters?name=adaptable&src=src&echooption=opt";
+ String adapterName = "adaptable";
+ url = strModelPrefix + "/adapters?name=" + adapterName + "&src=src&echooption=opt";
request(channel, HttpMethod.POST, url);
assertHttpOk();
StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
- assertEquals(statusResp.getStatus(), "Adapter adaptable registered");
+ assertEquals(statusResp.getStatus(), "Adapter " + adapterName + " registered");
+
+ // Assert adapter registered
+ url = strModelPrefix + "/adapters/" + adapterName;
+ request(channel, HttpMethod.GET, url);
+ assertHttpOk();
+
+ DescribeAdapterResponse resp =
+ JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
+ assertEquals(resp.getName(), adapterName);
+ assertEquals(resp.getSrc(), "src");
+ assertTrue(resp.isLoad());
+ assertFalse(resp.isPin());
}
private void testRegisterAdapterConflict() throws InterruptedException {
@@ -1070,17 +1083,27 @@ private void testRegisterAdapterHandlerError() throws InterruptedException {
private void testUpdateAdapter(Channel channel, boolean modelPrefix)
throws InterruptedException {
logTestFunction();
+
+ String adapterName = "adaptable";
String strModelPrefix = modelPrefix ? "/models/adaptecho" : "";
- String url = strModelPrefix + "/adapters/adaptable/update?src=src1";
+ String url = strModelPrefix + "/adapters/" + adapterName + "/update?pin=true";
request(channel, HttpMethod.POST, url);
assertHttpOk();
- url = strModelPrefix + "/adapters/adaptable/update?src=src";
- request(channel, HttpMethod.POST, url);
+ StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
+ assertEquals(statusResp.getStatus(), "Adapter " + adapterName + " updated");
+
+ // Assert adapter updated
+ url = strModelPrefix + "/adapters/" + adapterName;
+ request(channel, HttpMethod.GET, url);
assertHttpOk();
- StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
- assertEquals(statusResp.getStatus(), "Adapter adaptable updated");
+ DescribeAdapterResponse resp =
+ JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
+ assertEquals(resp.getName(), adapterName);
+ assertEquals(resp.getSrc(), "src");
+ assertTrue(resp.isLoad());
+ assertTrue(resp.isPin());
}
private void testUpdateAdapterModelNotFound() throws InterruptedException {
@@ -1117,7 +1140,11 @@ private void testUpdateAdapterHandlerError() throws InterruptedException {
String modelName = "adaptecho";
String adapterName = "adaptable";
String strModelPrefix = "/models/" + modelName;
- String url = strModelPrefix + "/adapters/" + adapterName + "/update?src=src1&error=true";
+ String url =
+ strModelPrefix
+ + "/adapters/"
+ + adapterName
+ + "/update?src=src1&load=false&error=true";
request(channel, HttpMethod.POST, url);
channel.closeFuture().sync();
channel.close().sync();
@@ -1135,6 +1162,8 @@ private void testUpdateAdapterHandlerError() throws InterruptedException {
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), adapterName);
assertEquals(resp.getSrc(), "src");
+ assertTrue(resp.isLoad());
+ assertFalse(resp.isPin());
}
private void testAdapterMissing() throws InterruptedException {
@@ -1290,6 +1319,8 @@ private void testDescribeAdapter(Channel channel, boolean modelPrefix)
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), "adaptable");
assertEquals(resp.getSrc(), "src");
+ assertTrue(resp.isLoad());
+ assertTrue(resp.isPin());
}
private void testDescribeAdapterModelNotFound() throws InterruptedException {
@@ -1337,12 +1368,24 @@ private void testUnregisterAdapter(Channel channel, boolean modelPrefix)
throws InterruptedException {
logTestFunction();
String strModelPrefix = modelPrefix ? "/models/adaptecho" : "";
- String url = strModelPrefix + "/adapters/adaptable";
+ String adapterName = "adaptable";
+ String url = strModelPrefix + "/adapters/" + adapterName;
request(channel, HttpMethod.DELETE, url);
assertHttpOk();
StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
- assertEquals(statusResp.getStatus(), "Adapter adaptable unregistered");
+ assertEquals(statusResp.getStatus(), "Adapter " + adapterName + " unregistered");
+
+ // Assert adapter unregistered
+ channel = connect(Connector.ConnectorType.MANAGEMENT);
+ assertNotNull(channel);
+
+ url = strModelPrefix + "/adapters";
+ request(channel, HttpMethod.GET, url);
+ assertHttpOk();
+
+ ListAdaptersResponse resp = JsonUtils.GSON.fromJson(result, ListAdaptersResponse.class);
+ assertFalse(resp.getAdapters().stream().anyMatch(a -> adapterName.equals(a.getName())));
}
private void testUnregisterAdapterModelNotFound() throws InterruptedException {
diff --git a/wlm/src/main/java/ai/djl/serving/wlm/Adapter.java b/wlm/src/main/java/ai/djl/serving/wlm/Adapter.java
index a1ace56e8b..6cb19ef38e 100644
--- a/wlm/src/main/java/ai/djl/serving/wlm/Adapter.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/Adapter.java
@@ -192,6 +192,15 @@ public void setOptions(Map options) {
this.options = options;
}
+ /**
+ * Returns whether to load the adapter weights.
+ *
+ * @return whether to load the adapter weights
+ */
+ public boolean isLoad() {
+ return Boolean.parseBoolean(options.getOrDefault("load", "true"));
+ }
+
/**
* Returns whether to pin the adapter.
*