From 07c11cf4d4b9a913fa52142fe134849f1e25e393 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Thu, 10 Oct 2024 21:11:56 +0800
Subject: [PATCH] [Bugfix] Fix lm_head weights tying with lora for llama
 (#9227)

---
 .../model_executor/layers/vocab_parallel_embedding.py | 11 ++++++++++-
 vllm/model_executor/models/llama.py                   |  3 ++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index ef6d401be2070..b448557af13b3 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -443,7 +443,7 @@ def __init__(self,
         super().__init__(num_embeddings, embedding_dim, params_dtype,
                          org_num_embeddings, padding_size, quant_config,
                          prefix)
-
+        self.quant_config = quant_config
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
@@ -455,6 +455,15 @@ def __init__(self,
         else:
             self.register_parameter("bias", None)
 
+    def tie_weights(self, embed_tokens: VocabParallelEmbedding):
+        """Tie the weights with word embeddings."""
+        # GGUF quantized embed_tokens.
+        if self.quant_config and self.quant_config.get_name() == "gguf":
+            return embed_tokens
+        else:
+            self.weight = embed_tokens.weight
+            return self
+
     def forward(self, input_):
         del input_
         raise RuntimeError("LMHead's weights should be used in the sampler.")
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 0589b581ff236..2a79a9edf2111 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -524,7 +524,8 @@ def __init__(
                 quant_config=quant_config,
             )
             if config.tie_word_embeddings:
-                self.lm_head = self.model.embed_tokens
+                self.lm_head = self.lm_head.tie_weights(
+                    self.model.embed_tokens)
 
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
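
Note (not part of the patch): a minimal, self-contained sketch of the tying pattern the patch introduces, under the assumption that GGUF checkpoints keep the embedding weight in quantized form and therefore reuse the embedding module itself rather than aliasing its weight Parameter. The ToyEmbedding/ToyLMHead classes below are hypothetical stand-ins for vLLM's VocabParallelEmbedding/ParallelLMHead and are illustrative only.

# Illustrative sketch only: mirrors the tie_weights() logic added above,
# stripped of tensor parallelism and quantization plumbing.
import torch
import torch.nn as nn

class ToyEmbedding(nn.Module):
    def __init__(self, vocab_size: int, hidden: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(vocab_size, hidden))

class ToyLMHead(nn.Module):
    def __init__(self, vocab_size: int, hidden: int, quant_name: str = ""):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(vocab_size, hidden))
        self.quant_name = quant_name  # e.g. "gguf"

    def tie_weights(self, embed_tokens: ToyEmbedding) -> nn.Module:
        # For GGUF, the embedding weight stays quantized, so the head
        # reuses the embedding module directly instead of sharing its
        # weight Parameter; otherwise the head aliases the weight and
        # returns itself, so it remains a distinct module (which is
        # what LoRA needs, instead of replacing lm_head wholesale).
        if self.quant_name == "gguf":
            return embed_tokens
        self.weight = embed_tokens.weight
        return self

embed = ToyEmbedding(32000, 4096)
lm_head = ToyLMHead(32000, 4096).tie_weights(embed)
assert lm_head.weight is embed.weight  # tied, not copied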