Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lora] Add load option to LoRA adapter API #2536

Merged
merged 2 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 70 additions & 43 deletions engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,38 @@ def _streaming_inference(self, batch: List, request_input: RequestInput,
self.hf_configs.device, **parameters))
return outputs

def add_lora(self, lora_name: str, lora_alias: str, lora_path: str):
    """
    Loads a LoRA adapter through the rolling batch.

    :param lora_name: adapter name passed to the rolling batch
    :param lora_alias: adapter alias, used only in the error message
    :param lora_path: adapter path passed to the rolling batch
    :return: the truthy value returned by the rolling batch's add_lora
    :raises NotImplementedError: if rolling batch is not enabled
    :raises RuntimeError: if the rolling batch reports a falsy result
    """
    if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
        result = self.rolling_batch.add_lora(lora_name, lora_path)
        if result:
            return result
        raise RuntimeError(f"Failed to load LoRA adapter {lora_alias}")

    raise NotImplementedError(
        "LoRA adapter API is only supported for rolling batch.")

def remove_lora(self, lora_name: str, lora_alias: str):
    """
    Removes a LoRA adapter through the rolling batch.

    A falsy result from the rolling batch is not treated as an error; it is
    only logged at info level and then returned to the caller.

    :param lora_name: adapter name passed to the rolling batch
    :param lora_alias: adapter alias, used only in the log message
    :return: whatever the rolling batch's remove_lora returned
    :raises NotImplementedError: if rolling batch is not enabled
    """
    if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
        outcome = self.rolling_batch.remove_lora(lora_name)
        if not outcome:
            logging.info(
                f"Remove LoRA adapter {lora_alias} returned false, the adapter may have already been evicted."
            )
        return outcome

    raise NotImplementedError(
        "LoRA adapter API is only supported for rolling batch.")

def pin_lora(self, lora_name: str, lora_alias: str):
    """
    Pins a LoRA adapter through the rolling batch.

    :param lora_name: adapter name passed to the rolling batch
    :param lora_alias: adapter alias, used only in the error message
    :return: the truthy value returned by the rolling batch's pin_lora
    :raises NotImplementedError: if rolling batch is not enabled
    :raises RuntimeError: if the rolling batch reports a falsy result
    """
    if is_rolling_batch_enabled(self.hf_configs.rolling_batch):
        result = self.rolling_batch.pin_lora(lora_name)
        if result:
            return result
        raise RuntimeError(f"Failed to pin LoRA adapter {lora_alias}")

    raise NotImplementedError(
        "LoRA adapter API is only supported for rolling batch.")

def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
# define tokenizer or feature extractor as kwargs to load it the pipeline correctly
if task in {
Expand Down Expand Up @@ -496,41 +528,38 @@ def register_adapter(inputs: Input):
"""
Registers lora adapter with the model.
"""
if not is_rolling_batch_enabled(_service.hf_configs.rolling_batch):
raise NotImplementedError(
"LoRA adapter API is only supported for rolling batch.")

adapter_name = inputs.get_property("name")
adapter_alias = inputs.get_property("alias") or adapter_name
adapter_path = inputs.get_property("src")
adapter_load = inputs.get_as_string(
"load").lower() == "true" if inputs.contains_key("load") else True
adapter_pin = inputs.get_as_string(
"pin").lower() == "true" if inputs.contains_key("pin") else False

added = False
loaded = False
try:
if not os.path.exists(adapter_path):
raise ValueError(
f"Only local LoRA models are supported. {adapter_path} is not a valid path"
)

added = _service.rolling_batch.add_lora(adapter_name, adapter_path)
if not added:
raise RuntimeError(
f"Failed to register LoRA adapter {adapter_alias}")
if not adapter_load and adapter_pin:
raise ValueError("Can not set load to false and pin to true")

if adapter_pin:
pinned = _service.rolling_batch.pin_lora(adapter_name)
if not pinned:
raise RuntimeError(
f"Failed to pin LoRA adapter {adapter_alias}")
if adapter_load:
loaded = _service.add_lora(adapter_name, adapter_alias,
adapter_path)

if adapter_pin:
_service.pin_lora(adapter_name, adapter_alias)
_service.adapter_registry[adapter_name] = inputs
except Exception as e:
if added:
logging.debug(f"Failed to register adapter: {e}", exc_info=True)
if loaded:
logging.info(
f"LoRA adapter {adapter_alias} was successfully added, but failed to pin, removing ..."
f"LoRA adapter {adapter_alias} was successfully loaded, but failed to pin, unloading ..."
)
_service.rolling_batch.remove_lora(adapter_name)
_service.remove_lora(adapter_name, adapter_alias)
if any(msg in str(e)
for msg in ("No free lora slots",
"greater than the number of GPU LoRA slots")):
Expand All @@ -546,39 +575,45 @@ def update_adapter(inputs: Input):
"""
Updates lora adapter with the model.
"""
if not is_rolling_batch_enabled(_service.hf_configs.rolling_batch):
raise NotImplementedError(
"LoRA adapter API is only supported for rolling batch.")

adapter_name = inputs.get_property("name")
adapter_alias = inputs.get_property("alias") or adapter_name
adapter_path = inputs.get_property("src")
adapter_load = inputs.get_as_string(
"load").lower() == "true" if inputs.contains_key("load") else True
adapter_pin = inputs.get_as_string(
"pin").lower() == "true" if inputs.contains_key("pin") else False

if adapter_name not in _service.adapter_registry:
raise ValueError(f"Adapter {adapter_alias} not registered.")

try:
if not adapter_load and adapter_pin:
raise ValueError("Can not set load to false and pin to true")

old_adapter = _service.adapter_registry[adapter_name]
if old_adapter.get_property("src") and old_adapter.get_property(
"src") != adapter_path:
old_adapter_path = old_adapter.get_property("src")
if adapter_path != old_adapter_path:
raise NotImplementedError(
f"Updating adapter path is not supported.")

old_adapter_pin = inputs.get_as_string(
"pin").lower() == "true" if inputs.contains_key("pin") else False
if old_adapter_pin and not adapter_pin:
raise NotImplementedError(f"Unpin adapter is not supported.")

if adapter_pin:
pinned = _service.rolling_batch.pin_lora(adapter_name)
if not pinned:
raise RuntimeError(
f"Failed to pin LoRA adapter {adapter_alias}")
old_adapter_load = old_adapter.get_as_string("load").lower(
) == "true" if old_adapter.contains_key("load") else True
if adapter_load != old_adapter_load:
if adapter_load:
_service.add_lora(adapter_name, adapter_alias, adapter_path)
else:
_service.remove_lora(adapter_name, adapter_alias)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If adapter_load is false, why are we removing the adapter? Shouldn't this just be a no-op?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to have an unload option.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, so unloading is supported only for unpinned adapters.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes


old_adapter_pin = old_adapter.get_as_string("pin").lower(
) == "true" if old_adapter.contains_key("pin") else False
if adapter_pin != old_adapter_pin:
if adapter_pin:
_service.pin_lora(adapter_name, adapter_alias)
else:
raise NotImplementedError(f"Unpin adapter is not supported.")
_service.adapter_registry[adapter_name] = inputs
except Exception as e:
logging.debug(f"Failed to update adapter: {e}", exc_info=True)
if any(msg in str(e)
for msg in ("No free lora slots",
"greater than the number of GPU LoRA slots")):
Expand All @@ -593,25 +628,17 @@ def unregister_adapter(inputs: Input):
"""
Unregisters lora adapter from the model.
"""
if not is_rolling_batch_enabled(_service.hf_configs.rolling_batch):
raise NotImplementedError(
"LoRA adapter API is only supported for rolling batch.")

adapter_name = inputs.get_property("name")
adapter_alias = inputs.get_property("alias") or adapter_name

if adapter_name not in _service.adapter_registry:
raise ValueError(f"Adapter {adapter_alias} not registered.")

try:
removed = _service.rolling_batch.remove_lora(adapter_name)
if not removed:
logging.info(
f"Remove LoRA adapter {adapter_alias} returned false, the adapter may have already been evicted."
)

_service.remove_lora(adapter_name, adapter_alias)
del _service.adapter_registry[adapter_name]
except Exception as e:
logging.debug(f"Failed to unregister adapter: {e}", exc_info=True)
return Output().error("remove_adapter_error", message=str(e))

logging.info(f"Unregistered adapter {adapter_alias} successfully")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,5 +283,5 @@ def pin_lora(self, lora_name):
# 2) An adapter is not evicted, call add_lora() is not necessary.
# But since whether an adapter is evicted is not exposed outside of engine,
# and add_lora() in this case will take negligible time, we will still call add_lora().
self.engine.add_lora(lora_request)
return self.engine.pin_lora(lora_request.lora_int_id)
loaded = self.engine.add_lora(lora_request)
return loaded and self.engine.pin_lora(lora_request.lora_int_id)
Original file line number Diff line number Diff line change
Expand Up @@ -198,5 +198,5 @@ def pin_lora(self, lora_name):
# 2) An adapter is not evicted, call add_lora() is not necessary.
# But since whether an adapter is evicted is not exposed outside of engine,
# and add_lora() in this case will take negligible time, we will still call add_lora().
self.engine.add_lora(lora_request)
return self.engine.pin_lora(lora_request.lora_int_id)
loaded = self.engine.add_lora(lora_request)
return loaded and self.engine.pin_lora(lora_request.lora_int_id)
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ private void handleListAdapters(
for (int i = pagination.getPageToken(); i < pagination.getLast(); ++i) {
String adapterName = keys.get(i);
Adapter<Input, Output> adapter = modelInfo.getAdapter(adapterName);
list.addAdapter(adapter.getName(), adapter.getSrc(), adapter.isPin());
list.addAdapter(adapter.getName(), adapter.getSrc(), adapter.isLoad(), adapter.isPin());
}

NettyUtils.sendJsonResponse(ctx, list);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
public class DescribeAdapterResponse {
private String name;
private String src;
private boolean load;
private boolean pin;

/**
Expand All @@ -30,6 +31,7 @@ public class DescribeAdapterResponse {
public DescribeAdapterResponse(Adapter<Input, Output> adapter) {
// Copy the adapter's current name, source, load flag and pin flag into this response.
this.name = adapter.getName();
this.src = adapter.getSrc();
this.load = adapter.isLoad();
this.pin = adapter.isPin();
}

Expand All @@ -51,6 +53,15 @@ public String getSrc() {
return src;
}

/**
* Returns whether to load the adapter weights.
*
* <p>The value is the load flag captured from the adapter when this response was constructed.
*
* @return {@code true} if the adapter weights are to be loaded, {@code false} otherwise
*/
public boolean isLoad() {
return load;
}

/**
* Returns whether to pin the adapter.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,24 @@ public List<AdapterItem> getAdapters() {
*
* @param name the adapter name
* @param src the adapter source
* @param load whether to load the adapter weights
* @param pin whether to pin the adapter
*/
public void addAdapter(String name, String src, boolean pin) {
adapters.add(new AdapterItem(name, src, pin));
public void addAdapter(String name, String src, boolean load, boolean pin) {
adapters.add(new AdapterItem(name, src, load, pin));
}

/** A class that holds the adapter response. */
public static final class AdapterItem {
private String name;
private String src;
private boolean load;
private boolean pin;

private AdapterItem(String name, String src, boolean pin) {
private AdapterItem(String name, String src, boolean load, boolean pin) {
this.name = name;
this.src = src;
this.load = load;
this.pin = pin;
}

Expand Down
63 changes: 53 additions & 10 deletions serving/src/test/java/ai/djl/serving/ModelServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1002,12 +1002,25 @@ private void testRegisterAdapter(Channel channel, boolean registerModel, boolean
testAdapterMissing();

String strModelPrefix = modelPrefix ? "/models/adaptecho" : "";
url = strModelPrefix + "/adapters?name=adaptable&src=src&echooption=opt";
String adapterName = "adaptable";
url = strModelPrefix + "/adapters?name=" + adapterName + "&src=src&echooption=opt";
request(channel, HttpMethod.POST, url);
assertHttpOk();

StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
assertEquals(statusResp.getStatus(), "Adapter adaptable registered");
assertEquals(statusResp.getStatus(), "Adapter " + adapterName + " registered");

// Assert adapter registered
url = strModelPrefix + "/adapters/" + adapterName;
request(channel, HttpMethod.GET, url);
assertHttpOk();

DescribeAdapterResponse resp =
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), adapterName);
assertEquals(resp.getSrc(), "src");
assertTrue(resp.isLoad());
assertFalse(resp.isPin());
}

private void testRegisterAdapterConflict() throws InterruptedException {
Expand Down Expand Up @@ -1070,17 +1083,27 @@ private void testRegisterAdapterHandlerError() throws InterruptedException {
private void testUpdateAdapter(Channel channel, boolean modelPrefix)
throws InterruptedException {
logTestFunction();

String adapterName = "adaptable";
String strModelPrefix = modelPrefix ? "/models/adaptecho" : "";
String url = strModelPrefix + "/adapters/adaptable/update?src=src1";
String url = strModelPrefix + "/adapters/" + adapterName + "/update?pin=true";
request(channel, HttpMethod.POST, url);
assertHttpOk();

url = strModelPrefix + "/adapters/adaptable/update?src=src";
request(channel, HttpMethod.POST, url);
StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
assertEquals(statusResp.getStatus(), "Adapter " + adapterName + " updated");

// Assert adapter updated
url = strModelPrefix + "/adapters/" + adapterName;
request(channel, HttpMethod.GET, url);
assertHttpOk();

StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
assertEquals(statusResp.getStatus(), "Adapter adaptable updated");
DescribeAdapterResponse resp =
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), adapterName);
assertEquals(resp.getSrc(), "src");
assertTrue(resp.isLoad());
assertTrue(resp.isPin());
}

private void testUpdateAdapterModelNotFound() throws InterruptedException {
Expand Down Expand Up @@ -1117,7 +1140,11 @@ private void testUpdateAdapterHandlerError() throws InterruptedException {
String modelName = "adaptecho";
String adapterName = "adaptable";
String strModelPrefix = "/models/" + modelName;
String url = strModelPrefix + "/adapters/" + adapterName + "/update?src=src1&error=true";
String url =
strModelPrefix
+ "/adapters/"
+ adapterName
+ "/update?src=src1&load=false&error=true";
request(channel, HttpMethod.POST, url);
channel.closeFuture().sync();
channel.close().sync();
Expand All @@ -1135,6 +1162,8 @@ private void testUpdateAdapterHandlerError() throws InterruptedException {
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), adapterName);
assertEquals(resp.getSrc(), "src");
assertTrue(resp.isLoad());
assertFalse(resp.isPin());
}

private void testAdapterMissing() throws InterruptedException {
Expand Down Expand Up @@ -1290,6 +1319,8 @@ private void testDescribeAdapter(Channel channel, boolean modelPrefix)
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), "adaptable");
assertEquals(resp.getSrc(), "src");
assertTrue(resp.isLoad());
assertTrue(resp.isPin());
}

private void testDescribeAdapterModelNotFound() throws InterruptedException {
Expand Down Expand Up @@ -1337,12 +1368,24 @@ private void testUnregisterAdapter(Channel channel, boolean modelPrefix)
throws InterruptedException {
logTestFunction();
String strModelPrefix = modelPrefix ? "/models/adaptecho" : "";
String url = strModelPrefix + "/adapters/adaptable";
String adapterName = "adaptable";
String url = strModelPrefix + "/adapters/" + adapterName;
request(channel, HttpMethod.DELETE, url);
assertHttpOk();

StatusResponse statusResp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
assertEquals(statusResp.getStatus(), "Adapter adaptable unregistered");
assertEquals(statusResp.getStatus(), "Adapter " + adapterName + " unregistered");

// Assert adapter unregistered
channel = connect(Connector.ConnectorType.MANAGEMENT);
assertNotNull(channel);

url = strModelPrefix + "/adapters";
request(channel, HttpMethod.GET, url);
assertHttpOk();

ListAdaptersResponse resp = JsonUtils.GSON.fromJson(result, ListAdaptersResponse.class);
assertFalse(resp.getAdapters().stream().anyMatch(a -> adapterName.equals(a.getName())));
}

private void testUnregisterAdapterModelNotFound() throws InterruptedException {
Expand Down
Loading
Loading