[Bugfix] Fix LoRA weight sharding (vllm-project#10450)
Signed-off-by: Jee Jee Li <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
jeejeelee and DarkLight1337 authored Nov 24, 2024
1 parent 17d8fc1 commit 1700c54
Showing 7 changed files with 258 additions and 168 deletions.
13 changes: 9 additions & 4 deletions .buildkite/test-pipeline.yaml
@@ -230,7 +230,7 @@ steps:
source_file_dependencies:
- vllm/lora
- tests/lora
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py
parallelism: 4

- label: "PyTorch Fullgraph Smoke Test" # 9min
@@ -475,18 +475,23 @@ steps:
- pytest -v -s distributed/test_pp_cudagraph.py
- pytest -v -s distributed/test_pipeline_parallel.py

- label: LoRA Long Context (Distributed) # 11min
# This test runs llama 13B, so it is required to run on 4 GPUs.
- label: LoRA TP Test (Distributed)
num_gpus: 4
soft_fail: true
source_file_dependencies:
- vllm/lora
- tests/lora/test_long_context
- tests/lora
commands:
# FIXIT: find out which code initialize cuda before running the test
# before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
# This test runs llama 13B, so it is required to run on 4 GPUs.
- pytest -v -s -x lora/test_long_context.py
# There is some Tensor Parallelism related processing logic in LoRA that
# requires multi-GPU testing for validation.
- pytest -v -s -x lora/test_chatglm3_tp.py
- pytest -v -s -x lora/test_llama_tp.py


- label: Weight Loading Multiple GPU Test # 33min
working_dir: "/vllm-workspace/tests"
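The comments in the new "LoRA TP Test (Distributed)" step above capture two constraints: workers must be spawned rather than forked because CUDA is already initialized before the tests run, and the tensor-parallel processing logic in LoRA only gets exercised on multiple GPUs. The following is a minimal local-reproduction sketch, assuming four visible GPUs; it is not part of the commit, and the engine arguments simply mirror the ones used in the new tests.

# Sketch only: local setup mirroring the new "LoRA TP Test (Distributed)" step.
# Assumes 4 visible GPUs; all engine arguments are taken from the tests in this commit.
import os

# The pipeline exports this before running pytest so that worker processes are
# spawned rather than forked (the FIXIT comment notes CUDA is initialized early).
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

import vllm

llm = vllm.LLM(
    "THUDM/chatglm3-6b",
    max_model_len=1024,
    enable_lora=True,
    max_loras=4,
    max_lora_rank=64,
    tensor_parallel_size=4,     # exercises the tensor-parallel sharding path under test
    fully_sharded_loras=True,   # the sharding mode touched by this bugfix
    trust_remote_code=True,
)

From here, requests carrying a LoRARequest can be issued exactly as the tests below do.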
63 changes: 53 additions & 10 deletions tests/lora/test_chatglm3.py → tests/lora/test_chatglm3_tp.py
@@ -1,12 +1,21 @@
from typing import List

import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.lora.request import LoRARequest

from ..utils import multi_gpu_test

MODEL_PATH = "THUDM/chatglm3-6b"

PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501

EXPECTED_LORA_OUTPUT = [
"SELECT count(*) FROM singer",
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501
"SELECT name , country , age FROM singer ORDER BY age",
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
@@ -20,7 +29,6 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
"Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501
),
]
print(prompts)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
outputs = llm.generate(
prompts,
@@ -37,23 +45,58 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts
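The middle of do_sample is collapsed in this diff. As a self-contained, hedged sketch of the usual vllm pattern such a helper follows (attach a LoRARequest to generate, then keep the first completion's text per prompt), something like the function below would do; generate_with_lora is an illustrative name, and the collapsed lines may differ in detail.

# Hedged sketch of what a helper like do_sample typically does; the exact lines
# collapsed in this diff may differ.
from typing import List

import vllm
from vllm.lora.request import LoRARequest


def generate_with_lora(llm: vllm.LLM, prompts: List[str], lora_path: str,
                       lora_id: int) -> List[str]:
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
    outputs = llm.generate(
        prompts,
        sampling_params,
        # lora_id == 0 would mean "no adapter"; otherwise attach the adapter.
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None,
    )
    # Keep only the first completion's text for each prompt.
    return [output.outputs[0].text.strip() for output in outputs]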


@fork_new_process_for_each_test
def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=1,
trust_remote_code=True)

expected_lora_output = [
"SELECT count(*) FROM singer",
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501
"SELECT name , country , age FROM singer ORDER BY age",
]
output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=False)

output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=True)
output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(expected_lora_output)):
assert output1[i] == expected_lora_output[i]
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
for i in range(len(expected_lora_output)):
assert output2[i] == expected_lora_output[i]
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]
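test_chatglm3_lora_tp4 and test_chatglm3_lora_tp4_fully_sharded_loras differ only in the fully_sharded_loras flag, that is, in how the LoRA weights are partitioned across tensor-parallel ranks, which is the code path this bugfix touches. As a purely conceptual toy (not vllm's implementation), the invariant these tests protect is that sharding a LoRA weight and recombining the per-rank results must reproduce the unsharded output; the sizes below are arbitrary.

# Conceptual illustration only; not vllm's implementation.
# Splitting lora_B column-wise across tensor-parallel ranks and concatenating
# the per-rank results (what an all-gather over the TP group achieves) must
# equal the unsharded LoRA delta.
import numpy as np

tp_size = 4                    # number of tensor-parallel ranks (assumption)
hidden, out, rank = 8, 16, 4   # toy sizes

x = np.random.randn(1, hidden)
lora_A = np.random.randn(hidden, rank)
lora_B = np.random.randn(rank, out)

# Unsharded reference: the full LoRA delta.
full_delta = x @ lora_A @ lora_B

# Sharded: each rank owns a column slice of lora_B; concatenate per-rank outputs.
shards = np.split(lora_B, tp_size, axis=1)
sharded_delta = np.concatenate([x @ lora_A @ b for b in shards], axis=1)

assert np.allclose(full_delta, sharded_delta)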
146 changes: 0 additions & 146 deletions tests/lora/test_llama.py

This file was deleted.

(Diffs for the remaining 4 changed files are not shown.)
