add fused_rms_norm, fused_rope
DesmonDay committed Oct 24, 2023
1 parent 5359c0e commit 0ff9fa9
Showing 2 changed files with 61 additions and 47 deletions.
8 changes: 6 additions & 2 deletions paddlenlp/transformers/qwen/configuration.py
@@ -37,7 +37,9 @@ def __init__(
rotary_emb_base=10000,
use_dynamic_ntk=True,
use_logn_attn=True,
use_flash_attn="auto",
use_flash_attention=False,
use_fused_rms_norm=False,
use_fused_rope=False,
intermediate_size=22016,
no_bias=True,
tie_word_embeddings=False,
@@ -60,7 +62,9 @@ def __init__(
self.rotary_emb_base = rotary_emb_base
self.use_dynamic_ntk = use_dynamic_ntk
self.use_logn_attn = use_logn_attn
self.use_flash_attn = use_flash_attn
self.use_flash_attention = use_flash_attention
self.use_fused_rms_norm = use_fused_rms_norm
self.use_fused_rope = use_fused_rope
self.no_bias = no_bias

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
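
The three new switches are independent booleans on the config, so they can be toggled per run. Below is a minimal sketch of setting them, assuming the configuration class defined in this file is QWenConfig (the class name itself is not visible in the hunks above):

    # Assumes the configuration class in this file is QWenConfig; the class
    # name does not appear in the hunks above.
    from paddlenlp.transformers.qwen.configuration import QWenConfig

    config = QWenConfig(
        use_flash_attention=True,  # replaces the old use_flash_attn="auto" switch
        use_fused_rms_norm=True,   # route RMSNorm through the fused_ln extension
        use_fused_rope=True,       # route RoPE through fused_rotary_position_embedding
    )
    print(config.use_flash_attention, config.use_fused_rms_norm, config.use_fused_rope)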
100 changes: 55 additions & 45 deletions paddlenlp/transformers/qwen/modeling.py
@@ -22,6 +22,7 @@
from paddle import Tensor, nn
from paddle.distributed import fleet
from paddle.distributed.fleet.utils import recompute
from paddle.utils import try_import

from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPast,
@@ -41,6 +42,8 @@
except:
flash_attention = None

from paddle.incubate.nn.functional import fused_rotary_position_embedding


def get_triangle_upper_mask(x, mask=None):
if mask is not None:
@@ -142,7 +145,7 @@ def _attn(self, query, key, value, attention_mask=None):
bsz, q_len, num_heads, head_dim = query.shape
_, kv_seq_len, _, _ = value.shape

if self.config.use_flash_attn and flash_attention is not None:
if self.config.use_flash_attention and flash_attention is not None:
# Flash Attention currently ignores the attention mask
# The current Flash Attention kernel doesn't support an attn mask
# Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
@@ -224,7 +227,7 @@ def forward(
self._ntk_cached = ntk_alpha
else:
ntk_alpha = self._ntk_cached
rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
rotary_pos_emb = self.rotary_emb(value, kv_seq_len, ntk_alpha=ntk_alpha)

if rotary_pos_emb is not None:
if isinstance(rotary_pos_emb, tuple):
@@ -233,13 +236,19 @@
rotary_pos_emb = (rotary_pos_emb,) * 2

if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
cur_len = query.shape[1]
q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
query = apply_rotary_pos_emb(query, q_pos_emb)
key = apply_rotary_pos_emb(key, k_pos_emb)
cos, sin = rotary_pos_emb
if self.config.use_fused_rope:
query, key, _ = fused_rotary_position_embedding(
query,
key,
v=None,
sin=sin,
cos=cos,
position_ids=None,
use_neox_rotary_style=False,
)
else:
query, key = apply_rotary_pos_emb(query, key, cos, sin)

if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
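
The attention forward now branches on config.use_fused_rope: when set, query and key are rotated in a single call to paddle.incubate.nn.functional.fused_rotary_position_embedding (non-NeoX pairing, cos/sin broadcast as [1, seq_len, 1, head_dim]); otherwise the eager apply_rotary_pos_emb path defined further down is used. A standalone sketch of that dispatch on toy tensors follows; the is_compiled_with_cuda check is an illustration only, since the fused kernel is generally limited to GPU builds of Paddle:

    import paddle
    from paddle.incubate.nn.functional import fused_rotary_position_embedding

    # Toy tensors in the same layout as above: [batch, seq_len, num_heads, head_dim].
    bs, seq_len, num_heads, head_dim = 1, 8, 2, 16
    query = paddle.randn([bs, seq_len, num_heads, head_dim])
    key = paddle.randn([bs, seq_len, num_heads, head_dim])

    # cos/sin broadcast over batch and heads: [1, seq_len, 1, head_dim]
    inv_freq = 1.0 / (10000.0 ** (paddle.arange(0, head_dim, 2, dtype="float32") / head_dim))
    freqs = paddle.outer(paddle.arange(seq_len, dtype="float32"), inv_freq)
    emb = paddle.concat([freqs, freqs], axis=-1)
    cos = emb.cos()[None, :, None, :]
    sin = emb.sin()[None, :, None, :]

    def rotate_half(x):
        # eager helper, same math as the rotate_half added further down in this file
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return paddle.concat([-x2, x1], axis=-1)

    use_fused_rope = paddle.is_compiled_with_cuda()  # stand-in for config.use_fused_rope
    if use_fused_rope:
        query, key, _ = fused_rotary_position_embedding(
            query, key, v=None, sin=sin, cos=cos,
            position_ids=None, use_neox_rotary_style=False,
        )
    else:
        # eager fallback: the apply_rotary_pos_emb(q, k, cos, sin) path
        query = (query * cos) + (rotate_half(query) * sin)
        key = (key * cos) + (rotate_half(key) * sin)
    print(query.shape)  # [1, 8, 2, 16]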
@@ -312,17 +321,9 @@ def forward(self, hidden_states):
class QWenBlock(nn.Layer):
def __init__(self, config):
super().__init__()
hidden_size = config.hidden_size

self.ln_1 = RMSNorm(
hidden_size,
eps=config.layer_norm_epsilon,
)
self.ln_1 = RMSNorm(config)
self.attn = QWenAttention(config)
self.ln_2 = RMSNorm(
hidden_size,
eps=config.layer_norm_epsilon,
)
self.ln_2 = RMSNorm(config)

self.mlp = QWenMLP(config)

@@ -537,10 +538,7 @@ def __init__(self, config):
for i in range(config.num_hidden_layers)
]
)
self.ln_f = RMSNorm(
self.embed_dim,
eps=config.layer_norm_epsilon,
)
self.ln_f = RMSNorm(config)

def get_input_embeddings(self):
return self.wte
@@ -870,12 +868,11 @@ def __init__(self, dim, base=10000):
super().__init__()
self.dim = dim
self.base = base
self.inv_freq = 1.0 / (base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim))
self._rotary_pos_emb_cache = None
self.inv_freq = 1.0 / (self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim))
self._seq_len_cached = 0
self._ntk_alpha_cached = 1.0

def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
def update_cos_sin_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
seqlen = max_seq_len + offset
if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
@@ -885,36 +882,46 @@ def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
seq = paddle.arange(self._seq_len_cached)
freqs = paddle.outer(seq.astype(self.inv_freq.dtype), self.inv_freq)
emb = paddle.concat([freqs, freqs], axis=-1)
self.cos_cached = emb.cos()[None, :, None, :]
self.sin_cached = emb.sin()[None, :, None, :]

def forward(self, x, max_seq_len, offset=0, ntk_alpha=1.0):
self.update_cos_sin_cache(max_seq_len, offset, ntk_alpha)
cos = self.cos_cached[:, offset : offset + max_seq_len, :, ...]
sin = self.sin_cached[:, offset : offset + max_seq_len, :, ...]
return (
cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
)

self._rotary_pos_emb_cache = emb.unsqueeze([0, 2])

def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
return self._rotary_pos_emb_cache[:, offset : offset + max_seq_len]
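
The cache-refresh logic is unchanged apart from storing cos and sin separately and casting them to the query dtype in forward: it rebuilds whenever the requested length outgrows the cache or the dynamic-NTK alpha changes, scaling the rotary base by ntk_alpha ** (dim / (dim - 2)). A condensed, self-contained sketch of that behaviour is below; the class name and the exact over-allocation of the cached length are not shown in this diff, so both are simplified here:

    import paddle

    class RotaryCacheSketch:  # hypothetical name, standing in for the class patched above
        def __init__(self, dim, base=10000):
            self.dim, self.base = dim, base
            self._seq_len_cached, self._ntk_alpha_cached = 0, 1.0

        def update_cos_sin_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
            seqlen = max_seq_len + offset
            if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
                # dynamic NTK: stretch the rotary base before recomputing the frequencies
                base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
                inv_freq = 1.0 / (base ** (paddle.arange(0, self.dim, 2, dtype="float32") / self.dim))
                self._seq_len_cached = seqlen  # the real code may cache a longer span
                self._ntk_alpha_cached = ntk_alpha
                seq = paddle.arange(self._seq_len_cached, dtype="float32")
                freqs = paddle.outer(seq, inv_freq)
                emb = paddle.concat([freqs, freqs], axis=-1)
                self.cos_cached = emb.cos()[None, :, None, :]  # [1, cached_len, 1, dim]
                self.sin_cached = emb.sin()[None, :, None, :]

    rope = RotaryCacheSketch(dim=16)
    rope.update_cos_sin_cache(max_seq_len=8)                  # builds the cache
    rope.update_cos_sin_cache(max_seq_len=8)                  # cache hit, nothing recomputed
    rope.update_cos_sin_cache(max_seq_len=8, ntk_alpha=2.0)   # alpha changed -> rebuild with larger base
    print(rope.cos_cached.shape)  # [1, 8, 1, 16]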

def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return paddle.concat([-x2, x1], axis=-1)

def _rotate_half(x):

x = x.reshape(x.shape[:-1] + [2, -1])
x1, x2 = x.unbind(axis=-2)
return paddle.concat([-x2, x1], axis=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


def apply_rotary_pos_emb(t, freqs):
rot_dim = freqs.shape[-1]
t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
t_ = t_.astype(paddle.float32)
t_pass_ = t_pass_.astype(paddle.float32)
t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
return paddle.concat([t_, t_pass_], axis=-1).astype(t.dtype)
def rms_norm_fused(x_in, w, eps):
fused_ln = try_import("fused_ln")
return fused_ln.fused_rms_norm(x_in, w, eps)[0]


class RMSNorm(nn.Layer):
def __init__(self, dim: int, eps: float = 1e-6):
def __init__(self, config):
super().__init__()
self.eps = eps
self.config = config
self.eps = config.layer_norm_epsilon
self.weight = paddle.create_parameter(
shape=[dim],
shape=[config.hidden_size],
dtype=paddle.get_default_dtype(),
default_initializer=nn.initializer.Constant(1.0),
)
@@ -923,5 +930,8 @@ def _norm(self, x):
return x * paddle.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
if self.config.use_fused_rms_norm:
return rms_norm_fused(x, self.weight, self.eps)

output = self._norm(x.astype(paddle.float32)).astype(x.dtype)
return output * self.weight
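
RMSNorm.forward now short-circuits to rms_norm_fused when config.use_fused_rms_norm is set; the fused_ln module pulled in via try_import is an external custom-op extension (not part of core Paddle) that is typically only present in GPU installs, so falling back to the eager math is the safe pattern. A small sketch comparing the two paths, where the fallback uses the same float32 rsqrt computation as _norm above:

    import paddle
    from paddle.utils import try_import

    def rms_norm_eager(x, weight, eps):
        # same math as RMSNorm._norm above, computed in float32 for stability
        x32 = x.astype(paddle.float32)
        out = x32 * paddle.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
        return out.astype(x.dtype) * weight

    x = paddle.randn([2, 4, 8])
    weight = paddle.ones([8])
    out = rms_norm_eager(x, weight, eps=1e-6)

    # Fused path: guard for the case where the fused_ln extension is not installed.
    try:
        fused_ln = try_import("fused_ln")
        out_fused = fused_ln.fused_rms_norm(x, weight, 1e-6)[0]
    except Exception:
        out_fused = out  # extension unavailable: keep the eager result
    print(out.shape)  # [2, 4, 8]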
