
force reset
weak-kajuma committed Dec 21, 2024
1 parent 4660c6e commit b4ff5f3
Showing 2 changed files with 148 additions and 34 deletions.
149 changes: 122 additions & 27 deletions src/transformers/models/diffllama/modeling_diffllama.py
@@ -31,15 +31,17 @@
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -224,7 +226,7 @@ def forward(
if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
return attn_output, attn_weights


class DiffLlamaFlashAttention2(DiffLlamaAttention):
@@ -374,7 +376,7 @@ def forward(
if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
return attn_output, attn_weights


class DiffLlamaSdpaAttention(DiffLlamaAttention):
@@ -477,27 +479,27 @@ def forward(
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value
return attn_output, None
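
Note: the attention variants now return a two-element (attn_output, attn_weights) tuple. The KV cache passed in as past_key_value is a Cache object that is updated in place, so it no longer needs to be returned. A minimal runnable sketch of that in-place update pattern, using DynamicCache directly with arbitrary toy shapes:

import torch
from transformers.cache_utils import DynamicCache

# Toy illustration (shapes are arbitrary): attention layers call
# past_key_value.update(...), so the caller's cache object already holds the
# new keys/values and forward() only returns (attn_output, attn_weights).
cache = DynamicCache()
key_states = torch.randn(1, 4, 3, 8)    # (batch, num_kv_heads, seq_len, head_dim)
value_states = torch.randn(1, 4, 3, 8)
key_states, value_states = cache.update(key_states, value_states, layer_idx=0)
print(cache.get_seq_length(layer_idx=0))  # 3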


class DiffLlamaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
def __init__(self, hidden_size, eps=1e-6):
"""
DiffLlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst DiffLlama is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)

def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


DIFFLLAMA_ATTENTION_CLASSES = {
@@ -511,7 +513,9 @@ class DiffLlamaDecoderLayer(nn.Module):
def __init__(self, config: DiffLlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size

self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

self.mlp = DiffLlamaMLP(config)
self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -525,6 +529,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -542,6 +547,9 @@ def forward(
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
@@ -559,6 +567,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -625,6 +634,93 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()


class DiffLlamaRotaryEmbedding(nn.Module):
def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[DiffLlamaConfig] = None,
):
super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`DiffLlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)

# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
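
Note: the new DiffLlamaRotaryEmbedding builds the cos/sin tables once per forward pass from inv_freq and position_ids, and every decoder layer then reuses them. A standalone sketch of the core computation with illustrative dimensions (default rope_type, no scaling factor):

import torch

# Standalone sketch of the core RoPE block above (illustrative dims):
# cos/sin are built once from inv_freq and position_ids, then shared by all layers.
head_dim, base = 8, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim/2,)
position_ids = torch.arange(6)[None, :]                                       # (batch=1, seq_len=6)

freqs = inv_freq[None, :, None].float() @ position_ids[:, None, :].float()    # (1, head_dim/2, 6)
freqs = freqs.transpose(1, 2)                                                  # (1, 6, head_dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                                        # (1, 6, head_dim)
cos, sin = emb.cos(), emb.sin()
print(cos.shape)  # torch.Size([1, 6, 8])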


DIFFLLAMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -722,6 +818,7 @@ def __init__(self, config: DiffLlamaConfig):
[DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = DiffLlamaRotaryEmbedding(config=config)

self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
@@ -749,6 +846,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -770,9 +868,9 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids)

# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False # noqa: F841
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True # noqa: F841
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
@@ -788,22 +886,16 @@ def forward(
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)

# embed positions
hidden_states = inputs_embeds

# normalized
# DiffLlama downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -824,6 +916,7 @@
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
@@ -834,6 +927,8 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)

hidden_states = layer_outputs[0]
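
Note: with self.rotary_emb moved into DiffLlamaModel, the (cos, sin) pair is computed once and handed to each layer via position_embeddings, and the attention modules apply it to their query/key states. A self-contained re-implementation of that rotation step, mirroring what the apply_rotary_pos_emb helper imported in the modular file does (shapes are illustrative):

import torch

# Sketch of applying the shared (cos, sin) to per-layer query/key states.
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

bsz, num_heads, seq_len, head_dim = 1, 2, 6, 8
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
cos = torch.randn(bsz, seq_len, head_dim)  # in the model these come from self.rotary_emb
sin = torch.randn(bsz, seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
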
33 changes: 26 additions & 7 deletions src/transformers/models/diffllama/modular_diffllama.py
@@ -24,7 +24,6 @@

from ...cache_utils import Cache, StaticCache
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
is_flash_attn_greater_or_equal_2_10,
logging,
@@ -37,8 +36,6 @@
LlamaForTokenClassification,
LlamaModel,
LlamaPreTrainedModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
@@ -163,7 +160,7 @@ def forward(
if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
return attn_output, attn_weights


class DiffLlamaFlashAttention2(DiffLlamaAttention):
@@ -313,7 +310,7 @@ def forward(
if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
return attn_output, attn_weights


class DiffLlamaSdpaAttention(DiffLlamaAttention):
@@ -416,7 +413,7 @@ def forward(
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value
return attn_output, None


DIFFLLAMA_ATTENTION_CLASSES = {
"eager": DiffLlamaAttention,
"flash_attention_2": DiffLlamaFlashAttention2,
"sdpa": DiffLlamaSdpaAttention,
}


class DiffLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: DiffLlamaConfig, layer_idx: int):
super().__init__(config, layer_idx)

self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)


class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
pass


class DiffLlamaModel(LlamaModel):
pass


class DiffLlamaForCausalLM(GemmaForCausalLM):
@@ -437,7 +456,7 @@ class DiffLlamaForTokenClassification(LlamaForTokenClassification):

__all__ = [
"DiffLlamaPreTrainedModel",
"DiffLlamaModel",
"DiffLlamaModel", # noqa: F822
"DiffLlamaForCausalLM",
"DiffLlamaForSequenceClassification",
"DiffLlamaForQuestionAnswering",
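
Note: the modular file now defines DiffLlamaDecoderLayer, DiffLlamaPreTrainedModel, and DiffLlamaModel as thin subclasses of their Llama counterparts, which the modular converter expands into the full modeling_diffllama.py shown above. For reference, a hypothetical end-to-end usage sketch: the checkpoint path is a placeholder, and attn_implementation is what the DIFFLLAMA_ATTENTION_CLASSES mapping dispatches on.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint path; no published DiffLlama weights are assumed here.
checkpoint = "path/to/diffllama-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)
inputs = tokenizer("Differential attention", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))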