diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index c0fe013d096454..3bd1e0d867111d 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -43,6 +43,7 @@ from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, @@ -62,6 +63,7 @@ class DiffLlamaMLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -69,8 +71,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj def rotate_half(x): @@ -529,37 +532,15 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -579,13 +560,9 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -637,40 +614,18 @@ def _init_weights(self, module): class DiffLlamaRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, + config: DiffLlamaConfig, device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[DiffLlamaConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`DiffLlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -819,10 +774,7 @@ def __init__(self, config: DiffLlamaConfig): ) self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = DiffLlamaRotaryEmbedding(config=config) - self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -839,7 +791,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -867,31 +819,22 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -900,7 +843,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -933,9 +875,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -945,18 +884,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -1080,6 +1014,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -1127,7 +1064,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1177,6 +1114,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1185,7 +1123,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:]