diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 86e49c315d..d16310489d 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -837,3 +837,10 @@ def build_alibi_bias(
     'multiquery_attention': MultiQueryAttention,
     'grouped_query_attention': GroupedQueryAttention
 }
+
+try:
+    import transformer_engine.pytorch as te
+
+    ATTN_CLASS_REGISTRY['te_multihead_attention'] = te.MultiheadAttention
+except:
+    pass
diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index c077ccb535..754cca9709 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -89,14 +89,29 @@ def __init__(
         }
 
         self.norm_1 = norm_class(d_model, device=device)
-        self.attn = attn_class(
-            d_model=d_model,
-            n_heads=n_heads,
-            fc_type=fc_type,
-            device=device,
-            **attn_config_subset_for_attn_class,
-            bias=not no_bias,
-        )
+
+        self.use_te_attn = False
+        if fc_type == 'te' and attn_config[
+                'attn_type'] == 'te_multihead_attention':
+            self.use_te_attn = True
+            self.attn = attn_class(
+                hidden_size=d_model,
+                num_attention_heads=n_heads,
+                num_gqa_groups=attn_config['kv_n_heads'],
+                fuse_qkv_params=True,
+                qkv_weight_interleaved=False,
+                input_layernorm=False,
+                bias=not no_bias,
+            )
+        else:
+            self.attn = attn_class(
+                d_model=d_model,
+                n_heads=n_heads,
+                fc_type=fc_type,
+                device=device,
+                **attn_config_subset_for_attn_class,
+                bias=not no_bias,
+            )
         self.norm_2 = None
         if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
                        False):
@@ -126,16 +141,30 @@ def forward(
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
             torch.Tensor, torch.Tensor]]]:
         a = self.norm_1(x)
-        b, attn_weights, past_key_value = self.attn(
-            a,
-            past_key_value=past_key_value,
-            attn_bias=attn_bias,
-            rotary_emb_w_meta_info=rotary_emb_w_meta_info,
-            attention_mask=attention_mask,
-            is_causal=is_causal,
-            needs_weights=output_attentions,
-            attention_mask_in_length=attention_mask_in_length,
-        )
+        if self.use_te_attn:
+            assert rotary_emb_w_meta_info is None, 'rotary embeddings not supported with TE attn'
+            assert output_attentions is False, 'output_attentions not supported with TE attn'
+            assert past_key_value is None, 'past_key_value not supported with TE attn'
+            assert attention_mask_in_length is None, 'attention_mask_in_length not supported with TE attn'
+            b = self.attn(a,
+                          core_attention_bias=attn_bias,
+                          attention_mask=attention_mask,
+                          checkpoint_core_attention=False,
+                          attn_mask_type='causal')
+            attn_weights = None
+            past_key_value = None
+        else:
+            b, attn_weights, past_key_value = self.attn(
+                a,
+                past_key_value=past_key_value,
+                attn_bias=attn_bias,
+                rotary_emb_w_meta_info=rotary_emb_w_meta_info,
+                attention_mask=attention_mask,
+                is_causal=is_causal,
+                needs_weights=output_attentions,
+                attention_mask_in_length=attention_mask_in_length,
+            )
+
         x = x + self.resid_attn_dropout(b)
         m = x
         if self.norm_2 is not None:
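
Reviewer note: the two hunks above work together. The guarded import makes the 'te_multihead_attention' registry entry optional, and MPTBlock.__init__ switches to TransformerEngine's constructor keywords (hidden_size / num_attention_heads / num_gqa_groups) when fc_type == 'te' and attn_config['attn_type'] == 'te_multihead_attention'. The snippet below is a minimal, self-contained sketch of that pattern only; build_attention and DummyAttention are illustrative stand-ins for this note, not llmfoundry APIs.

# Sketch of the optional-backend registry pattern used in this PR (illustrative only).

ATTN_CLASS_REGISTRY = {}


class DummyAttention:
    """Stand-in for llmfoundry's native attention classes (d_model/n_heads signature)."""

    def __init__(self, d_model: int, n_heads: int, bias: bool = True):
        self.d_model, self.n_heads, self.bias = d_model, n_heads, bias


ATTN_CLASS_REGISTRY['multihead_attention'] = DummyAttention

try:
    # TransformerEngine is an optional dependency; if the import fails,
    # the TE entry simply never appears in the registry.
    import transformer_engine.pytorch as te
    ATTN_CLASS_REGISTRY['te_multihead_attention'] = te.MultiheadAttention
except ImportError:
    te = None


def build_attention(d_model: int, n_heads: int, attn_config: dict,
                    fc_type: str, no_bias: bool = False):
    # Mirrors the branch added to MPTBlock.__init__: TE's MultiheadAttention
    # takes hidden_size/num_attention_heads/num_gqa_groups instead of
    # d_model/n_heads, so the two backends are constructed differently.
    attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
    if fc_type == 'te' and attn_config['attn_type'] == 'te_multihead_attention':
        return attn_class(
            hidden_size=d_model,
            num_attention_heads=n_heads,
            num_gqa_groups=attn_config['kv_n_heads'],
            fuse_qkv_params=True,
            qkv_weight_interleaved=False,
            input_layernorm=False,
            bias=not no_bias,
        )
    return attn_class(d_model=d_model, n_heads=n_heads, bias=not no_bias)

With this in place, selecting the TE path is a matter of setting fc_type to 'te' and attn_config['attn_type'] to 'te_multihead_attention' (plus kv_n_heads) in the model config; on machines without TransformerEngine installed, the registry lookup fails with a KeyError at build time rather than at import time, and the TE forward path additionally rejects rotary embeddings, cached key/values, attention-weight outputs, and attention_mask_in_length via the asserts added in forward.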