forked from axolotl-ai-cloud/axolotl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
deepseekv2 liger support (axolotl-ai-cloud#1878)
* deepseekv2 liger support * add comment * add missing impl
- Loading branch information
Showing
2 changed files
with
153 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
""" | ||
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss | ||
""" | ||
# pylint: disable=duplicate-code | ||
|
||
from typing import List, Optional, Tuple, Union | ||
|
||
import torch | ||
from liger_kernel.transformers.fused_linear_cross_entropy import ( | ||
LigerFusedLinearCrossEntropyLoss, | ||
) | ||
from torch.nn import CrossEntropyLoss | ||
from transformers.modeling_outputs import CausalLMOutputWithPast | ||
|
||
|
||
# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) | ||
# @replace_return_docstrings( | ||
# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC | ||
# ) | ||
def lce_forward( | ||
self, | ||
input_ids: torch.LongTensor = None, | ||
attention_mask: Optional[torch.Tensor] = None, | ||
position_ids: Optional[torch.LongTensor] = None, | ||
past_key_values: Optional[List[torch.FloatTensor]] = None, | ||
inputs_embeds: Optional[torch.FloatTensor] = None, | ||
labels: Optional[torch.LongTensor] = None, | ||
use_cache: Optional[bool] = None, | ||
output_attentions: Optional[bool] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
return_dict: Optional[bool] = None, | ||
) -> Union[Tuple, CausalLMOutputWithPast]: | ||
r""" | ||
Args: | ||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | ||
Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., | ||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored | ||
(masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. | ||
Returns: | ||
Example: | ||
```python | ||
>>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM | ||
>>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) | ||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) | ||
>>> prompt = "Hey, are you conscious? Can you talk to me?" | ||
>>> inputs = tokenizer(prompt, return_tensors="pt") | ||
>>> # Generate | ||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30) | ||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] | ||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." | ||
```""" | ||
output_attentions = ( | ||
output_attentions | ||
if output_attentions is not None | ||
else self.config.output_attentions | ||
) | ||
output_hidden_states = ( | ||
output_hidden_states | ||
if output_hidden_states is not None | ||
else self.config.output_hidden_states | ||
) | ||
return_dict = ( | ||
return_dict if return_dict is not None else self.config.use_return_dict | ||
) | ||
|
||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) | ||
outputs = self.model( | ||
input_ids=input_ids, | ||
attention_mask=attention_mask, | ||
position_ids=position_ids, | ||
past_key_values=past_key_values, | ||
inputs_embeds=inputs_embeds, | ||
use_cache=use_cache, | ||
output_attentions=output_attentions, | ||
output_hidden_states=output_hidden_states, | ||
return_dict=return_dict, | ||
) | ||
|
||
hidden_states = outputs[0] | ||
|
||
loss = None | ||
logits = None | ||
|
||
if self.training: | ||
shift_hidden_states = hidden_states[..., :-1, :].contiguous() | ||
shift_labels = labels[..., 1:].contiguous() | ||
|
||
# flatten tokens | ||
shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) | ||
shift_labels = shift_labels.view(-1) | ||
|
||
lce = LigerFusedLinearCrossEntropyLoss() | ||
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) | ||
else: | ||
logits = self.lm_head(hidden_states) | ||
logits = logits.float() | ||
|
||
loss = None | ||
if labels is not None: | ||
# Shift so that tokens < n predict n | ||
shift_logits = logits[..., :-1, :].contiguous() | ||
shift_labels = labels[..., 1:].contiguous() | ||
# Flatten the tokens | ||
loss_fct = CrossEntropyLoss() | ||
shift_logits = shift_logits.view(-1, self.config.vocab_size) | ||
shift_labels = shift_labels.view(-1) | ||
# Enable model parallelism | ||
shift_labels = shift_labels.to(shift_logits.device) | ||
loss = loss_fct(shift_logits, shift_labels) | ||
|
||
if not return_dict: | ||
output = (logits,) + outputs[1:] | ||
return (loss,) + output if loss is not None else output | ||
|
||
return CausalLMOutputWithPast( | ||
loss=loss, | ||
logits=logits, | ||
past_key_values=outputs.past_key_values, | ||
hidden_states=outputs.hidden_states, | ||
attentions=outputs.attentions, | ||
) |