From ac563e6d6967e1c2499c8e677aab0e4f003b9481 Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Wed, 8 May 2024 16:19:17 -0700 Subject: [PATCH] Refactoring attention (#1182) * refactoring * adding back a function that got deleted by mistake * adding co-authors Co-Authored-By: Vitaliy Chiley Co-Authored-By: Cheng Li * adding co-authors Co-Authored-By: Vitaliy Chiley * adding co-authors Co-authored-by: Vitaliy Chiley Co-authored-by: Vitaliy Chiley * Update config_utils.py adding co-authors Co-authored-by: Vitaliy Chiley Co-authored-by: Vitaliy Chiley Co-authored-by: Cheng Li Co-authored-by: Cheng Li <@cli99> * lint Co-authored-by: Vitaliy Chiley Co-authored-by: Vitaliy Chiley * Adding_co_authors Co-authored-by: Vitaliy Chiley Co-authored-by: Vitaliy Chiley Co-authored-by: Cheng Li * Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * addressing comments * adding_co_authors Co-authored-by: Cheng Li * Update llmfoundry/utils/config_utils.py --------- Co-authored-by: Vitaliy Chiley Co-authored-by: Vitaliy Chiley Co-authored-by: Cheng Li Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> Co-authored-by: Cheng Li --- llmfoundry/models/layers/attention.py | 228 +++++++++++++-------- llmfoundry/models/layers/blocks.py | 66 +++--- llmfoundry/models/mpt/configuration_mpt.py | 6 +- llmfoundry/models/mpt/modeling_mpt.py | 69 +++++-- llmfoundry/models/utils/config_moe_args.py | 14 +- llmfoundry/utils/config_utils.py | 24 ++- 6 files changed, 268 insertions(+), 139 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 82fee68af6..4884b568fd 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -5,10 +5,9 @@ import math import warnings -from typing import Any, Optional +from typing import Any, Dict, Optional, Tuple import torch -import torch.nn as nn import transformers from einops import rearrange from packaging import version @@ -233,7 +232,6 @@ def flash_attn_fn( dropout_p: float = 0.0, training: bool = False, needs_weights: bool = False, - multiquery: bool = False, should_repeat_kv_for_gqa: Optional[bool] = True, sliding_window_size: int = -1, alibi_slopes: Optional[torch.Tensor] = None, @@ -506,6 +504,54 @@ def forward( flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: + query, key, value = self.get_qkv(x) + + if rotary_emb_w_meta_info is not None: + query, key, value = self._apply_rotary_embeddings( + rotary_emb_w_meta_info, + query, + key, + value, + ) + + extra_attn_kwargs = self.get_implementation_specific_args( + attention_mask, + alibi_slopes, + flash_attn_padding_info, + ) + + context, attn_weights, past_key_value = self.attn_fn( + query, + key, + value, + n_heads=self.n_heads, + kv_n_heads=self.kv_n_heads, + past_key_value=past_key_value, + softmax_scale=self.softmax_scale, + attn_bias=attn_bias, + is_causal=is_causal, + dropout_p=self.attn_dropout_p, + training=self.training, + needs_weights=needs_weights, + **extra_attn_kwargs, + ) + + return self.out_proj(context), attn_weights, past_key_value + + def get_qkv( + self, + x: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Computes and returns the query, key, and value tensors. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + query (torch.Tensor): The query tensor. + key (torch.Tensor): The key tensor. + value (torch.Tensor): The value tensor. + """ qkv = self.Wqkv(x) if self.clip_qkv: @@ -520,8 +566,6 @@ def forward( dim=2, ) - key_padding_mask = attention_mask - if self.qk_ln or self.qk_gn: # Applying layernorm to qk q_shape, k_shape = query.shape, key.shape @@ -533,97 +577,105 @@ def forward( query = self.q_ln(query).to(dtype).view(q_shape) key = self.k_ln(key).to(dtype).view(k_shape) - if rotary_emb_w_meta_info is not None: - rotary_emb = rotary_emb_w_meta_info['rotary_emb'] - seq_len = rotary_emb_w_meta_info['seq_len'] - offset_info = rotary_emb_w_meta_info['offset_info'] - bsz, seqlen = query.shape[:2] - query = query.view(bsz, seqlen, -1, self.head_dim) - key = key.view(bsz, seqlen, -1, self.head_dim) - - if rotary_emb_w_meta_info['impl'] == 'dail': - value = value.view(bsz, seqlen, -1, self.head_dim) - - kv = torch.stack([key, value], dim=2) - query, kv = rotary_emb( - query, - kv, - seqlen_offset=offset_info, - max_seqlen=seq_len, + return query, key, value + + def _apply_rotary_embeddings( + self, + rotary_emb_w_meta_info: Dict[str, Any], + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + rotary_emb = rotary_emb_w_meta_info['rotary_emb'] + seq_len = rotary_emb_w_meta_info['seq_len'] + offset_info = rotary_emb_w_meta_info['offset_info'] + bsz, seqlen = query.shape[:2] + query = query.view(bsz, seqlen, -1, self.head_dim) + key = key.view(bsz, seqlen, -1, self.head_dim) + + if rotary_emb_w_meta_info['impl'] == 'dail': + value = value.view(bsz, seqlen, -1, self.head_dim) + + kv = torch.stack([key, value], dim=2) + query, kv = rotary_emb( + query, + kv, + seqlen_offset=offset_info, + max_seqlen=seq_len, + ) + [key, value] = torch.unbind(kv, dim=2) + + value = value.view(bsz, seqlen, -1) + elif rotary_emb_w_meta_info['impl'] == 'hf': + if is_transformers_version_gte('4.38'): + (cos, sin) = rotary_emb( + x=value, + position_ids=offset_info, + ) + else: + (cos, sin) = rotary_emb(x=value, seq_len=seq_len) + if is_transformers_version_gte('4.38'): + query, key = apply_rotary_pos_emb( + q=query, + k=key, + cos=cos, + sin=sin, + position_ids=None, + unsqueeze_dim=2, + ) + elif is_transformers_version_gte('4.36'): + query, key = apply_rotary_pos_emb( + q=query, + k=key, + cos=cos, + sin=sin, + position_ids=offset_info, + unsqueeze_dim=2, + ) + else: + query = query.transpose(1, 2) + key = key.transpose(1, 2) + query, key = apply_rotary_pos_emb( + q=query, + k=key, + cos=cos, + sin=sin, + position_ids=offset_info, ) - [key, value] = torch.unbind(kv, dim=2) - - value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim) - elif rotary_emb_w_meta_info['impl'] == 'hf': - if is_transformers_version_gte('4.38'): - (cos, sin) = rotary_emb( - x=value, - position_ids=offset_info, - ) - else: - (cos, sin) = rotary_emb(x=value, seq_len=seq_len) - if is_transformers_version_gte('4.38'): - query, key = apply_rotary_pos_emb( - q=query, - k=key, - cos=cos, - sin=sin, - position_ids=None, - unsqueeze_dim=2, - ) - elif is_transformers_version_gte('4.36'): - query, key = apply_rotary_pos_emb( - q=query, - k=key, - cos=cos, - sin=sin, - position_ids=offset_info, - unsqueeze_dim=2, - ) - else: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - query, key = apply_rotary_pos_emb( - q=query, - k=key, - cos=cos, - sin=sin, - position_ids=offset_info, - ) - query = query.transpose(1, 2) - key = key.transpose(1, 2) - - query = query.view(bsz, seqlen, self.d_model) - key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim) - - extra_attn_kwargs = {} + query = query.transpose(1, 2) + key = key.transpose(1, 2) + + query = query.view(bsz, seqlen, -1) + key = key.view(bsz, seqlen, -1) + return query, key, value + + def get_implementation_specific_args( + self, + attention_mask: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, + flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + ) -> dict[str, Any]: + """Returns attention implementation specific args. + + Args: + attention_mask (Optional[torch.Tensor]): The attention mask. + alibi_slopes (Optional[torch.Tensor]): The alibi slopes. + flash_attn_padding_info (Optional[dict[str, torch.Tensor]]): The padding information, only required for flash attention. + + Returns: + extra_attn_kwargs (dict[str, Any]): Implementation specific args. + """ if self.attn_impl == 'flash': - key_padding_mask = None extra_attn_kwargs = { 'should_repeat_kv_for_gqa': not is_flash_v2_installed(), 'sliding_window_size': self.sliding_window_size, 'alibi_slopes': alibi_slopes, 'flash_attn_padding_info': flash_attn_padding_info, + 'key_padding_mask': None, } - - context, attn_weights, past_key_value = self.attn_fn( - query, - key, - value, - self.n_heads, - self.kv_n_heads, - past_key_value=past_key_value, - softmax_scale=self.softmax_scale, - attn_bias=attn_bias, - key_padding_mask=key_padding_mask, - is_causal=is_causal, - dropout_p=self.attn_dropout_p, - training=self.training, - needs_weights=needs_weights, - **extra_attn_kwargs, - ) - - return self.out_proj(context), attn_weights, past_key_value + else: + extra_attn_kwargs = {'key_padding_mask': attention_mask} + return extra_attn_kwargs @attention_classes.register_class('multihead_attention') diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 494bdcdff1..3ff65fd8b3 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -3,7 +3,7 @@ """GPT Blocks used for the GPT Model.""" -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Set, Tuple import torch import torch.nn as nn @@ -88,6 +88,8 @@ def __init__( self.norm_attn_norm = FusedNormAttentionNorm( d_model=d_model, n_heads=n_heads, + args_to_exclude_in_attn_class=self. + args_to_exclude_in_attn_class, attn_config=attn_config, ffn_has_norm=ffn_has_norm, fc_type=fc_type, @@ -99,21 +101,10 @@ def __init__( else: assert isinstance(attn_config['attn_type'], str) # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs - args_to_exclude_in_attn_class = { - 'attn_type', - 'alibi', - 'attn_uses_sequence_id', - 'alibi_bias_max', - 'rope', - 'rope_theta', - 'rope_impl', - 'rope_dail_config', - 'rope_hf_config', - } attn_config_subset_for_attn_class = { k: v for k, v in attn_config.items() - if k not in args_to_exclude_in_attn_class + if k not in self.args_to_exclude_in_attn_class } self.norm_1 = build_norm( @@ -153,6 +144,20 @@ def __init__( self.resid_ffn_dropout = nn.Dropout(resid_pdrop) self.use_pad_tok_in_ffn = use_pad_tok_in_ffn + @property + def args_to_exclude_in_attn_class(self): + return { + 'attn_type', + 'alibi', + 'attn_uses_sequence_id', + 'alibi_bias_max', + 'rope', + 'rope_theta', + 'rope_impl', + 'rope_dail_config', + 'rope_hf_config', + } + def forward( self, x: torch.Tensor, @@ -196,6 +201,24 @@ def forward( if self.norm_2 is not None: m = self.norm_2(x) + n = self.apply_ffn(attention_mask, m) + x = x + self.resid_ffn_dropout(n) + return x, attn_weights, past_key_value + + def apply_ffn( + self, + attention_mask: Optional[torch.ByteTensor], + m: torch.Tensor, + ) -> torch.Tensor: + """Apply feed forward layers to the input. + + Args: + attention_mask (Optional[torch.ByteTensor]): The attention mask. + m (torch.Tensor): The input. + + Returns: + n (torch.Tensor): The output. + """ batch_size, seq_len = m.size()[:2] indices = None if not self.use_pad_tok_in_ffn: @@ -205,8 +228,7 @@ def forward( if not self.use_pad_tok_in_ffn: assert pad_input is not None n = pad_input(n, indices, batch_size, seq_len) - x = x + self.resid_ffn_dropout(n) - return x, attn_weights, past_key_value + return n class FusedNormAttentionNorm(nn.Module): @@ -215,6 +237,7 @@ def __init__( self, d_model: int, n_heads: int, + args_to_exclude_in_attn_class: Set[str], attn_config: Optional[Dict] = None, ffn_has_norm: bool = False, fc_type: str = 'torch', @@ -228,18 +251,7 @@ def __init__( assert attn_config is not None assert isinstance(attn_config['attn_type'], str) - # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs - args_to_exclude_in_attn_class = { - 'attn_type', - 'alibi', - 'attn_uses_sequence_id', - 'alibi_bias_max', - 'rope', - 'rope_theta', - 'rope_impl', - 'rope_dail_config', - 'rope_hf_config', - } + # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs attn_config_subset_for_attn_class = { k: v for k, v in attn_config.items() diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 78653fabdc..a1716fa214 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -140,7 +140,9 @@ def __init__( self.n_heads = n_heads self.n_layers = n_layers self.expansion_ratio = expansion_ratio - self.max_seq_len = max_seq_len + if max_seq_len != int(max_seq_len): + raise ValueError('max_seq_len must be an integer') + self.max_seq_len = int(max_seq_len) self.vocab_size = vocab_size self.resid_pdrop = resid_pdrop self.emb_pdrop = emb_pdrop @@ -327,3 +329,5 @@ def _validate_config(self) -> None: raise ImportError( 'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.6', ) + if (self.attn_config.get('seq_parallel_world_size', 1) or 1) > 1: + raise NotImplementedError('Sequence Parallelism is not supported.') diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 15f1440b47..8726879208 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -365,15 +365,9 @@ def __init__(self, config: MPTConfig): ) self.emb_drop = nn.Dropout(config.emb_pdrop) self.mb_args = None - block_args = config.to_dict() - if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: - block_args['ffn_config'] = config_moe_args( - block_args['ffn_config'], - config.d_model, - config.expansion_ratio, - config.n_layers, - ) - self.mb_args = block_args['ffn_config'].get('args') + self.shift_labels = True + + block_args = self.extract_block_args(config.to_dict()) self.blocks = nn.ModuleList([ MPTBlock( @@ -442,6 +436,18 @@ def __init__(self, config: MPTConfig): log.debug(self) log.debug(f'Using {self.config.init_config["name"]} initialization.') + def extract_block_args(self, block_args: Dict[str, Any]) -> Dict[str, Any]: + """Sets the block args.""" + if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: + block_args['ffn_config'] = config_moe_args( + block_args['ffn_config'], + block_args['d_model'], + block_args['expansion_ratio'], + block_args['n_layers'], + ) + self.mb_args = block_args['ffn_config'].get('args') + return block_args + def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]: return self.wte @@ -581,17 +587,17 @@ def forward( ) elif input_ids is not None: bsz = input_ids.size(0) - S = input_ids.size(1) x = self.wte(input_ids) input_device = input_ids.device elif inputs_embeds is not None: bsz = inputs_embeds.size(0) - S = inputs_embeds.size(1) x = inputs_embeds input_device = inputs_embeds.device else: raise ValueError('You must specify input_ids or inputs_embeds') + S = self.get_sequence_length(x) + assert ( S <= self.config.max_seq_len ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' @@ -744,6 +750,17 @@ def forward( attentions=all_self_attns, ) + def get_sequence_length(self, x: torch.Tensor) -> int: + """Returns the sequence length. + + Args: + x (torch.Tensor): The input Tensor. + + Returns: + S (int): The sequence length. + """ + return x.size(1) + # Param Initialization, needed for device='meta' fast initialization def param_init_fn(self, module: nn.Module) -> None: init_fn_name = self.config.init_config['name'] @@ -1084,7 +1101,7 @@ def __init__( use_logits=True, metrics=train_metrics, eval_metrics=eval_metrics, - shift_labels=True, + shift_labels=model.transformer.shift_labels, allow_embedding_resizing=True, ) @@ -1140,7 +1157,11 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> Union[dict, torch.Tensor]: - targets = self.get_targets(batch) + if self.model.transformer.shift_labels: + targets = self.get_targets(batch) + else: + targets = batch['labels'] + losses = self.loss_fn( outputs.logits.view(-1, outputs.logits.size(-1)), targets.view(-1), @@ -1150,6 +1171,12 @@ def loss(self, outputs: CausalLMOutputWithPast, loss = losses.sum() else: loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum() + if 'sample_weighing_factor' in batch: + if batch['sample_weighing_factor'].shape[0] > 1: + raise ValueError( + 'Sample weighing factor is not supported when batch["sample_weighing_factor"].shape[0] > 1.', + ) + loss = loss * batch['sample_weighing_factor'][0].item() if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: # MegaBlocks MoE load balancing loss @@ -1187,9 +1214,19 @@ def flops_per_batch(self, batch: Mapping): params = self.n_active_params params_flops_per_token = 2 * params params_flops_per_seq = params_flops_per_token * msl - attn_flops_per_seq = ( + attn_flops_per_seq = self.get_attention_flops(msl) + return (params_flops_per_seq + attn_flops_per_seq) * 3 * bs + + def get_attention_flops(self, msl: int) -> int: + """Computes the attention flops for the batch. + + Args: + msl (int): The batch sequence length. + + Returns: + attn_flops (int): The attention flops. + """ + return ( self.model.config.n_layers * 2 * 2 * (self.model.config.d_model * (msl**2)) ) - - return (params_flops_per_seq + attn_flops_per_seq) * 3 * bs diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py index 2d9a8cadd4..29f8d1bfcc 100644 --- a/llmfoundry/models/utils/config_moe_args.py +++ b/llmfoundry/models/utils/config_moe_args.py @@ -3,6 +3,7 @@ """Helper function to configure MPT with MoEs.""" +import inspect from typing import Union import torch @@ -143,7 +144,10 @@ def config_megablocks_moe_args( elif lbl_process_group == 'global_group': lbl_process_group = distributed.group.WORLD elif isinstance(lbl_process_group, int): - lbl_process_group = create_set_process_group(lbl_process_group) + if lbl_process_group > 1: + lbl_process_group = create_set_process_group(lbl_process_group) + else: + lbl_process_group = None elif lbl_process_group is not None: raise ValueError( f'Unknown {lbl_process_group=}. Options are: none | expert_group | global_group | .', @@ -153,6 +157,14 @@ def config_megablocks_moe_args( ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio) ffn_config.setdefault('ffn_hidden_size', ffn_hidden_size) + args_to_keep_in_ffn_config = inspect.signature( + megablocks.layers.arguments.Arguments, + ).parameters + + ffn_config = { + k: v for k, v in ffn_config.items() if k in args_to_keep_in_ffn_config + } + args = megablocks.layers.arguments.Arguments( hidden_size=d_model, **ffn_config, diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 211ed08d3e..9470ce2ac6 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -360,19 +360,29 @@ def get_hf_config_value(config: Union[dict, PretrainedConfig], key: str) -> Any: def calculate_batch_size_info( global_batch_size: int, - device_microbatch_size: Union[int, Literal['auto']], -) -> Tuple[int, Union[int, Literal['auto']], Union[int, Literal['auto']]]: - if global_batch_size % dist.get_world_size() != 0: + device_microbatch_size: Union[int, float, Literal['auto']], + data_replication_degree: int = 1, +) -> Tuple[Union[int, float], Union[int, float, Literal['auto']], Union[ + int, Literal['auto']]]: + if dist.get_world_size() % data_replication_degree != 0: raise ValueError( - f'Global batch size {global_batch_size} is not divisible by {dist.get_world_size()} ' + f'World size {dist.get_world_size()} is not divisible by data replication degree {data_replication_degree}.', + ) + if global_batch_size % ( + dist.get_world_size() // data_replication_degree + ) != 0: + raise ValueError( + f'Global batchsize {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} ' + 'as a result, the batch size would be truncated, please adjust `global_batch_size` ' + f'to be divisible by world size, {dist.get_world_size()}.', ) - device_batch_size = global_batch_size // dist.get_world_size() + device_batch_size = global_batch_size / dist.get_world_size() + if device_batch_size == round(device_batch_size): + device_batch_size = round(device_batch_size) if device_microbatch_size == 'auto': device_grad_accum = 'auto' - elif isinstance(device_microbatch_size, int): + elif isinstance(device_microbatch_size, (int, float)): if device_microbatch_size > device_batch_size: log.warn( f'device_microbatch_size > device_batch_size, ' + @@ -390,9 +400,11 @@ def calculate_batch_size_info( # Coming soon: this conversion math will be done inside Composer Trainer def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]: + data_replication_degree = 1 device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info( cfg['global_train_batch_size'], cfg['device_train_microbatch_size'], + data_replication_degree=data_replication_degree, ) cfg['n_gpus'] = dist.get_world_size() cfg['device_train_batch_size'] = device_train_batch_size