
Commit

[Model] Support GGUF models newly added in transformers 4.46.0 (vllm-project#9685)

Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Isotr0py and DarkLight1337 authored Jan 13, 2025
1 parent 9597a09 commit d14e98d
Showing 7 changed files with 162 additions and 87 deletions.
22 changes: 8 additions & 14 deletions examples/offline_inference/gguf_inference.py
@@ -3,27 +3,20 @@
 from vllm import LLM, SamplingParams
 
 
-def run_gguf_inference(model_path):
-    PROMPT_TEMPLATE = "<|system|>\n{system_message}</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"  # noqa: E501
-    system_message = "You are a friendly chatbot who always responds in the style of a pirate."  # noqa: E501
+def run_gguf_inference(model_path, tokenizer):
     # Sample prompts.
     prompts = [
         "How many helicopters can a human eat in one sitting?",
         "What's the future of AI?",
     ]
-    prompts = [
-        PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt)
-        for prompt in prompts
-    ]
+    prompts = [[{"role": "user", "content": prompt}] for prompt in prompts]
     # Create a sampling params object.
     sampling_params = SamplingParams(temperature=0, max_tokens=128)
 
     # Create an LLM.
-    llm = LLM(model=model_path,
-              tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-              gpu_memory_utilization=0.95)
+    llm = LLM(model=model_path, tokenizer=tokenizer)
 
-    outputs = llm.generate(prompts, sampling_params)
+    outputs = llm.chat(prompts, sampling_params)
     # Print the outputs.
     for output in outputs:
         prompt = output.prompt
@@ -32,7 +25,8 @@ def run_gguf_inference(model_path):
 
 
 if __name__ == "__main__":
-    repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
-    filename = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
+    repo_id = "bartowski/Phi-3-medium-4k-instruct-GGUF"
+    filename = "Phi-3-medium-4k-instruct-IQ2_M.gguf"
+    tokenizer = "microsoft/Phi-3-medium-4k-instruct"
     model = hf_hub_download(repo_id, filename=filename)
-    run_gguf_inference(model)
+    run_gguf_inference(model, tokenizer)
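
The example now passes the base model's Hugging Face tokenizer explicitly and switches from llm.generate() with a hand-built prompt template to llm.chat(), which applies that tokenizer's chat template automatically. Passing the original tokenizer also matches vLLM's GGUF guidance, since building a tokenizer from the GGUF file itself can be slow and lossy. A minimal end-to-end sketch of the new pattern, using the same repo, file, and tokenizer as the example above (any GGUF file plus its source model's tokenizer should work the same way):

from huggingface_hub import hf_hub_download

from vllm import LLM, SamplingParams

# Download the GGUF file and point vLLM at it, keeping the original tokenizer.
gguf_file = hf_hub_download("bartowski/Phi-3-medium-4k-instruct-GGUF",
                            filename="Phi-3-medium-4k-instruct-IQ2_M.gguf")
llm = LLM(model=gguf_file, tokenizer="microsoft/Phi-3-medium-4k-instruct")

# llm.chat() takes chat-style conversations and applies the chat template.
conversations = [[{"role": "user", "content": "What's the future of AI?"}]]
outputs = llm.chat(conversations, SamplingParams(temperature=0, max_tokens=128))
print(outputs[0].outputs[0].text)
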
105 changes: 74 additions & 31 deletions tests/models/decoder_only/language/test_gguf.py
@@ -4,45 +4,90 @@
 """
 
 import os
+from typing import List, NamedTuple, Type
 
 import pytest
 from huggingface_hub import hf_hub_download
 from transformers import AutoTokenizer
 
 from tests.quantization.utils import is_quant_method_supported
 
+from ....conftest import VllmRunner
 from ...utils import check_logprobs_close
 
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
 MAX_MODEL_LEN = 1024
 
 
+class GGUFTestConfig(NamedTuple):
+    original_model: str
+    gguf_repo: str
+    gguf_filename: str
+
+    @property
+    def gguf_model(self):
+        return hf_hub_download(self.gguf_repo, filename=self.gguf_filename)
+
+
+LLAMA_CONFIG = GGUFTestConfig(
+    original_model="meta-llama/Llama-3.2-1B-Instruct",
+    gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF",
+    gguf_filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf",
+)
+
+QWEN2_CONFIG = GGUFTestConfig(
+    original_model="Qwen/Qwen2.5-1.5B-Instruct",
+    gguf_repo="Qwen/Qwen2.5-1.5B-Instruct-GGUF",
+    gguf_filename="qwen2.5-1.5b-instruct-q6_k.gguf",
+)
+
+PHI3_CONFIG = GGUFTestConfig(
+    original_model="microsoft/Phi-3.5-mini-instruct",
+    gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF",
+    gguf_filename="Phi-3.5-mini-instruct-IQ4_XS.gguf",
+)
+
+GPT2_CONFIG = GGUFTestConfig(
+    original_model="openai-community/gpt2-large",
+    gguf_repo="QuantFactory/gpt2-large-GGUF",
+    gguf_filename="gpt2-large.Q4_K_M.gguf",
+)
+
+STABLELM_CONFIG = GGUFTestConfig(
+    original_model="stabilityai/stablelm-3b-4e1t",
+    gguf_repo="afrideva/stablelm-3b-4e1t-GGUF",
+    gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf",
+)
+
+STARCODER_CONFIG = GGUFTestConfig(
+    original_model="bigcode/starcoder2-3b",
+    gguf_repo="QuantFactory/starcoder2-3b-GGUF",
+    gguf_filename="starcoder2-3b.Q6_K.gguf",
+)
+
+MODELS = [
+    LLAMA_CONFIG,
+    QWEN2_CONFIG,
+    PHI3_CONFIG,
+    GPT2_CONFIG,
+    STABLELM_CONFIG,
+    # STARCODER_CONFIG, # broken
+]
+
+
 @pytest.mark.skipif(not is_quant_method_supported("gguf"),
                     reason="gguf is not supported on this GPU type.")
-@pytest.mark.parametrize(("original_model", "gguf_id", "gguf_path"), [
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     "bartowski/Llama-3.2-1B-Instruct-GGUF",
-     "Llama-3.2-1B-Instruct-Q4_K_M.gguf"),
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     "bartowski/Llama-3.2-1B-Instruct-GGUF",
-     "Llama-3.2-1B-Instruct-IQ4_XS.gguf"),
-    ("Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct-GGUF",
-     "qwen2-1_5b-instruct-q4_k_m.gguf"),
-    ("Qwen/Qwen2-1.5B-Instruct", "legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
-     "Qwen2-1.5B-Instruct.IQ4_XS.gguf"),
-])
+@pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("tp_size", [1, 2])
 def test_models(
-    num_gpus_available,
-    vllm_runner,
-    example_prompts,
-    original_model,
-    gguf_id,
-    gguf_path,
+    num_gpus_available: int,
+    vllm_runner: Type[VllmRunner],
+    example_prompts: List[str],
+    model: GGUFTestConfig,
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
@@ -51,28 +96,26 @@ def test_models(
     if num_gpus_available < tp_size:
         pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
 
-    gguf_model = hf_hub_download(gguf_id, filename=gguf_path)
-
-    tokenizer = AutoTokenizer.from_pretrained(original_model)
-    messages = [[{
-        'role': 'user',
-        'content': prompt
-    }] for prompt in example_prompts]
-    example_prompts = tokenizer.apply_chat_template(messages,
-                                                    tokenize=False,
-                                                    add_generation_prompt=True)
+    tokenizer = AutoTokenizer.from_pretrained(model.original_model)
+    if tokenizer.chat_template is not None:
+        messages = [[{
+            'role': 'user',
+            'content': prompt
+        }] for prompt in example_prompts]
+        example_prompts = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True)
 
     # Run unquantized model.
-    with vllm_runner(model_name=original_model,
+    with vllm_runner(model_name=model.original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as original_model:
 
         original_outputs = original_model.generate_greedy_logprobs(
             example_prompts[:-1], max_tokens, num_logprobs)
 
     # Run gguf model.
-    with vllm_runner(model_name=gguf_model,
+    with vllm_runner(model_name=model.gguf_model,
+                     tokenizer_name=model.original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as gguf_model:
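
The test matrix is now a list of GGUFTestConfig entries, so covering another architecture that gains GGUF support is a one-entry change to this file. An illustrative sketch (the repo and file names below are hypothetical placeholders, not part of this commit):

NEW_ARCH_CONFIG = GGUFTestConfig(
    original_model="example-org/new-arch-3b-instruct",  # hypothetical HF repo
    gguf_repo="example-org/new-arch-3b-instruct-GGUF",  # hypothetical GGUF repo
    gguf_filename="new-arch-3b-instruct.Q4_K_M.gguf",
)

MODELS.append(NEW_ARCH_CONFIG)  # or list it directly in MODELS above
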
58 changes: 35 additions & 23 deletions vllm/model_executor/layers/linear.py
@@ -447,8 +447,14 @@ def weight_loader(self,
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
         if is_gguf_weight_type:
-            param.data[loaded_shard_id].copy_(loaded_weight)
-            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            if loaded_shard_id is not None:
+                param.data[loaded_shard_id].copy_(loaded_weight)
+                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            else:
+                param.shard_weight_type = {
+                    i: loaded_weight.item()
+                    for i, _ in enumerate(self.output_sizes)
+                }
             return
 
         if is_gguf_weight:
@@ -459,15 +465,15 @@
             shard_size = loaded_weight.size(output_dim) // tp_size
             start_idx = tp_rank * shard_size
 
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
-
-            param.shard_id.append(loaded_shard_id)
-            param.shard_id_map[loaded_shard_id] = len(param.data_container)
-            param.data_container.append(loaded_weight)
-            if len(param.data_container) == 2:
-                self.qweight = param.materialize_nested()
-            return
+            if loaded_shard_id is not None:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
+                param.shard_id.append(loaded_shard_id)
+                param.shard_id_map[loaded_shard_id] = len(param.data_container)
+                param.data_container.append(loaded_weight)
+                if len(param.data_container) == 2:
+                    self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
@@ -811,10 +817,16 @@ def weight_loader(self,
         # initialize GGUF param after we know the quantize type
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
-        if is_gguf_weight_type and loaded_shard_id is not None:
+        if is_gguf_weight_type:
             idx_map = {"q": 0, "k": 1, "v": 2}
-            param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
-            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            if loaded_shard_id is not None:
+                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
+                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            else:
+                param.shard_weight_type = {
+                    k: loaded_weight.item()
+                    for k in idx_map
+                }
             return
 
         if is_gguf_weight:
@@ -825,15 +837,15 @@
             shard_size = loaded_weight.size(output_dim) // tp_size
             start_idx = tp_rank * shard_size
 
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
-
-            param.shard_id.append(loaded_shard_id)
-            param.shard_id_map[loaded_shard_id] = len(param.data_container)
-            param.data_container.append(loaded_weight)
-            if len(param.data_container) == 3:
-                self.qweight = param.materialize_nested()
-            return
+            if loaded_shard_id is not None:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
+                param.shard_id.append(loaded_shard_id)
+                param.shard_id_map[loaded_shard_id] = len(param.data_container)
+                param.data_container.append(loaded_weight)
+                if len(param.data_container) == 3:
+                    self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
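
Both weight loaders above now accept loaded_shard_id=None, which is what a GGUF checkpoint that already stores the merged projection as a single tensor produces (GPT-2's fused c_attn, for example): the loader then records one quantization type per logical shard and loads the tensor directly instead of collecting per-shard pieces. A stripped-down toy illustration of that dispatch, not vLLM code, just the shape of the two call patterns:

from typing import Dict, List, Optional

import torch


class ToyMergedLoader:
    """Illustrative stand-in for a merged/QKV weight loader (not vLLM code)."""

    def __init__(self, num_shards: int):
        self.num_shards = num_shards
        self.shard_weight_type: Dict[int, str] = {}
        self.data_container: List[torch.Tensor] = []
        self.full_weight: Optional[torch.Tensor] = None

    def load(self,
             weight: torch.Tensor,
             shard_id: Optional[int] = None,
             quant_type: str = "Q4_K") -> None:
        if shard_id is not None:
            # Llama-style GGUF: q/k/v (or gate/up) arrive as separate tensors
            # and are collected shard by shard.
            self.shard_weight_type[shard_id] = quant_type
            self.data_container.append(weight)
        else:
            # Fused-style GGUF (e.g. GPT-2's c_attn): the checkpoint already
            # stores the merged tensor, so record one type per logical shard
            # and keep the whole tensor instead of stacking shards.
            self.shard_weight_type = {
                i: quant_type for i in range(self.num_shards)
            }
            self.full_weight = weight
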
11 changes: 8 additions & 3 deletions vllm/model_executor/models/gpt2.py
@@ -198,7 +198,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         assert not config.scale_attn_by_inverse_layer_idx
         assert not config.reorder_and_upcast_attn
         self.embed_dim = config.hidden_size
-        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          self.embed_dim,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.wte")
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
@@ -259,7 +262,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = self.transformer.wte
         else:
             self.lm_head = ParallelLMHead(self.config.vocab_size,
-                                          self.config.hidden_size)
+                                          self.config.hidden_size,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.lm_head")
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
@@ -304,7 +309,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "lm_head.weight" in name:
+            if name.startswith("lm_head"):
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
                 continue
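
With quant_config now passed to the token embedding and LM head, and the tied lm_head skip broadened to any parameter name starting with "lm_head", GGUF-quantized GPT-2 checkpoints can be loaded directly. A minimal sketch using the GPT-2 entry from the test matrix above (GPT-2 has no chat template, so plain generate() is used with the base tokenizer):

from huggingface_hub import hf_hub_download

from vllm import LLM, SamplingParams

gguf_file = hf_hub_download("QuantFactory/gpt2-large-GGUF",
                            filename="gpt2-large.Q4_K_M.gguf")
llm = LLM(model=gguf_file, tokenizer="openai-community/gpt2-large")
outputs = llm.generate(["The meaning of life is"],
                       SamplingParams(temperature=0, max_tokens=32))
print(outputs[0].outputs[0].text)
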
3 changes: 2 additions & 1 deletion vllm/model_executor/models/llama.py
@@ -156,7 +156,8 @@ def __init__(
         )
 
         is_neox_style = True
-        if quant_config is not None and quant_config.get_name() == "gguf":
+        is_gguf = quant_config and quant_config.get_name() == "gguf"
+        if is_gguf and config.model_type == "llama":
             is_neox_style = False
 
         self.rotary_emb = get_rope(
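
Context for the is_neox_style change: GGUF conversion of Llama-family checkpoints permutes the attention weights for the interleaved (GPT-J style) rotary layout, so vLLM disables NeoX-style rotation for them; restricting the override to config.model_type == "llama" keeps it from affecting other architectures that reuse this Llama implementation (such as Phi-3). For reference, a small sketch of the two rotation layouts the flag chooses between, mirroring the standard formulation rather than vLLM's internal kernels:

import torch


def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # NeoX style: pair element i with element i + d/2 (split halves).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    # GPT-J style: pair adjacent elements (2i, 2i + 1), i.e. interleaved.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)
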
