Skip to content

Commit

Permalink
[lora] Add load option to LoRA adapter API
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed Nov 8, 2024
1 parent e96a1c3 commit d63adb8
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 56 deletions.
101 changes: 63 additions & 38 deletions engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,38 @@ def _streaming_inference(self, batch: List, request_input: RequestInput,
self.hf_configs.device, **parameters))
return outputs

def add_lora(self, lora_name: str, lora_alias: str, lora_path: str):
if not is_rolling_batch_enabled(self.hf_configs.rolling_batch):
raise NotImplementedError(
"LoRA adapter API is only supported for rolling batch.")

loaded = self.rolling_batch.add_lora(lora_name, lora_path)
if not loaded:
raise RuntimeError(f"Failed to load LoRA adapter {lora_alias}")
return loaded

def remove_lora(self, lora_name: str, lora_alias: str):
if not is_rolling_batch_enabled(self.hf_configs.rolling_batch):
raise NotImplementedError(
"LoRA adapter API is only supported for rolling batch.")

removed = self.rolling_batch.remove_lora(lora_name)
if not removed:
logging.info(
f"Remove LoRA adapter {lora_alias} returned false, the adapter may have already been evicted."
)
return removed

def pin_lora(self, lora_name: str, lora_alias: str):
if not is_rolling_batch_enabled(self.hf_configs.rolling_batch):
raise NotImplementedError(
"LoRA adapter API is only supported for rolling batch.")

pinned = self.rolling_batch.pin_lora(lora_name)
if not pinned:
raise RuntimeError(f"Failed to pin LoRA adapter {lora_alias}")
return pinned

def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
# define tokenizer or feature extractor as kwargs to load it the pipeline correctly
if task in {
Expand Down Expand Up @@ -496,41 +528,36 @@ def register_adapter(inputs: Input):
"""
Registers lora adapter with the model.
"""
if not is_rolling_batch_enabled(_service.hf_configs.rolling_batch):
raise NotImplementedError(
"LoRA adapter API is only supported for rolling batch.")

adapter_name = inputs.get_property("name")
adapter_alias = inputs.get_property("alias") or adapter_name
adapter_path = inputs.get_property("src")
adapter_load = inputs.get_as_string(
"load").lower() == "true" if inputs.contains_key("load") else True
adapter_pin = inputs.get_as_string(
"pin").lower() == "true" if inputs.contains_key("pin") else False

added = False
loaded = False
try:
if not os.path.exists(adapter_path):
raise ValueError(
f"Only local LoRA models are supported. {adapter_path} is not a valid path"
)

added = _service.rolling_batch.add_lora(adapter_name, adapter_path)
if not added:
raise RuntimeError(
f"Failed to register LoRA adapter {adapter_alias}")
if adapter_load:
loaded = _service.add_lora(adapter_name, adapter_alias,
adapter_path)

if adapter_pin:
pinned = _service.rolling_batch.pin_lora(adapter_name)
if not pinned:
raise RuntimeError(
f"Failed to pin LoRA adapter {adapter_alias}")

if not adapter_load:
raise RuntimeError("Need to set load=true to set pin=true")
_service.pin_lora(adapter_name, adapter_alias)
_service.adapter_registry[adapter_name] = inputs
except Exception as e:
if added:
if loaded:
logging.info(
f"LoRA adapter {adapter_alias} was successfully added, but failed to pin, removing ..."
)
_service.rolling_batch.remove_lora(adapter_name)
_service.remove_lora(adapter_name, adapter_alias)
if any(msg in str(e)
for msg in ("No free lora slots",
"greater than the number of GPU LoRA slots")):
Expand All @@ -546,13 +573,11 @@ def update_adapter(inputs: Input):
"""
Updates lora adapter with the model.
"""
if not is_rolling_batch_enabled(_service.hf_configs.rolling_batch):
raise NotImplementedError(
"LoRA adapter API is only supported for rolling batch.")

adapter_name = inputs.get_property("name")
adapter_alias = inputs.get_property("alias") or adapter_name
adapter_path = inputs.get_property("src")
adapter_load = inputs.get_as_string(
"load").lower() == "true" if inputs.contains_key("load") else True
adapter_pin = inputs.get_as_string(
"pin").lower() == "true" if inputs.contains_key("pin") else False

Expand All @@ -566,17 +591,26 @@ def update_adapter(inputs: Input):
raise NotImplementedError(
f"Updating adapter path is not supported.")

old_adapter_pin = inputs.get_as_string(
"pin").lower() == "true" if inputs.contains_key("pin") else False
old_adapter_load = old_adapter.get_as_string("load").lower(
) == "true" if old_adapter.contains_key("load") else True
if old_adapter_load != adapter_load:
if adapter_load:
_service.add_lora(adapter_name, adapter_alias, adapter_path)
else:
if adapter_pin:
raise RuntimeError("Need to set pin=false to set load=false")
_service.remove_lora(adapter_name, adapter_alias)

old_adapter_pin = old_adapter.get_as_string("pin").lower(
) == "true" if old_adapter.contains_key("pin") else False
if old_adapter_pin and not adapter_pin:
raise NotImplementedError(f"Unpin adapter is not supported.")

if adapter_pin:
pinned = _service.rolling_batch.pin_lora(adapter_name)
if not pinned:
raise RuntimeError(
f"Failed to pin LoRA adapter {adapter_alias}")

if old_adapter_pin != adapter_pin:
if adapter_pin:
if not adapter_load:
raise RuntimeError("Need to set load=true to set pin=true")
_service.pin_lora(adapter_name, adapter_alias)
_service.adapter_registry[adapter_name] = inputs
except Exception as e:
if any(msg in str(e)
Expand All @@ -593,23 +627,14 @@ def unregister_adapter(inputs: Input):
"""
Unregisters lora adapter from the model.
"""
if not is_rolling_batch_enabled(_service.hf_configs.rolling_batch):
raise NotImplementedError(
"LoRA adapter API is only supported for rolling batch.")

adapter_name = inputs.get_property("name")
adapter_alias = inputs.get_property("alias") or adapter_name

if adapter_name not in _service.adapter_registry:
raise ValueError(f"Adapter {adapter_alias} not registered.")

try:
removed = _service.rolling_batch.remove_lora(adapter_name)
if not removed:
logging.info(
f"Remove LoRA adapter {adapter_alias} returned false, the adapter may have already been evicted."
)

_service.remove_lora(adapter_name, adapter_alias)
del _service.adapter_registry[adapter_name]
except Exception as e:
return Output().error("remove_adapter_error", message=str(e))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,5 +283,5 @@ def pin_lora(self, lora_name):
# 2) An adapter is not evicted, call add_lora() is not necessary.
# But since whether an adapter is evicted is not exposed outside of engine,
# and add_lora() in this case will take negligible time, we will still call add_lora().
self.engine.add_lora(lora_request)
return self.engine.pin_lora(lora_request.lora_int_id)
return self.engine.add_lora(lora_request) and self.engine.pin_lora(
lora_request.lora_int_id)
Original file line number Diff line number Diff line change
Expand Up @@ -198,5 +198,5 @@ def pin_lora(self, lora_name):
# 2) An adapter is not evicted, call add_lora() is not necessary.
# But since whether an adapter is evicted is not exposed outside of engine,
# and add_lora() in this case will take negligible time, we will still call add_lora().
self.engine.add_lora(lora_request)
return self.engine.pin_lora(lora_request.lora_int_id)
return self.engine.add_lora(lora_request) and self.engine.pin_lora(
lora_request.lora_int_id)
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ private void handleListAdapters(
for (int i = pagination.getPageToken(); i < pagination.getLast(); ++i) {
String adapterName = keys.get(i);
Adapter<Input, Output> adapter = modelInfo.getAdapter(adapterName);
list.addAdapter(adapter.getName(), adapter.getSrc(), adapter.isPin());
list.addAdapter(adapter.getName(), adapter.getSrc(), adapter.isLoad(), adapter.isPin());
}

NettyUtils.sendJsonResponse(ctx, list);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
public class DescribeAdapterResponse {
private String name;
private String src;
private boolean load;
private boolean pin;

/**
Expand All @@ -30,6 +31,7 @@ public class DescribeAdapterResponse {
public DescribeAdapterResponse(Adapter<Input, Output> adapter) {
this.name = adapter.getName();
this.src = adapter.getSrc();
this.load = adapter.isLoad();
this.pin = adapter.isPin();
}

Expand All @@ -51,6 +53,15 @@ public String getSrc() {
return src;
}

/**
* Returns whether to load the adapter weights.
*
* @return whether to load the adapter weights
*/
public boolean isLoad() {
return load;
}

/**
* Returns whether to pin the adapter.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,24 @@ public List<AdapterItem> getAdapters() {
*
* @param name the adapter name
* @param src the adapter source
* @param load whether to load the adapter weights
* @param pin whether to pin the adapter
*/
public void addAdapter(String name, String src, boolean pin) {
adapters.add(new AdapterItem(name, src, pin));
public void addAdapter(String name, String src, boolean load, boolean pin) {
adapters.add(new AdapterItem(name, src, load, pin));
}

/** A class that holds the adapter response. */
public static final class AdapterItem {
private String name;
private String src;
private boolean load;
private boolean pin;

private AdapterItem(String name, String src, boolean pin) {
private AdapterItem(String name, String src, boolean load, boolean pin) {
this.name = name;
this.src = src;
this.load = load;
this.pin = pin;
}

Expand Down
63 changes: 53 additions & 10 deletions serving/src/test/java/ai/djl/serving/ModelServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1002,12 +1002,25 @@ private void testRegisterAdapter(Channel channel, boolean registerModel, boolean
testAdapterMissing();

String strModelPrefix = modelPrefix ? "/models/adaptecho" : "";
url = strModelPrefix + "/adapters?name=adaptable&src=src&echooption=opt";
String adapterName = "adaptable";
url = strModelPrefix + "/adapters?name=" + adapterName + "&src=src&echooption=opt";
request(channel, HttpMethod.POST, url);
assertHttpOk();

StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
assertEquals(statusResp.getStatus(), "Adapter adaptable registered");
assertEquals(statusResp.getStatus(), "Adapter " + adapterName + " registered");

// Assert adapter registered
url = strModelPrefix + "/adapters/" + adapterName;
request(channel, HttpMethod.GET, url);
assertHttpOk();

DescribeAdapterResponse resp =
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), adapterName);
assertEquals(resp.getSrc(), "src");
assertTrue(resp.isLoad());
assertFalse(resp.isPin());
}

private void testRegisterAdapterConflict() throws InterruptedException {
Expand Down Expand Up @@ -1070,17 +1083,27 @@ private void testRegisterAdapterHandlerError() throws InterruptedException {
private void testUpdateAdapter(Channel channel, boolean modelPrefix)
throws InterruptedException {
logTestFunction();

String adapterName = "adaptable";
String strModelPrefix = modelPrefix ? "/models/adaptecho" : "";
String url = strModelPrefix + "/adapters/adaptable/update?src=src1";
String url = strModelPrefix + "/adapters/" + adapterName + "/update?pin=true";
request(channel, HttpMethod.POST, url);
assertHttpOk();

url = strModelPrefix + "/adapters/adaptable/update?src=src";
request(channel, HttpMethod.POST, url);
StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
assertEquals(statusResp.getStatus(), "Adapter " + adapterName + " updated");

// Assert adapter updated
url = strModelPrefix + "/adapters/" + adapterName;
request(channel, HttpMethod.GET, url);
assertHttpOk();

StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
assertEquals(statusResp.getStatus(), "Adapter adaptable updated");
DescribeAdapterResponse resp =
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), adapterName);
assertEquals(resp.getSrc(), "src");
assertTrue(resp.isLoad());
assertTrue(resp.isPin());
}

private void testUpdateAdapterModelNotFound() throws InterruptedException {
Expand Down Expand Up @@ -1117,7 +1140,11 @@ private void testUpdateAdapterHandlerError() throws InterruptedException {
String modelName = "adaptecho";
String adapterName = "adaptable";
String strModelPrefix = "/models/" + modelName;
String url = strModelPrefix + "/adapters/" + adapterName + "/update?src=src1&error=true";
String url =
strModelPrefix
+ "/adapters/"
+ adapterName
+ "/update?src=src1&load=false&error=true";
request(channel, HttpMethod.POST, url);
channel.closeFuture().sync();
channel.close().sync();
Expand All @@ -1135,6 +1162,8 @@ private void testUpdateAdapterHandlerError() throws InterruptedException {
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), adapterName);
assertEquals(resp.getSrc(), "src");
assertTrue(resp.isLoad());
assertFalse(resp.isPin());
}

private void testAdapterMissing() throws InterruptedException {
Expand Down Expand Up @@ -1290,6 +1319,8 @@ private void testDescribeAdapter(Channel channel, boolean modelPrefix)
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), "adaptable");
assertEquals(resp.getSrc(), "src");
assertTrue(resp.isLoad());
assertTrue(resp.isPin());
}

private void testDescribeAdapterModelNotFound() throws InterruptedException {
Expand Down Expand Up @@ -1337,12 +1368,24 @@ private void testUnregisterAdapter(Channel channel, boolean modelPrefix)
throws InterruptedException {
logTestFunction();
String strModelPrefix = modelPrefix ? "/models/adaptecho" : "";
String url = strModelPrefix + "/adapters/adaptable";
String adapterName = "adaptable";
String url = strModelPrefix + "/adapters/" + adapterName;
request(channel, HttpMethod.DELETE, url);
assertHttpOk();

StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
assertEquals(statusResp.getStatus(), "Adapter adaptable unregistered");
assertEquals(statusResp.getStatus(), "Adapter " + adapterName + " unregistered");

// Assert adapter unregistered
channel = connect(Connector.ConnectorType.MANAGEMENT);
assertNotNull(channel);

url = strModelPrefix + "/adapters";
request(channel, HttpMethod.GET, url);
assertHttpOk();

ListAdaptersResponse resp = JsonUtils.GSON.fromJson(result, ListAdaptersResponse.class);
assertFalse(resp.getAdapters().stream().anyMatch(a -> adapterName.equals(a.getName())));
}

private void testUnregisterAdapterModelNotFound() throws InterruptedException {
Expand Down
9 changes: 9 additions & 0 deletions wlm/src/main/java/ai/djl/serving/wlm/Adapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,15 @@ public void setOptions(Map<String, String> options) {
this.options = options;
}

/**
* Returns whether to load the adapter weights.
*
* @return whether to load the adapter weights
*/
public boolean isLoad() {
return Boolean.parseBoolean(options.getOrDefault("load", "true"));
}

/**
* Returns whether to pin the adapter.
*
Expand Down

0 comments on commit d63adb8

Please sign in to comment.