From 2ba8f8d5a46f2b32849018163f533efb83ed2900 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Fri, 8 Nov 2024 14:14:17 -0800 Subject: [PATCH] [lora] Add load option to LoRA adapter API --- .../python/setup/djl_python/huggingface.py | 102 +++++++++++------- .../rolling_batch/lmi_dist_rolling_batch.py | 4 +- .../rolling_batch/vllm_rolling_batch.py | 4 +- .../http/AdapterManagementRequestHandler.java | 2 +- .../serving/http/DescribeAdapterResponse.java | 11 ++ .../http/list/ListAdaptersResponse.java | 9 +- .../java/ai/djl/serving/ModelServerTest.java | 63 +++++++++-- .../main/java/ai/djl/serving/wlm/Adapter.java | 9 ++ 8 files changed, 148 insertions(+), 56 deletions(-) diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py index 7600074241..ff69120336 100644 --- a/engines/python/setup/djl_python/huggingface.py +++ b/engines/python/setup/djl_python/huggingface.py @@ -302,6 +302,38 @@ def _streaming_inference(self, batch: List, request_input: RequestInput, self.hf_configs.device, **parameters)) return outputs + def add_lora(self, lora_name: str, lora_alias: str, lora_path: str): + if not is_rolling_batch_enabled(self.hf_configs.rolling_batch): + raise NotImplementedError( + "LoRA adapter API is only supported for rolling batch.") + + loaded = self.rolling_batch.add_lora(lora_name, lora_path) + if not loaded: + raise RuntimeError(f"Failed to load LoRA adapter {lora_alias}") + return loaded + + def remove_lora(self, lora_name: str, lora_alias: str): + if not is_rolling_batch_enabled(self.hf_configs.rolling_batch): + raise NotImplementedError( + "LoRA adapter API is only supported for rolling batch.") + + removed = self.rolling_batch.remove_lora(lora_name) + if not removed: + logging.info( + f"Remove LoRA adapter {lora_alias} returned false, the adapter may have already been evicted." + ) + return removed + + def pin_lora(self, lora_name: str, lora_alias: str): + if not is_rolling_batch_enabled(self.hf_configs.rolling_batch): + raise NotImplementedError( + "LoRA adapter API is only supported for rolling batch.") + + pinned = self.rolling_batch.pin_lora(lora_name) + if not pinned: + raise RuntimeError(f"Failed to pin LoRA adapter {lora_alias}") + return pinned + def get_pipeline(self, task: str, model_id_or_path: str, kwargs): # define tokenizer or feature extractor as kwargs to load it the pipeline correctly if task in { @@ -496,41 +528,36 @@ def register_adapter(inputs: Input): """ Registers lora adapter with the model. """ - if not is_rolling_batch_enabled(_service.hf_configs.rolling_batch): - raise NotImplementedError( - "LoRA adapter API is only supported for rolling batch.") - adapter_name = inputs.get_property("name") adapter_alias = inputs.get_property("alias") or adapter_name adapter_path = inputs.get_property("src") + adapter_load = inputs.get_as_string( + "load").lower() == "true" if inputs.contains_key("load") else True adapter_pin = inputs.get_as_string( "pin").lower() == "true" if inputs.contains_key("pin") else False - added = False + loaded = False try: if not os.path.exists(adapter_path): raise ValueError( f"Only local LoRA models are supported. {adapter_path} is not a valid path" ) - added = _service.rolling_batch.add_lora(adapter_name, adapter_path) - if not added: - raise RuntimeError( - f"Failed to register LoRA adapter {adapter_alias}") + if adapter_load: + loaded = _service.add_lora(adapter_name, adapter_alias, + adapter_path) if adapter_pin: - pinned = _service.rolling_batch.pin_lora(adapter_name) - if not pinned: - raise RuntimeError( - f"Failed to pin LoRA adapter {adapter_alias}") - + if not adapter_load: + raise RuntimeError("Need to set load=true to set pin=true") + _service.pin_lora(adapter_name, adapter_alias) _service.adapter_registry[adapter_name] = inputs except Exception as e: - if added: + if loaded: logging.info( f"LoRA adapter {adapter_alias} was successfully added, but failed to pin, removing ..." ) - _service.rolling_batch.remove_lora(adapter_name) + _service.remove_lora(adapter_name, adapter_alias) if any(msg in str(e) for msg in ("No free lora slots", "greater than the number of GPU LoRA slots")): @@ -546,13 +573,11 @@ def update_adapter(inputs: Input): """ Updates lora adapter with the model. """ - if not is_rolling_batch_enabled(_service.hf_configs.rolling_batch): - raise NotImplementedError( - "LoRA adapter API is only supported for rolling batch.") - adapter_name = inputs.get_property("name") adapter_alias = inputs.get_property("alias") or adapter_name adapter_path = inputs.get_property("src") + adapter_load = inputs.get_as_string( + "load").lower() == "true" if inputs.contains_key("load") else True adapter_pin = inputs.get_as_string( "pin").lower() == "true" if inputs.contains_key("pin") else False @@ -566,17 +591,27 @@ def update_adapter(inputs: Input): raise NotImplementedError( f"Updating adapter path is not supported.") - old_adapter_pin = inputs.get_as_string( - "pin").lower() == "true" if inputs.contains_key("pin") else False + old_adapter_load = old_adapter.get_as_string("load").lower( + ) == "true" if old_adapter.contains_key("load") else True + if old_adapter_load != adapter_load: + if adapter_load: + _service.add_lora(adapter_name, adapter_alias, adapter_path) + else: + if adapter_pin: + raise RuntimeError( + "Need to set pin=false to set load=false") + _service.remove_lora(adapter_name, adapter_alias) + + old_adapter_pin = old_adapter.get_as_string("pin").lower( + ) == "true" if old_adapter.contains_key("pin") else False if old_adapter_pin and not adapter_pin: raise NotImplementedError(f"Unpin adapter is not supported.") - if adapter_pin: - pinned = _service.rolling_batch.pin_lora(adapter_name) - if not pinned: - raise RuntimeError( - f"Failed to pin LoRA adapter {adapter_alias}") - + if old_adapter_pin != adapter_pin: + if adapter_pin: + if not adapter_load: + raise RuntimeError("Need to set load=true to set pin=true") + _service.pin_lora(adapter_name, adapter_alias) _service.adapter_registry[adapter_name] = inputs except Exception as e: if any(msg in str(e) @@ -593,10 +628,6 @@ def unregister_adapter(inputs: Input): """ Unregisters lora adapter from the model. """ - if not is_rolling_batch_enabled(_service.hf_configs.rolling_batch): - raise NotImplementedError( - "LoRA adapter API is only supported for rolling batch.") - adapter_name = inputs.get_property("name") adapter_alias = inputs.get_property("alias") or adapter_name @@ -604,12 +635,7 @@ def unregister_adapter(inputs: Input): raise ValueError(f"Adapter {adapter_alias} not registered.") try: - removed = _service.rolling_batch.remove_lora(adapter_name) - if not removed: - logging.info( - f"Remove LoRA adapter {adapter_alias} returned false, the adapter may have already been evicted." - ) - + _service.remove_lora(adapter_name, adapter_alias) del _service.adapter_registry[adapter_name] except Exception as e: return Output().error("remove_adapter_error", message=str(e)) diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py index 613f1a5def..32d3a02635 100644 --- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py @@ -283,5 +283,5 @@ def pin_lora(self, lora_name): # 2) An adapter is not evicted, call add_lora() is not necessary. # But since whether an adapter is evicted is not exposed outside of engine, # and add_lora() in this case will take negligible time, we will still call add_lora(). - self.engine.add_lora(lora_request) - return self.engine.pin_lora(lora_request.lora_int_id) + return self.engine.add_lora(lora_request) and self.engine.pin_lora( + lora_request.lora_int_id) diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py index 96dc3bf73c..f5ca38dd92 100644 --- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py @@ -198,5 +198,5 @@ def pin_lora(self, lora_name): # 2) An adapter is not evicted, call add_lora() is not necessary. # But since whether an adapter is evicted is not exposed outside of engine, # and add_lora() in this case will take negligible time, we will still call add_lora(). - self.engine.add_lora(lora_request) - return self.engine.pin_lora(lora_request.lora_int_id) + return self.engine.add_lora(lora_request) and self.engine.pin_lora( + lora_request.lora_int_id) diff --git a/serving/src/main/java/ai/djl/serving/http/AdapterManagementRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/AdapterManagementRequestHandler.java index 8530a2562e..f2b8acecd7 100644 --- a/serving/src/main/java/ai/djl/serving/http/AdapterManagementRequestHandler.java +++ b/serving/src/main/java/ai/djl/serving/http/AdapterManagementRequestHandler.java @@ -168,7 +168,7 @@ private void handleListAdapters( for (int i = pagination.getPageToken(); i < pagination.getLast(); ++i) { String adapterName = keys.get(i); Adapter adapter = modelInfo.getAdapter(adapterName); - list.addAdapter(adapter.getName(), adapter.getSrc(), adapter.isPin()); + list.addAdapter(adapter.getName(), adapter.getSrc(), adapter.isLoad(), adapter.isPin()); } NettyUtils.sendJsonResponse(ctx, list); diff --git a/serving/src/main/java/ai/djl/serving/http/DescribeAdapterResponse.java b/serving/src/main/java/ai/djl/serving/http/DescribeAdapterResponse.java index 3020b19a0d..94fe4ef631 100644 --- a/serving/src/main/java/ai/djl/serving/http/DescribeAdapterResponse.java +++ b/serving/src/main/java/ai/djl/serving/http/DescribeAdapterResponse.java @@ -20,6 +20,7 @@ public class DescribeAdapterResponse { private String name; private String src; + private boolean load; private boolean pin; /** @@ -30,6 +31,7 @@ public class DescribeAdapterResponse { public DescribeAdapterResponse(Adapter adapter) { this.name = adapter.getName(); this.src = adapter.getSrc(); + this.load = adapter.isLoad(); this.pin = adapter.isPin(); } @@ -51,6 +53,15 @@ public String getSrc() { return src; } + /** + * Returns whether to load the adapter weights. + * + * @return whether to load the adapter weights + */ + public boolean isLoad() { + return load; + } + /** * Returns whether to pin the adapter. * diff --git a/serving/src/main/java/ai/djl/serving/http/list/ListAdaptersResponse.java b/serving/src/main/java/ai/djl/serving/http/list/ListAdaptersResponse.java index abde232572..35678fdc37 100644 --- a/serving/src/main/java/ai/djl/serving/http/list/ListAdaptersResponse.java +++ b/serving/src/main/java/ai/djl/serving/http/list/ListAdaptersResponse.java @@ -58,21 +58,24 @@ public List getAdapters() { * * @param name the adapter name * @param src the adapter source + * @param load whether to load the adapter weights * @param pin whether to pin the adapter */ - public void addAdapter(String name, String src, boolean pin) { - adapters.add(new AdapterItem(name, src, pin)); + public void addAdapter(String name, String src, boolean load, boolean pin) { + adapters.add(new AdapterItem(name, src, load, pin)); } /** A class that holds the adapter response. */ public static final class AdapterItem { private String name; private String src; + private boolean load; private boolean pin; - private AdapterItem(String name, String src, boolean pin) { + private AdapterItem(String name, String src, boolean load, boolean pin) { this.name = name; this.src = src; + this.load = load; this.pin = pin; } diff --git a/serving/src/test/java/ai/djl/serving/ModelServerTest.java b/serving/src/test/java/ai/djl/serving/ModelServerTest.java index 2dfebb4d1e..29ac2add28 100644 --- a/serving/src/test/java/ai/djl/serving/ModelServerTest.java +++ b/serving/src/test/java/ai/djl/serving/ModelServerTest.java @@ -1002,12 +1002,25 @@ private void testRegisterAdapter(Channel channel, boolean registerModel, boolean testAdapterMissing(); String strModelPrefix = modelPrefix ? "/models/adaptecho" : ""; - url = strModelPrefix + "/adapters?name=adaptable&src=src&echooption=opt"; + String adapterName = "adaptable"; + url = strModelPrefix + "/adapters?name=" + adapterName + "&src=src&echooption=opt"; request(channel, HttpMethod.POST, url); assertHttpOk(); StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class); - assertEquals(statusResp.getStatus(), "Adapter adaptable registered"); + assertEquals(statusResp.getStatus(), "Adapter " + adapterName + " registered"); + + // Assert adapter registered + url = strModelPrefix + "/adapters/" + adapterName; + request(channel, HttpMethod.GET, url); + assertHttpOk(); + + DescribeAdapterResponse resp = + JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class); + assertEquals(resp.getName(), adapterName); + assertEquals(resp.getSrc(), "src"); + assertTrue(resp.isLoad()); + assertFalse(resp.isPin()); } private void testRegisterAdapterConflict() throws InterruptedException { @@ -1070,17 +1083,27 @@ private void testRegisterAdapterHandlerError() throws InterruptedException { private void testUpdateAdapter(Channel channel, boolean modelPrefix) throws InterruptedException { logTestFunction(); + + String adapterName = "adaptable"; String strModelPrefix = modelPrefix ? "/models/adaptecho" : ""; - String url = strModelPrefix + "/adapters/adaptable/update?src=src1"; + String url = strModelPrefix + "/adapters/" + adapterName + "/update?pin=true"; request(channel, HttpMethod.POST, url); assertHttpOk(); - url = strModelPrefix + "/adapters/adaptable/update?src=src"; - request(channel, HttpMethod.POST, url); + StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class); + assertEquals(statusResp.getStatus(), "Adapter " + adapterName + " updated"); + + // Assert adapter updated + url = strModelPrefix + "/adapters/" + adapterName; + request(channel, HttpMethod.GET, url); assertHttpOk(); - StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class); - assertEquals(statusResp.getStatus(), "Adapter adaptable updated"); + DescribeAdapterResponse resp = + JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class); + assertEquals(resp.getName(), adapterName); + assertEquals(resp.getSrc(), "src"); + assertTrue(resp.isLoad()); + assertTrue(resp.isPin()); } private void testUpdateAdapterModelNotFound() throws InterruptedException { @@ -1117,7 +1140,11 @@ private void testUpdateAdapterHandlerError() throws InterruptedException { String modelName = "adaptecho"; String adapterName = "adaptable"; String strModelPrefix = "/models/" + modelName; - String url = strModelPrefix + "/adapters/" + adapterName + "/update?src=src1&error=true"; + String url = + strModelPrefix + + "/adapters/" + + adapterName + + "/update?src=src1&load=false&error=true"; request(channel, HttpMethod.POST, url); channel.closeFuture().sync(); channel.close().sync(); @@ -1135,6 +1162,8 @@ private void testUpdateAdapterHandlerError() throws InterruptedException { JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class); assertEquals(resp.getName(), adapterName); assertEquals(resp.getSrc(), "src"); + assertTrue(resp.isLoad()); + assertFalse(resp.isPin()); } private void testAdapterMissing() throws InterruptedException { @@ -1290,6 +1319,8 @@ private void testDescribeAdapter(Channel channel, boolean modelPrefix) JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class); assertEquals(resp.getName(), "adaptable"); assertEquals(resp.getSrc(), "src"); + assertTrue(resp.isLoad()); + assertTrue(resp.isPin()); } private void testDescribeAdapterModelNotFound() throws InterruptedException { @@ -1337,12 +1368,24 @@ private void testUnregisterAdapter(Channel channel, boolean modelPrefix) throws InterruptedException { logTestFunction(); String strModelPrefix = modelPrefix ? "/models/adaptecho" : ""; - String url = strModelPrefix + "/adapters/adaptable"; + String adapterName = "adaptable"; + String url = strModelPrefix + "/adapters/" + adapterName; request(channel, HttpMethod.DELETE, url); assertHttpOk(); StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class); - assertEquals(statusResp.getStatus(), "Adapter adaptable unregistered"); + assertEquals(statusResp.getStatus(), "Adapter " + adapterName + " unregistered"); + + // Assert adapter unregistered + channel = connect(Connector.ConnectorType.MANAGEMENT); + assertNotNull(channel); + + url = strModelPrefix + "/adapters"; + request(channel, HttpMethod.GET, url); + assertHttpOk(); + + ListAdaptersResponse resp = JsonUtils.GSON.fromJson(result, ListAdaptersResponse.class); + assertFalse(resp.getAdapters().stream().anyMatch(a -> adapterName.equals(a.getName()))); } private void testUnregisterAdapterModelNotFound() throws InterruptedException { diff --git a/wlm/src/main/java/ai/djl/serving/wlm/Adapter.java b/wlm/src/main/java/ai/djl/serving/wlm/Adapter.java index a1ace56e8b..6cb19ef38e 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/Adapter.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/Adapter.java @@ -192,6 +192,15 @@ public void setOptions(Map options) { this.options = options; } + /** + * Returns whether to load the adapter weights. + * + * @return whether to load the adapter weights + */ + public boolean isLoad() { + return Boolean.parseBoolean(options.getOrDefault("load", "true")); + } + /** * Returns whether to pin the adapter. *