[Bugfix] Fix lm_head weights tying with lora for llama (vllm-project#…
Isotr0py authored Oct 10, 2024
1 parent f3a507f · commit 07c11cf
Showing 2 changed files with 12 additions and 2 deletions.
vllm/model_executor/layers/vocab_parallel_embedding.py: 11 changes (10 additions & 1 deletion)
@@ -443,7 +443,7 @@ def __init__(self,
         super().__init__(num_embeddings, embedding_dim, params_dtype,
                          org_num_embeddings, padding_size, quant_config,
                          prefix)
-
+        self.quant_config = quant_config
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
@@ -455,6 +455,15 @@ def __init__(self,
         else:
             self.register_parameter("bias", None)
 
+    def tie_weights(self, embed_tokens: VocabParallelEmbedding):
+        """Tie the weights with word embeddings."""
+        # GGUF quantized embed_tokens.
+        if self.quant_config and self.quant_config.get_name() == "gguf":
+            return embed_tokens
+        else:
+            self.weight = embed_tokens.weight
+            return self
+
     def forward(self, input_):
         del input_
         raise RuntimeError("LMHead's weights should be used in the sampler.")
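
The new tie_weights method either shares the embedding's weight Parameter with the LM head or, when the embedding is GGUF-quantized, hands back the embedding module itself to be used as the head. A minimal sketch of that pattern, using made-up stub classes (StubEmbedding / StubLMHead), not vLLM's actual VocabParallelEmbedding / ParallelLMHead:

# Minimal sketch of the tying pattern; the classes are illustrative stand-ins.
import torch
from torch import nn


class StubEmbedding(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(vocab_size, hidden_size))


class StubLMHead(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int, quant_name=None):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(vocab_size, hidden_size))
        self.quant_name = quant_name  # stands in for quant_config.get_name()

    def tie_weights(self, embed_tokens: "StubEmbedding"):
        """Tie the weights with the word embeddings."""
        if self.quant_name == "gguf":
            # GGUF-quantized embeddings cannot simply share a Parameter,
            # so the embedding module itself becomes the head.
            return embed_tokens
        # Otherwise share the Parameter object; the head keeps its own type.
        self.weight = embed_tokens.weight
        return self


embed = StubEmbedding(8, 4)
head = StubLMHead(8, 4).tie_weights(embed)
assert head.weight is embed.weight  # same Parameter object, no copy

Because the call site assigns the return value back to the head attribute (see the llama.py change below), the non-GGUF path keeps the head as a distinct module that merely shares storage with the embedding, instead of replacing the head object outright.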
vllm/model_executor/models/llama.py: 3 changes (2 additions & 1 deletion)
@@ -524,7 +524,8 @@ def __init__(
             quant_config=quant_config,
         )
         if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+            self.lm_head = self.lm_head.tie_weights(
+                self.model.embed_tokens)
 
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
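
The commit title ties the fix to LoRA; a plausible reading is that the old assignment made self.lm_head literally become the embedding module when tie_word_embeddings was set, changing the head's type, while the new call shares only the weight and keeps the head a ParallelLMHead. A short continuation of the stub sketch above (same assumptions; it only illustrates the type difference, not vLLM's LoRA machinery):

# Reuses StubEmbedding / StubLMHead from the sketch above.
embed_tokens = StubEmbedding(8, 4)
lm_head = StubLMHead(8, 4)

old_style = embed_tokens                        # pre-fix: head replaced wholesale
new_style = lm_head.tie_weights(embed_tokens)   # post-fix: weight shared, type kept

print(isinstance(old_style, StubLMHead))   # False: type-based wrapping misses it
print(isinstance(new_style, StubLMHead))   # True (non-GGUF path)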
