
Commit

convert modular
weak-kajuma committed Dec 21, 2024
1 parent 484a493 commit 0ce2023
Showing 1 changed file with 29 additions and 91 deletions.
120 changes: 29 additions & 91 deletions src/transformers/models/diffllama/modeling_diffllama.py
@@ -43,6 +43,7 @@
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import (
+    LossKwargs,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -62,15 +63,17 @@
 class DiffLlamaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
+        self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
 
-    def forward(self, hidden_state):
-        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
 
 
 def rotate_half(x):
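
Note on the MLP hunk above: the forward keeps the gated formulation down(act(gate(x)) * up(x)), i.e. SwiGLU-style when the configured activation is SiLU. A minimal standalone sketch of that pattern, with made-up sizes and SiLU hard-coded where the real module resolves ACT2FN[config.hidden_act]:

import torch
import torch.nn as nn

class GatedMLP(nn.Module):
    """Illustrative gated MLP; not the DiffLlama module itself."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()  # stand-in for ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate branch is activated, up branch stays linear; their product is projected back down.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

mlp = GatedMLP(hidden_size=16, intermediate_size=64)
out = mlp(torch.randn(2, 5, 16))  # (batch, seq_len, hidden_size) -> same shape
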
@@ -529,37 +532,15 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-        **kwargs,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*):
-                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
-                query_sequence_length, key_sequence_length)` if default attention is used.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
-                Indices depicting the position of the input sequence tokens in the sequence
-            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
-                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
-                with `head_dim` being the embedding dimension of each attention head.
-            kwargs (`dict`, *optional*):
-                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
-                into the model
-        """
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -579,13 +560,9 @@ def forward(
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
-
         if output_attentions:
             outputs += (self_attn_weights,)
 
-        if use_cache:
-            outputs += (present_key_value,)
-
         return outputs
 
 
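The two decoder-layer hunks above slim the per-layer return value down to the hidden states plus, optionally, the attention weights; the key/value cache is no longer threaded through each layer's outputs. A toy layer (not DiffLlama's, and using plain nn.MultiheadAttention) following the same output contract, for illustration:

from typing import Tuple
import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    """Toy decoder block showing the (hidden_states, [attn_weights]) return shape."""

    def __init__(self, dim: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=2, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor, ...]:
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states, attn_weights = self.attn(
            hidden_states, hidden_states, hidden_states, need_weights=output_attentions
        )
        hidden_states = residual + hidden_states
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        # No cache entry in the tuple: in the real model the cache is updated in place instead.
        return outputs

layer = ToyLayer(dim=8)
out = layer(torch.randn(1, 4, 8), output_attentions=True)  # (hidden_states, attn_weights)
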
@@ -637,40 +614,18 @@ def _init_weights(self, module):
 class DiffLlamaRotaryEmbedding(nn.Module):
     def __init__(
         self,
-        dim=None,
-        max_position_embeddings=2048,
-        base=10000,
+        config: DiffLlamaConfig,
         device=None,
-        scaling_factor=1.0,
-        rope_type="default",
-        config: Optional[DiffLlamaConfig] = None,
     ):
         super().__init__()
-        # TODO (joao): remove the `if` below, only used for BC
-        self.rope_kwargs = {}
-        if config is None:
-            logger.warning_once(
-                "`DiffLlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
-                "`config` argument. All other arguments will be removed in v4.46"
-            )
-            self.rope_kwargs = {
-                "rope_type": rope_type,
-                "factor": scaling_factor,
-                "dim": dim,
-                "base": base,
-                "max_position_embeddings": max_position_embeddings,
-            }
-            self.rope_type = rope_type
-            self.max_seq_len_cached = max_position_embeddings
-            self.original_max_seq_len = max_position_embeddings
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
         else:
-            # BC: "rope_type" was originally "type"
-            if config.rope_scaling is not None:
-                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-            else:
-                self.rope_type = "default"
-            self.max_seq_len_cached = config.max_position_embeddings
-            self.original_max_seq_len = config.max_position_embeddings
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
 
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
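
The rotary-embedding hunk above drops the deprecated per-argument constructor in favor of reading everything from the config. A small sketch of the resulting rope_type resolution, using a stand-in namespace object rather than the real DiffLlamaConfig:

from types import SimpleNamespace

def resolve_rope_type(config) -> str:
    # "rope_type" was originally called "type", so both keys are accepted.
    if getattr(config, "rope_scaling", None) is not None:
        return config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
    return "default"

cfg = SimpleNamespace(rope_scaling={"type": "linear", "factor": 2.0}, max_position_embeddings=4096)
print(resolve_rope_type(cfg))                               # -> "linear"
print(resolve_rope_type(SimpleNamespace(rope_scaling=None)))  # -> "default"
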
@@ -819,10 +774,7 @@ def __init__(self, config: DiffLlamaConfig):
         )
         self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = DiffLlamaRotaryEmbedding(config=config)
-
         self.gradient_checkpointing = False
-        if getattr(config, "pretraining_tp", 1) != 1:
-            logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -839,7 +791,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -867,31 +819,22 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        # kept for BC (non `Cache` `past_key_values` inputs)
-        return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
-            return_legacy_cache = True
-            if past_key_values is None:
-                past_key_values = DynamicCache()
-            else:
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-                logger.warning_once(
-                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
-                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
-                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
-                )
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
 
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )
 
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
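
The hunk above replaces the legacy tuple-of-tuples cache handling with a plain DynamicCache default. A hedged sketch of how that cache object is used, assuming the transformers.cache_utils API of recent releases (tensor shapes below are made up for illustration):

import torch
from transformers.cache_utils import DynamicCache

past_key_values = None
use_cache = True

if use_cache and past_key_values is None:
    past_key_values = DynamicCache()  # callers may also pass their own Cache instance

# Inside attention, each layer appends its new key/value states to the cache in place:
k = torch.randn(1, 2, 4, 8)  # (batch, num_kv_heads, seq_len, head_dim)
v = torch.randn(1, 2, 4, 8)
keys, values = past_key_values.update(k, v, layer_idx=0)

print(past_key_values.get_seq_length())  # 4: used above to build cache_position offsets
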
@@ -900,7 +843,6 @@
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
-        next_decoder_cache = None
 
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
@@ -933,9 +875,6 @@
 
             hidden_states = layer_outputs[0]
 
-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
 
@@ -945,18 +884,13 @@
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = next_decoder_cache if use_cache else None
-        if return_legacy_cache:
-            next_cache = next_cache.to_legacy_cache()
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
+        output = BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=next_cache,
+            past_key_values=past_key_values if use_cache else None,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
+        return output if return_dict else output.to_tuple()
 
     def _update_causal_mask(
         self,
@@ -1080,6 +1014,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         return causal_mask
 
 
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
 class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
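
KwargsForCausalLM merges the two kwargs TypedDicts so a single **kwargs: Unpack[...] annotation covers both flash-attention and loss arguments. An illustration of the general pattern with invented field names (not the real FlashAttentionKwargs/LossKwargs fields):

from typing_extensions import TypedDict, Unpack

class AttnKwargsExample(TypedDict, total=False):
    window_size: int

class LossKwargsExample(TypedDict, total=False):
    num_items_in_batch: int

class CombinedKwargs(AttnKwargsExample, LossKwargsExample): ...

def forward(**kwargs: Unpack[CombinedKwargs]) -> None:
    # kwargs is a plain dict at runtime; the annotation only constrains what callers may pass.
    print(sorted(kwargs))

forward(num_items_in_batch=3, window_size=128)
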
@@ -1127,7 +1064,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
-        **loss_kwargs,
+        **kwargs: Unpack[KwargsForCausalLM],
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1177,6 +1114,7 @@
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            **kwargs,
         )
 
         hidden_states = outputs[0]
@@ -1185,7 +1123,7 @@
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
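
The final hunk forwards **kwargs into self.loss_function and passes logits, labels, and vocab_size by keyword. A hedged sketch of what a causal-LM loss with that signature typically does (shift-by-one cross-entropy, optionally normalized by num_items_in_batch); this is not the exact transformers implementation:

from typing import Optional
import torch
import torch.nn.functional as F

def causal_lm_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    vocab_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    **kwargs,
) -> torch.Tensor:
    # Predict token t+1 from position t: drop the last logit and the first label.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index, reduction=reduction)
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss

logits = torch.randn(2, 5, 11)          # (batch, seq_len, vocab_size)
labels = torch.randint(0, 11, (2, 5))
print(causal_lm_loss(logits=logits, labels=labels, vocab_size=11))
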
