add core_attn

DesmonDay committed Oct 25, 2023
1 parent 1689eaf commit c6338d8

Showing 2 changed files with 13 additions and 3 deletions.
paddlenlp/transformers/qwen/configuration.py: 2 changes (2 additions, 0 deletions)
@@ -34,6 +34,7 @@ def __init__(
         max_position_embeddings=8192,
         scale_attn_weights=True,
         use_cache=True,
+        recompute_granularity="full",
         kv_channels=128,
         rotary_pct=1.0,
         rotary_emb_base=10000,
@@ -59,6 +60,7 @@ def __init__(
         self.initializer_range = initializer_range
         self.scale_attn_weights = scale_attn_weights
         self.use_cache = use_cache
+        self.recompute_granularity = recompute_granularity
         self.max_position_embeddings = max_position_embeddings
         self.kv_channels = kv_channels
         self.rotary_pct = rotary_pct
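
The new option travels through the config constructor like any other field and defaults to "full". A minimal usage sketch (not part of the commit), assuming the class exported by paddlenlp/transformers/qwen/configuration.py is QWenConfig and that the remaining arguments keep their defaults:

# Hypothetical usage sketch; QWenConfig is assumed to be the class whose
# __init__ is patched above.
from paddlenlp.transformers.qwen.configuration import QWenConfig

config = QWenConfig(recompute_granularity="core_attn")
print(config.recompute_granularity)  # "core_attn"; the constructor default is "full"
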
paddlenlp/transformers/qwen/modeling.py: 14 changes (11 additions, 3 deletions)
@@ -103,14 +103,15 @@ def __init__(self, config):

         self.config = config
         self.seq_length = config.seq_length

         self.hidden_size = config.hidden_size
         self.split_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)

         self.scale_attn_weights = True
         self.enable_recompute = False
+        self.recompute_granularity = config.recompute_granularity

         self.projection_size = config.kv_channels * config.num_attention_heads

@@ -282,7 +283,13 @@ def forward(
             logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
             query = query * logn_tensor.expand(query.shape)

-        attn_output, attn_weight = self._attn(query, key, value, attention_mask)
+        has_gradient = not (query.stop_gradient and key.stop_gradient and value.stop_gradient)
+        if self.enable_recompute and self.training and has_gradient and self.recompute_granularity == "core_attn":
+            attn_output, attn_weight = recompute(
+                self._attn, query, key, value, attention_mask, use_reentrant=self.config.recompute_use_reentrant
+            )
+        else:
+            attn_output, attn_weight = self._attn(query, key, value, attention_mask)
         context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)

         attn_output = self.c_proj(context_layer)
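
The guard above only pays the recompute cost when it can help: it is skipped outside training and whenever none of query, key, or value requires a gradient. A standalone sketch of the same pattern (not the repository's code), assuming paddle.distributed.fleet.utils.recompute as the checkpointing primitive and a [batch, num_heads, seq_len, head_dim] layout; the commit's actual recompute import and the recompute_use_reentrant config field live outside the shown hunks:

# Standalone sketch of the core_attn recompute guard; names below are illustrative.
import paddle
from paddle.distributed.fleet.utils import recompute


def core_attn(query, key, value):
    # Stand-in for QWenAttention._attn: the activation-heavy step worth recomputing.
    # Assumes [batch, num_heads, seq_len, head_dim] tensors and no attention mask.
    weights = paddle.nn.functional.softmax(paddle.matmul(query, key, transpose_y=True), axis=-1)
    return paddle.matmul(weights, value), weights


def attn_with_optional_recompute(query, key, value, enable_recompute, training, granularity):
    # Recompute only when gradients will actually flow back through the inputs;
    # otherwise the second forward pass during backward is wasted work.
    has_gradient = not (query.stop_gradient and key.stop_gradient and value.stop_gradient)
    if enable_recompute and training and has_gradient and granularity == "core_attn":
        return recompute(core_attn, query, key, value)
    return core_attn(query, key, value)
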
@@ -537,6 +544,7 @@ def __init__(self, config):
         self.num_hidden_layers = config.num_hidden_layers
         self.embed_dim = config.hidden_size
         self.enable_recompute = False
+        self.recompute_granularity = config.recompute_granularity

         if config.tensor_parallel_degree > 1:
             self.wte = mpu.VocabParallelEmbedding(
@@ -688,7 +696,7 @@ def forward(
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if self.enable_recompute and self.training and has_gradient:
+            if self.enable_recompute and self.training and has_gradient and self.recompute_granularity == "full":
                 outputs = self.recompute_training(
                     block,
                     hidden_states,
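
Read together, the two guards select where checkpointing happens: with recompute_granularity == "full" the model-level forward recomputes entire decoder blocks via recompute_training, while "core_attn" recomputes only the attention score and weighting step inside each attention layer, which saves less activation memory but re-runs less work in the backward pass.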
