Lora merge #2962

Open · wants to merge 7 commits into base: main
38 changes: 32 additions & 6 deletions demos/common/export_models/export_model.py
@@ -17,7 +17,10 @@
import argparse
import os
from openvino_tokenizers import convert_tokenizer, connect_models
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
from peft import PeftModel
#from peft.utils import PeftConfig
import torch
import tempfile  # used below to create a temporary folder for the merged model
import jinja2
import json
import shutil
@@ -38,12 +41,14 @@ def add_common_arguments(parser):
subparsers = parser.add_subparsers(help='subcommand help', required=True, dest='task')
parser_text = subparsers.add_parser('text_generation', help='export model for chat and completion endpoints')
add_common_arguments(parser_text)
parser_text.add_argument('--kv_cache_precision', default=None, choices=["u8"], help='u8 or empty (model default). Reduced kv cache precision to u8 lowers the cache size consumption.', dest='kv_cache_precision')
parser_text.add_argument('--kv_cache_precision', default=None, choices=["u8", "fp32"], help='u8, fp32 or empty (model default). Reduced kv cache precision to u8 lowers the cache size consumption.', dest='kv_cache_precision')
parser_text.add_argument('--enable_prefix_caching', action='store_true', help='This algorithm is used to cache the prompt tokens.', dest='enable_prefix_caching')
parser_text.add_argument('--disable_dynamic_split_fuse', action='store_false', help='Disables the dynamic split fuse scheduling algorithm.', dest='dynamic_split_fuse')
parser_text.add_argument('--max_num_batched_tokens', default=None, help='empty or integer. The maximum number of tokens that can be batched together.', dest='max_num_batched_tokens')
parser_text.add_argument('--max_num_seqs', default=None, help='256 by default. The maximum number of sequences that can be processed together.', dest='max_num_seqs')
parser_text.add_argument('--cache_size', default=10, type=int, help='cache size in GB', dest='cache_size')
parser_text.add_argument('--adapter', action='append', help='LoRA adapter: HuggingFace model id or a local folder with the adapter. Can be passed multiple times to merge several adapters.', dest='adapter')
parser_text.add_argument('--tokenizer', default=None, help='alternative tokenizer for the adapter; defaults to the base model tokenizer', dest='tokenizer')
parser_embeddings = subparsers.add_parser('embeddings', help='export model for embeddings endpoint')
add_common_arguments(parser_embeddings)
parser_embeddings.add_argument('--skip_normalize', default=True, action='store_false', help='Skip normalizing the embeddings.', dest='normalize')
@@ -244,22 +249,43 @@ def add_servable_to_config(config_path, mediapipe_name, base_path):
json.dump(config_data, config_file, indent=4)
print("Added servable to config file", config_path)

def export_text_generation_model(model_repository_path, source_model, model_name, precision, task_parameters, config_file_path):
def export_text_generation_model(model_repository_path, source_model, model_name, precision, task_parameters, config_file_path, adapter, adapter_tokenizer):
model_path = "./"
if os.path.isfile(os.path.join(source_model, 'openvino_model.xml')):
print("OV model is source folder. Skipping conversion.")
model_path = source_model
else: # assume HF model name or local pytorch model folder
llm_model_path = os.path.join(model_repository_path, model_name)
print("Exporting LLM model to ", llm_model_path)
tmp_folder = None
if not os.path.isdir(llm_model_path) or args['overwrite_models']:
optimum_command = "optimum-cli export openvino --disable-convert-tokenizer --model {} --weight-format {} --trust-remote-code {}".format(source_model, precision, llm_model_path)
if adapter is not None:
if len(adapter) > 1 and adapter_tokenizer is not None:
raise ValueError("Only one adapter can be used with a custom tokenizer")
if adapter_tokenizer is None:
adapter_tokenizer = source_model
tmp_folder = tempfile.mkdtemp()
print("Loading model with adapter")
HFmodel = AutoModelForCausalLM.from_pretrained(source_model, trust_remote_code=True)
for adapteri in adapter:
print("Loading adapter", adapteri)
HFmodel.resize_token_embeddings(len(AutoTokenizer.from_pretrained(adapter_tokenizer)), mean_resizing=False)
HFmodel = PeftModel.from_pretrained(HFmodel, adapteri)
print("Merging model with adapters")
HFmodel = HFmodel.merge_and_unload()
HFmodel.save_pretrained(tmp_folder)
tokenizer = AutoTokenizer.from_pretrained(adapter_tokenizer, trust_remote_code=True)
tokenizer.save_pretrained(tmp_folder)
source_model = tmp_folder
print("Exporting LLM model to ", llm_model_path)
optimum_command = "optimum-cli export openvino --task text-generation-with-past --disable-convert-tokenizer --model {} --weight-format {} --trust-remote-code {}".format(source_model, precision, llm_model_path)
if os.system(optimum_command):
raise ValueError("Failed to export llm model", source_model)
print("Exporting tokenizer to ", llm_model_path)
convert_tokenizer_command = "convert_tokenizer --utf8_replace_mode replace --with-detokenizer --skip-special-tokens --streaming-detokenizer -o {} {}".format(llm_model_path, source_model)
if (os.system(convert_tokenizer_command)):
raise ValueError("Failed to export tokenizer model", source_model)
if adapter is not None:
shutil.rmtree(tmp_folder)
os.makedirs(os.path.join(model_repository_path, model_name), exist_ok=True)
gtemplate = jinja2.Environment(loader=jinja2.BaseLoader).from_string(text_generation_graph_template)
graph_content = gtemplate.render(tokenizer_model="{}_tokenizer_model".format(model_name), embeddings_model="{}_embeddings_model".format(model_name), model_path=model_path, **task_parameters)
Expand Down Expand Up @@ -368,7 +394,7 @@ def export_rerank_model(model_repository_path, source_model, model_name, precisi
print("template params:",template_parameters)

if args['task'] == 'text_generation':
export_text_generation_model(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, args['config_file_path'])
export_text_generation_model(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, args['config_file_path'], args['adapter'], args['tokenizer'])

elif args['task'] == 'embeddings':
export_embeddings_model(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, str(args['version']), args['config_file_path'])
1 change: 1 addition & 0 deletions demos/common/export_models/requirements.txt
@@ -9,3 +9,4 @@ sentence_transformers==3.1.1
openai
transformers<4.48
einops
peft>=0.14.0
4 changes: 4 additions & 0 deletions demos/continuous_batching/README.md
@@ -331,6 +331,10 @@ Check this simple [text generation scaling demo](https://github.com/openvinotool
Check the [guide of using lm-evaluation-harness](https://github.com/openvinotoolkit/model_server/blob/main/demos/continuous_batching/accuracy/README.md)


## Using LoRA adapters with LLM models

Check this guide on [using LoRA adapters for text generation](./lora_adapters/README.md).

## References
- [Chat Completions API](../../docs/model_server_rest_api_chat.md)
- [Completions API](../../docs/model_server_rest_api_completions.md)
124 changes: 124 additions & 0 deletions demos/continuous_batching/lora_adapters/README.md
@@ -0,0 +1,124 @@
# Using LoRA adapters for text generation {#ovms_demos_continuous_batching_lora}

[LoRA adapters](https://arxiv.org/pdf/2106.09685) can be used to efficiently fine-tune LLM models. There are two methods for employing the adapters for serving:
- merging the base model with the adapters and exporting the combined final model
- adding the adapters at runtime in the server deployment along with the base model

## Merging adapters with the main model

In this scenario, the base model is merged with one or more adapters. This can be done using the `peft` Python library. Such a merged model can then be optimized, quantized and prepared for deployment in the model server.

![merged](merged.png)

Clients will call the final merged model by its name.
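
For reference, the merge itself comes down to a few `peft` calls. Below is a minimal sketch that roughly mirrors what the export script does internally; it uses the base model and adapter from this demo, and the `merged_model` output folder is only an example:

```python
# Minimal merge sketch: fold a LoRA adapter into the base model with peft.
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

base_model_id = "meta-llama/Llama-2-7b-hf"
adapter_id = "yard1/llama-2-7b-sql-lora-test"

model = AutoModelForCausalLM.from_pretrained(base_model_id)
tokenizer = AutoTokenizer.from_pretrained(adapter_id)

# Align the embedding table with the adapter tokenizer before attaching the adapter
model.resize_token_embeddings(len(tokenizer), mean_resizing=False)
model = PeftModel.from_pretrained(model, adapter_id)

# Fold the adapter weights into the base weights and save a standalone checkpoint
model = model.merge_and_unload()
model.save_pretrained("merged_model")
tokenizer.save_pretrained("merged_model")
```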

These steps can be automated using the `export_model.py` script, as presented below.

Install python dependencies:
```console
git clone https://github.com/openvinotoolkit/model_server.git
cd model_server
pip3 install -U -r demos/common/export_models/requirements.txt
cd demos/continuous_batching/lora_adapters
```

Export the base model and an adapter into a merged model. When targeting CPU:

```console
python export_model.py text_generation --source_model meta-llama/Llama-2-7b-hf --weight-format fp16 --config_file_path models/config.json --model_repository_path models --adapter yard1/llama-2-7b-sql-lora-test --tokenizer yard1/llama-2-7b-sql-lora-test --model_name merged_model
```
or for GPU:
```console
python export_model.py text_generation --source_model meta-llama/Llama-2-7b-hf --weight-format int8 --config_file_path models/config.json --model_repository_path models --adapter yard1/llama-2-7b-sql-lora-test --tokenizer yard1/llama-2-7b-sql-lora-test --model_name merged_model --target_device GPU --overwrite_models
```

To compare the results, let's also export the base model alone:
```console
python export_model.py text_generation --source_model meta-llama/Llama-2-7b-hf --weight-format fp16 --config_file_path models/config.json --model_repository_path models --model_name base_model
```

> **Note:** The `tokenizer` parameter is needed only when the adapter uses a different tokenizer than the base model.

The exported models can now be deployed in the model server.

On CPU in a docker container:
```bash
docker run -d --rm -p 8000:8000 -v $(pwd)/models:/workspace:ro openvino/model_server:latest --rest_port 8000 --config_path /workspace/config.json
```

On GPU in a docker container:
```bash
docker run -d --rm -p 8000:8000 --device /dev/dri --group-add=$(stat -c "%g" /dev/dri/render* | head -n 1) -v $(pwd)/models:/workspace:ro openvino/model_server:latest-gpu --rest_port 8000 --config_path /workspace/config.json
```

On baremetal after installation of the binary package:
```console
ovms --rest_port 8000 --config_path models/config.json
```

Now we can test the merged model from a client:

```console
curl http://localhost:8000/v3/completions -H "Content-Type: application/json" \
-d '{
"model": "merged_model",
"prompt": "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"
}' | jq
```
```json
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"text": " Write a SQL query to answer the question based on the table schema. context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR) question: name the icao for lilongwe international airport [/assistant] SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' \n\n</s>"
}
],
"created": 1736933735,
"model": "merged_model",
"object": "text_completion",
"usage": {
"prompt_tokens": 64,
"completion_tokens": 82,
"total_tokens": 146
}
}

```
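
The same request can also be sent from Python with the `openai` client (included in the export requirements). A short sketch, assuming the server started above is listening on `localhost:8000`:

```python
# Sketch of the completions call through the OpenAI-compatible REST endpoint.
from openai import OpenAI

# The model server does not validate the API key, but the client requires a value.
client = OpenAI(base_url="http://localhost:8000/v3", api_key="unused")

response = client.completions.create(
    model="merged_model",
    prompt="[user] Write a SQL query to answer the question based on the table schema.\n\n"
           " context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n"
           " question: Name the ICAO for lilongwe international airport [/user] [assistant]",
)
print(response.choices[0].text)
```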
The results are different when calling the base model:

```console
curl http://localhost:8000/v3/completions -H "Content-Type: application/json" \
-d '{
"model": "base_model",
"prompt": "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"
}' | jq
```
```json
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": "\n\n Answer: lilongwe\n\n[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for mwanza international airport [/user] [assistant]\n\n Answer: mwanza\n\n[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for namibia [/user] [assistant]\n\n Answer: namibia\n\n[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for"
}
],
"created": 1736933826,
"model": "base_model",
"object": "text_completion",
"usage": {
"prompt_tokens": 64,
"completion_tokens": 200,
"total_tokens": 264
}
}

```
> **Note:** The results may differ between calls, especially with temperature > 0. Be aware that the adapter above is intended for testing purposes only.


## Adding the adapters at runtime

TBD
59 changes: 59 additions & 0 deletions demos/continuous_batching/lora_adapters/hf_compare_lora.py
@@ -0,0 +1,59 @@
#
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
device = "cpu"
# Make prompts
prompt = [
'''"[user] Write a SQL query to answer the question based on the table schema.\n
\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n
\n question: Name the ICAO for lilongwe international airport [/user] [assistant]''']

# Load Models
base_model = "meta-llama/Llama-2-7b-hf"
peft_adapter = "yard1/llama-2-7b-sql-lora-test"

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(base_model)


def generate_base(model, prompt, tokenizer):
    print("Generating results")
    tokens = tokenizer(prompt, return_tensors='pt').to(device)
    res = model.generate(**tokens, max_new_tokens=100)
    res_sentences = [tokenizer.decode(i) for i in res]
    print("Results:", res_sentences)

def merge_models(model, adapter):
    # Attach the LoRA adapter to the base model; peft applies the adapter weights during inference
    print("Merging model with adapter")
    adapter_tokenizer = AutoTokenizer.from_pretrained(adapter)
    model.resize_token_embeddings(len(adapter_tokenizer), mean_resizing=False)
    model = PeftModel.from_pretrained(model, adapter)
    model = model.eval()
    model = model.to(device)
    return model, adapter_tokenizer

print("BASE MODEL")
generate_base(model, prompt, tokenizer)
model, adapter_tokenizer = merge_models(model, peft_adapter)
print("MERGED MODEL")
generate_base(model, prompt, adapter_tokenizer)


