From 0ff9fa9a7ed0974f674a145f320d01bcf3e363a6 Mon Sep 17 00:00:00 2001
From: DesmonDay <908660116@qq.com>
Date: Tue, 24 Oct 2023 21:08:32 +0800
Subject: [PATCH] add fused_rms_norm, fused_rope

---
 paddlenlp/transformers/qwen/configuration.py |   8 +-
 paddlenlp/transformers/qwen/modeling.py      | 100 ++++++++++---------
 2 files changed, 61 insertions(+), 47 deletions(-)

diff --git a/paddlenlp/transformers/qwen/configuration.py b/paddlenlp/transformers/qwen/configuration.py
index 3fb6ac69d3de..72a0e703df44 100644
--- a/paddlenlp/transformers/qwen/configuration.py
+++ b/paddlenlp/transformers/qwen/configuration.py
@@ -37,7 +37,9 @@ def __init__(
         rotary_emb_base=10000,
         use_dynamic_ntk=True,
         use_logn_attn=True,
-        use_flash_attn="auto",
+        use_flash_attention=False,
+        use_fused_rms_norm=False,
+        use_fused_rope=False,
         intermediate_size=22016,
         no_bias=True,
         tie_word_embeddings=False,
@@ -60,7 +62,9 @@ def __init__(
         self.rotary_emb_base = rotary_emb_base
         self.use_dynamic_ntk = use_dynamic_ntk
         self.use_logn_attn = use_logn_attn
-        self.use_flash_attn = use_flash_attn
+        self.use_flash_attention = use_flash_attention
+        self.use_fused_rms_norm = use_fused_rms_norm
+        self.use_fused_rope = use_fused_rope
         self.no_bias = no_bias
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py
index bf74b6dbd448..25130da5f77a 100644
--- a/paddlenlp/transformers/qwen/modeling.py
+++ b/paddlenlp/transformers/qwen/modeling.py
@@ -22,6 +22,7 @@
 from paddle import Tensor, nn
 from paddle.distributed import fleet
 from paddle.distributed.fleet.utils import recompute
+from paddle.utils import try_import

 from paddlenlp.transformers.model_outputs import (
     BaseModelOutputWithPast,
@@ -41,6 +42,8 @@
 except:
     flash_attention = None

+from paddle.incubate.nn.functional import fused_rotary_position_embedding
+

 def get_triangle_upper_mask(x, mask=None):
     if mask is not None:
@@ -142,7 +145,7 @@ def _attn(self, query, key, value, attention_mask=None):
         bsz, q_len, num_heads, head_dim = query.shape
         _, kv_seq_len, _, _ = value.shape

-        if self.config.use_flash_attn and flash_attention is not None:
+        if self.config.use_flash_attention and flash_attention is not None:
             # Flash Attention now ignore attention mask
             # Current Flash Attention doesn't support attn maskt
             # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
@@ -224,7 +227,7 @@ def forward(
                 self._ntk_cached = ntk_alpha
             else:
                 ntk_alpha = self._ntk_cached
-            rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
+            rotary_pos_emb = self.rotary_emb(value, kv_seq_len, ntk_alpha=ntk_alpha)

         if rotary_pos_emb is not None:
             if isinstance(rotary_pos_emb, tuple):
@@ -233,13 +236,19 @@ def forward(
                 rotary_pos_emb = (rotary_pos_emb,) * 2

         if rotary_pos_emb is not None:
-            q_pos_emb, k_pos_emb = rotary_pos_emb
-            # Slice the pos emb for current inference
-            cur_len = query.shape[1]
-            q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
-            k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
-            query = apply_rotary_pos_emb(query, q_pos_emb)
-            key = apply_rotary_pos_emb(key, k_pos_emb)
+            cos, sin = rotary_pos_emb
+            if self.config.use_fused_rope:
+                query, key, _ = fused_rotary_position_embedding(
+                    query,
+                    key,
+                    v=None,
+                    sin=sin,
+                    cos=cos,
+                    position_ids=None,
+                    use_neox_rotary_style=False,
+                )
+            else:
+                query, key = apply_rotary_pos_emb(query, key, cos, sin)

         if layer_past is not None:
             past_key, past_value = layer_past[0], layer_past[1]
@@ -312,17 +321,9 @@ def forward(self, hidden_states):

 class QWenBlock(nn.Layer):
     def __init__(self, config):
         super().__init__()
-        hidden_size = config.hidden_size
-
-        self.ln_1 = RMSNorm(
-            hidden_size,
-            eps=config.layer_norm_epsilon,
-        )
+        self.ln_1 = RMSNorm(config)
         self.attn = QWenAttention(config)
-        self.ln_2 = RMSNorm(
-            hidden_size,
-            eps=config.layer_norm_epsilon,
-        )
+        self.ln_2 = RMSNorm(config)
         self.mlp = QWenMLP(config)

@@ -537,10 +538,7 @@ def __init__(self, config):
                 for i in range(config.num_hidden_layers)
             ]
         )
-        self.ln_f = RMSNorm(
-            self.embed_dim,
-            eps=config.layer_norm_epsilon,
-        )
+        self.ln_f = RMSNorm(config)

     def get_input_embeddings(self):
         return self.wte
@@ -870,12 +868,11 @@ def __init__(self, dim, base=10000):
         super().__init__()
         self.dim = dim
         self.base = base
-        self.inv_freq = 1.0 / (base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim))
-        self._rotary_pos_emb_cache = None
+        self.inv_freq = 1.0 / (self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim))
         self._seq_len_cached = 0
         self._ntk_alpha_cached = 1.0

-    def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
+    def update_cos_sin_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
         seqlen = max_seq_len + offset
         if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
             base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
@@ -885,36 +882,46 @@ def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
             seq = paddle.arange(self._seq_len_cached)
             freqs = paddle.outer(seq.astype(self.inv_freq.dtype), self.inv_freq)
             emb = paddle.concat([freqs, freqs], axis=-1)
+            self.cos_cached = emb.cos()[None, :, None, :]
+            self.sin_cached = emb.sin()[None, :, None, :]
+
+    def forward(self, x, max_seq_len, offset=0, ntk_alpha=1.0):
+        self.update_cos_sin_cache(max_seq_len, offset, ntk_alpha)
+        cos = self.cos_cached[:, offset : offset + max_seq_len, :, ...]
+        sin = self.sin_cached[:, offset : offset + max_seq_len, :, ...]
+        return (
+            cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
+            sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
+        )

-            self._rotary_pos_emb_cache = emb.unsqueeze([0, 2])
-
-    def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
-        self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
-        return self._rotary_pos_emb_cache[:, offset : offset + max_seq_len]


+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return paddle.concat([-x2, x1], axis=-1)

-def _rotate_half(x):
-    x = x.reshape(x.shape[:-1] + [2, -1])
-    x1, x2 = x.unbind(axis=-2)
-    return paddle.concat([-x2, x1], axis=-1)


+def apply_rotary_pos_emb(q, k, cos, sin):
+    cos = cos[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
+    sin = sin[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed

-def apply_rotary_pos_emb(t, freqs):
-    rot_dim = freqs.shape[-1]
-    t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
-    t_ = t_.astype(paddle.float32)
-    t_pass_ = t_pass_.astype(paddle.float32)
-    t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
-    return paddle.concat([t_, t_pass_], axis=-1).astype(t.dtype)


+def rms_norm_fused(x_in, w, eps):
+    fused_ln = try_import("fused_ln")
+    return fused_ln.fused_rms_norm(x_in, w, eps)[0]
+

 class RMSNorm(nn.Layer):
-    def __init__(self, dim: int, eps: float = 1e-6):
+    def __init__(self, config):
         super().__init__()
-        self.eps = eps
+        self.config = config
+        self.eps = config.layer_norm_epsilon
         self.weight = paddle.create_parameter(
-            shape=[dim],
+            shape=[config.hidden_size],
             dtype=paddle.get_default_dtype(),
             default_initializer=nn.initializer.Constant(1.0),
         )
@@ -923,5 +930,8 @@ def _norm(self, x):
         return x * paddle.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

     def forward(self, x):
+        if self.config.use_fused_rms_norm:
+            return rms_norm_fused(x, self.weight, self.eps)
+
         output = self._norm(x.astype(paddle.float32)).astype(x.dtype)
         return output * self.weight
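
Usage note (not part of the patch): a minimal sketch of how the three new config switches might be enabled when loading a model. The checkpoint name "qwen/qwen-7b" is hypothetical, and the sketch assumes the Qwen config/model classes are exported from paddlenlp.transformers as QWenConfig and QWenForCausalLM; adjust to whatever names the installed PaddleNLP version actually exposes.

    # Sketch only: class names and checkpoint id below are assumptions.
    from paddlenlp.transformers import QWenConfig, QWenForCausalLM

    config = QWenConfig.from_pretrained("qwen/qwen-7b")
    config.use_flash_attention = True  # replaces the old use_flash_attn="auto" switch
    config.use_fused_rms_norm = True   # RMSNorm.forward -> fused_ln.fused_rms_norm
    config.use_fused_rope = True       # RoPE -> paddle.incubate fused_rotary_position_embedding
    model = QWenForCausalLM.from_pretrained("qwen/qwen-7b", config=config)

All three switches default to False in this patch, so existing configs keep the unfused code paths unless they opt in.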