From aa9509e9be67d79e4b6bceb12e9f0fc20f083e19 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 5 Oct 2023 22:00:25 +0000 Subject: [PATCH 001/106] .. --- llmfoundry/models/layers/attention.py | 34 +++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index bea6284fb5..f313210e48 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -17,6 +17,30 @@ from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY +def _apply_rotary_position_embeddings(rotation_matrix: torch.Tensor, + query: torch.Tensor, key: torch.Tensor): + rotation_matrix = rotation_matrix.unsqueeze(-2).unsqueeze(0) + # sin [batch_size, sequence_length, num_heads, embed_size_per_head//2] + # cos [batch_size, sequence_length, num_heads, embed_size_per_head//2] + sin, cos = rotation_matrix.chunk(2, dim=-1) + # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(rotation_matrix) + # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(rotation_matrix) + + # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2] + rotate_half_query_layer = torch.stack([-query[..., 1::2], query[..., ::2]], + dim=-1).reshape_as(query) + query = (query * cos_pos + rotate_half_query_layer * sin_pos).to(query) + + # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] + rotate_half_key_layer = torch.stack([-key[..., 1::2], key[..., ::2]], + dim=-1).reshape_as(key) + key = (key * cos_pos + rotate_half_key_layer * sin_pos).to(key) + + return query, key + + def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool: # disable causal when it is not needed @@ -521,6 +545,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + rotation_matrix: Optional[torch.Tensor] = None, is_causal: bool = True, needs_weights: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ @@ -539,6 +564,15 @@ def forward( dim=2, ) + if rotation_matrix is not None: + query = query.view(*(query.shape[:-1]), -1, self.head_dim) + key = key.view(*(key.shape[:-1]), -1, self.head_dim) + query, key = self.apply_rotary_position_embeddings( + rotation_matrix, query, key) + query = query.reshape(*(query.shape[:-2]), self.d_model) + key = key.reshape(*(key.shape[:-2]), + self.kv_n_heads * self.head_dim) + key_padding_mask = attention_mask if self.qk_ln: From 7354fcceb2aedccee09489be1875fc396fc56948 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 5 Oct 2023 15:23:11 -0700 Subject: [PATCH 002/106] .. 
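The pairwise rotation applied by `_apply_rotary_position_embeddings` in PATCH 001 above encodes absolute positions so that the query/key dot product depends only on the relative offset between positions. A minimal self-contained check of that property (plain PyTorch, toy shapes; nothing here is llm-foundry API):

    import torch

    def rotate_pairs(x: torch.Tensor, pos: int, theta: torch.Tensor) -> torch.Tensor:
        # x: [head_dim]; theta: [head_dim // 2] inverse frequencies
        angles = pos * theta                     # one angle per 2-d pair
        cos = angles.cos().repeat_interleave(2)  # [cos θ0, cos θ0, cos θ1, cos θ1, ...]
        sin = angles.sin().repeat_interleave(2)
        # pairwise "rotate half": [-x1, x0, -x3, x2, ...], as in the patch
        x_rot = torch.stack([-x[1::2], x[0::2]], dim=-1).reshape_as(x)
        return x * cos + x_rot * sin

    head_dim, base = 8, 10000.0
    theta = 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    q, k = torch.randn(head_dim), torch.randn(head_dim)

    s1 = rotate_pairs(q, 5, theta) @ rotate_pairs(k, 2, theta)    # positions (5, 2), offset 3
    s2 = rotate_pairs(q, 13, theta) @ rotate_pairs(k, 10, theta)  # positions (13, 10), offset 3
    print(torch.allclose(s1, s2, atol=1e-5))  # True: the score depends only on the offset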
--- llmfoundry/models/layers/attention.py | 2 +- llmfoundry/models/layers/blocks.py | 6 +++- llmfoundry/models/mpt/configuration_mpt.py | 4 +++ llmfoundry/models/mpt/modeling_mpt.py | 33 ++++++++++++++++++++++ 4 files changed, 43 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index f313210e48..4a830c6f0c 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -16,7 +16,7 @@ from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY - +# Code taken from https://github.com/huggingface/transformers/blob/v4.33.3/src/transformers/models/roformer/modeling_roformer.py def _apply_rotary_position_embeddings(rotation_matrix: torch.Tensor, query: torch.Tensor, key: torch.Tensor): rotation_matrix = rotation_matrix.unsqueeze(-2).unsqueeze(0) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index a08ef6d77f..c3df3e65bd 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -41,6 +41,8 @@ def __init__( 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8, + 'rope': False, + 'rope_bf': 10000, } if ffn_config is None: @@ -58,7 +60,7 @@ def __init__( # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs args_to_exclude_in_attn_class = { 'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', - 'alibi_bias_max' + 'alibi_bias_max', 'rope', 'rope_bf' } attn_config_subset_for_attn_class = { k: v @@ -94,6 +96,7 @@ def forward( x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, + rotation_matrix: Optional[torch.Tensor] = None, attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, @@ -104,6 +107,7 @@ def forward( a, past_key_value=past_key_value, attn_bias=attn_bias, + rotation_matrix=rotation_matrix, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 251e4f5caf..e535b9b44d 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -19,6 +19,8 @@ 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8, + 'rope': False, + 'rope_bf': 10000, } ffn_config_defaults: Dict = { @@ -94,6 +96,8 @@ def __init__( Defaults to ``False`` meaning any provided `sequence_id` will be ignored. alibi (bool): Whether to use the alibi bias instead of position embeddings. alibi_bias_max (int): The maximum value of the alibi bias. + rope (bool): Whether to use rotary positional embeddings. + rope_bf (int): The base frequency for rope. kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. ffn_config (Dict): A dictionary used to configure the model's ffn module: ffn_type (str): type of ffn to use. 
Options: mptmlp, te_ln_mlp diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index cd162195b6..3b3e55a054 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -144,6 +144,15 @@ def __init__(self, config: MPTConfig): use_sequence_id=self.attn_uses_sequence_id, ) + self.alibi = config.attn_config['rope'] + self._rotation_matrix_initialized = False + self.rotation_matrix = None + assert (config.d_model % config.n_heads == 0) + self.rope_max_seq_len = config.max_seq_len + self.rope_head_dim = config.d_model//config.n_heads + self.rope_bf = config.attn_config['rope_bf'] + self.rotation_matrix_shape = (config.max_seq_len, self.rope_head_dim) + if config.no_bias: for module in self.modules(): if hasattr(module, 'bias') and isinstance( @@ -237,6 +246,27 @@ def _attn_bias( return attn_bias, None + # Code taken from https://github.com/huggingface/transformers/blob/v4.33.3/src/transformers/models/roformer/modeling_roformer.py + @torch.no_grad() + def _rotation_matrix(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. [head_dim // 2:] + """ + if not self._rotation_matrix_initialized: + self.rotation_matrix = None + if self.rope: + self.rotation_matrix = torch.empty(self.rotation_matrix_shape, device=device, dtype=dtype) + + theta = 1/(torch.pow(self.rope_bf, 2 * (torch.arange(self.rope_head_dim)//2) / self.rope_head_dim)) + position_enc = torch.outer(torch.arange(self.rope_max_seq_len), theta) + + sentinel = self.rope_head_dim // 2 if self.rope_head_dim % 2 == 0 else (self.rope_head_dim // 2) + 1 + self.rotation_matrix[:, 0:sentinel] = torch.FloatTensor(torch.sin(position_enc[:, 0::2])).to(self.rotation_matrix) + self.rotation_matrix[:, sentinel:] = torch.FloatTensor(torch.cos(position_enc[:, 1::2])).to(self.rotation_matrix) + self._rotation_matrix_initialized = True + return self.rotation_matrix + def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) -> torch.Tensor: s_k, s_q = attn_bias.shape[-2:] @@ -421,6 +451,8 @@ def forward( sequence_id=sequence_id, ) + rotation_matrix = self._rotation_matrix(device=x.device, dtype=torch.float32) + # initialize the past key values cache if it should be used if use_cache and past_key_values is None: past_key_values = [() for _ in range(self.config.n_layers) @@ -438,6 +470,7 @@ def forward( x, past_key_value=past_key_value, attn_bias=attn_bias, + rotation_matrix=rotation_matrix, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), From 6801142db7aed40945ca73c3e66951c2b81aa8b4 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 5 Oct 2023 15:38:54 -0700 Subject: [PATCH 003/106] .. 
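PATCH 002 above adds `_rotation_matrix`, which precomputes a (max_seq_len, head_dim) table with sin(m*θi) in the first half of each row and cos(m*θi) in the second half, the layout that `chunk(2, dim=-1)` in `_apply_rotary_position_embeddings` expects. A small standalone sketch of that construction for toy sizes (names and sizes are illustrative only):

    import torch

    rope_bf, head_dim, max_seq_len = 10000, 4, 3
    # each frequency appears twice, matching torch.arange(head_dim) // 2 in the patch
    theta = 1 / torch.pow(torch.tensor(float(rope_bf)),
                          2 * (torch.arange(head_dim) // 2) / head_dim)
    position_enc = torch.outer(torch.arange(max_seq_len).float(), theta)

    sentinel = head_dim // 2  # head_dim is even here
    table = torch.empty(max_seq_len, head_dim)
    table[:, :sentinel] = torch.sin(position_enc[:, 0::2])
    table[:, sentinel:] = torch.cos(position_enc[:, 1::2])
    print(table)  # row m is [sin(m*θ0), sin(m*θ1), cos(m*θ0), cos(m*θ1)]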
--- llmfoundry/models/mpt/modeling_mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 3b3e55a054..ae0b286956 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -264,7 +264,7 @@ def _rotation_matrix(self, device: torch.device, dtype: torch.dtype) -> torch.Te sentinel = self.rope_head_dim // 2 if self.rope_head_dim % 2 == 0 else (self.rope_head_dim // 2) + 1 self.rotation_matrix[:, 0:sentinel] = torch.FloatTensor(torch.sin(position_enc[:, 0::2])).to(self.rotation_matrix) self.rotation_matrix[:, sentinel:] = torch.FloatTensor(torch.cos(position_enc[:, 1::2])).to(self.rotation_matrix) - self._rotation_matrix_initialized = True + self._rotation_matrix_initialized = True return self.rotation_matrix def _apply_prefix_mask(self, attn_bias: torch.Tensor, From eff62709b0d9dec4760f1be12ac62997e0aef23e Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 5 Oct 2023 15:55:48 -0700 Subject: [PATCH 004/106] .. --- .../yamls/pretrain/mpt-125m-realistic.yaml | 177 ++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 scripts/train/yamls/pretrain/mpt-125m-realistic.yaml diff --git a/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml b/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml new file mode 100644 index 0000000000..2eb571aca3 --- /dev/null +++ b/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml @@ -0,0 +1,177 @@ +global_seed: 17 +max_seq_len: 2048 # max tokens per sequence during training + +data_local: /root/my-copy-c4 +data_remote: oci://mosaicml-internal-dataset-c4/preconcat-gpt_neox/ # If blank, files must be present in data_local + +# Model +model: + name: mpt_causal_lm + init_device: meta + d_model: 768 + n_heads: 12 + n_layers: 12 + expansion_ratio: 4 + max_seq_len: ${max_seq_len} + vocab_size: 50368 # update for hero run with custom tokenizer + no_bias: true + attn_config: + alibi: false + attn_impl: triton + clip_qkv: 6 + attn_uses_sequence_id: true + pos_bias_type: rope # Keep alibi: true, and pos_bias_type can be one of alibi, rope, or... + rope_bf: 10000 # Use 10000 for original rope, 500000 for rope_abf (https://arxiv.org/pdf/2309.16039.pdf) + +# Tokenizer +tokenizer: + name: EleutherAI/gpt-neox-20b # default tokenizer used for MPT + kwargs: + model_max_length: ${max_seq_len} + +# Optimization +global_train_batch_size: 512 # ~1M tokens, update for hero run, must be divisible by gpu_num +max_duration: 100000ba # update for hero run, e.g. 100000ba ~= 100B tokens + +optimizer: + name: decoupled_lionw + lr: 0.0006 + betas: + - 0.9 + - 0.95 + weight_decay: 0.0006 + +scheduler: + name: cosine_with_warmup + t_warmup: 5000ba + alpha_f: 0.1 + +# System +seed: ${global_seed} +precision: amp_bf16 +device_train_microbatch_size: 16 # NOTE: Please ensure that ${gpus} is less than or equal to ${global_train_batch_size}/${device_train_microbatch_size}. + # 1. If you want to increase the number of gpus beyond this number, please decrease device_train_microbatch_size correspondingly to + # satisfy this equation. + # 2. If you want to increase the number of gpus beyond this number and keep device_train_microbatch_size the same (or increase it), + # you would need to increase global_train_batch_size. However, in that case other hyperparameters like lr would probably need to + # be changed and we cannot guarantee convergence. 
+device_eval_batch_size: 16 + +# FSDP +fsdp_config: + activation_checkpointing: false + activation_checkpointing_reentrant: false + activation_cpu_offload: false + limit_all_gathers: true + mixed_precision: PURE + sharding_strategy: SHARD_GRAD_OP + state_dict_type: full + verbose: false + +# Logging +eval_first: false +eval_interval: 900000ba +log_to_console: true +console_log_interval: 1ba +progress_bar: false +python_log_level: DEBUG +loggers: + wandb: {} + +# Checkpointing +autoresume: true +save_filename: ep{epoch}-ba{batch}/rank{rank}.pt +save_folder: oci://mosaicml-internal-checkpoints/shashank/rope-vs-alibi/mpt-125m # update for hero run +save_interval: 100000ba # update for hero run, e.g. 2000ba +save_num_checkpoints_to_keep: 1 + +# Algos and Callbacks +algorithms: + gradient_clipping: + clipping_threshold: 1 + clipping_type: norm + +callbacks: + # generate_callback: + # batch_log_interval: 10 # update for hero run, e.g. 2000 + # do_sample: true + # max_new_tokens: 100 + # prompts: + # - The quick brown fox jumps over + # - |- + # Vegan Banana Bread + # Instructions: + # 1. + # - The other day I was explaining what generative AI is to my five year old. + # temperature: 1 + # top_k: 50 + # top_p: 0.95 + # use_cache: true + lr_monitor: {} + memory_monitor: {} + # mono_ckpt_saver: + # batch_interval: ${save_interval} + # filename: ep{epoch}-ba{batch}/mono.pt + # save_folder: ${save_folder} + runtime_estimator: {} + scheduled_gc: + batch_interval: 2000 + speed_monitor: + window_size: 10 + +# Dataloaders +eos_token_id: 0 # update for hero run with custom tokenizer +num_canonical_nodes: 128 # update for hero run, must be codivisible by # physical nodes + +# Eval loader +eval_loader: + name: text + drop_last: false + num_workers: 8 + + # dataset yaml from https://github.com/mosaicml/llm-foundry/blob/main/scripts/train/yamls/pretrain/mpt-125m.yaml - + dataset: + eos_token_id: ${eos_token_id} + local: ${data_local} + remote: ${data_remote} + split: val + shuffle: false + max_seq_len: ${max_seq_len} + shuffle_seed: ${global_seed} + + + +# In-context-learning tasks + +## If you want to use one of our suites of tasks +# icl_tasks: eval/yamls/tasks_light.yaml + +## Or if you want to manually specify individual tasks +icl_tasks: +- + label: piqa + dataset_uri: eval/local_data/commonsense_reasoning/piqa.jsonl + num_fewshot: [0, 1, 5] + icl_task_type: multiple_choice + continuation_delimiter: 'Answer: ' +- + label: lambada + dataset_uri: eval/local_data/language_understanding/lambada_openai.jsonl + num_fewshot: [0] + icl_task_type: language_modeling + +# Train loader +train_loader: + name: text + drop_last: true + num_workers: 8 + + # dataset yaml from https://github.com/mosaicml/llm-foundry/blob/main/scripts/train/yamls/pretrain/mpt-125m.yaml - + dataset: + eos_token_id: ${eos_token_id} + local: ${data_local} + remote: ${data_remote} + split: train + shuffle: true + max_seq_len: ${max_seq_len} + shuffle_seed: ${global_seed} \ No newline at end of file From cdc6798d666aad97c33f58110a8f565d2f655e00 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 5 Oct 2023 15:57:04 -0700 Subject: [PATCH 005/106] .. 
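For reference on the inline comments in the YAML above: 512 sequences x 2048 tokens = 1,048,576, i.e. roughly 1M tokens per batch, so `max_duration: 100000ba` corresponds to roughly 100B training tokens, consistent with the "~1M tokens" and "100000ba ~= 100B tokens" notes.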
--- scripts/train/yamls/pretrain/mpt-125m-realistic.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml b/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml index 2eb571aca3..5e69108c21 100644 --- a/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml +++ b/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml @@ -20,7 +20,7 @@ model: attn_impl: triton clip_qkv: 6 attn_uses_sequence_id: true - pos_bias_type: rope # Keep alibi: true, and pos_bias_type can be one of alibi, rope, or... + rope: true # Keep alibi: true, and pos_bias_type can be one of alibi, rope, or... rope_bf: 10000 # Use 10000 for original rope, 500000 for rope_abf (https://arxiv.org/pdf/2309.16039.pdf) # Tokenizer From a74afb42d92cbcf6279712486e0a994b075b9bc7 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 5 Oct 2023 16:00:43 -0700 Subject: [PATCH 006/106] .. --- llmfoundry/models/mpt/modeling_mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index ae0b286956..2294498044 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -144,7 +144,7 @@ def __init__(self, config: MPTConfig): use_sequence_id=self.attn_uses_sequence_id, ) - self.alibi = config.attn_config['rope'] + self.rope = config.attn_config['rope'] self._rotation_matrix_initialized = False self.rotation_matrix = None assert (config.d_model % config.n_heads == 0) From 3c025859e3f01629f38bf3a3765f9e9d07971c65 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 5 Oct 2023 16:03:26 -0700 Subject: [PATCH 007/106] .. --- llmfoundry/models/layers/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 4a830c6f0c..427188350b 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -567,7 +567,7 @@ def forward( if rotation_matrix is not None: query = query.view(*(query.shape[:-1]), -1, self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) - query, key = self.apply_rotary_position_embeddings( + query, key = self._apply_rotary_position_embeddings( rotation_matrix, query, key) query = query.reshape(*(query.shape[:-2]), self.d_model) key = key.reshape(*(key.shape[:-2]), From 47f5af676d088ea532f8e5da4b127cbaf1788bac Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 5 Oct 2023 16:06:37 -0700 Subject: [PATCH 008/106] .. --- llmfoundry/models/layers/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 427188350b..6364ac9f8e 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -567,7 +567,7 @@ def forward( if rotation_matrix is not None: query = query.view(*(query.shape[:-1]), -1, self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) - query, key = self._apply_rotary_position_embeddings( + query, key = _apply_rotary_position_embeddings( rotation_matrix, query, key) query = query.reshape(*(query.shape[:-2]), self.d_model) key = key.reshape(*(key.shape[:-2]), From 3389f782e03bc92b2de044b48587b1221241ea14 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 5 Oct 2023 16:12:56 -0700 Subject: [PATCH 009/106] .. 
--- llmfoundry/models/layers/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 6364ac9f8e..65296d0935 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -565,6 +565,7 @@ def forward( ) if rotation_matrix is not None: + breakpoint() query = query.view(*(query.shape[:-1]), -1, self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) query, key = _apply_rotary_position_embeddings( From c9f2154490b401714f3ebb6a2ef13092ca70b045 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 5 Oct 2023 16:14:54 -0700 Subject: [PATCH 010/106] .. --- llmfoundry/models/layers/attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 65296d0935..6364ac9f8e 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -565,7 +565,6 @@ def forward( ) if rotation_matrix is not None: - breakpoint() query = query.view(*(query.shape[:-1]), -1, self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) query, key = _apply_rotary_position_embeddings( From 9db76a8b9221bfb3530ed6d6a040dd1d700f6d4a Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 5 Oct 2023 16:26:46 -0700 Subject: [PATCH 011/106] .. --- llmfoundry/models/layers/attention.py | 1 + .../pretrain/mpt-125m-realistic-rope.yaml | 177 ++++++++++++++++++ .../yamls/pretrain/mpt-125m-realistic.yaml | 4 +- 3 files changed, 180 insertions(+), 2 deletions(-) create mode 100644 scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 6364ac9f8e..c5df05b384 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -564,6 +564,7 @@ def forward( dim=2, ) + breakpoint() if rotation_matrix is not None: query = query.view(*(query.shape[:-1]), -1, self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) diff --git a/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml b/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml new file mode 100644 index 0000000000..5e69108c21 --- /dev/null +++ b/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml @@ -0,0 +1,177 @@ +global_seed: 17 +max_seq_len: 2048 # max tokens per sequence during training + +data_local: /root/my-copy-c4 +data_remote: oci://mosaicml-internal-dataset-c4/preconcat-gpt_neox/ # If blank, files must be present in data_local + +# Model +model: + name: mpt_causal_lm + init_device: meta + d_model: 768 + n_heads: 12 + n_layers: 12 + expansion_ratio: 4 + max_seq_len: ${max_seq_len} + vocab_size: 50368 # update for hero run with custom tokenizer + no_bias: true + attn_config: + alibi: false + attn_impl: triton + clip_qkv: 6 + attn_uses_sequence_id: true + rope: true # Keep alibi: true, and pos_bias_type can be one of alibi, rope, or... + rope_bf: 10000 # Use 10000 for original rope, 500000 for rope_abf (https://arxiv.org/pdf/2309.16039.pdf) + +# Tokenizer +tokenizer: + name: EleutherAI/gpt-neox-20b # default tokenizer used for MPT + kwargs: + model_max_length: ${max_seq_len} + +# Optimization +global_train_batch_size: 512 # ~1M tokens, update for hero run, must be divisible by gpu_num +max_duration: 100000ba # update for hero run, e.g. 
100000ba ~= 100B tokens + +optimizer: + name: decoupled_lionw + lr: 0.0006 + betas: + - 0.9 + - 0.95 + weight_decay: 0.0006 + +scheduler: + name: cosine_with_warmup + t_warmup: 5000ba + alpha_f: 0.1 + +# System +seed: ${global_seed} +precision: amp_bf16 +device_train_microbatch_size: 16 # NOTE: Please ensure that ${gpus} is less than or equal to ${global_train_batch_size}/${device_train_microbatch_size}. + # 1. If you want to increase the number of gpus beyond this number, please decrease device_train_microbatch_size correspondingly to + # satisfy this equation. + # 2. If you want to increase the number of gpus beyond this number and keep device_train_microbatch_size the same (or increase it), + # you would need to increase global_train_batch_size. However, in that case other hyperparameters like lr would probably need to + # be changed and we cannot guarantee convergence. +device_eval_batch_size: 16 + +# FSDP +fsdp_config: + activation_checkpointing: false + activation_checkpointing_reentrant: false + activation_cpu_offload: false + limit_all_gathers: true + mixed_precision: PURE + sharding_strategy: SHARD_GRAD_OP + state_dict_type: full + verbose: false + +# Logging +eval_first: false +eval_interval: 900000ba +log_to_console: true +console_log_interval: 1ba +progress_bar: false +python_log_level: DEBUG +loggers: + wandb: {} + +# Checkpointing +autoresume: true +save_filename: ep{epoch}-ba{batch}/rank{rank}.pt +save_folder: oci://mosaicml-internal-checkpoints/shashank/rope-vs-alibi/mpt-125m # update for hero run +save_interval: 100000ba # update for hero run, e.g. 2000ba +save_num_checkpoints_to_keep: 1 + +# Algos and Callbacks +algorithms: + gradient_clipping: + clipping_threshold: 1 + clipping_type: norm + +callbacks: + # generate_callback: + # batch_log_interval: 10 # update for hero run, e.g. 2000 + # do_sample: true + # max_new_tokens: 100 + # prompts: + # - The quick brown fox jumps over + # - |- + # Vegan Banana Bread + # Instructions: + # 1. + # - The other day I was explaining what generative AI is to my five year old. 
+ # temperature: 1 + # top_k: 50 + # top_p: 0.95 + # use_cache: true + lr_monitor: {} + memory_monitor: {} + # mono_ckpt_saver: + # batch_interval: ${save_interval} + # filename: ep{epoch}-ba{batch}/mono.pt + # save_folder: ${save_folder} + runtime_estimator: {} + scheduled_gc: + batch_interval: 2000 + speed_monitor: + window_size: 10 + +# Dataloaders +eos_token_id: 0 # update for hero run with custom tokenizer +num_canonical_nodes: 128 # update for hero run, must be codivisible by # physical nodes + +# Eval loader +eval_loader: + name: text + drop_last: false + num_workers: 8 + + # dataset yaml from https://github.com/mosaicml/llm-foundry/blob/main/scripts/train/yamls/pretrain/mpt-125m.yaml - + dataset: + eos_token_id: ${eos_token_id} + local: ${data_local} + remote: ${data_remote} + split: val + shuffle: false + max_seq_len: ${max_seq_len} + shuffle_seed: ${global_seed} + + + +# In-context-learning tasks + +## If you want to use one of our suites of tasks +# icl_tasks: eval/yamls/tasks_light.yaml + +## Or if you want to manually specify individual tasks +icl_tasks: +- + label: piqa + dataset_uri: eval/local_data/commonsense_reasoning/piqa.jsonl + num_fewshot: [0, 1, 5] + icl_task_type: multiple_choice + continuation_delimiter: 'Answer: ' +- + label: lambada + dataset_uri: eval/local_data/language_understanding/lambada_openai.jsonl + num_fewshot: [0] + icl_task_type: language_modeling + +# Train loader +train_loader: + name: text + drop_last: true + num_workers: 8 + + # dataset yaml from https://github.com/mosaicml/llm-foundry/blob/main/scripts/train/yamls/pretrain/mpt-125m.yaml - + dataset: + eos_token_id: ${eos_token_id} + local: ${data_local} + remote: ${data_remote} + split: train + shuffle: true + max_seq_len: ${max_seq_len} + shuffle_seed: ${global_seed} \ No newline at end of file diff --git a/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml b/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml index 5e69108c21..d414c837c3 100644 --- a/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml +++ b/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml @@ -16,11 +16,11 @@ model: vocab_size: 50368 # update for hero run with custom tokenizer no_bias: true attn_config: - alibi: false + alibi: true attn_impl: triton clip_qkv: 6 attn_uses_sequence_id: true - rope: true # Keep alibi: true, and pos_bias_type can be one of alibi, rope, or... + rope: false # Keep alibi: true, and pos_bias_type can be one of alibi, rope, or... rope_bf: 10000 # Use 10000 for original rope, 500000 for rope_abf (https://arxiv.org/pdf/2309.16039.pdf) # Tokenizer From 722eb0c2b5c6411aa91bcea6ba189940bc7d1029 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 5 Oct 2023 16:33:40 -0700 Subject: [PATCH 012/106] .. --- llmfoundry/models/layers/attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index c5df05b384..6364ac9f8e 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -564,7 +564,6 @@ def forward( dim=2, ) - breakpoint() if rotation_matrix is not None: query = query.view(*(query.shape[:-1]), -1, self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) From 4eb9f17f5f7270661614d699f31e93571099baf5 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 6 Oct 2023 07:35:14 -0700 Subject: [PATCH 013/106] .. 
--- llmfoundry/models/layers/__init__.py | 2 + llmfoundry/models/layers/attention.py | 23 +++-- llmfoundry/models/layers/blocks.py | 3 + llmfoundry/models/layers/rotary_embedding.py | 56 ++++++++++++ llmfoundry/models/mpt/modeling_mpt.py | 89 +++++++++++++------- 5 files changed, 133 insertions(+), 40 deletions(-) create mode 100644 llmfoundry/models/layers/rotary_embedding.py diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index 68aa0fe7fe..fea55042d2 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -29,4 +29,6 @@ 'SharedEmbedding', 'FFN_CLASS_REGISTRY', 'build_ffn', + 'RotaryEmbedding', + 'apply_rotary_pos_emb', ] diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 6364ac9f8e..c4355b6f56 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -15,6 +15,7 @@ from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY +from llmfoundry.models.layers.rotary_embedding import apply_rotary_pos_emb # Code taken from https://github.com/huggingface/transformers/blob/v4.33.3/src/transformers/models/roformer/modeling_roformer.py def _apply_rotary_position_embeddings(rotation_matrix: torch.Tensor, @@ -546,6 +547,7 @@ def forward( attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, rotation_matrix: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, is_causal: bool = True, needs_weights: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ @@ -564,14 +566,19 @@ def forward( dim=2, ) - if rotation_matrix is not None: - query = query.view(*(query.shape[:-1]), -1, self.head_dim) - key = key.view(*(key.shape[:-1]), -1, self.head_dim) - query, key = _apply_rotary_position_embeddings( - rotation_matrix, query, key) - query = query.reshape(*(query.shape[:-2]), self.d_model) - key = key.reshape(*(key.shape[:-2]), - self.kv_n_heads * self.head_dim) + # if rotation_matrix is not None: + # query = query.view(*(query.shape[:-1]), -1, self.head_dim) + # key = key.view(*(key.shape[:-1]), -1, self.head_dim) + # query, key = _apply_rotary_position_embeddings( + # rotation_matrix, query, key) + # query = query.reshape(*(query.shape[:-2]), self.d_model) + # key = key.reshape(*(key.shape[:-2]), + # self.kv_n_heads * self.head_dim) + + + if rotary_emb is not None: + (cos, sin, pos) = rotary_emb + query, key = apply_rotary_pos_emb(query, key, cos, sin, pos) key_padding_mask = attention_mask diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index c3df3e65bd..6dc11713ab 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -97,6 +97,8 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, rotation_matrix: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + pos: Optional[torch.Tensor] = None, attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, @@ -108,6 +110,7 @@ def forward( past_key_value=past_key_value, attn_bias=attn_bias, rotation_matrix=rotation_matrix, + rotary_emb=rotary_emb, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, diff --git a/llmfoundry/models/layers/rotary_embedding.py 
b/llmfoundry/models/layers/rotary_embedding.py new file mode 100644 index 0000000000..c41af9bcc3 --- /dev/null +++ b/llmfoundry/models/layers/rotary_embedding.py @@ -0,0 +1,56 @@ +import torch +from torch import nn + +# Code taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def forward(self, device, dtype, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed \ No newline at end of file diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 2294498044..947a5d8573 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -38,6 +38,7 @@ from llmfoundry.models.layers.ffn import MPTMLP as MPTMLP from llmfoundry.models.layers.ffn import build_ffn as build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY +from llmfoundry.models.layers.rotary_embedding import RotaryEmbedding from llmfoundry.models.mpt.configuration_mpt import MPTConfig # NOTE: All utils are imported directly even if unused so that @@ -145,14 +146,18 @@ def __init__(self, config: MPTConfig): ) self.rope = config.attn_config['rope'] - self._rotation_matrix_initialized = False - self.rotation_matrix = None assert (config.d_model % config.n_heads == 0) - self.rope_max_seq_len = config.max_seq_len self.rope_head_dim = config.d_model//config.n_heads + self.rope_max_seq_len = config.max_seq_len self.rope_bf = config.attn_config['rope_bf'] + + self._rotation_matrix_initialized = False + self.rotation_matrix = None self.rotation_matrix_shape = (config.max_seq_len, self.rope_head_dim) + self._rotary_embedding_initialized = False + self.rotary_embedding = None + if config.no_bias: for module in self.modules(): if hasattr(module, 'bias') and isinstance( @@ -266,6 +271,22 @@ def _rotation_matrix(self, device: torch.device, dtype: torch.dtype) -> torch.Te self.rotation_matrix[:, sentinel:] = torch.FloatTensor(torch.cos(position_enc[:, 1::2])).to(self.rotation_matrix) self._rotation_matrix_initialized = True return self.rotation_matrix + + @torch.no_grad() + def _rotary_emb(self, device, dtype, seq_len) -> torch.Tensor: + if not self._rotary_embedding_initialized: + self.rotary_embedding= None + if self.rope: + self.rotary_embedding=RotaryEmbedding( + self.rope_head_dim, + max_position_embeddings=self.rope_max_seq_len, + base=self.rope_bf, + ) + + self._rotary_embedding_initialized = True + if self.rotary_embedding is None: + return None + return self.rotary_embedding(device, dtype, seq_len) def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) -> torch.Tensor: @@ -392,42 +413,43 @@ def forward( ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' tok_emb = self.wte(input_ids) + past_position = 0 + if past_key_values is not None: + if len(past_key_values) != self.config.n_layers: + raise ValueError( + f'past_key_values must provide a past_key_value for each attention ' + + + f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).' + ) + # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim). + # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq). 
+ # Here we shift position embedding using the `seq` dim of the past key + past_position = past_key_values[0][0].size(1) + if self.attn_impl == 'torch': + past_position = past_key_values[0][0].size(3) + + pos = torch.arange( + past_position, + S + past_position, + dtype=torch.long, + device=input_ids.device, + ).unsqueeze(0) + if attention_mask is not None: + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), + dim=1)[:, past_position:], + min=0, + ) + if self.learned_pos_emb: - past_position = 0 - if past_key_values is not None: - if len(past_key_values) != self.config.n_layers: - raise ValueError( - f'past_key_values must provide a past_key_value for each attention ' - + - f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).' - ) - # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim). - # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq). - # Here we shift position embedding using the `seq` dim of the past key - past_position = past_key_values[0][0].size(1) - if self.attn_impl == 'torch': - past_position = past_key_values[0][0].size(3) - if S + past_position > self.config.max_seq_len: raise ValueError( f'Cannot forward input with past sequence length {past_position} and current sequence length ' + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.' ) - pos = torch.arange( - past_position, - S + past_position, - dtype=torch.long, - device=input_ids.device, - ).unsqueeze(0) - if attention_mask is not None: - # adjust the position indices to account for padding tokens - pos = torch.clamp( - pos - torch.cumsum((~attention_mask).to(torch.int32), - dim=1)[:, past_position:], - min=0, - ) - + pos_emb = self.wpe(pos) x = tok_emb + pos_emb else: @@ -452,6 +474,8 @@ def forward( ) rotation_matrix = self._rotation_matrix(device=x.device, dtype=torch.float32) + + (sin, cos) = self._rotary_emb(x.device, x.dtype, S) # initialize the past key values cache if it should be used if use_cache and past_key_values is None: @@ -471,6 +495,7 @@ def forward( past_key_value=past_key_value, attn_bias=attn_bias, rotation_matrix=rotation_matrix, + rotary_emb=(sin, cos, pos), attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), From c675c22a92c72b3ca2a22514ed57546e1efae324 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 6 Oct 2023 18:29:09 +0000 Subject: [PATCH 014/106] .. 
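PATCH 013 above switches to the Llama-style helpers in `llmfoundry/models/layers/rotary_embedding.py`. A hedged usage sketch of those helpers, assuming the patched tree is importable; shapes follow the [batch, heads, seq_len, head_dim] layout used at the attention.py call site after the transpose:

    import torch
    from llmfoundry.models.layers.rotary_embedding import (RotaryEmbedding,
                                                           apply_rotary_pos_emb)

    bsz, n_heads, seq_len, head_dim = 2, 12, 16, 64
    rotary = RotaryEmbedding(head_dim, max_position_embeddings=2048, base=10000)
    # cos and sin come back as [1, 1, seq_len, head_dim]
    cos, sin = rotary(device='cpu', dtype=torch.float32, seq_len=seq_len)

    q = torch.randn(bsz, n_heads, seq_len, head_dim)
    k = torch.randn(bsz, n_heads, seq_len, head_dim)
    pos = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)  # position ids, [batch, seq_len]

    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, pos)  # same shapes as q and k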
--- llmfoundry/models/layers/attention.py | 35 +++++++----- llmfoundry/models/layers/rotary_embedding.py | 46 ++++++++++----- llmfoundry/models/mpt/modeling_mpt.py | 60 ++++++++++++-------- 3 files changed, 90 insertions(+), 51 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index c4355b6f56..9f7216935f 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -17,6 +17,7 @@ from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY from llmfoundry.models.layers.rotary_embedding import apply_rotary_pos_emb + # Code taken from https://github.com/huggingface/transformers/blob/v4.33.3/src/transformers/models/roformer/modeling_roformer.py def _apply_rotary_position_embeddings(rotation_matrix: torch.Tensor, query: torch.Tensor, key: torch.Tensor): @@ -566,19 +567,27 @@ def forward( dim=2, ) - # if rotation_matrix is not None: - # query = query.view(*(query.shape[:-1]), -1, self.head_dim) - # key = key.view(*(key.shape[:-1]), -1, self.head_dim) - # query, key = _apply_rotary_position_embeddings( - # rotation_matrix, query, key) - # query = query.reshape(*(query.shape[:-2]), self.d_model) - # key = key.reshape(*(key.shape[:-2]), - # self.kv_n_heads * self.head_dim) - - - if rotary_emb is not None: - (cos, sin, pos) = rotary_emb - query, key = apply_rotary_pos_emb(query, key, cos, sin, pos) + if rotation_matrix is not None: + query = query.view(*(query.shape[:-1]), -1, self.head_dim) + key = key.view(*(key.shape[:-1]), -1, self.head_dim) + + if False: # roformer implementation of rope + query, key = _apply_rotary_position_embeddings( + rotation_matrix, query, key) + + if True: # llama implementation of rope + (cos, sin, pos) = rotary_emb + query = query.transpose(1, 2) + key = key.transpose(1, 2) + query, key = apply_rotary_pos_emb(query, key, cos, sin, pos) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + + breakpoint() + + query = query.reshape(*(query.shape[:-2]), self.d_model) + key = key.reshape(*(key.shape[:-2]), + self.kv_n_heads * self.head_dim) key_padding_mask = attention_mask diff --git a/llmfoundry/models/layers/rotary_embedding.py b/llmfoundry/models/layers/rotary_embedding.py index c41af9bcc3..7f8303dc57 100644 --- a/llmfoundry/models/layers/rotary_embedding.py +++ b/llmfoundry/models/layers/rotary_embedding.py @@ -1,31 +1,49 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + import torch from torch import nn + # Code taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py class RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() + def __init__(self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + dtype=None): + super().__init__() + if dtype is None: + dtype = torch.get_default_dtype() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) + inv_freq = 1.0 / (self.base**( + torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) + self._set_cos_sin_cache(seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=dtype) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, + device=device, + dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer('cos_cached', + emb.cos()[None, None, :, :].to(dtype), + persistent=False) + self.register_buffer('sin_cached', + emb.sin()[None, None, :, :].to(dtype), + persistent=False) def forward(self, device, dtype, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] @@ -40,8 +58,8 @@ def forward(self, device, dtype, seq_len=None): def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) @@ -53,4 +71,4 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed \ No newline at end of file + return q_embed, k_embed diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 947a5d8573..9b7474567f 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -147,7 +147,7 @@ def __init__(self, config: MPTConfig): self.rope = config.attn_config['rope'] assert (config.d_model % config.n_heads == 0) - self.rope_head_dim = config.d_model//config.n_heads + self.rope_head_dim = config.d_model // config.n_heads self.rope_max_seq_len = config.max_seq_len self.rope_bf = config.attn_config['rope_bf'] @@ -253,39 +253,50 @@ def _attn_bias( # Code taken from https://github.com/huggingface/transformers/blob/v4.33.3/src/transformers/models/roformer/modeling_roformer.py @torch.no_grad() - def _rotation_matrix(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: - """ - Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in - the 2nd half of the vector. [head_dim // 2:] + def _rotation_matrix(self, device: torch.device, + dtype: torch.dtype) -> torch.Tensor: + """Identical to the XLM create_sinusoidal_embeddings except features are + not interleaved. + + The cos features are in the 2nd half of the vector. 
[head_dim // 2:] """ if not self._rotation_matrix_initialized: self.rotation_matrix = None if self.rope: - self.rotation_matrix = torch.empty(self.rotation_matrix_shape, device=device, dtype=dtype) - - theta = 1/(torch.pow(self.rope_bf, 2 * (torch.arange(self.rope_head_dim)//2) / self.rope_head_dim)) - position_enc = torch.outer(torch.arange(self.rope_max_seq_len), theta) - - sentinel = self.rope_head_dim // 2 if self.rope_head_dim % 2 == 0 else (self.rope_head_dim // 2) + 1 - self.rotation_matrix[:, 0:sentinel] = torch.FloatTensor(torch.sin(position_enc[:, 0::2])).to(self.rotation_matrix) - self.rotation_matrix[:, sentinel:] = torch.FloatTensor(torch.cos(position_enc[:, 1::2])).to(self.rotation_matrix) + self.rotation_matrix = torch.empty(self.rotation_matrix_shape, + device=device, + dtype=dtype) + + theta = 1 / (torch.pow( + self.rope_bf, 2 * (torch.arange(self.rope_head_dim) // 2) / + self.rope_head_dim)) + position_enc = torch.outer(torch.arange(self.rope_max_seq_len), + theta) + + sentinel = self.rope_head_dim // 2 if self.rope_head_dim % 2 == 0 else ( + self.rope_head_dim // 2) + 1 + self.rotation_matrix[:, 0:sentinel] = torch.FloatTensor( + torch.sin(position_enc[:, 0::2])).to(self.rotation_matrix) + self.rotation_matrix[:, sentinel:] = torch.FloatTensor( + torch.cos(position_enc[:, 1::2])).to(self.rotation_matrix) self._rotation_matrix_initialized = True return self.rotation_matrix - + @torch.no_grad() def _rotary_emb(self, device, dtype, seq_len) -> torch.Tensor: if not self._rotary_embedding_initialized: - self.rotary_embedding= None + self.rotary_embedding = None if self.rope: - self.rotary_embedding=RotaryEmbedding( + self.rotary_embedding = RotaryEmbedding( self.rope_head_dim, max_position_embeddings=self.rope_max_seq_len, base=self.rope_bf, - ) - + device=device, + dtype=dtype) + self._rotary_embedding_initialized = True if self.rotary_embedding is None: - return None + return None return self.rotary_embedding(device, dtype, seq_len) def _apply_prefix_mask(self, attn_bias: torch.Tensor, @@ -438,10 +449,10 @@ def forward( # adjust the position indices to account for padding tokens pos = torch.clamp( pos - torch.cumsum((~attention_mask).to(torch.int32), - dim=1)[:, past_position:], + dim=1)[:, past_position:], min=0, ) - + if self.learned_pos_emb: if S + past_position > self.config.max_seq_len: raise ValueError( @@ -449,7 +460,7 @@ def forward( + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.' ) - + pos_emb = self.wpe(pos) x = tok_emb + pos_emb else: @@ -473,8 +484,9 @@ def forward( sequence_id=sequence_id, ) - rotation_matrix = self._rotation_matrix(device=x.device, dtype=torch.float32) - + rotation_matrix = self._rotation_matrix(device=x.device, + dtype=torch.float32) + (sin, cos) = self._rotary_emb(x.device, x.dtype, S) # initialize the past key values cache if it should be used From de765c40595ec3612b4a4e7b4452ace307cb43f3 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 6 Oct 2023 18:44:28 +0000 Subject: [PATCH 015/106] .. 
--- llmfoundry/models/layers/attention.py | 2 -- llmfoundry/models/layers/rotary_embedding.py | 2 +- scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 9f7216935f..3e4d85620d 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -583,8 +583,6 @@ def forward( query = query.transpose(1, 2) key = key.transpose(1, 2) - breakpoint() - query = query.reshape(*(query.shape[:-2]), self.d_model) key = key.reshape(*(key.shape[:-2]), self.kv_n_heads * self.head_dim) diff --git a/llmfoundry/models/layers/rotary_embedding.py b/llmfoundry/models/layers/rotary_embedding.py index 7f8303dc57..eebc296b6d 100644 --- a/llmfoundry/models/layers/rotary_embedding.py +++ b/llmfoundry/models/layers/rotary_embedding.py @@ -71,4 +71,4 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed + return q_embed.to(q), k_embed.to(k) diff --git a/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml b/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml index 5e69108c21..bf6da8ad7f 100644 --- a/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml +++ b/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml @@ -20,7 +20,7 @@ model: attn_impl: triton clip_qkv: 6 attn_uses_sequence_id: true - rope: true # Keep alibi: true, and pos_bias_type can be one of alibi, rope, or... + rope: true rope_bf: 10000 # Use 10000 for original rope, 500000 for rope_abf (https://arxiv.org/pdf/2309.16039.pdf) # Tokenizer From 529ada8609294230b0cb3080c73572aa6c7c376f Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 6 Oct 2023 20:45:27 +0000 Subject: [PATCH 016/106] .. --- llmfoundry/models/mpt/modeling_mpt.py | 55 ++++++++++++++------------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 9b7474567f..a02acee508 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -424,34 +424,35 @@ def forward( ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' tok_emb = self.wte(input_ids) - past_position = 0 - if past_key_values is not None: - if len(past_key_values) != self.config.n_layers: - raise ValueError( - f'past_key_values must provide a past_key_value for each attention ' - + - f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).' + if self.learned_pos_emb or self.rope: + past_position = 0 + if past_key_values is not None: + if len(past_key_values) != self.config.n_layers: + raise ValueError( + f'past_key_values must provide a past_key_value for each attention ' + + + f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).' + ) + # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim). + # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq). 
+ # Here we shift position embedding using the `seq` dim of the past key + past_position = past_key_values[0][0].size(1) + if self.attn_impl == 'torch': + past_position = past_key_values[0][0].size(3) + + pos = torch.arange( + past_position, + S + past_position, + dtype=torch.long, + device=input_ids.device, + ).unsqueeze(0) + if attention_mask is not None: + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), + dim=1)[:, past_position:], + min=0, ) - # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim). - # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq). - # Here we shift position embedding using the `seq` dim of the past key - past_position = past_key_values[0][0].size(1) - if self.attn_impl == 'torch': - past_position = past_key_values[0][0].size(3) - - pos = torch.arange( - past_position, - S + past_position, - dtype=torch.long, - device=input_ids.device, - ).unsqueeze(0) - if attention_mask is not None: - # adjust the position indices to account for padding tokens - pos = torch.clamp( - pos - torch.cumsum((~attention_mask).to(torch.int32), - dim=1)[:, past_position:], - min=0, - ) if self.learned_pos_emb: if S + past_position > self.config.max_seq_len: From 7d39ffc42813c0578dbfbf3a04e86bc550c22c2c Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 6 Oct 2023 20:45:54 +0000 Subject: [PATCH 017/106] .. --- llmfoundry/models/mpt/modeling_mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index a02acee508..bf7fb77783 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -450,7 +450,7 @@ def forward( # adjust the position indices to account for padding tokens pos = torch.clamp( pos - torch.cumsum((~attention_mask).to(torch.int32), - dim=1)[:, past_position:], + dim=1)[:, past_position:], min=0, ) From bb927691649ee2105faac526300b5818a74e2b37 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 6 Oct 2023 21:49:59 +0000 Subject: [PATCH 018/106] .. 
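The clamp-over-cumsum logic above adjusts position ids so that left padding does not consume positions. A tiny standalone check of what it produces for one padded sequence (plain PyTorch, toy values):

    import torch

    S, past_position = 5, 0
    attention_mask = torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.bool)  # two pad tokens, then text

    pos = torch.arange(past_position, S + past_position).unsqueeze(0)
    pos = torch.clamp(
        pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:],
        min=0,
    )
    print(pos)  # tensor([[0, 0, 0, 1, 2]])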
--- llmfoundry/models/mpt/modeling_mpt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index bf7fb77783..b20b9fd415 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -283,7 +283,7 @@ def _rotation_matrix(self, device: torch.device, return self.rotation_matrix @torch.no_grad() - def _rotary_emb(self, device, dtype, seq_len) -> torch.Tensor: + def _rotary_emb(self, device, dtype, seq_len, pos) -> torch.Tensor: if not self._rotary_embedding_initialized: self.rotary_embedding = None if self.rope: @@ -297,7 +297,7 @@ def _rotary_emb(self, device, dtype, seq_len) -> torch.Tensor: self._rotary_embedding_initialized = True if self.rotary_embedding is None: return None - return self.rotary_embedding(device, dtype, seq_len) + return (*(self.rotary_embedding(device, dtype, seq_len)), pos) def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) -> torch.Tensor: @@ -488,7 +488,7 @@ def forward( rotation_matrix = self._rotation_matrix(device=x.device, dtype=torch.float32) - (sin, cos) = self._rotary_emb(x.device, x.dtype, S) + rotary_emb = self._rotary_emb(x.device, x.dtype, S, pos) # initialize the past key values cache if it should be used if use_cache and past_key_values is None: @@ -508,7 +508,7 @@ def forward( past_key_value=past_key_value, attn_bias=attn_bias, rotation_matrix=rotation_matrix, - rotary_emb=(sin, cos, pos), + rotary_emb=rotary_emb, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), From 841becb1408f5af7b03f722e93f4474e051ff6c1 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 6 Oct 2023 23:04:14 +0000 Subject: [PATCH 019/106] .. 
--- llmfoundry/models/layers/attention.py | 31 +++++++++++-------- llmfoundry/models/mpt/modeling_mpt.py | 12 ++++--- .../pretrain/mpt-125m-realistic-rope.yaml | 4 +-- 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 3e4d85620d..7c0302c385 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -567,21 +567,26 @@ def forward( dim=2, ) - if rotation_matrix is not None: + # # Roformer implementation of rope + # if rotation_matrix is not None: + # query = query.view(*(query.shape[:-1]), -1, self.head_dim) + # key = key.view(*(key.shape[:-1]), -1, self.head_dim) + # query, key = _apply_rotary_position_embeddings( + # rotation_matrix, query, key) + # query = query.reshape(*(query.shape[:-2]), self.d_model) + # key = key.reshape(*(key.shape[:-2]), + # self.kv_n_heads * self.head_dim) + + # Llama implementation of rope + if rotary_emb is not None: query = query.view(*(query.shape[:-1]), -1, self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) - - if False: # roformer implementation of rope - query, key = _apply_rotary_position_embeddings( - rotation_matrix, query, key) - - if True: # llama implementation of rope - (cos, sin, pos) = rotary_emb - query = query.transpose(1, 2) - key = key.transpose(1, 2) - query, key = apply_rotary_pos_emb(query, key, cos, sin, pos) - query = query.transpose(1, 2) - key = key.transpose(1, 2) + (cos, sin, pos) = rotary_emb + query = query.transpose(1, 2) + key = key.transpose(1, 2) + query, key = apply_rotary_pos_emb(query, key, cos, sin, pos) + query = query.transpose(1, 2) + key = key.transpose(1, 2) query = query.reshape(*(query.shape[:-2]), self.d_model) key = key.reshape(*(key.shape[:-2]), diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index b20b9fd415..c534681358 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -295,7 +295,7 @@ def _rotary_emb(self, device, dtype, seq_len, pos) -> torch.Tensor: dtype=dtype) self._rotary_embedding_initialized = True - if self.rotary_embedding is None: + if self.rotary_embedding is None or pos is None: return None return (*(self.rotary_embedding(device, dtype, seq_len)), pos) @@ -485,10 +485,12 @@ def forward( sequence_id=sequence_id, ) - rotation_matrix = self._rotation_matrix(device=x.device, - dtype=torch.float32) - - rotary_emb = self._rotary_emb(x.device, x.dtype, S, pos) + rotation_matrix=None + rotary_emb=None + if self.rope: + # rotation_matrix = self._rotation_matrix(device=x.device, dtype=torch.float32) + rotary_emb = self._rotary_emb(x.device, x.dtype, S, pos) + breakpoint() # initialize the past key values cache if it should be used if use_cache and past_key_values is None: diff --git a/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml b/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml index bf6da8ad7f..50403efdb0 100644 --- a/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml +++ b/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml @@ -16,11 +16,11 @@ model: vocab_size: 50368 # update for hero run with custom tokenizer no_bias: true attn_config: - alibi: false + alibi: true attn_impl: triton clip_qkv: 6 attn_uses_sequence_id: true - rope: true + rope: false rope_bf: 10000 # Use 10000 for original rope, 500000 for rope_abf (https://arxiv.org/pdf/2309.16039.pdf) # Tokenizer From e5d0e65bcbf2d87d556bb61737a764b31b9bdc0c Mon Sep 17 00:00:00 2001 
From: Shashank Rajput Date: Fri, 6 Oct 2023 23:36:43 +0000 Subject: [PATCH 020/106] .. --- llmfoundry/models/layers/attention.py | 18 +++++++++--------- llmfoundry/models/mpt/modeling_mpt.py | 12 ++++++------ .../pretrain/mpt-125m-realistic-rope.yaml | 6 +++--- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 7c0302c385..46225a3881 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -567,15 +567,15 @@ def forward( dim=2, ) - # # Roformer implementation of rope - # if rotation_matrix is not None: - # query = query.view(*(query.shape[:-1]), -1, self.head_dim) - # key = key.view(*(key.shape[:-1]), -1, self.head_dim) - # query, key = _apply_rotary_position_embeddings( - # rotation_matrix, query, key) - # query = query.reshape(*(query.shape[:-2]), self.d_model) - # key = key.reshape(*(key.shape[:-2]), - # self.kv_n_heads * self.head_dim) + # Roformer implementation of rope + if rotation_matrix is not None: + query = query.view(*(query.shape[:-1]), -1, self.head_dim) + key = key.view(*(key.shape[:-1]), -1, self.head_dim) + query, key = _apply_rotary_position_embeddings( + rotation_matrix, query, key) + query = query.reshape(*(query.shape[:-2]), self.d_model) + key = key.reshape(*(key.shape[:-2]), + self.kv_n_heads * self.head_dim) # Llama implementation of rope if rotary_emb is not None: diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index c534681358..68de5b920a 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -424,6 +424,7 @@ def forward( ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' tok_emb = self.wte(input_ids) + pos = None if self.learned_pos_emb or self.rope: past_position = 0 if past_key_values is not None: @@ -485,12 +486,11 @@ def forward( sequence_id=sequence_id, ) - rotation_matrix=None - rotary_emb=None - if self.rope: - # rotation_matrix = self._rotation_matrix(device=x.device, dtype=torch.float32) - rotary_emb = self._rotary_emb(x.device, x.dtype, S, pos) - breakpoint() + + # rotation_matrix is used for the roformer implementation of rope, rotary_emb is used for llama implemenation of rope + rotation_matrix = None + # rotation_matrix = self._rotation_matrix(device=x.device, dtype=torch.float32) + rotary_emb = self._rotary_emb(x.device, x.dtype, S, pos) # initialize the past key values cache if it should be used if use_cache and past_key_values is None: diff --git a/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml b/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml index 50403efdb0..e4a94727ed 100644 --- a/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml +++ b/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml @@ -16,11 +16,11 @@ model: vocab_size: 50368 # update for hero run with custom tokenizer no_bias: true attn_config: - alibi: true + alibi: false attn_impl: triton clip_qkv: 6 - attn_uses_sequence_id: true - rope: false + attn_uses_sequence_id: false + rope: true rope_bf: 10000 # Use 10000 for original rope, 500000 for rope_abf (https://arxiv.org/pdf/2309.16039.pdf) # Tokenizer From dabd2315f75926765360065c7a67501680c1b361 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 7 Oct 2023 15:01:55 +0000 Subject: [PATCH 021/106] .. 
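Note on the two rope code paths toggled in the preceding patch: the RoFormer-style helper rotates interleaved feature pairs (2i, 2i+1), while the Llama-style apply_rotary_pos_emb rotates pairs (i, i + head_dim/2). The two agree up to a fixed permutation of the feature dimensions, so either one can serve as the single implementation. A minimal standalone check of that equivalence (plain torch, with illustrative helper names, not the llm-foundry classes) is sketched below.

import torch

def rope_interleaved(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # RoFormer-style: rotate feature pairs (0, 1), (2, 3), ...
    cos2 = torch.repeat_interleave(cos, 2, dim=-1)
    sin2 = torch.repeat_interleave(sin, 2, dim=-1)
    x_rot = torch.stack([-x[..., 1::2], x[..., 0::2]], dim=-1).reshape_as(x)
    return x * cos2 + x_rot * sin2

def rope_rotate_half(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Llama-style: rotate feature pairs (i, i + d/2)
    d = x.shape[-1] // 2
    cos2 = torch.cat([cos, cos], dim=-1)
    sin2 = torch.cat([sin, sin], dim=-1)
    x_rot = torch.cat([-x[..., d:], x[..., :d]], dim=-1)
    return x * cos2 + x_rot * sin2

head_dim, seq_len, base = 8, 5, 10000.0
theta = 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)  # [d/2]
angles = torch.outer(torch.arange(seq_len).float(), theta)               # [seq, d/2]
x = torch.randn(seq_len, head_dim)

out_interleaved = rope_interleaved(x, angles.cos(), angles.sin())
perm = torch.cat([torch.arange(0, head_dim, 2), torch.arange(1, head_dim, 2)])
out_half = rope_rotate_half(x[..., perm], angles.cos(), angles.sin())
assert torch.allclose(out_interleaved, out_half[..., torch.argsort(perm)], atol=1e-6)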
--- llmfoundry/models/mpt/configuration_mpt.py | 8 ++++---- scripts/inference/run_mpt_with_ft.py | 2 +- scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml | 1 + scripts/train/yamls/pretrain/mpt-125m-realistic.yaml | 1 + 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index e535b9b44d..bb7f998de7 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -154,10 +154,10 @@ def __init__( del kwargs['name'] if 'loss_fn' in kwargs: del kwargs['loss_fn'] - if self.attn_config.get('alibi', False): + if self.attn_config.get('alibi', False) or self.attn_config.get('rope', False): self.learned_pos_emb = False warnings.warn( - f'alibi is turned on, setting `learned_pos_emb` to `False.`') + f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`') super().__init__(**kwargs) self._validate_config() @@ -221,9 +221,9 @@ def _validate_config(self) -> None: ) if self.init_config.get('name', None) is None: raise ValueError(f"{self.init_config=} 'name' needs to be set.") - if not self.learned_pos_emb and not self.attn_config['alibi']: + if not (self.learned_pos_emb or self.attn_config['alibi'] or self.attn_config['rope']): warnings.warn( - f'Positional information not being provided to the model using either learned_pos_emb or alibi.' + f'Positional information not being provided to the model using either learned_pos_emb or alibi or rope.' ) if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp': try: diff --git a/scripts/inference/run_mpt_with_ft.py b/scripts/inference/run_mpt_with_ft.py index 10ccf6b78b..8781e690fa 100644 --- a/scripts/inference/run_mpt_with_ft.py +++ b/scripts/inference/run_mpt_with_ft.py @@ -280,7 +280,7 @@ def main(): shared_contexts_ratio = args.shared_contexts_ratio layernorm_eps = args.layernorm_eps use_attention_linear_bias = args.alibi - has_positional_encoding = not args.alibi + has_positional_encoding = not args.alibi # TODO: Should be: has_positional_encoding = not (args.alibi or args.rope) print('\n=================== Arguments ===================') for k, v in vars(args).items(): diff --git a/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml b/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml index e4a94727ed..2cbd41e2ea 100644 --- a/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml +++ b/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml @@ -15,6 +15,7 @@ model: max_seq_len: ${max_seq_len} vocab_size: 50368 # update for hero run with custom tokenizer no_bias: true + learned_pos_emb: false attn_config: alibi: false attn_impl: triton diff --git a/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml b/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml index d414c837c3..9513a75b34 100644 --- a/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml +++ b/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml @@ -15,6 +15,7 @@ model: max_seq_len: ${max_seq_len} vocab_size: 50368 # update for hero run with custom tokenizer no_bias: true + learned_pos_emb: false attn_config: alibi: true attn_impl: triton From 7f1109a70c2206a59189f6021aa44537e48d21bb Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 7 Oct 2023 15:55:38 +0000 Subject: [PATCH 022/106] .. 
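The configuration change in the previous patch gives alibi and rope the same treatment when deciding whether the learned position-embedding table is needed. A distilled, standalone version of that precedence logic (just the truth table it implements, with a hypothetical helper name, not the MPTConfig API itself):

import warnings

def effective_learned_pos_emb(learned_pos_emb: bool, alibi: bool, rope: bool) -> bool:
    # alibi or rope supply positional information, so the learned table is switched off
    if alibi or rope:
        learned_pos_emb = False
    # if nothing supplies positional information, the model still runs (NoPE) but warns
    if not (learned_pos_emb or alibi or rope):
        warnings.warn('Positional information not being provided to the model '
                      'using either learned_pos_emb or alibi or rope.')
    return learned_pos_emb

assert effective_learned_pos_emb(True, alibi=False, rope=True) is False
assert effective_learned_pos_emb(True, alibi=False, rope=False) is True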
--- scripts/train/yamls/pretrain/mpt-125m-realistic.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml b/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml index 9513a75b34..65e2859c82 100644 --- a/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml +++ b/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml @@ -20,7 +20,7 @@ model: alibi: true attn_impl: triton clip_qkv: 6 - attn_uses_sequence_id: true + attn_uses_sequence_id: false rope: false # Keep alibi: true, and pos_bias_type can be one of alibi, rope, or... rope_bf: 10000 # Use 10000 for original rope, 500000 for rope_abf (https://arxiv.org/pdf/2309.16039.pdf) From e98841d3e88abb2cdcf0321480122882257e48f2 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sun, 8 Oct 2023 16:01:29 +0000 Subject: [PATCH 023/106] removed the roformer impementation of rope --- llmfoundry/models/layers/attention.py | 36 --------------------------- llmfoundry/models/layers/blocks.py | 2 -- llmfoundry/models/mpt/modeling_mpt.py | 36 --------------------------- 3 files changed, 74 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 46225a3881..9e3ecb147e 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -18,31 +18,6 @@ from llmfoundry.models.layers.rotary_embedding import apply_rotary_pos_emb -# Code taken from https://github.com/huggingface/transformers/blob/v4.33.3/src/transformers/models/roformer/modeling_roformer.py -def _apply_rotary_position_embeddings(rotation_matrix: torch.Tensor, - query: torch.Tensor, key: torch.Tensor): - rotation_matrix = rotation_matrix.unsqueeze(-2).unsqueeze(0) - # sin [batch_size, sequence_length, num_heads, embed_size_per_head//2] - # cos [batch_size, sequence_length, num_heads, embed_size_per_head//2] - sin, cos = rotation_matrix.chunk(2, dim=-1) - # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] - sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(rotation_matrix) - # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] - cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(rotation_matrix) - - # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2] - rotate_half_query_layer = torch.stack([-query[..., 1::2], query[..., ::2]], - dim=-1).reshape_as(query) - query = (query * cos_pos + rotate_half_query_layer * sin_pos).to(query) - - # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] - rotate_half_key_layer = torch.stack([-key[..., 1::2], key[..., ::2]], - dim=-1).reshape_as(key) - key = (key * cos_pos + rotate_half_key_layer * sin_pos).to(key) - - return query, key - - def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool: # disable causal when it is not needed @@ -547,7 +522,6 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - rotation_matrix: Optional[torch.Tensor] = None, rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, is_causal: bool = True, needs_weights: bool = False, @@ -567,16 +541,6 @@ def forward( dim=2, ) - # Roformer implementation of rope - if rotation_matrix is not None: - query = query.view(*(query.shape[:-1]), -1, self.head_dim) - key = key.view(*(key.shape[:-1]), -1, self.head_dim) - query, key = _apply_rotary_position_embeddings( - rotation_matrix, query, key) - query = 
query.reshape(*(query.shape[:-2]), self.d_model) - key = key.reshape(*(key.shape[:-2]), - self.kv_n_heads * self.head_dim) - # Llama implementation of rope if rotary_emb is not None: query = query.view(*(query.shape[:-1]), -1, self.head_dim) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 6dc11713ab..49c2b510f6 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -96,7 +96,6 @@ def forward( x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, - rotation_matrix: Optional[torch.Tensor] = None, rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, pos: Optional[torch.Tensor] = None, attention_mask: Optional[torch.ByteTensor] = None, @@ -109,7 +108,6 @@ def forward( a, past_key_value=past_key_value, attn_bias=attn_bias, - rotation_matrix=rotation_matrix, rotary_emb=rotary_emb, attention_mask=attention_mask, is_causal=is_causal, diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 68de5b920a..8da7277e16 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -251,37 +251,6 @@ def _attn_bias( return attn_bias, None - # Code taken from https://github.com/huggingface/transformers/blob/v4.33.3/src/transformers/models/roformer/modeling_roformer.py - @torch.no_grad() - def _rotation_matrix(self, device: torch.device, - dtype: torch.dtype) -> torch.Tensor: - """Identical to the XLM create_sinusoidal_embeddings except features are - not interleaved. - - The cos features are in the 2nd half of the vector. [head_dim // 2:] - """ - if not self._rotation_matrix_initialized: - self.rotation_matrix = None - if self.rope: - self.rotation_matrix = torch.empty(self.rotation_matrix_shape, - device=device, - dtype=dtype) - - theta = 1 / (torch.pow( - self.rope_bf, 2 * (torch.arange(self.rope_head_dim) // 2) / - self.rope_head_dim)) - position_enc = torch.outer(torch.arange(self.rope_max_seq_len), - theta) - - sentinel = self.rope_head_dim // 2 if self.rope_head_dim % 2 == 0 else ( - self.rope_head_dim // 2) + 1 - self.rotation_matrix[:, 0:sentinel] = torch.FloatTensor( - torch.sin(position_enc[:, 0::2])).to(self.rotation_matrix) - self.rotation_matrix[:, sentinel:] = torch.FloatTensor( - torch.cos(position_enc[:, 1::2])).to(self.rotation_matrix) - self._rotation_matrix_initialized = True - return self.rotation_matrix - @torch.no_grad() def _rotary_emb(self, device, dtype, seq_len, pos) -> torch.Tensor: if not self._rotary_embedding_initialized: @@ -486,10 +455,6 @@ def forward( sequence_id=sequence_id, ) - - # rotation_matrix is used for the roformer implementation of rope, rotary_emb is used for llama implemenation of rope - rotation_matrix = None - # rotation_matrix = self._rotation_matrix(device=x.device, dtype=torch.float32) rotary_emb = self._rotary_emb(x.device, x.dtype, S, pos) # initialize the past key values cache if it should be used @@ -509,7 +474,6 @@ def forward( x, past_key_value=past_key_value, attn_bias=attn_bias, - rotation_matrix=rotation_matrix, rotary_emb=rotary_emb, attention_mask=attention_mask, is_causal=self.is_causal, From dea3b0333f09b7d502745f4af49d4e35249308ef Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sun, 8 Oct 2023 17:40:57 +0000 Subject: [PATCH 024/106] .. 
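With only the Llama-style path left after the patch above, the property that makes rotary embeddings a drop-in replacement for alibi or learned positions is that the query/key dot product after rotation depends only on the relative offset between positions, not on their absolute values. A small numerical check of that property, using the same rotate-half formulation (standalone torch with an illustrative helper, not the module in this repo):

import torch

def rope_at(x: torch.Tensor, position: float, theta: torch.Tensor) -> torch.Tensor:
    # rotate-half formulation applied to a single head vector x at one position
    d = x.shape[-1]
    angles = position * theta                      # [d/2]
    cos = torch.cat([angles.cos(), angles.cos()])
    sin = torch.cat([angles.sin(), angles.sin()])
    x_rot = torch.cat([-x[d // 2:], x[:d // 2]])
    return x * cos + x_rot * sin

head_dim = 16
theta = 1.0 / 10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim)
q, k = torch.randn(head_dim), torch.randn(head_dim)

# same relative offset (4), different absolute positions -> same attention logit
s1 = torch.dot(rope_at(q, 7.0, theta), rope_at(k, 3.0, theta))
s2 = torch.dot(rope_at(q, 104.0, theta), rope_at(k, 100.0, theta))
assert torch.allclose(s1, s2, atol=1e-3)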
--- llmfoundry/models/layers/__init__.py | 2 ++ llmfoundry/models/layers/attention.py | 1 - llmfoundry/models/layers/blocks.py | 1 - llmfoundry/models/layers/rotary_embedding.py | 33 ++++++++------------ llmfoundry/models/mpt/modeling_mpt.py | 5 +-- 5 files changed, 18 insertions(+), 24 deletions(-) diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index fea55042d2..8f189dbda5 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -10,6 +10,8 @@ from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm +from llmfoundry.models.layers.rotary_embedding import (RotaryEmbedding, + apply_rotary_pos_emb) __all__ = [ 'scaled_multihead_dot_product_attention', diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 9e3ecb147e..c37b0f9e6f 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -541,7 +541,6 @@ def forward( dim=2, ) - # Llama implementation of rope if rotary_emb is not None: query = query.view(*(query.shape[:-1]), -1, self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 49c2b510f6..0c22226680 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -97,7 +97,6 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - pos: Optional[torch.Tensor] = None, attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, diff --git a/llmfoundry/models/layers/rotary_embedding.py b/llmfoundry/models/layers/rotary_embedding.py index eebc296b6d..77f1a819af 100644 --- a/llmfoundry/models/layers/rotary_embedding.py +++ b/llmfoundry/models/layers/rotary_embedding.py @@ -1,22 +1,17 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +# Code taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + import torch from torch import nn -# Code taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py class RotaryEmbedding(nn.Module): - def __init__(self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - dtype=None): + def __init__(self, dim: int, max_position_embeddings: int, base: int, + device: torch.device, dtype: torch.dtype): super().__init__() - if dtype is None: - dtype = torch.get_default_dtype() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -24,16 +19,13 @@ def __init__(self, torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer('inv_freq', inv_freq, persistent=False) + self.max_seq_len_cached = max_position_embeddings # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache(seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=dtype) + self._set_cos_sin_cache(seq_len=max_position_embeddings, dtype=dtype) - def _set_cos_sin_cache(self, seq_len, device, dtype): + def _set_cos_sin_cache(self, seq_len: int, dtype: torch.dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, - device=device, - dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached).to(self.inv_freq) freqs = torch.einsum('i,j->ij', t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -45,10 +37,10 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): emb.sin()[None, None, :, :].to(dtype), persistent=False) - def forward(self, device, dtype, seq_len=None): + def forward(self, dtype: torch.dtype, seq_len: int): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype) + self._set_cos_sin_cache(seq_len=seq_len, dtype=dtype) return ( self.cos_cached[:, :, :seq_len, ...].to(dtype=dtype), @@ -56,14 +48,15 @@ def forward(self, device, dtype, seq_len=None): ) -def rotate_half(x): +def rotate_half(x: torch.Tensor): """Rotates half the hidden dims of the input.""" x1 = x[..., :x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): +def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor, position_ids: torch.Tensor): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 8da7277e16..049341accb 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -252,7 +252,8 @@ def _attn_bias( return attn_bias, None @torch.no_grad() - def _rotary_emb(self, device, dtype, seq_len, pos) -> torch.Tensor: + def _rotary_emb(self, device: torch.device, dtype: torch.dtype, + seq_len: int, pos: Union[torch.Tensor, None]): if not self._rotary_embedding_initialized: self.rotary_embedding = None if self.rope: @@ -266,7 +267,7 @@ def _rotary_emb(self, device, dtype, seq_len, pos) -> torch.Tensor: self._rotary_embedding_initialized = True if self.rotary_embedding is None or pos is None: return None - return (*(self.rotary_embedding(device, dtype, seq_len)), pos) + return (*(self.rotary_embedding(dtype, seq_len)), pos) def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) -> torch.Tensor: From 2927a8cb0b57e2aa2ce3fecbded5acc9d6693e7f Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sun, 8 Oct 2023 18:50:34 +0000 Subject: [PATCH 025/106] fixed all the lint errors --- llmfoundry/models/layers/attention.py | 3 +- llmfoundry/models/layers/blocks.py | 3 +- llmfoundry/models/layers/rotary_embedding.py | 40 ++++++++------------ llmfoundry/models/mpt/modeling_mpt.py | 2 +- 4 files changed, 20 insertions(+), 28 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index c37b0f9e6f..ac59f0c07e 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -522,7 +522,8 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = 
None, attention_mask: Optional[torch.Tensor] = None, - rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor, + torch.Tensor]] = None, is_causal: bool = True, needs_weights: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 0c22226680..71ca316336 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -96,7 +96,8 @@ def forward( x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, - rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor, + torch.Tensor]] = None, attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, diff --git a/llmfoundry/models/layers/rotary_embedding.py b/llmfoundry/models/layers/rotary_embedding.py index 77f1a819af..797c5d7b2e 100644 --- a/llmfoundry/models/layers/rotary_embedding.py +++ b/llmfoundry/models/layers/rotary_embedding.py @@ -1,7 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -# Code taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +# Code modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py import torch from torch import nn @@ -12,39 +12,29 @@ class RotaryEmbedding(nn.Module): def __init__(self, dim: int, max_position_embeddings: int, base: int, device: torch.device, dtype: torch.dtype): super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base**( - torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer('inv_freq', inv_freq, persistent=False) - self.max_seq_len_cached = max_position_embeddings - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=max_position_embeddings, dtype=dtype) + self.max_position_embeddings = max_position_embeddings - def _set_cos_sin_cache(self, seq_len: int, dtype: torch.dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached).to(self.inv_freq) + inv_freq = 1.0 / (base + **(torch.arange(0, dim, 2).float().to(device) / dim)) + t = torch.arange(self.max_position_embeddings).to(inv_freq) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) + freqs = torch.einsum('i,j->ij', t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer('cos_cached', - emb.cos()[None, None, :, :].to(dtype), - persistent=False) - self.register_buffer('sin_cached', - emb.sin()[None, None, :, :].to(dtype), - persistent=False) + self.cos_cached = emb.cos()[None, None, :, :].to(dtype) + self.sin_cached = emb.sin()[None, None, :, :].to(dtype) - def forward(self, dtype: torch.dtype, seq_len: int): + def forward(self, dtype: torch.dtype, device: torch.device, seq_len: int): # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, dtype=dtype) + if seq_len > self.max_position_embeddings: + raise ValueError( + 'The sequence length is greater than the maximum sequence length.' 
+ ) return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=dtype), + self.cos_cached[:, :, :seq_len, ...].to(dtype=dtype, device=device), + self.sin_cached[:, :, :seq_len, ...].to(dtype=dtype, device=device), ) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 049341accb..07eab0c52e 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -267,7 +267,7 @@ def _rotary_emb(self, device: torch.device, dtype: torch.dtype, self._rotary_embedding_initialized = True if self.rotary_embedding is None or pos is None: return None - return (*(self.rotary_embedding(dtype, seq_len)), pos) + return (*(self.rotary_embedding(dtype, device, seq_len)), pos) def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) -> torch.Tensor: From d605fbf2128c573f700406500677d8074a5bb0a5 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sun, 8 Oct 2023 18:56:13 +0000 Subject: [PATCH 026/106] .. --- llmfoundry/models/mpt/modeling_mpt.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 07eab0c52e..137d942ab9 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -146,15 +146,9 @@ def __init__(self, config: MPTConfig): ) self.rope = config.attn_config['rope'] - assert (config.d_model % config.n_heads == 0) self.rope_head_dim = config.d_model // config.n_heads self.rope_max_seq_len = config.max_seq_len self.rope_bf = config.attn_config['rope_bf'] - - self._rotation_matrix_initialized = False - self.rotation_matrix = None - self.rotation_matrix_shape = (config.max_seq_len, self.rope_head_dim) - self._rotary_embedding_initialized = False self.rotary_embedding = None From 7b250f77dbde07c50ab900492b39b71112ce7201 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Mon, 9 Oct 2023 14:46:00 +0000 Subject: [PATCH 027/106] .. 
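One detail worth calling out from the modeling changes earlier in this series: when left padding is present, the position ids handed to the rotary embedding are shifted so that the first real token is rotated as position 0. A worked example of that clamp trick, mirroring the pos computation in modeling_mpt.py on a toy attention mask:

import torch

# batch of one sequence: two left-padding tokens followed by three real tokens
attention_mask = torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.bool)
S, past_position = attention_mask.shape[1], 0

pos = torch.arange(past_position, S + past_position).unsqueeze(0)     # [[0, 1, 2, 3, 4]]
pad_counts = torch.cumsum((~attention_mask).to(torch.int32), dim=1)   # [[1, 2, 2, 2, 2]]
pos = torch.clamp(pos - pad_counts[:, past_position:], min=0)

print(pos)  # tensor([[0, 0, 0, 1, 2]]) -> the real tokens get positions 0, 1, 2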
--- scripts/inference/run_mpt_with_ft.py | 2 +- .../pretrain/mpt-125m-realistic-rope.yaml | 178 ------------------ .../yamls/pretrain/mpt-125m-realistic.yaml | 178 ------------------ 3 files changed, 1 insertion(+), 357 deletions(-) delete mode 100644 scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml delete mode 100644 scripts/train/yamls/pretrain/mpt-125m-realistic.yaml diff --git a/scripts/inference/run_mpt_with_ft.py b/scripts/inference/run_mpt_with_ft.py index 8781e690fa..3b5bce6b3a 100644 --- a/scripts/inference/run_mpt_with_ft.py +++ b/scripts/inference/run_mpt_with_ft.py @@ -280,7 +280,7 @@ def main(): shared_contexts_ratio = args.shared_contexts_ratio layernorm_eps = args.layernorm_eps use_attention_linear_bias = args.alibi - has_positional_encoding = not args.alibi # TODO: Should be: has_positional_encoding = not (args.alibi or args.rope) + has_positional_encoding = not args.alibi # TODO: Should probably be: has_positional_encoding = not (args.alibi or args.rope) print('\n=================== Arguments ===================') for k, v in vars(args).items(): diff --git a/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml b/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml deleted file mode 100644 index 2cbd41e2ea..0000000000 --- a/scripts/train/yamls/pretrain/mpt-125m-realistic-rope.yaml +++ /dev/null @@ -1,178 +0,0 @@ -global_seed: 17 -max_seq_len: 2048 # max tokens per sequence during training - -data_local: /root/my-copy-c4 -data_remote: oci://mosaicml-internal-dataset-c4/preconcat-gpt_neox/ # If blank, files must be present in data_local - -# Model -model: - name: mpt_causal_lm - init_device: meta - d_model: 768 - n_heads: 12 - n_layers: 12 - expansion_ratio: 4 - max_seq_len: ${max_seq_len} - vocab_size: 50368 # update for hero run with custom tokenizer - no_bias: true - learned_pos_emb: false - attn_config: - alibi: false - attn_impl: triton - clip_qkv: 6 - attn_uses_sequence_id: false - rope: true - rope_bf: 10000 # Use 10000 for original rope, 500000 for rope_abf (https://arxiv.org/pdf/2309.16039.pdf) - -# Tokenizer -tokenizer: - name: EleutherAI/gpt-neox-20b # default tokenizer used for MPT - kwargs: - model_max_length: ${max_seq_len} - -# Optimization -global_train_batch_size: 512 # ~1M tokens, update for hero run, must be divisible by gpu_num -max_duration: 100000ba # update for hero run, e.g. 100000ba ~= 100B tokens - -optimizer: - name: decoupled_lionw - lr: 0.0006 - betas: - - 0.9 - - 0.95 - weight_decay: 0.0006 - -scheduler: - name: cosine_with_warmup - t_warmup: 5000ba - alpha_f: 0.1 - -# System -seed: ${global_seed} -precision: amp_bf16 -device_train_microbatch_size: 16 # NOTE: Please ensure that ${gpus} is less than or equal to ${global_train_batch_size}/${device_train_microbatch_size}. - # 1. If you want to increase the number of gpus beyond this number, please decrease device_train_microbatch_size correspondingly to - # satisfy this equation. - # 2. If you want to increase the number of gpus beyond this number and keep device_train_microbatch_size the same (or increase it), - # you would need to increase global_train_batch_size. However, in that case other hyperparameters like lr would probably need to - # be changed and we cannot guarantee convergence. 
-device_eval_batch_size: 16 - -# FSDP -fsdp_config: - activation_checkpointing: false - activation_checkpointing_reentrant: false - activation_cpu_offload: false - limit_all_gathers: true - mixed_precision: PURE - sharding_strategy: SHARD_GRAD_OP - state_dict_type: full - verbose: false - -# Logging -eval_first: false -eval_interval: 900000ba -log_to_console: true -console_log_interval: 1ba -progress_bar: false -python_log_level: DEBUG -loggers: - wandb: {} - -# Checkpointing -autoresume: true -save_filename: ep{epoch}-ba{batch}/rank{rank}.pt -save_folder: oci://mosaicml-internal-checkpoints/shashank/rope-vs-alibi/mpt-125m # update for hero run -save_interval: 100000ba # update for hero run, e.g. 2000ba -save_num_checkpoints_to_keep: 1 - -# Algos and Callbacks -algorithms: - gradient_clipping: - clipping_threshold: 1 - clipping_type: norm - -callbacks: - # generate_callback: - # batch_log_interval: 10 # update for hero run, e.g. 2000 - # do_sample: true - # max_new_tokens: 100 - # prompts: - # - The quick brown fox jumps over - # - |- - # Vegan Banana Bread - # Instructions: - # 1. - # - The other day I was explaining what generative AI is to my five year old. - # temperature: 1 - # top_k: 50 - # top_p: 0.95 - # use_cache: true - lr_monitor: {} - memory_monitor: {} - # mono_ckpt_saver: - # batch_interval: ${save_interval} - # filename: ep{epoch}-ba{batch}/mono.pt - # save_folder: ${save_folder} - runtime_estimator: {} - scheduled_gc: - batch_interval: 2000 - speed_monitor: - window_size: 10 - -# Dataloaders -eos_token_id: 0 # update for hero run with custom tokenizer -num_canonical_nodes: 128 # update for hero run, must be codivisible by # physical nodes - -# Eval loader -eval_loader: - name: text - drop_last: false - num_workers: 8 - - # dataset yaml from https://github.com/mosaicml/llm-foundry/blob/main/scripts/train/yamls/pretrain/mpt-125m.yaml - - dataset: - eos_token_id: ${eos_token_id} - local: ${data_local} - remote: ${data_remote} - split: val - shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} - - - -# In-context-learning tasks - -## If you want to use one of our suites of tasks -# icl_tasks: eval/yamls/tasks_light.yaml - -## Or if you want to manually specify individual tasks -icl_tasks: -- - label: piqa - dataset_uri: eval/local_data/commonsense_reasoning/piqa.jsonl - num_fewshot: [0, 1, 5] - icl_task_type: multiple_choice - continuation_delimiter: 'Answer: ' -- - label: lambada - dataset_uri: eval/local_data/language_understanding/lambada_openai.jsonl - num_fewshot: [0] - icl_task_type: language_modeling - -# Train loader -train_loader: - name: text - drop_last: true - num_workers: 8 - - # dataset yaml from https://github.com/mosaicml/llm-foundry/blob/main/scripts/train/yamls/pretrain/mpt-125m.yaml - - dataset: - eos_token_id: ${eos_token_id} - local: ${data_local} - remote: ${data_remote} - split: train - shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} \ No newline at end of file diff --git a/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml b/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml deleted file mode 100644 index 65e2859c82..0000000000 --- a/scripts/train/yamls/pretrain/mpt-125m-realistic.yaml +++ /dev/null @@ -1,178 +0,0 @@ -global_seed: 17 -max_seq_len: 2048 # max tokens per sequence during training - -data_local: /root/my-copy-c4 -data_remote: oci://mosaicml-internal-dataset-c4/preconcat-gpt_neox/ # If blank, files must be present in data_local - -# Model -model: - name: mpt_causal_lm - init_device: meta - 
d_model: 768 - n_heads: 12 - n_layers: 12 - expansion_ratio: 4 - max_seq_len: ${max_seq_len} - vocab_size: 50368 # update for hero run with custom tokenizer - no_bias: true - learned_pos_emb: false - attn_config: - alibi: true - attn_impl: triton - clip_qkv: 6 - attn_uses_sequence_id: false - rope: false # Keep alibi: true, and pos_bias_type can be one of alibi, rope, or... - rope_bf: 10000 # Use 10000 for original rope, 500000 for rope_abf (https://arxiv.org/pdf/2309.16039.pdf) - -# Tokenizer -tokenizer: - name: EleutherAI/gpt-neox-20b # default tokenizer used for MPT - kwargs: - model_max_length: ${max_seq_len} - -# Optimization -global_train_batch_size: 512 # ~1M tokens, update for hero run, must be divisible by gpu_num -max_duration: 100000ba # update for hero run, e.g. 100000ba ~= 100B tokens - -optimizer: - name: decoupled_lionw - lr: 0.0006 - betas: - - 0.9 - - 0.95 - weight_decay: 0.0006 - -scheduler: - name: cosine_with_warmup - t_warmup: 5000ba - alpha_f: 0.1 - -# System -seed: ${global_seed} -precision: amp_bf16 -device_train_microbatch_size: 16 # NOTE: Please ensure that ${gpus} is less than or equal to ${global_train_batch_size}/${device_train_microbatch_size}. - # 1. If you want to increase the number of gpus beyond this number, please decrease device_train_microbatch_size correspondingly to - # satisfy this equation. - # 2. If you want to increase the number of gpus beyond this number and keep device_train_microbatch_size the same (or increase it), - # you would need to increase global_train_batch_size. However, in that case other hyperparameters like lr would probably need to - # be changed and we cannot guarantee convergence. -device_eval_batch_size: 16 - -# FSDP -fsdp_config: - activation_checkpointing: false - activation_checkpointing_reentrant: false - activation_cpu_offload: false - limit_all_gathers: true - mixed_precision: PURE - sharding_strategy: SHARD_GRAD_OP - state_dict_type: full - verbose: false - -# Logging -eval_first: false -eval_interval: 900000ba -log_to_console: true -console_log_interval: 1ba -progress_bar: false -python_log_level: DEBUG -loggers: - wandb: {} - -# Checkpointing -autoresume: true -save_filename: ep{epoch}-ba{batch}/rank{rank}.pt -save_folder: oci://mosaicml-internal-checkpoints/shashank/rope-vs-alibi/mpt-125m # update for hero run -save_interval: 100000ba # update for hero run, e.g. 2000ba -save_num_checkpoints_to_keep: 1 - -# Algos and Callbacks -algorithms: - gradient_clipping: - clipping_threshold: 1 - clipping_type: norm - -callbacks: - # generate_callback: - # batch_log_interval: 10 # update for hero run, e.g. 2000 - # do_sample: true - # max_new_tokens: 100 - # prompts: - # - The quick brown fox jumps over - # - |- - # Vegan Banana Bread - # Instructions: - # 1. - # - The other day I was explaining what generative AI is to my five year old. 
- # temperature: 1 - # top_k: 50 - # top_p: 0.95 - # use_cache: true - lr_monitor: {} - memory_monitor: {} - # mono_ckpt_saver: - # batch_interval: ${save_interval} - # filename: ep{epoch}-ba{batch}/mono.pt - # save_folder: ${save_folder} - runtime_estimator: {} - scheduled_gc: - batch_interval: 2000 - speed_monitor: - window_size: 10 - -# Dataloaders -eos_token_id: 0 # update for hero run with custom tokenizer -num_canonical_nodes: 128 # update for hero run, must be codivisible by # physical nodes - -# Eval loader -eval_loader: - name: text - drop_last: false - num_workers: 8 - - # dataset yaml from https://github.com/mosaicml/llm-foundry/blob/main/scripts/train/yamls/pretrain/mpt-125m.yaml - - dataset: - eos_token_id: ${eos_token_id} - local: ${data_local} - remote: ${data_remote} - split: val - shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} - - - -# In-context-learning tasks - -## If you want to use one of our suites of tasks -# icl_tasks: eval/yamls/tasks_light.yaml - -## Or if you want to manually specify individual tasks -icl_tasks: -- - label: piqa - dataset_uri: eval/local_data/commonsense_reasoning/piqa.jsonl - num_fewshot: [0, 1, 5] - icl_task_type: multiple_choice - continuation_delimiter: 'Answer: ' -- - label: lambada - dataset_uri: eval/local_data/language_understanding/lambada_openai.jsonl - num_fewshot: [0] - icl_task_type: language_modeling - -# Train loader -train_loader: - name: text - drop_last: true - num_workers: 8 - - # dataset yaml from https://github.com/mosaicml/llm-foundry/blob/main/scripts/train/yamls/pretrain/mpt-125m.yaml - - dataset: - eos_token_id: ${eos_token_id} - local: ${data_local} - remote: ${data_remote} - split: train - shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} \ No newline at end of file From 196b8e1f476746fbbc4ce1ca501805be677e59ff Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Mon, 9 Oct 2023 16:15:20 +0000 Subject: [PATCH 028/106] ../llmfoundry/models/mpt/modeling_mpt.py --- llmfoundry/models/mpt/modeling_mpt.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 137d942ab9..c9f2f76c36 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -249,7 +249,6 @@ def _attn_bias( def _rotary_emb(self, device: torch.device, dtype: torch.dtype, seq_len: int, pos: Union[torch.Tensor, None]): if not self._rotary_embedding_initialized: - self.rotary_embedding = None if self.rope: self.rotary_embedding = RotaryEmbedding( self.rope_head_dim, @@ -405,6 +404,12 @@ def forward( if self.attn_impl == 'torch': past_position = past_key_values[0][0].size(3) + if S + past_position > self.config.max_seq_len: + raise ValueError( + f'Cannot forward input with past sequence length {past_position} and current sequence length ' + + + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.' + ) pos = torch.arange( past_position, S + past_position, @@ -419,19 +424,9 @@ def forward( min=0, ) + x = tok_emb if self.learned_pos_emb: - if S + past_position > self.config.max_seq_len: - raise ValueError( - f'Cannot forward input with past sequence length {past_position} and current sequence length ' - + - f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.' 
- ) - - pos_emb = self.wpe(pos) - x = tok_emb + pos_emb - else: - # ALiBi and NoPE use this path (RoPE will also use this path if / when enabled) - x = tok_emb + x += self.wpe(pos) if self.embedding_fraction == 1: x = self.emb_drop(x) From 0c3942ee57eac74e1db29fb3a6b95b71c58a8d13 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 10 Oct 2023 23:19:06 +0000 Subject: [PATCH 029/106] .. --- llmfoundry/models/layers/__init__.py | 7 +- llmfoundry/models/layers/attention.py | 12 ++- llmfoundry/models/layers/blocks.py | 12 ++- llmfoundry/models/layers/rotary_embedding.py | 100 ++++++++++++++----- llmfoundry/models/mpt/configuration_mpt.py | 23 ++++- llmfoundry/models/mpt/modeling_mpt.py | 57 +++++++---- 6 files changed, 150 insertions(+), 61 deletions(-) diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index 8f189dbda5..2cb539fadb 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -10,8 +10,9 @@ from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm -from llmfoundry.models.layers.rotary_embedding import (RotaryEmbedding, - apply_rotary_pos_emb) +from llmfoundry.models.layers.rotary_embedding import ( + DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, + RotaryEmbedding, apply_rotary_pos_emb) __all__ = [ 'scaled_multihead_dot_product_attention', @@ -32,5 +33,7 @@ 'FFN_CLASS_REGISTRY', 'build_ffn', 'RotaryEmbedding', + 'LinearScalingRotaryEmbedding', + 'DynamicNTKScalingRotaryEmbedding', 'apply_rotary_pos_emb', ] diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index ac59f0c07e..f2260ab8d9 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -5,7 +5,7 @@ import math import warnings -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -522,8 +522,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor, - torch.Tensor]] = None, + rotary_emb_w_offset_info: Optional[Dict] = None, is_causal: bool = True, needs_weights: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ @@ -542,12 +541,15 @@ def forward( dim=2, ) - if rotary_emb is not None: + if rotary_emb_w_offset_info is not None: query = query.view(*(query.shape[:-1]), -1, self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) - (cos, sin, pos) = rotary_emb query = query.transpose(1, 2) key = key.transpose(1, 2) + rotary_emb = rotary_emb_w_offset_info['rotary_emb'] + seq_len = rotary_emb_w_offset_info['seq_len'] + pos = rotary_emb_w_offset_info['pos'] + (cos, sin) = rotary_emb(x, seq_len) query, key = apply_rotary_pos_emb(query, key, cos, sin, pos) query = query.transpose(1, 2) key = key.transpose(1, 2) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 71ca316336..3f86510e13 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -42,7 +42,9 @@ def __init__( 'alibi': False, 'alibi_bias_max': 8, 'rope': False, - 'rope_bf': 10000, + 'rope_theta': 10000, + 'rope_scaling_type': 'no_scaling', + 'rope_scaling_factor': 1.0, } if ffn_config is 
None: @@ -60,7 +62,8 @@ def __init__( # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs args_to_exclude_in_attn_class = { 'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', - 'alibi_bias_max', 'rope', 'rope_bf' + 'alibi_bias_max', 'rope', 'rope_theta', 'rope_scaling_type', + 'rope_scaling_factor' } attn_config_subset_for_attn_class = { k: v @@ -96,8 +99,7 @@ def forward( x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, - rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor, - torch.Tensor]] = None, + rotary_emb_w_offset_info: Optional[Dict] = None, attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, @@ -108,7 +110,7 @@ def forward( a, past_key_value=past_key_value, attn_bias=attn_bias, - rotary_emb=rotary_emb, + rotary_emb_w_offset_info=rotary_emb_w_offset_info, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, diff --git a/llmfoundry/models/layers/rotary_embedding.py b/llmfoundry/models/layers/rotary_embedding.py index 797c5d7b2e..7b2ee7ab05 100644 --- a/llmfoundry/models/layers/rotary_embedding.py +++ b/llmfoundry/models/layers/rotary_embedding.py @@ -9,35 +9,90 @@ class RotaryEmbedding(nn.Module): - def __init__(self, dim: int, max_position_embeddings: int, base: int, - device: torch.device, dtype: torch.dtype): + def __init__(self, dim: int, base: float): super().__init__() + self.dim = dim + self.base = base + self.max_seq_len_cached = -1 - self.max_position_embeddings = max_position_embeddings - - inv_freq = 1.0 / (base - **(torch.arange(0, dim, 2).float().to(device) / dim)) - t = torch.arange(self.max_position_embeddings).to(inv_freq) + self.caches_initialized = False + self.cos_cached = torch.Tensor() + self.sin_cached = torch.Tensor() + def _set_cos_sin_cache(self, x: torch.Tensor, seq_len: int): + self.max_seq_len_cached = seq_len + inv_freq = self._get_inv_freq(x, seq_len) + t = self._get_t(x) freqs = torch.einsum('i,j->ij', t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.cos_cached = emb.cos()[None, None, :, :].to(dtype) - self.sin_cached = emb.sin()[None, None, :, :].to(dtype) + self.cos_cached = emb.cos() + self.sin_cached = emb.sin() + self.caches_initialized = True - def forward(self, dtype: torch.dtype, device: torch.device, seq_len: int): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_position_embeddings: - raise ValueError( - 'The sequence length is greater than the maximum sequence length.' 
- ) + def _get_t(self, x: torch.Tensor): + t = torch.arange(self.max_seq_len_cached).to(x) + return t + + def _get_inv_freq(self, x: torch.Tensor, seq_len: int): + del seq_len + inv_freq = ( + 1.0 / (self.base**(torch.arange(0, self.dim, 2) / self.dim))).to(x) + return inv_freq + + @torch.no_grad() + def forward(self, x: torch.Tensor, seq_len: int): + # x is only used to get the correct dtype and device + if (not self.caches_initialized) or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, x=x) return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=dtype, device=device), - self.sin_cached[:, :, :seq_len, ...].to(dtype=dtype, device=device), + self.cos_cached[:seq_len], + self.sin_cached[:seq_len], ) +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. + + Credits to the Reddit user /u/kaiokendev + """ + + def __init__(self, dim: int, base: float, scaling_factor: float): + self.scaling_factor = scaling_factor + super().__init__(dim, base) + + def _get_t(self, x: torch.Tensor): + t = (torch.arange(self.max_seq_len_cached) / self.scaling_factor).to(x) + return t + + +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. + + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__(self, dim: int, base: float, scaling_factor: float, + max_position_embeddings: float): + self.scaling_factor = scaling_factor + self.max_position_embeddings = max_position_embeddings + super().__init__(dim, base) + + def _get_inv_freq(self, x: torch.Tensor, seq_len: int): + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - + (self.scaling_factor - 1))**(self.dim / (self.dim - 2)) + inv_freq = (1.0 / + (base**(torch.arange(0, self.dim, 2) / self.dim))).to(x) + else: + inv_freq = ( + 1.0 / + (self.base**(torch.arange(0, self.dim, 2) / self.dim))).to(x) + return inv_freq + + def rotate_half(x: torch.Tensor): """Rotates half the hidden dims of the input.""" x1 = x[..., :x.shape[-1] // 2] @@ -45,13 +100,12 @@ def rotate_half(x: torch.Tensor): return torch.cat((-x2, x1), dim=-1) +# Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + cos = cos[position_ids].unsqueeze( + 1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] + sin = sin[position_ids].unsqueeze(1) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed.to(q), k_embed.to(k) + return q_embed, k_embed diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index bb7f998de7..faa2c2fac9 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -20,7 +20,9 @@ 'alibi': False, 'alibi_bias_max': 8, 'rope': False, - 'rope_bf': 10000, + 'rope_theta': 10000, + 'rope_scaling_type': 'no_scaling', + 'rope_scaling_factor': 1.0, } ffn_config_defaults: Dict = { @@ -97,7 +99,9 @@ def __init__( alibi (bool): Whether to use the alibi bias instead of position embeddings. alibi_bias_max (int): The maximum value of the alibi bias. rope (bool): Whether to use rotary positional embeddings. - rope_bf (int): The base frequency for rope. + rope_theta (int): The base frequency for rope. + rope_scaling_type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla. + rope_scaling_factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling_type. kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. ffn_config (Dict): A dictionary used to configure the model's ffn module: ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp @@ -154,10 +158,12 @@ def __init__( del kwargs['name'] if 'loss_fn' in kwargs: del kwargs['loss_fn'] - if self.attn_config.get('alibi', False) or self.attn_config.get('rope', False): + if self.attn_config.get('alibi', False) or self.attn_config.get( + 'rope', False): self.learned_pos_emb = False warnings.warn( - f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`') + f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`' + ) super().__init__(**kwargs) self._validate_config() @@ -221,7 +227,8 @@ def _validate_config(self) -> None: ) if self.init_config.get('name', None) is None: raise ValueError(f"{self.init_config=} 'name' needs to be set.") - if not (self.learned_pos_emb or self.attn_config['alibi'] or self.attn_config['rope']): + if not (self.learned_pos_emb or self.attn_config['alibi'] or + self.attn_config['rope']): warnings.warn( f'Positional information not being provided to the model using either learned_pos_emb or alibi or rope.' ) @@ -237,6 +244,12 @@ def _validate_config(self) -> None: + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156' ) + if self.attn_config['rope_scaling_type'] not in [ + 'no_scaling', 'linear', 'dynamic' + ]: + raise ValueError( + 'rope_scaling_type should be one of "no_scaling", "linear" or "dynamic".' 
+ ) if self.ffn_config['ffn_type'] == 'mptmlp': self.ffn_config['fc_type'] = self.fc_type elif self.ffn_config['ffn_type'] == 'te_ln_mlp': diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index c9f2f76c36..741b21f26c 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -38,7 +38,9 @@ from llmfoundry.models.layers.ffn import MPTMLP as MPTMLP from llmfoundry.models.layers.ffn import build_ffn as build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY -from llmfoundry.models.layers.rotary_embedding import RotaryEmbedding +from llmfoundry.models.layers.rotary_embedding import ( + DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, + RotaryEmbedding) from llmfoundry.models.mpt.configuration_mpt import MPTConfig # NOTE: All utils are imported directly even if unused so that @@ -147,8 +149,10 @@ def __init__(self, config: MPTConfig): self.rope = config.attn_config['rope'] self.rope_head_dim = config.d_model // config.n_heads + self.rope_theta = config.attn_config['rope_theta'] + self.rope_scaling_type = config.attn_config['rope_scaling_type'] + self.rope_scaling_factor = config.attn_config['rope_scaling_factor'] self.rope_max_seq_len = config.max_seq_len - self.rope_bf = config.attn_config['rope_bf'] self._rotary_embedding_initialized = False self.rotary_embedding = None @@ -246,21 +250,25 @@ def _attn_bias( return attn_bias, None @torch.no_grad() - def _rotary_emb(self, device: torch.device, dtype: torch.dtype, - seq_len: int, pos: Union[torch.Tensor, None]): + def _rotary_emb(self): if not self._rotary_embedding_initialized: - if self.rope: - self.rotary_embedding = RotaryEmbedding( + if self.rope_scaling_type == 'no_scaling': + self.rotary_embedding = RotaryEmbedding(self.rope_head_dim, + base=self.rope_theta) + elif self.rope_scaling_type == 'linear': + self.rotary_embedding = LinearScalingRotaryEmbedding( + self.rope_head_dim, + base=self.rope_theta, + scaling_factor=self.rope_scaling_factor) + elif self.rope_scaling_type == 'dynamic': + self.rotary_embedding = DynamicNTKScalingRotaryEmbedding( self.rope_head_dim, - max_position_embeddings=self.rope_max_seq_len, - base=self.rope_bf, - device=device, - dtype=dtype) + base=self.rope_theta, + scaling_factor=self.rope_scaling_factor, + max_position_embeddings=self.rope_max_seq_len) self._rotary_embedding_initialized = True - if self.rotary_embedding is None or pos is None: - return None - return (*(self.rotary_embedding(dtype, device, seq_len)), pos) + return self.rotary_embedding def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) -> torch.Tensor: @@ -386,8 +394,9 @@ def forward( S <= self.config.max_seq_len ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' + rotary_emb_w_offset_info = None + pos_emb = 0.0 tok_emb = self.wte(input_ids) - pos = None if self.learned_pos_emb or self.rope: past_position = 0 if past_key_values is not None: @@ -404,7 +413,8 @@ def forward( if self.attn_impl == 'torch': past_position = past_key_values[0][0].size(3) - if S + past_position > self.config.max_seq_len: + if self.learned_pos_emb and (S + past_position > + self.config.max_seq_len): raise ValueError( f'Cannot forward input with past sequence length {past_position} and current sequence length ' + @@ -423,10 +433,17 @@ def forward( dim=1)[:, past_position:], min=0, ) + if self.rope: + rotary_emb = self._rotary_emb() + rotary_emb_w_offset_info = { + 'rotary_emb': 
rotary_emb, + 'pos': pos, + 'seq_len': S + past_position + } + if self.learned_pos_emb: + pos_emb = self.wpe(pos) - x = tok_emb - if self.learned_pos_emb: - x += self.wpe(pos) + x = tok_emb + pos_emb if self.embedding_fraction == 1: x = self.emb_drop(x) @@ -445,8 +462,6 @@ def forward( sequence_id=sequence_id, ) - rotary_emb = self._rotary_emb(x.device, x.dtype, S, pos) - # initialize the past key values cache if it should be used if use_cache and past_key_values is None: past_key_values = [() for _ in range(self.config.n_layers) @@ -464,7 +479,7 @@ def forward( x, past_key_value=past_key_value, attn_bias=attn_bias, - rotary_emb=rotary_emb, + rotary_emb_w_offset_info=rotary_emb_w_offset_info, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), From 829b2a47a6f7995baaf89b3f157e426e004daa1e Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 12 Oct 2023 05:30:47 +0000 Subject: [PATCH 030/106] .. --- tests/test_flash_triton_torch.py | 75 ++++++++++++++++++++++++++++++-- 1 file changed, 71 insertions(+), 4 deletions(-) diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index 145d4a5885..ee4acaea49 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -19,7 +19,31 @@ def allclose_helper(t0: torch.Tensor, @pytest.mark.parametrize('attn_impl_1', ['flash', 'triton', 'torch']) @pytest.mark.parametrize('clip_qkv', [True, False]) @pytest.mark.parametrize('qk_ln', [True, False]) -@pytest.mark.parametrize('alibi', [True, False]) +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'no_scaling', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'linear', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'dynamic', + 'rope_scaling_factor': 1.0 +}]) @pytest.mark.parametrize( 'attn_type', ['multihead_attention', 'multiquery_attention', 'grouped_query_attention']) @@ -27,7 +51,7 @@ def test_attn_impl(attn_impl_0: str, attn_impl_1: str, clip_qkv: bool, qk_ln: bool, - alibi: bool, + pos_emb_config: dict, attn_type: str, device: str = 'cuda'): """Compare all attn impl with each other. @@ -35,7 +59,11 @@ def test_attn_impl(attn_impl_0: str, Includes testing with and without attn_clip_qkv, attn_qk_ln, and alibi. 
""" from llmfoundry.models.layers import attention - + from llmfoundry.models.layers.rotary_embedding import ( + DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, + RotaryEmbedding) + alibi = pos_emb_config['alibi'] + rope = pos_emb_config['rope'] if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'): pytest.xfail('flash attn does not support alibi') @@ -51,7 +79,8 @@ def test_attn_impl(attn_impl_0: str, }) n, s, f = 2, 16, cfg.d_model - + assert cfg.d_model % cfg.n_heads == 0 + rope_head_dim = cfg.d_model // cfg.n_heads if attn_type == 'grouped_query_attention': cfg.kv_n_heads = 2 @@ -87,6 +116,27 @@ def gen_bias(attn_impl: str): return attn_bias + def gen_rotary_emb(): + if pos_emb_config['rope_scaling_type'] == 'no_scaling': + rotary_embedding = RotaryEmbedding( + rope_head_dim, base=pos_emb_config['rope_theta']) + elif pos_emb_config['rope_scaling_type'] == 'linear': + rotary_embedding = LinearScalingRotaryEmbedding( + rope_head_dim, + base=pos_emb_config['rope_theta'], + scaling_factor=pos_emb_config['rope_scaling_factor']) + elif pos_emb_config['rope_scaling_type'] == 'dynamic': + rotary_embedding = DynamicNTKScalingRotaryEmbedding( + rope_head_dim, + base=pos_emb_config['rope_theta'], + scaling_factor=pos_emb_config['rope_scaling_factor'], + max_position_embeddings=s) + else: + raise ValueError( + 'rope_scaling_type should be one no_scaling, linear, or dynamic' + ) + return rotary_embedding + x0 = torch.randn(n, s, f).to(device) x1 = x0.clone().detach() x0.requires_grad = True @@ -94,16 +144,33 @@ def gen_bias(attn_impl: str): with torch.autocast(x0.device.type): attn_bias = gen_bias(attn0.attn_impl) + pos = torch.arange(s).unsqueeze(0).to(device=device) + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1), + min=0, + ) + + rotary_emb_w_offset_info = None + if rope: + rotary_emb = gen_rotary_emb() + rotary_emb_w_offset_info = { + 'rotary_emb': rotary_emb, + 'pos': pos, + 'seq_len': s + } y0, _, _ = attn0(x0, past_key_value=None, attn_bias=attn_bias, attention_mask=attention_mask, + rotary_emb_w_offset_info=rotary_emb_w_offset_info, is_causal=True) attn_bias = gen_bias(attn1.attn_impl) y1, _, _ = attn1(x1, past_key_value=None, attn_bias=attn_bias, attention_mask=attention_mask, + rotary_emb_w_offset_info=rotary_emb_w_offset_info, is_causal=True) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) From 1629d1a0f63b51655df127e18fc9e051c52e58b5 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 12 Oct 2023 18:09:10 +0000 Subject: [PATCH 031/106] .. --- tests/test_flash_triton_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index ee4acaea49..f5903eb21a 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -56,7 +56,7 @@ def test_attn_impl(attn_impl_0: str, device: str = 'cuda'): """Compare all attn impl with each other. - Includes testing with and without attn_clip_qkv, attn_qk_ln, and alibi. + Includes testing with and without attn_clip_qkv, attn_qk_ln, alibi, and rope. 
""" from llmfoundry.models.layers import attention from llmfoundry.models.layers.rotary_embedding import ( From eb658a39d0990482357f8250cc1a47fb1a96dc52 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 12 Oct 2023 20:15:56 +0000 Subject: [PATCH 032/106] added unit test to test rotary embeddings --- tests/test_rotary_embedding.py | 93 ++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 tests/test_rotary_embedding.py diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py new file mode 100644 index 0000000000..f08f84444b --- /dev/null +++ b/tests/test_rotary_embedding.py @@ -0,0 +1,93 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from composer.utils import reproducibility + + +def allclose_helper(t0: torch.Tensor, + t1: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2): + return torch.allclose(t0, t1, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize('dtype', [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize('rope_scaling_type', ['no_scaling', 'linear', 'dynamic']) +@pytest.mark.parametrize('tensor_type', ['query', 'key']) +def test_rotation_scaling_factor_1(device: str, dtype: torch.dtype, rope_scaling_type: str, tensor_type: str): + """Checks all the rotation embedding techniques with scaling factor 1.""" + from llmfoundry.models.layers.rotary_embedding import ( + DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, + RotaryEmbedding, apply_rotary_pos_emb) + + reproducibility.seed_all(7) + + rope_head_dim = 8 + assert rope_head_dim % 2 == 0 + rope_theta = 5 + rope_scaling_factor = 1.0 + + seq_len = 7 + batch_size = 1 + num_heads = 1 + pos = torch.arange(seq_len, device=device, dtype=torch.long).repeat(batch_size, 1) # + + # x will test the first half cosine part of the rotation and second half of the sine part + x = torch.ones((batch_size, num_heads, seq_len, rope_head_dim), device=device, dtype=dtype) + x[..., rope_head_dim//2:] = 0.0 + + # y will test the first half sine part of the rotation and second half of the cosine part + y = torch.ones((batch_size, num_heads, seq_len, rope_head_dim), device=device, dtype=dtype) + y[..., :rope_head_dim//2] = 0.0 + + z = torch.zeros((batch_size, num_heads, seq_len, rope_head_dim), device=device, dtype=dtype) + + def gen_rotary_emb(): + if rope_scaling_type == 'no_scaling': + rotary_embedding = RotaryEmbedding( + rope_head_dim, base=rope_theta) + elif rope_scaling_type == 'linear': + rotary_embedding = LinearScalingRotaryEmbedding( + rope_head_dim, + base=rope_theta, + scaling_factor=rope_scaling_factor) + elif rope_scaling_type == 'dynamic': + rotary_embedding = DynamicNTKScalingRotaryEmbedding( + rope_head_dim, + base=rope_theta, + scaling_factor=rope_scaling_factor, + max_position_embeddings=seq_len) + else: + raise ValueError( + 'rope_scaling_type should be one no_scaling, linear, or dynamic' + ) + return rotary_embedding + + rotary_emb = gen_rotary_emb() + (cos, sin) = rotary_emb(x, seq_len) + assert allclose_helper(cos[:, :rope_head_dim//2], cos[:, rope_head_dim//2:]) + assert allclose_helper(sin[:, :rope_head_dim//2], sin[:, rope_head_dim//2:]) + + assert allclose_helper(cos*cos + sin*sin, torch.ones_like(cos)) + + if tensor_type == 'query': + x, _ = apply_rotary_pos_emb(x, z, cos, sin, pos) + y, _ = apply_rotary_pos_emb(y, z, cos, sin, pos) + elif tensor_type == 'key': + _, x = apply_rotary_pos_emb(z, x, cos, sin, pos) + _, 
y = apply_rotary_pos_emb(z, y, cos, sin, pos) + + assert allclose_helper(x[..., :rope_head_dim//2], y[..., rope_head_dim//2:]) + assert allclose_helper(x[..., rope_head_dim//2:], -y[..., :rope_head_dim//2]) + + inv_freq = ( + 1.0 / (rope_theta**(torch.arange(0, rope_head_dim, 2) / rope_head_dim))).to(x) + t = torch.arange(seq_len).to(x) + + expected_rotation_angles = torch.outer(t, inv_freq) + + assert allclose_helper(x[..., :rope_head_dim//2], expected_rotation_angles.cos()) + assert allclose_helper(x[..., rope_head_dim//2:], expected_rotation_angles.sin()) \ No newline at end of file From 5cb95f6ce6a9fc3b68726d675e5d733ed299ad4f Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 12 Oct 2023 20:17:47 +0000 Subject: [PATCH 033/106] .. --- tests/test_rotary_embedding.py | 53 ++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py index f08f84444b..7e9cf27d87 100644 --- a/tests/test_rotary_embedding.py +++ b/tests/test_rotary_embedding.py @@ -15,9 +15,11 @@ def allclose_helper(t0: torch.Tensor, @pytest.mark.parametrize('device', ['cpu', 'cuda']) @pytest.mark.parametrize('dtype', [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize('rope_scaling_type', ['no_scaling', 'linear', 'dynamic']) +@pytest.mark.parametrize('rope_scaling_type', + ['no_scaling', 'linear', 'dynamic']) @pytest.mark.parametrize('tensor_type', ['query', 'key']) -def test_rotation_scaling_factor_1(device: str, dtype: torch.dtype, rope_scaling_type: str, tensor_type: str): +def test_rotation_scaling_factor_1(device: str, dtype: torch.dtype, + rope_scaling_type: str, tensor_type: str): """Checks all the rotation embedding techniques with scaling factor 1.""" from llmfoundry.models.layers.rotary_embedding import ( DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, @@ -33,22 +35,28 @@ def test_rotation_scaling_factor_1(device: str, dtype: torch.dtype, rope_scaling seq_len = 7 batch_size = 1 num_heads = 1 - pos = torch.arange(seq_len, device=device, dtype=torch.long).repeat(batch_size, 1) # + pos = torch.arange(seq_len, device=device, + dtype=torch.long).repeat(batch_size, 1) # # x will test the first half cosine part of the rotation and second half of the sine part - x = torch.ones((batch_size, num_heads, seq_len, rope_head_dim), device=device, dtype=dtype) - x[..., rope_head_dim//2:] = 0.0 - + x = torch.ones((batch_size, num_heads, seq_len, rope_head_dim), + device=device, + dtype=dtype) + x[..., rope_head_dim // 2:] = 0.0 + # y will test the first half sine part of the rotation and second half of the cosine part - y = torch.ones((batch_size, num_heads, seq_len, rope_head_dim), device=device, dtype=dtype) - y[..., :rope_head_dim//2] = 0.0 + y = torch.ones((batch_size, num_heads, seq_len, rope_head_dim), + device=device, + dtype=dtype) + y[..., :rope_head_dim // 2] = 0.0 - z = torch.zeros((batch_size, num_heads, seq_len, rope_head_dim), device=device, dtype=dtype) + z = torch.zeros((batch_size, num_heads, seq_len, rope_head_dim), + device=device, + dtype=dtype) def gen_rotary_emb(): if rope_scaling_type == 'no_scaling': - rotary_embedding = RotaryEmbedding( - rope_head_dim, base=rope_theta) + rotary_embedding = RotaryEmbedding(rope_head_dim, base=rope_theta) elif rope_scaling_type == 'linear': rotary_embedding = LinearScalingRotaryEmbedding( rope_head_dim, @@ -68,10 +76,12 @@ def gen_rotary_emb(): rotary_emb = gen_rotary_emb() (cos, sin) = rotary_emb(x, seq_len) - assert allclose_helper(cos[:, 
:rope_head_dim//2], cos[:, rope_head_dim//2:]) - assert allclose_helper(sin[:, :rope_head_dim//2], sin[:, rope_head_dim//2:]) + assert allclose_helper(cos[:, :rope_head_dim // 2], + cos[:, rope_head_dim // 2:]) + assert allclose_helper(sin[:, :rope_head_dim // 2], + sin[:, rope_head_dim // 2:]) - assert allclose_helper(cos*cos + sin*sin, torch.ones_like(cos)) + assert allclose_helper(cos * cos + sin * sin, torch.ones_like(cos)) if tensor_type == 'query': x, _ = apply_rotary_pos_emb(x, z, cos, sin, pos) @@ -80,14 +90,19 @@ def gen_rotary_emb(): _, x = apply_rotary_pos_emb(z, x, cos, sin, pos) _, y = apply_rotary_pos_emb(z, y, cos, sin, pos) - assert allclose_helper(x[..., :rope_head_dim//2], y[..., rope_head_dim//2:]) - assert allclose_helper(x[..., rope_head_dim//2:], -y[..., :rope_head_dim//2]) + assert allclose_helper(x[..., :rope_head_dim // 2], y[..., + rope_head_dim // 2:]) + assert allclose_helper(x[..., rope_head_dim // 2:], + -y[..., :rope_head_dim // 2]) inv_freq = ( - 1.0 / (rope_theta**(torch.arange(0, rope_head_dim, 2) / rope_head_dim))).to(x) + 1.0 / + (rope_theta**(torch.arange(0, rope_head_dim, 2) / rope_head_dim))).to(x) t = torch.arange(seq_len).to(x) expected_rotation_angles = torch.outer(t, inv_freq) - assert allclose_helper(x[..., :rope_head_dim//2], expected_rotation_angles.cos()) - assert allclose_helper(x[..., rope_head_dim//2:], expected_rotation_angles.sin()) \ No newline at end of file + assert allclose_helper(x[..., :rope_head_dim // 2], + expected_rotation_angles.cos()) + assert allclose_helper(x[..., rope_head_dim // 2:], + expected_rotation_angles.sin()) From 30aa448a3dfd6829f784c09ac5bd96e6ed024f72 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 12 Oct 2023 22:12:24 +0000 Subject: [PATCH 034/106] .. --- llmfoundry/models/layers/rotary_embedding.py | 4 +- tests/test_model.py | 256 ++++++++++++++++--- tests/test_rotary_embedding.py | 3 +- 3 files changed, 227 insertions(+), 36 deletions(-) diff --git a/llmfoundry/models/layers/rotary_embedding.py b/llmfoundry/models/layers/rotary_embedding.py index 7b2ee7ab05..cb488ceca8 100644 --- a/llmfoundry/models/layers/rotary_embedding.py +++ b/llmfoundry/models/layers/rotary_embedding.py @@ -47,8 +47,8 @@ def forward(self, x: torch.Tensor, seq_len: int): self._set_cos_sin_cache(seq_len=seq_len, x=x) return ( - self.cos_cached[:seq_len], - self.sin_cached[:seq_len], + self.cos_cached[:seq_len].to(x), + self.sin_cached[:seq_len].to(x), ) diff --git a/tests/test_model.py b/tests/test_model.py index 6ea530731a..134d2efea0 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -527,13 +527,39 @@ def test_mpt_creation(norm_type: str, no_bias: bool): ('flash', 'gpu'), ('triton', 'gpu'), ('torch', 'gpu')]) -@pytest.mark.parametrize('alibi', [True, False]) -def test_forward_with_padding(attention_impl: str, device: str, alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'no_scaling', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'linear', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'dynamic', + 'rope_scaling_factor': 1.0 +}]) +def test_forward_with_padding(attention_impl: str, device: str, + pos_emb_config: dict): # Test that different placement of padding does not affect the output. 
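    # With rope enabled, padding invariance relies on the position-id adjustment in
    # MPTModel.forward, which shifts positions by the cumulative pad count, roughly:
    #   pos = torch.arange(past_position, S + past_position).unsqueeze(0)
    #   pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1), min=0)
    # so left- and right-padded inputs rotate their non-pad tokens identically.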
if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attention_impl} attention.' ) + alibi = pos_emb_config['alibi'] if alibi and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') @@ -551,7 +577,7 @@ def test_forward_with_padding(attention_impl: str, device: str, alibi: bool): resid_pdrop=0.2, attn_config={ 'attn_impl': attention_impl, - 'alibi': alibi, + **pos_emb_config, }, init_config={ 'name': 'baseline_', @@ -705,15 +731,39 @@ def test_advanced_mask_building(attention_impl: str): ('flash', 'gpu'), ('triton', 'gpu'), ('torch', 'gpu')]) -@pytest.mark.parametrize('alibi', [True, False]) -def test_generate(attention_impl: str, device: str, alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'no_scaling', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'linear', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'dynamic', + 'rope_scaling_factor': 1.0 +}]) +def test_generate(attention_impl: str, device: str, pos_emb_config: dict): # Test that generate works, and produces the same output with or without # padding in the input. if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attention_impl} attention.' ) - if alibi and attention_impl == 'flash': + if pos_emb_config['alibi'] and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') reproducibility.seed_all(1234) @@ -730,7 +780,7 @@ def test_generate(attention_impl: str, device: str, alibi: bool): resid_pdrop=0.2, attn_config={ 'attn_impl': attention_impl, - 'alibi': alibi, + **pos_emb_config, }, ) mpt = MPTForCausalLM(hf_config) @@ -900,8 +950,32 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): check_hf_model_equivalence(mpt, mpt2) -@pytest.mark.parametrize('alibi', [True, False]) -def test_forward_with_cache_and_padding(alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'no_scaling', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'linear', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'dynamic', + 'rope_scaling_factor': 1.0 +}]) +def test_forward_with_cache_and_padding(pos_emb_config: dict): # Tests that the result is the same with or without padding when using kv caching hf_config = MPTConfig( init_device='cpu', @@ -914,7 +988,7 @@ def test_forward_with_cache_and_padding(alibi: bool): resid_pdrop=0.2, attn_config={ 'attn_impl': 'torch', - 'alibi': alibi, + **pos_emb_config, }, use_cache=True, init_config={ @@ -973,15 +1047,39 @@ def test_forward_with_cache_and_padding(alibi: bool): ('triton', 'gpu'), ('torch', 'gpu'), ]) -@pytest.mark.parametrize('alibi', [True, False]) -def test_forward_with_cache(attn_impl: str, device: str, alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 
'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'no_scaling', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'linear', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'dynamic', + 'rope_scaling_factor': 1.0 +}]) +def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: bool): # Test that model forward with and without the key-value cache produces the # same output. if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attn_impl} attention.' ) - if alibi and attn_impl == 'flash': + if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') composer_device = get_device(device) @@ -997,10 +1095,8 @@ def test_forward_with_cache(attn_impl: str, device: str, alibi: bool): resid_pdrop=0.2, attn_config={ 'attn_impl': attn_impl, - 'alibi': alibi, + **pos_emb_config, }, - attn_impl=attn_impl, - alibi=alibi, use_cache=True, init_config={ 'name': 'baseline_', @@ -1084,8 +1180,32 @@ def test_forward_with_cache(attn_impl: str, device: str, alibi: bool): ) -@pytest.mark.parametrize('alibi', [True, False]) -def test_generate_with_past_kv(alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'no_scaling', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'linear', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'dynamic', + 'rope_scaling_factor': 1.0 +}]) +def test_generate_with_past_kv(pos_emb_config: dict): hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1097,7 +1217,7 @@ def test_generate_with_past_kv(alibi: bool): resid_pdrop=0.2, attn_config={ 'attn_impl': 'torch', - 'alibi': alibi, + **pos_emb_config, }, use_cache=True, init_config={ @@ -1144,9 +1264,33 @@ def test_generate_with_past_kv(alibi: bool): 'do_sample': True, 'top_p': 0.95 }]) -@pytest.mark.parametrize('alibi', [True, False]) +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'no_scaling', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'linear', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'dynamic', + 'rope_scaling_factor': 1.0 +}]) def test_generation_kwargs_dont_crash(generation_kwargs: Dict[str, Any], - alibi: bool): + pos_emb_config: dict): hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1158,7 +1302,7 @@ def test_generation_kwargs_dont_crash(generation_kwargs: Dict[str, Any], resid_pdrop=0.2, attn_config={ 'attn_impl': 'torch', - 'alibi': alibi, + **pos_emb_config, }, use_cache=True, ) @@ -1176,14 +1320,38 @@ def test_generation_kwargs_dont_crash(generation_kwargs: Dict[str, Any], @pytest.mark.gpu @pytest.mark.parametrize('attention_impl', ['torch', 'flash', 'triton']) -@pytest.mark.parametrize('alibi', [True, False]) -def test_model_to(attention_impl: str, alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': 
False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'no_scaling', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'linear', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'dynamic', + 'rope_scaling_factor': 1.0 +}]) +def test_model_to(attention_impl: str, pos_emb_config: dict): # test that moving the model to diff devices and dtypes in diff ways does not break the model if not torch.cuda.is_available(): pytest.skip( f'This test requires CUDA to be available in order to run with {attention_impl} attention.' ) - if alibi and attention_impl == 'flash': + if pos_emb_config['alibi'] and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') hf_config = MPTConfig( @@ -1197,7 +1365,7 @@ def test_model_to(attention_impl: str, alibi: bool): resid_pdrop=0.2, attn_config={ 'attn_impl': attention_impl, - 'alibi': alibi, + **pos_emb_config, }, use_cache=True, init_config={ @@ -1277,18 +1445,42 @@ def test_alibi_vs_hf(): ('triton', 'gpu'), ('torch', 'gpu'), ]) -@pytest.mark.parametrize('alibi', [True, False]) +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'no_scaling', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'linear', + 'rope_scaling_factor': 1.0 +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_scaling_type': 'dynamic', + 'rope_scaling_factor': 1.0 +}]) @pytest.mark.parametrize('output_attentions', [True, False]) @pytest.mark.parametrize('output_hidden_states', [True, False]) def test_forward_with_output_attentions_and_output_hidden_states( - attn_impl: str, device: str, alibi: bool, output_attentions: bool, - output_hidden_states: bool): + attn_impl: str, device: str, pos_emb_config: dict, + output_attentions: bool, output_hidden_states: bool): # Test that model forward with output_attentions_and_output_hidden_states if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attn_impl} attention.' 
) - if alibi and attn_impl == 'flash': + if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') if output_attentions and attn_impl in ['flash', 'triton']: pytest.skip(f'output_attentions only implemented with torch attention.') @@ -1308,10 +1500,8 @@ def test_forward_with_output_attentions_and_output_hidden_states( resid_pdrop=0.2, attn_config={ 'attn_impl': attn_impl, - 'alibi': alibi, + **pos_emb_config, }, - attn_impl=attn_impl, - alibi=alibi, use_cache=True, init_config={ 'name': 'baseline_', diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py index 7e9cf27d87..9d257a9392 100644 --- a/tests/test_rotary_embedding.py +++ b/tests/test_rotary_embedding.py @@ -20,7 +20,8 @@ def allclose_helper(t0: torch.Tensor, @pytest.mark.parametrize('tensor_type', ['query', 'key']) def test_rotation_scaling_factor_1(device: str, dtype: torch.dtype, rope_scaling_type: str, tensor_type: str): - """Checks all the rotation embedding techniques with scaling factor 1.""" + """Checks all the rotation embedding techniques (with scaling factor 1) + produce the expected rotation.""" from llmfoundry.models.layers.rotary_embedding import ( DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, RotaryEmbedding, apply_rotary_pos_emb) From 52119f5720affadb3758614654bc1db94597dfb2 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 13 Oct 2023 00:49:22 +0000 Subject: [PATCH 035/106] .. --- llmfoundry/models/layers/blocks.py | 17 ++- llmfoundry/models/mpt/configuration_mpt.py | 15 ++- llmfoundry/models/mpt/modeling_mpt.py | 4 +- tests/test_flash_triton_torch.py | 33 +++-- tests/test_model.py | 144 ++++++++++++++------- 5 files changed, 139 insertions(+), 74 deletions(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 3f86510e13..9349c91ce9 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -43,8 +43,10 @@ def __init__( 'alibi_bias_max': 8, 'rope': False, 'rope_theta': 10000, - 'rope_scaling_type': 'no_scaling', - 'rope_scaling_factor': 1.0, + 'rope_scaling': { + 'type': 'no_scaling', + 'factor': 1.0 + } } if ffn_config is None: @@ -61,9 +63,14 @@ def __init__( # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs args_to_exclude_in_attn_class = { - 'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', - 'alibi_bias_max', 'rope', 'rope_theta', 'rope_scaling_type', - 'rope_scaling_factor' + 'attn_type', + 'prefix_lm', + 'alibi', + 'attn_uses_sequence_id', + 'alibi_bias_max', + 'rope', + 'rope_theta', + 'rope_scaling', } attn_config_subset_for_attn_class = { k: v diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index faa2c2fac9..652bbdab47 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -21,8 +21,10 @@ 'alibi_bias_max': 8, 'rope': False, 'rope_theta': 10000, - 'rope_scaling_type': 'no_scaling', - 'rope_scaling_factor': 1.0, + 'rope_scaling': { + 'type': 'no_scaling', + 'factor': 1.0 + } } ffn_config_defaults: Dict = { @@ -100,8 +102,9 @@ def __init__( alibi_bias_max (int): The maximum value of the alibi bias. rope (bool): Whether to use rotary positional embeddings. rope_theta (int): The base frequency for rope. - rope_scaling_type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 
'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla. - rope_scaling_factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling_type. + rope_scaling (Dict): A dictionary used to configure rope's scaling behavior (when scaling beyond the training length) + type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla. + factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type. kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. ffn_config (Dict): A dictionary used to configure the model's ffn module: ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp @@ -244,11 +247,11 @@ def _validate_config(self) -> None: + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156' ) - if self.attn_config['rope_scaling_type'] not in [ + if self.attn_config['rope_scaling']['type'] not in [ 'no_scaling', 'linear', 'dynamic' ]: raise ValueError( - 'rope_scaling_type should be one of "no_scaling", "linear" or "dynamic".' + 'rope_scaling.type should be one of "no_scaling", "linear" or "dynamic".' ) if self.ffn_config['ffn_type'] == 'mptmlp': self.ffn_config['fc_type'] = self.fc_type diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 741b21f26c..34c918a283 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -150,8 +150,8 @@ def __init__(self, config: MPTConfig): self.rope = config.attn_config['rope'] self.rope_head_dim = config.d_model // config.n_heads self.rope_theta = config.attn_config['rope_theta'] - self.rope_scaling_type = config.attn_config['rope_scaling_type'] - self.rope_scaling_factor = config.attn_config['rope_scaling_factor'] + self.rope_scaling_type = config.attn_config['rope_scaling']['type'] + self.rope_scaling_factor = config.attn_config['rope_scaling']['factor'] self.rope_max_seq_len = config.max_seq_len self._rotary_embedding_initialized = False self.rotary_embedding = None diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index f5903eb21a..cc6438b008 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -29,20 +29,26 @@ def allclose_helper(t0: torch.Tensor, 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'no_scaling', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'no_scaling', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'linear', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'linear', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'dynamic', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'dynamic', + 'factor': 1.0 + } }]) @pytest.mark.parametrize( 'attn_type', @@ -56,7 +62,8 @@ def test_attn_impl(attn_impl_0: str, device: str = 'cuda'): """Compare all attn impl with each other. 
- Includes testing with and without attn_clip_qkv, attn_qk_ln, alibi, and rope. + Includes testing with and without attn_clip_qkv, attn_qk_ln, alibi, and + rope. """ from llmfoundry.models.layers import attention from llmfoundry.models.layers.rotary_embedding import ( @@ -117,23 +124,23 @@ def gen_bias(attn_impl: str): return attn_bias def gen_rotary_emb(): - if pos_emb_config['rope_scaling_type'] == 'no_scaling': + if pos_emb_config['rope_scaling']['type'] == 'no_scaling': rotary_embedding = RotaryEmbedding( rope_head_dim, base=pos_emb_config['rope_theta']) - elif pos_emb_config['rope_scaling_type'] == 'linear': + elif pos_emb_config['rope_scaling']['type'] == 'linear': rotary_embedding = LinearScalingRotaryEmbedding( rope_head_dim, base=pos_emb_config['rope_theta'], - scaling_factor=pos_emb_config['rope_scaling_factor']) - elif pos_emb_config['rope_scaling_type'] == 'dynamic': + scaling_factor=pos_emb_config['rope_scaling']['factor']) + elif pos_emb_config['rope_scaling']['type'] == 'dynamic': rotary_embedding = DynamicNTKScalingRotaryEmbedding( rope_head_dim, base=pos_emb_config['rope_theta'], - scaling_factor=pos_emb_config['rope_scaling_factor'], + scaling_factor=pos_emb_config['rope_scaling']['factor'], max_position_embeddings=s) else: raise ValueError( - 'rope_scaling_type should be one no_scaling, linear, or dynamic' + 'rope_scaling.type should be one no_scaling, linear, or dynamic' ) return rotary_embedding diff --git a/tests/test_model.py b/tests/test_model.py index 134d2efea0..529e28d857 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -537,20 +537,26 @@ def test_mpt_creation(norm_type: str, no_bias: bool): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'no_scaling', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'no_scaling', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'linear', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'linear', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'dynamic', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'dynamic', + 'factor': 1.0 + } }]) def test_forward_with_padding(attention_impl: str, device: str, pos_emb_config: dict): @@ -741,20 +747,26 @@ def test_advanced_mask_building(attention_impl: str): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'no_scaling', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'no_scaling', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'linear', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'linear', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'dynamic', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'dynamic', + 'factor': 1.0 + } }]) def test_generate(attention_impl: str, device: str, pos_emb_config: dict): # Test that generate works, and produces the same output with or without @@ -960,20 +972,26 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'no_scaling', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'no_scaling', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'linear', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'linear', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 
'rope_scaling_type': 'dynamic', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'dynamic', + 'factor': 1.0 + } }]) def test_forward_with_cache_and_padding(pos_emb_config: dict): # Tests that the result is the same with or without padding when using kv caching @@ -1057,20 +1075,26 @@ def test_forward_with_cache_and_padding(pos_emb_config: dict): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'no_scaling', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'no_scaling', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'linear', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'linear', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'dynamic', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'dynamic', + 'factor': 1.0 + } }]) def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: bool): # Test that model forward with and without the key-value cache produces the @@ -1190,20 +1214,26 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: bool): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'no_scaling', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'no_scaling', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'linear', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'linear', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'dynamic', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'dynamic', + 'factor': 1.0 + } }]) def test_generate_with_past_kv(pos_emb_config: dict): hf_config = MPTConfig( @@ -1274,20 +1304,26 @@ def test_generate_with_past_kv(pos_emb_config: dict): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'no_scaling', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'no_scaling', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'linear', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'linear', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'dynamic', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'dynamic', + 'factor': 1.0 + } }]) def test_generation_kwargs_dont_crash(generation_kwargs: Dict[str, Any], pos_emb_config: dict): @@ -1330,20 +1366,26 @@ def test_generation_kwargs_dont_crash(generation_kwargs: Dict[str, Any], 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'no_scaling', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'no_scaling', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'linear', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'linear', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'dynamic', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'dynamic', + 'factor': 1.0 + } }]) def test_model_to(attention_impl: str, pos_emb_config: dict): # test that moving the model to diff devices and dtypes in diff ways does not break the model @@ -1455,20 +1497,26 @@ def test_alibi_vs_hf(): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'no_scaling', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'no_scaling', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': 
True, 'rope_theta': 10000, - 'rope_scaling_type': 'linear', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'linear', + 'factor': 1.0 + } }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_scaling_type': 'dynamic', - 'rope_scaling_factor': 1.0 + 'rope_scaling': { + 'type': 'dynamic', + 'factor': 1.0 + } }]) @pytest.mark.parametrize('output_attentions', [True, False]) @pytest.mark.parametrize('output_hidden_states', [True, False]) From 048b886f1b2eac6c0a79c813028228d9197ba9d8 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 13 Oct 2023 01:35:52 +0000 Subject: [PATCH 036/106] .. --- llmfoundry/models/layers/rotary_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/layers/rotary_embedding.py b/llmfoundry/models/layers/rotary_embedding.py index cb488ceca8..7b2ee7ab05 100644 --- a/llmfoundry/models/layers/rotary_embedding.py +++ b/llmfoundry/models/layers/rotary_embedding.py @@ -47,8 +47,8 @@ def forward(self, x: torch.Tensor, seq_len: int): self._set_cos_sin_cache(seq_len=seq_len, x=x) return ( - self.cos_cached[:seq_len].to(x), - self.sin_cached[:seq_len].to(x), + self.cos_cached[:seq_len], + self.sin_cached[:seq_len], ) From c2ee0de9799f5327af2e7b2a8abb72b26df8bfac Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 13 Oct 2023 02:08:23 +0000 Subject: [PATCH 037/106] .. --- llmfoundry/models/layers/attention.py | 2 +- llmfoundry/models/mpt/modeling_mpt.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index f2260ab8d9..02da7297b1 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -12,10 +12,10 @@ from einops import rearrange from packaging import version from torch import nn +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY -from llmfoundry.models.layers.rotary_embedding import apply_rotary_pos_emb def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 34c918a283..cf44356ec8 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -28,6 +28,10 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.modeling_outputs import (BaseModelOutputWithPast, CausalLMOutputWithPast) +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as RotaryEmbedding +from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding as LinearScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import LlamaDynamicNTKScalingRotaryEmbedding as DynamicNTKScalingRotaryEmbedding + from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias from llmfoundry.models.layers.blocks import MPTBlock @@ -38,9 +42,6 @@ from llmfoundry.models.layers.ffn import MPTMLP as MPTMLP from llmfoundry.models.layers.ffn import build_ffn as build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY -from llmfoundry.models.layers.rotary_embedding import ( - DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, - RotaryEmbedding) from llmfoundry.models.mpt.configuration_mpt import MPTConfig # NOTE: All utils are imported directly even if unused so that @@ -154,7 +155,7 @@ def 
__init__(self, config: MPTConfig): self.rope_scaling_factor = config.attn_config['rope_scaling']['factor'] self.rope_max_seq_len = config.max_seq_len self._rotary_embedding_initialized = False - self.rotary_embedding = None + self.rotary_embedding = self._rotary_emb() if config.no_bias: for module in self.modules(): From 9e8a1d64455c078d20dc395f26ba871f183788ed Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 13 Oct 2023 02:41:25 +0000 Subject: [PATCH 038/106] .. --- llmfoundry/models/layers/__init__.py | 7 -- llmfoundry/models/layers/rotary_embedding.py | 111 ------------------- llmfoundry/models/mpt/modeling_mpt.py | 49 ++++---- tests/test_flash_triton_torch.py | 7 +- tests/test_rotary_embedding.py | 109 ------------------ 5 files changed, 22 insertions(+), 261 deletions(-) delete mode 100644 llmfoundry/models/layers/rotary_embedding.py delete mode 100644 tests/test_rotary_embedding.py diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index 2cb539fadb..68aa0fe7fe 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -10,9 +10,6 @@ from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm -from llmfoundry.models.layers.rotary_embedding import ( - DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, - RotaryEmbedding, apply_rotary_pos_emb) __all__ = [ 'scaled_multihead_dot_product_attention', @@ -32,8 +29,4 @@ 'SharedEmbedding', 'FFN_CLASS_REGISTRY', 'build_ffn', - 'RotaryEmbedding', - 'LinearScalingRotaryEmbedding', - 'DynamicNTKScalingRotaryEmbedding', - 'apply_rotary_pos_emb', ] diff --git a/llmfoundry/models/layers/rotary_embedding.py b/llmfoundry/models/layers/rotary_embedding.py deleted file mode 100644 index 7b2ee7ab05..0000000000 --- a/llmfoundry/models/layers/rotary_embedding.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2022 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -# Code modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py - -import torch -from torch import nn - - -class RotaryEmbedding(nn.Module): - - def __init__(self, dim: int, base: float): - super().__init__() - self.dim = dim - self.base = base - self.max_seq_len_cached = -1 - - self.caches_initialized = False - self.cos_cached = torch.Tensor() - self.sin_cached = torch.Tensor() - - def _set_cos_sin_cache(self, x: torch.Tensor, seq_len: int): - self.max_seq_len_cached = seq_len - inv_freq = self._get_inv_freq(x, seq_len) - t = self._get_t(x) - freqs = torch.einsum('i,j->ij', t, inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.cos_cached = emb.cos() - self.sin_cached = emb.sin() - self.caches_initialized = True - - def _get_t(self, x: torch.Tensor): - t = torch.arange(self.max_seq_len_cached).to(x) - return t - - def _get_inv_freq(self, x: torch.Tensor, seq_len: int): - del seq_len - inv_freq = ( - 1.0 / (self.base**(torch.arange(0, self.dim, 2) / self.dim))).to(x) - return inv_freq - - @torch.no_grad() - def forward(self, x: torch.Tensor, seq_len: int): - # x is only used to get the correct dtype and device - if (not self.caches_initialized) or seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, x=x) - - return ( - self.cos_cached[:seq_len], - 
self.sin_cached[:seq_len], - ) - - -class LinearScalingRotaryEmbedding(RotaryEmbedding): - """RotaryEmbedding extended with linear scaling. - - Credits to the Reddit user /u/kaiokendev - """ - - def __init__(self, dim: int, base: float, scaling_factor: float): - self.scaling_factor = scaling_factor - super().__init__(dim, base) - - def _get_t(self, x: torch.Tensor): - t = (torch.arange(self.max_seq_len_cached) / self.scaling_factor).to(x) - return t - - -class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): - """RotaryEmbedding extended with Dynamic NTK scaling. - - Credits to the Reddit users /u/bloc97 and /u/emozilla - """ - - def __init__(self, dim: int, base: float, scaling_factor: float, - max_position_embeddings: float): - self.scaling_factor = scaling_factor - self.max_position_embeddings = max_position_embeddings - super().__init__(dim, base) - - def _get_inv_freq(self, x: torch.Tensor, seq_len: int): - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - - (self.scaling_factor - 1))**(self.dim / (self.dim - 2)) - inv_freq = (1.0 / - (base**(torch.arange(0, self.dim, 2) / self.dim))).to(x) - else: - inv_freq = ( - 1.0 / - (self.base**(torch.arange(0, self.dim, 2) / self.dim))).to(x) - return inv_freq - - -def rotate_half(x: torch.Tensor): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb -def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor, position_ids: torch.Tensor): - cos = cos[position_ids].unsqueeze( - 1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] - sin = sin[position_ids].unsqueeze(1) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index cf44356ec8..d97b82010f 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -73,6 +73,22 @@ log = logging.getLogger(__name__) +def _rotary_embedding(config: Dict): + rope_head_dim = config.d_model // config.n_heads + if config.attn_config['rope_scaling']['type'] == 'no_scaling': + return RotaryEmbedding(rope_head_dim, + base=config.attn_config['rope_theta']) + elif config.attn_config['rope_scaling']['type'] == 'linear': + return LinearScalingRotaryEmbedding( + rope_head_dim, + base=config.attn_config['rope_theta'], + scaling_factor=config.attn_config['rope_scaling']['factor']) + elif config.attn_config['rope_scaling']['type'] == 'dynamic': + return DynamicNTKScalingRotaryEmbedding( + rope_head_dim, + base=config.attn_config['rope_theta'], + scaling_factor=config.attn_config['rope_scaling']['factor'], + max_position_embeddings=config.max_seq_len) class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig @@ -149,13 +165,8 @@ def __init__(self, config: MPTConfig): ) self.rope = config.attn_config['rope'] - self.rope_head_dim = config.d_model // config.n_heads - self.rope_theta = config.attn_config['rope_theta'] - self.rope_scaling_type = config.attn_config['rope_scaling']['type'] - self.rope_scaling_factor = config.attn_config['rope_scaling']['factor'] - self.rope_max_seq_len = config.max_seq_len - self._rotary_embedding_initialized = False - self.rotary_embedding = self._rotary_emb() + if self.rope: + 
self.rotary_embedding = _rotary_embedding(config) if config.no_bias: for module in self.modules(): @@ -250,27 +261,6 @@ def _attn_bias( return attn_bias, None - @torch.no_grad() - def _rotary_emb(self): - if not self._rotary_embedding_initialized: - if self.rope_scaling_type == 'no_scaling': - self.rotary_embedding = RotaryEmbedding(self.rope_head_dim, - base=self.rope_theta) - elif self.rope_scaling_type == 'linear': - self.rotary_embedding = LinearScalingRotaryEmbedding( - self.rope_head_dim, - base=self.rope_theta, - scaling_factor=self.rope_scaling_factor) - elif self.rope_scaling_type == 'dynamic': - self.rotary_embedding = DynamicNTKScalingRotaryEmbedding( - self.rope_head_dim, - base=self.rope_theta, - scaling_factor=self.rope_scaling_factor, - max_position_embeddings=self.rope_max_seq_len) - - self._rotary_embedding_initialized = True - return self.rotary_embedding - def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) -> torch.Tensor: s_k, s_q = attn_bias.shape[-2:] @@ -435,9 +425,8 @@ def forward( min=0, ) if self.rope: - rotary_emb = self._rotary_emb() rotary_emb_w_offset_info = { - 'rotary_emb': rotary_emb, + 'rotary_emb': self.rotary_embedding, 'pos': pos, 'seq_len': S + past_position } diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index cc6438b008..56c7af1934 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -5,7 +5,9 @@ import torch from composer.utils import reproducibility from omegaconf import OmegaConf as om - +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as RotaryEmbedding +from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding as LinearScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import LlamaDynamicNTKScalingRotaryEmbedding as DynamicNTKScalingRotaryEmbedding def allclose_helper(t0: torch.Tensor, t1: torch.Tensor, @@ -66,9 +68,6 @@ def test_attn_impl(attn_impl_0: str, rope. 
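    The rotary embeddings exercised here are the Hugging Face Llama implementations
    (RotaryEmbedding and its linear / dynamic NTK scaling variants) imported from
    transformers at the top of this file.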
""" from llmfoundry.models.layers import attention - from llmfoundry.models.layers.rotary_embedding import ( - DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, - RotaryEmbedding) alibi = pos_emb_config['alibi'] rope = pos_emb_config['rope'] if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'): diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py deleted file mode 100644 index 9d257a9392..0000000000 --- a/tests/test_rotary_embedding.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2022 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -import pytest -import torch -from composer.utils import reproducibility - - -def allclose_helper(t0: torch.Tensor, - t1: torch.Tensor, - rtol: float = 1e-2, - atol: float = 1e-2): - return torch.allclose(t0, t1, rtol=rtol, atol=atol) - - -@pytest.mark.parametrize('device', ['cpu', 'cuda']) -@pytest.mark.parametrize('dtype', [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize('rope_scaling_type', - ['no_scaling', 'linear', 'dynamic']) -@pytest.mark.parametrize('tensor_type', ['query', 'key']) -def test_rotation_scaling_factor_1(device: str, dtype: torch.dtype, - rope_scaling_type: str, tensor_type: str): - """Checks all the rotation embedding techniques (with scaling factor 1) - produce the expected rotation.""" - from llmfoundry.models.layers.rotary_embedding import ( - DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, - RotaryEmbedding, apply_rotary_pos_emb) - - reproducibility.seed_all(7) - - rope_head_dim = 8 - assert rope_head_dim % 2 == 0 - rope_theta = 5 - rope_scaling_factor = 1.0 - - seq_len = 7 - batch_size = 1 - num_heads = 1 - pos = torch.arange(seq_len, device=device, - dtype=torch.long).repeat(batch_size, 1) # - - # x will test the first half cosine part of the rotation and second half of the sine part - x = torch.ones((batch_size, num_heads, seq_len, rope_head_dim), - device=device, - dtype=dtype) - x[..., rope_head_dim // 2:] = 0.0 - - # y will test the first half sine part of the rotation and second half of the cosine part - y = torch.ones((batch_size, num_heads, seq_len, rope_head_dim), - device=device, - dtype=dtype) - y[..., :rope_head_dim // 2] = 0.0 - - z = torch.zeros((batch_size, num_heads, seq_len, rope_head_dim), - device=device, - dtype=dtype) - - def gen_rotary_emb(): - if rope_scaling_type == 'no_scaling': - rotary_embedding = RotaryEmbedding(rope_head_dim, base=rope_theta) - elif rope_scaling_type == 'linear': - rotary_embedding = LinearScalingRotaryEmbedding( - rope_head_dim, - base=rope_theta, - scaling_factor=rope_scaling_factor) - elif rope_scaling_type == 'dynamic': - rotary_embedding = DynamicNTKScalingRotaryEmbedding( - rope_head_dim, - base=rope_theta, - scaling_factor=rope_scaling_factor, - max_position_embeddings=seq_len) - else: - raise ValueError( - 'rope_scaling_type should be one no_scaling, linear, or dynamic' - ) - return rotary_embedding - - rotary_emb = gen_rotary_emb() - (cos, sin) = rotary_emb(x, seq_len) - assert allclose_helper(cos[:, :rope_head_dim // 2], - cos[:, rope_head_dim // 2:]) - assert allclose_helper(sin[:, :rope_head_dim // 2], - sin[:, rope_head_dim // 2:]) - - assert allclose_helper(cos * cos + sin * sin, torch.ones_like(cos)) - - if tensor_type == 'query': - x, _ = apply_rotary_pos_emb(x, z, cos, sin, pos) - y, _ = apply_rotary_pos_emb(y, z, cos, sin, pos) - elif tensor_type == 'key': - _, x = apply_rotary_pos_emb(z, x, cos, sin, pos) - _, y = apply_rotary_pos_emb(z, y, cos, sin, pos) - - assert 
allclose_helper(x[..., :rope_head_dim // 2], y[..., - rope_head_dim // 2:]) - assert allclose_helper(x[..., rope_head_dim // 2:], - -y[..., :rope_head_dim // 2]) - - inv_freq = ( - 1.0 / - (rope_theta**(torch.arange(0, rope_head_dim, 2) / rope_head_dim))).to(x) - t = torch.arange(seq_len).to(x) - - expected_rotation_angles = torch.outer(t, inv_freq) - - assert allclose_helper(x[..., :rope_head_dim // 2], - expected_rotation_angles.cos()) - assert allclose_helper(x[..., rope_head_dim // 2:], - expected_rotation_angles.sin()) From 9c2a2a6e7553c6e1de1441918a9c169c4412968b Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 13 Oct 2023 02:58:24 +0000 Subject: [PATCH 039/106] .. --- llmfoundry/models/mpt/modeling_mpt.py | 30 +++++++++++++++------------ tests/test_model.py | 2 +- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index d97b82010f..97cfd1918a 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -28,10 +28,12 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.modeling_outputs import (BaseModelOutputWithPast, CausalLMOutputWithPast) -from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as RotaryEmbedding -from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding as LinearScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import LlamaDynamicNTKScalingRotaryEmbedding as DynamicNTKScalingRotaryEmbedding - +from transformers.models.llama.modeling_llama import \ + LlamaDynamicNTKScalingRotaryEmbedding as DynamicNTKScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LlamaLinearScalingRotaryEmbedding as LinearScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LlamaRotaryEmbedding as RotaryEmbedding from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias from llmfoundry.models.layers.blocks import MPTBlock @@ -73,22 +75,24 @@ log = logging.getLogger(__name__) -def _rotary_embedding(config: Dict): + +def _rotary_embedding(config: MPTConfig): rope_head_dim = config.d_model // config.n_heads if config.attn_config['rope_scaling']['type'] == 'no_scaling': return RotaryEmbedding(rope_head_dim, - base=config.attn_config['rope_theta']) + base=config.attn_config['rope_theta']) elif config.attn_config['rope_scaling']['type'] == 'linear': return LinearScalingRotaryEmbedding( - rope_head_dim, - base=config.attn_config['rope_theta'], - scaling_factor=config.attn_config['rope_scaling']['factor']) + rope_head_dim, + base=config.attn_config['rope_theta'], + scaling_factor=config.attn_config['rope_scaling']['factor']) elif config.attn_config['rope_scaling']['type'] == 'dynamic': return DynamicNTKScalingRotaryEmbedding( - rope_head_dim, - base=config.attn_config['rope_theta'], - scaling_factor=config.attn_config['rope_scaling']['factor'], - max_position_embeddings=config.max_seq_len) + rope_head_dim, + base=config.attn_config['rope_theta'], + scaling_factor=config.attn_config['rope_scaling']['factor'], + max_position_embeddings=config.max_seq_len) + class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig diff --git a/tests/test_model.py b/tests/test_model.py index 529e28d857..e8cd0b61c4 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1096,7 +1096,7 @@ def test_forward_with_cache_and_padding(pos_emb_config: dict): 'factor': 1.0 } }]) -def 
test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: bool): +def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): # Test that model forward with and without the key-value cache produces the # same output. if not torch.cuda.is_available() and device == 'gpu': From 6fa3037a3ef3734669c6d1d7d7c7bae73f3dd333 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 13 Oct 2023 16:00:13 +0000 Subject: [PATCH 040/106] .. --- llmfoundry/models/layers/attention.py | 2 +- tests/test_flash_triton_torch.py | 21 ++++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 02da7297b1..f02226d589 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -549,7 +549,7 @@ def forward( rotary_emb = rotary_emb_w_offset_info['rotary_emb'] seq_len = rotary_emb_w_offset_info['seq_len'] pos = rotary_emb_w_offset_info['pos'] - (cos, sin) = rotary_emb(x, seq_len) + (cos, sin) = rotary_emb(query, seq_len) query, key = apply_rotary_pos_emb(query, key, cos, sin, pos) query = query.transpose(1, 2) key = key.transpose(1, 2) diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index 56c7af1934..1567a807c4 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -5,9 +5,13 @@ import torch from composer.utils import reproducibility from omegaconf import OmegaConf as om -from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as RotaryEmbedding -from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding as LinearScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import LlamaDynamicNTKScalingRotaryEmbedding as DynamicNTKScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LlamaDynamicNTKScalingRotaryEmbedding as DynamicNTKScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LlamaLinearScalingRotaryEmbedding as LinearScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LlamaRotaryEmbedding as RotaryEmbedding + def allclose_helper(t0: torch.Tensor, t1: torch.Tensor, @@ -124,15 +128,15 @@ def gen_bias(attn_impl: str): def gen_rotary_emb(): if pos_emb_config['rope_scaling']['type'] == 'no_scaling': - rotary_embedding = RotaryEmbedding( - rope_head_dim, base=pos_emb_config['rope_theta']) + return RotaryEmbedding(rope_head_dim, + base=pos_emb_config['rope_theta']) elif pos_emb_config['rope_scaling']['type'] == 'linear': - rotary_embedding = LinearScalingRotaryEmbedding( + return LinearScalingRotaryEmbedding( rope_head_dim, base=pos_emb_config['rope_theta'], scaling_factor=pos_emb_config['rope_scaling']['factor']) elif pos_emb_config['rope_scaling']['type'] == 'dynamic': - rotary_embedding = DynamicNTKScalingRotaryEmbedding( + return DynamicNTKScalingRotaryEmbedding( rope_head_dim, base=pos_emb_config['rope_theta'], scaling_factor=pos_emb_config['rope_scaling']['factor'], @@ -141,7 +145,6 @@ def gen_rotary_emb(): raise ValueError( 'rope_scaling.type should be one no_scaling, linear, or dynamic' ) - return rotary_embedding x0 = torch.randn(n, s, f).to(device) x1 = x0.clone().detach() @@ -159,7 +162,7 @@ def gen_rotary_emb(): rotary_emb_w_offset_info = None if rope: - rotary_emb = gen_rotary_emb() + rotary_emb = gen_rotary_emb().to(device) rotary_emb_w_offset_info = { 'rotary_emb': rotary_emb, 'pos': pos, From 
53235207bc05766450185a29f205e8145f582568 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 13 Oct 2023 21:50:04 +0000 Subject: [PATCH 041/106] .. --- llmfoundry/models/layers/attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index f02226d589..d441531d29 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -546,14 +546,15 @@ def forward( key = key.view(*(key.shape[:-1]), -1, self.head_dim) query = query.transpose(1, 2) key = key.transpose(1, 2) + rotary_emb = rotary_emb_w_offset_info['rotary_emb'] seq_len = rotary_emb_w_offset_info['seq_len'] pos = rotary_emb_w_offset_info['pos'] (cos, sin) = rotary_emb(query, seq_len) query, key = apply_rotary_pos_emb(query, key, cos, sin, pos) + query = query.transpose(1, 2) key = key.transpose(1, 2) - query = query.reshape(*(query.shape[:-2]), self.d_model) key = key.reshape(*(key.shape[:-2]), self.kv_n_heads * self.head_dim) From b0960f78b5e660686fbba7fb5d71bbb002db844c Mon Sep 17 00:00:00 2001 From: ShashankMosaicML <144760128+ShashankMosaicML@users.noreply.github.com> Date: Fri, 13 Oct 2023 15:45:05 -0700 Subject: [PATCH 042/106] Update llmfoundry/models/mpt/modeling_mpt.py Accepting the suggestion Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> --- llmfoundry/models/mpt/modeling_mpt.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 97cfd1918a..62c7da7cd5 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -434,8 +434,10 @@ def forward( 'pos': pos, 'seq_len': S + past_position } + x = tok_emb if self.learned_pos_emb: pos_emb = self.wpe(pos) + x = x + pos_emb x = tok_emb + pos_emb From f6632e116bcf324ad74aafb30f95cf8ec124b615 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 13 Oct 2023 22:58:07 +0000 Subject: [PATCH 043/106] incorporated some suggestions from the pr --- llmfoundry/models/layers/attention.py | 24 ++++++++++++------------ llmfoundry/models/mpt/modeling_mpt.py | 11 ++++------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index d441531d29..b3bb6e193d 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -5,7 +5,7 @@ import math import warnings -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional import torch import torch.nn as nn @@ -55,7 +55,7 @@ def scaled_multihead_dot_product_attention( value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, @@ -64,7 +64,7 @@ def scaled_multihead_dot_product_attention( training: bool = False, needs_weights: bool = False, multiquery: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: if multiquery: @@ -168,7 +168,7 @@ def scaled_multihead_dot_product_attention( def check_valid_inputs(*tensors: torch.Tensor, - valid_dtypes: Optional[List[torch.dtype]] = None): + valid_dtypes: 
Optional[list[torch.dtype]] = None): if valid_dtypes is None: valid_dtypes = [torch.float16, torch.bfloat16] for tensor in tensors: @@ -184,7 +184,7 @@ def flash_attn_fn( value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, @@ -193,7 +193,7 @@ def flash_attn_fn( training: bool = False, needs_weights: bool = False, multiquery: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: try: from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip @@ -304,7 +304,7 @@ def triton_flash_attn_fn( value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, @@ -313,7 +313,7 @@ def triton_flash_attn_fn( training: bool = False, needs_weights: bool = False, multiquery: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: try: from llmfoundry.models.layers.flash_attn_triton import flash_attn_func @@ -519,13 +519,13 @@ def __init__( def forward( self, x: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - rotary_emb_w_offset_info: Optional[Dict] = None, + rotary_emb_w_offset_info: Optional[dict] = None, is_causal: bool = True, needs_weights: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) @@ -663,7 +663,7 @@ def __init__( def attn_bias_shape( attn_impl: str, n_heads: int, seq_len: int, alibi: bool, prefix_lm: bool, causal: bool, - use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]: + use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]: if attn_impl == 'flash': return None elif attn_impl in ['torch', 'triton']: diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 62c7da7cd5..18050da079 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -390,7 +390,6 @@ def forward( ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' rotary_emb_w_offset_info = None - pos_emb = 0.0 tok_emb = self.wte(input_ids) if self.learned_pos_emb or self.rope: past_position = 0 @@ -434,12 +433,10 @@ def forward( 'pos': pos, 'seq_len': S + past_position } - x = tok_emb - if self.learned_pos_emb: - pos_emb = self.wpe(pos) - x = x + pos_emb - - x = tok_emb + pos_emb + x = tok_emb + if self.learned_pos_emb: + pos_emb = self.wpe(pos) + x = x + pos_emb if self.embedding_fraction == 1: x = self.emb_drop(x) From df749ae12b56da9158d6c35f1c9ff0e773d42e24 Mon Sep 17 00:00:00 2001 From: 
Shashank Rajput Date: Sat, 14 Oct 2023 00:48:21 +0000 Subject: [PATCH 044/106] .. --- llmfoundry/models/layers/attention.py | 16 ++++++++-------- llmfoundry/models/mpt/modeling_mpt.py | 8 +++----- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index b3bb6e193d..a3a6b2d32a 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -541,6 +541,14 @@ def forward( dim=2, ) + key_padding_mask = attention_mask + + if self.qk_ln: + # Applying layernorm to qk + dtype = query.dtype + query = self.q_ln(query).to(dtype) + key = self.k_ln(key).to(dtype) + if rotary_emb_w_offset_info is not None: query = query.view(*(query.shape[:-1]), -1, self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) @@ -559,14 +567,6 @@ def forward( key = key.reshape(*(key.shape[:-2]), self.kv_n_heads * self.head_dim) - key_padding_mask = attention_mask - - if self.qk_ln: - # Applying layernorm to qk - dtype = query.dtype - query = self.q_ln(query).to(dtype) - key = self.k_ln(key).to(dtype) - context, attn_weights, past_key_value = self.attn_fn( query, key, diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 18050da079..fa42de78ec 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -390,7 +390,7 @@ def forward( ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' rotary_emb_w_offset_info = None - tok_emb = self.wte(input_ids) + x = self.wte(input_ids) if self.learned_pos_emb or self.rope: past_position = 0 if past_key_values is not None: @@ -433,10 +433,8 @@ def forward( 'pos': pos, 'seq_len': S + past_position } - x = tok_emb - if self.learned_pos_emb: - pos_emb = self.wpe(pos) - x = x + pos_emb + if self.learned_pos_emb: + x = x + self.wpe(pos) if self.embedding_fraction == 1: x = self.emb_drop(x) From 8b886bafecb625b3f0016802cb8c04426f4f0210 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 17 Oct 2023 17:55:21 +0000 Subject: [PATCH 045/106] .. 
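This patch swaps the Hugging Face llama rotary classes for an in-repo module: it adds llmfoundry/models/layers/rotary_embedding.py (RotaryEmbedding, LinearScalingRotaryEmbedding, DynamicNTKScalingRotaryEmbedding, apply_rotary_pos_emb), exports the new names from llmfoundry/models/layers/__init__.py, points attention.py and modeling_mpt.py at the local implementation, and adds tests/test_rotary_embedding.py.

The rotation itself is independent of that wiring. Below is a minimal, self-contained sketch of the cos/sin rotation that apply_rotary_pos_emb performs on [batch, seq_len, n_heads, head_dim] tensors; the helper names and toy shapes are illustrative only, not the shipped code.

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Split the head dimension in half and rotate: (x1, x2) -> (-x2, x1).
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor,
                   sin: torch.Tensor, position_ids: torch.Tensor):
        # cos, sin: [max_seq_len, head_dim]; gather the per-position angles and
        # broadcast them over the heads dimension (heads sit at dim -2 here).
        cos = cos[position_ids].unsqueeze(-2)  # [batch, seq_len, 1, head_dim]
        sin = sin[position_ids].unsqueeze(-2)
        return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

    # Toy usage: batch=1, seq_len=4, n_heads=2, head_dim=8, base theta=10000.
    head_dim, seq_len = 8, 4
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)
    emb = torch.cat((angles, angles), dim=-1)      # [seq_len, head_dim]
    pos = torch.arange(seq_len).unsqueeze(0)       # [1, seq_len]
    q = torch.randn(1, seq_len, 2, head_dim)
    k = torch.randn(1, seq_len, 2, head_dim)
    q_rot, k_rot = apply_rope(q, k, emb.cos(), emb.sin(), pos)

Because the rotation only mixes pairs of channels with unit-norm (cos, sin) coefficients, it preserves vector norms, which is one of the properties the new test checks via cos * cos + sin * sin == 1.
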
--- llmfoundry/models/layers/__init__.py | 7 + llmfoundry/models/layers/attention.py | 11 +- llmfoundry/models/layers/rotary_embedding.py | 140 +++++++++++++++++++ llmfoundry/models/mpt/modeling_mpt.py | 20 +-- tests/test_rotary_embedding.py | 114 +++++++++++++++ 5 files changed, 275 insertions(+), 17 deletions(-) create mode 100644 llmfoundry/models/layers/rotary_embedding.py create mode 100644 tests/test_rotary_embedding.py diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index 68aa0fe7fe..2cb539fadb 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -10,6 +10,9 @@ from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm +from llmfoundry.models.layers.rotary_embedding import ( + DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, + RotaryEmbedding, apply_rotary_pos_emb) __all__ = [ 'scaled_multihead_dot_product_attention', @@ -29,4 +32,8 @@ 'SharedEmbedding', 'FFN_CLASS_REGISTRY', 'build_ffn', + 'RotaryEmbedding', + 'LinearScalingRotaryEmbedding', + 'DynamicNTKScalingRotaryEmbedding', + 'apply_rotary_pos_emb', ] diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index a3a6b2d32a..b2e4b7a4fc 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -12,10 +12,10 @@ from einops import rearrange from packaging import version from torch import nn -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY +from llmfoundry.models.layers.rotary_embedding import apply_rotary_pos_emb def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, @@ -552,8 +552,6 @@ def forward( if rotary_emb_w_offset_info is not None: query = query.view(*(query.shape[:-1]), -1, self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) - query = query.transpose(1, 2) - key = key.transpose(1, 2) rotary_emb = rotary_emb_w_offset_info['rotary_emb'] seq_len = rotary_emb_w_offset_info['seq_len'] @@ -561,11 +559,8 @@ def forward( (cos, sin) = rotary_emb(query, seq_len) query, key = apply_rotary_pos_emb(query, key, cos, sin, pos) - query = query.transpose(1, 2) - key = key.transpose(1, 2) - query = query.reshape(*(query.shape[:-2]), self.d_model) - key = key.reshape(*(key.shape[:-2]), - self.kv_n_heads * self.head_dim) + query = query.view(*(query.shape[:-2]), self.d_model) + key = key.view(*(key.shape[:-2]), self.kv_n_heads * self.head_dim) context, attn_weights, past_key_value = self.attn_fn( query, diff --git a/llmfoundry/models/layers/rotary_embedding.py b/llmfoundry/models/layers/rotary_embedding.py new file mode 100644 index 0000000000..b3993d069f --- /dev/null +++ b/llmfoundry/models/layers/rotary_embedding.py @@ -0,0 +1,140 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +# Code modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +class RotaryEmbedding(torch.nn.Module): + + def __init__(self, dim: int, max_position_embeddings: int, base: int, + device: torch.device): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base**( + 
torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + self.max_seq_len_cached = self.max_position_embeddings + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache(seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype()) + + def _set_cos_sin_cache(self, seq_len: int, device: torch.device, + dtype: torch.dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, + device=device, + dtype=self.inv_freq.dtype) + + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer('cos_cached', + emb.cos().to(dtype), + persistent=False) + self.register_buffer('sin_cached', + emb.sin().to(dtype), + persistent=False) + + def forward(self, x: torch.tensor, seq_len: int): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, + device=x.device, + dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. + + Credits to the Reddit user /u/kaiokendev + """ + + def __init__(self, dim: int, max_position_embeddings: int, base: int, + device: torch.device, scaling_factor: float): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len: int, device: torch.device, + dtype: torch.dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, + device=device, + dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer('cos_cached', + emb.cos().to(dtype), + persistent=False) + self.register_buffer('sin_cached', + emb.sin().to(dtype), + persistent=False) + + +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. 
+ + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__(self, dim: int, max_position_embeddings: int, base: int, + device: torch.device, scaling_factor: float): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len: int, device: torch.device, + dtype: torch.dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - + (self.scaling_factor - 1))**(self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base**( + torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, + device=device, + dtype=self.inv_freq.dtype) + + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer('cos_cached', + emb.cos().to(dtype), + persistent=False) + self.register_buffer('sin_cached', + emb.sin().to(dtype), + persistent=False) + + +def rotate_half(x: torch.Tensor): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb +def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor, position_ids: torch.Tensor): + cos = cos[position_ids].unsqueeze( + -2) # [seq_len, dim] -> [batch_size, seq_len, 1, head_dim] + sin = sin[position_ids].unsqueeze(-2) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index fa42de78ec..693a2848b4 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -28,12 +28,6 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.modeling_outputs import (BaseModelOutputWithPast, CausalLMOutputWithPast) -from transformers.models.llama.modeling_llama import \ - LlamaDynamicNTKScalingRotaryEmbedding as DynamicNTKScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaLinearScalingRotaryEmbedding as LinearScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaRotaryEmbedding as RotaryEmbedding from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias from llmfoundry.models.layers.blocks import MPTBlock @@ -44,6 +38,9 @@ from llmfoundry.models.layers.ffn import MPTMLP as MPTMLP from llmfoundry.models.layers.ffn import build_ffn as build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY +from llmfoundry.models.layers.rotary_embedding import ( + DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, + RotaryEmbedding) from llmfoundry.models.mpt.configuration_mpt import MPTConfig # NOTE: All utils are imported directly even if unused so that @@ -80,18 +77,23 @@ def _rotary_embedding(config: MPTConfig): rope_head_dim = config.d_model // config.n_heads if config.attn_config['rope_scaling']['type'] == 'no_scaling': return RotaryEmbedding(rope_head_dim, - base=config.attn_config['rope_theta']) + max_position_embeddings=config.max_seq_len, + 
base=config.attn_config['rope_theta'], + device=config.init_device) elif config.attn_config['rope_scaling']['type'] == 'linear': return LinearScalingRotaryEmbedding( rope_head_dim, + max_position_embeddings=config.max_seq_len, base=config.attn_config['rope_theta'], + device=config.init_device, scaling_factor=config.attn_config['rope_scaling']['factor']) elif config.attn_config['rope_scaling']['type'] == 'dynamic': return DynamicNTKScalingRotaryEmbedding( rope_head_dim, + max_position_embeddings=config.max_seq_len, base=config.attn_config['rope_theta'], - scaling_factor=config.attn_config['rope_scaling']['factor'], - max_position_embeddings=config.max_seq_len) + device=config.init_device, + scaling_factor=config.attn_config['rope_scaling']['factor']) class MPTPreTrainedModel(PreTrainedModel): diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py new file mode 100644 index 0000000000..fc7e338921 --- /dev/null +++ b/tests/test_rotary_embedding.py @@ -0,0 +1,114 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from composer.utils import reproducibility + + +def allclose_helper(t0: torch.Tensor, + t1: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2): + return torch.allclose(t0, t1, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize('dtype', [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize('rope_scaling_type', + ['no_scaling', 'linear', 'dynamic']) +@pytest.mark.parametrize('tensor_type', ['query', 'key']) +def test_rotation_scaling_factor_1(device: str, dtype: torch.dtype, + rope_scaling_type: str, tensor_type: str): + """Checks all the rotation embedding techniques (with scaling factor 1) + produce the expected rotation.""" + from llmfoundry.models.layers.rotary_embedding import ( + DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, + RotaryEmbedding, apply_rotary_pos_emb) + + reproducibility.seed_all(7) + + rope_head_dim = 8 + assert rope_head_dim % 2 == 0 + rope_theta = 5 + rope_scaling_factor = 1.0 + + seq_len = 7 + batch_size = 1 + num_heads = 1 + pos = torch.arange(seq_len, device=device, + dtype=torch.long).repeat(batch_size, 1) # + + # x will test the first half cosine part of the rotation and second half of the sine part + x = torch.ones((batch_size, seq_len, num_heads, rope_head_dim), + device=device, + dtype=dtype) + x[..., rope_head_dim // 2:] = 0.0 + + # y will test the first half sine part of the rotation and second half of the cosine part + y = torch.ones((batch_size, seq_len, num_heads, rope_head_dim), + device=device, + dtype=dtype) + y[..., :rope_head_dim // 2] = 0.0 + + z = torch.zeros((batch_size, seq_len, num_heads, rope_head_dim), + device=device, + dtype=dtype) + + def gen_rotary_emb(): + if rope_scaling_type == 'no_scaling': + rotary_embedding = RotaryEmbedding(rope_head_dim, + max_position_embeddings=seq_len, + base=rope_theta, + device=device) + elif rope_scaling_type == 'linear': + rotary_embedding = LinearScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=seq_len, + base=rope_theta, + device=device, + scaling_factor=rope_scaling_factor) + elif rope_scaling_type == 'dynamic': + rotary_embedding = DynamicNTKScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=seq_len, + base=rope_theta, + device=device, + scaling_factor=rope_scaling_factor) + else: + raise ValueError( + 'rope_scaling_type should be one no_scaling, linear, or dynamic' + ) + return 
rotary_embedding + + rotary_emb = gen_rotary_emb() + (cos, sin) = rotary_emb(x, seq_len) + assert allclose_helper(cos[:, :rope_head_dim // 2], + cos[:, rope_head_dim // 2:]) + assert allclose_helper(sin[:, :rope_head_dim // 2], + sin[:, rope_head_dim // 2:]) + + assert allclose_helper(cos * cos + sin * sin, torch.ones_like(cos)) + + if tensor_type == 'query': + x, _ = apply_rotary_pos_emb(x, z, cos, sin, pos) + y, _ = apply_rotary_pos_emb(y, z, cos, sin, pos) + elif tensor_type == 'key': + _, x = apply_rotary_pos_emb(z, x, cos, sin, pos) + _, y = apply_rotary_pos_emb(z, y, cos, sin, pos) + + assert allclose_helper(x[..., :rope_head_dim // 2], y[..., + rope_head_dim // 2:]) + assert allclose_helper(x[..., rope_head_dim // 2:], + -y[..., :rope_head_dim // 2]) + + inv_freq = ( + 1.0 / + (rope_theta**(torch.arange(0, rope_head_dim, 2) / rope_head_dim))).to(x) + t = torch.arange(seq_len).to(x) + + expected_rotation_angles = torch.outer(t, inv_freq) + assert allclose_helper(x[..., :rope_head_dim // 2].squeeze(), + expected_rotation_angles.cos()) + assert allclose_helper(x[..., rope_head_dim // 2:].squeeze(), + expected_rotation_angles.sin()) From dc58fc744f97104bc3297fa1bbd3682de854897c Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 17 Oct 2023 18:18:21 +0000 Subject: [PATCH 046/106] .. --- llmfoundry/models/layers/attention.py | 7 ++++++- llmfoundry/models/layers/rotary_embedding.py | 13 ++++++++----- tests/test_rotary_embedding.py | 8 ++++---- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index c47607278f..a3f3aaf64a 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -591,7 +591,12 @@ def forward( seq_len = rotary_emb_w_offset_info['seq_len'] pos = rotary_emb_w_offset_info['pos'] (cos, sin) = rotary_emb(query, seq_len) - query, key = apply_rotary_pos_emb(query, key, cos, sin, pos) + query, key = apply_rotary_pos_emb(query, + key, + cos, + sin, + pos, + dim_heads_index=2) query = query.view(*(query.shape[:-2]), self.d_model) key = key.view(*(key.shape[:-2]), self.kv_n_heads * self.head_dim) diff --git a/llmfoundry/models/layers/rotary_embedding.py b/llmfoundry/models/layers/rotary_embedding.py index b3993d069f..e6aac35565 100644 --- a/llmfoundry/models/layers/rotary_embedding.py +++ b/llmfoundry/models/layers/rotary_embedding.py @@ -130,11 +130,14 @@ def rotate_half(x: torch.Tensor): # Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb -def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor, position_ids: torch.Tensor): - cos = cos[position_ids].unsqueeze( - -2) # [seq_len, dim] -> [batch_size, seq_len, 1, head_dim] - sin = sin[position_ids].unsqueeze(-2) +def apply_rotary_pos_emb(q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + dim_heads_index: int = 1): + cos = cos[position_ids].unsqueeze(dim_heads_index) + sin = sin[position_ids].unsqueeze(dim_heads_index) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py index fc7e338921..fa2ea8f2dc 100644 --- a/tests/test_rotary_embedding.py +++ b/tests/test_rotary_embedding.py @@ -91,11 +91,11 @@ def gen_rotary_emb(): assert allclose_helper(cos * cos + sin * sin, torch.ones_like(cos)) if tensor_type == 'query': - x, _ = 
apply_rotary_pos_emb(x, z, cos, sin, pos) - y, _ = apply_rotary_pos_emb(y, z, cos, sin, pos) + x, _ = apply_rotary_pos_emb(x, z, cos, sin, pos, dim_heads_index=2) + y, _ = apply_rotary_pos_emb(y, z, cos, sin, pos, dim_heads_index=2) elif tensor_type == 'key': - _, x = apply_rotary_pos_emb(z, x, cos, sin, pos) - _, y = apply_rotary_pos_emb(z, y, cos, sin, pos) + _, x = apply_rotary_pos_emb(z, x, cos, sin, pos, dim_heads_index=2) + _, y = apply_rotary_pos_emb(z, y, cos, sin, pos, dim_heads_index=2) assert allclose_helper(x[..., :rope_head_dim // 2], y[..., rope_head_dim // 2:]) From ed5a47745f329af33af517f68d24b5a628f0096f Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 17 Oct 2023 23:24:31 +0000 Subject: [PATCH 047/106] .. --- llmfoundry/models/layers/rotary_embedding.py | 23 +++++++++++++++----- llmfoundry/models/mpt/modeling_mpt.py | 5 +---- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/llmfoundry/models/layers/rotary_embedding.py b/llmfoundry/models/layers/rotary_embedding.py index e6aac35565..931ef2f959 100644 --- a/llmfoundry/models/layers/rotary_embedding.py +++ b/llmfoundry/models/layers/rotary_embedding.py @@ -7,8 +7,11 @@ # Code modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim: int, max_position_embeddings: int, base: int, - device: torch.device): + def __init__(self, + dim: int, + max_position_embeddings: int, + base: int, + device=None): super().__init__() self.dim = dim @@ -60,8 +63,12 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): Credits to the Reddit user /u/kaiokendev """ - def __init__(self, dim: int, max_position_embeddings: int, base: int, - device: torch.device, scaling_factor: float): + def __init__(self, + dim: int, + max_position_embeddings: int, + base: int, + scaling_factor: float, + device=None): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) @@ -90,8 +97,12 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): Credits to the Reddit users /u/bloc97 and /u/emozilla """ - def __init__(self, dim: int, max_position_embeddings: int, base: int, - device: torch.device, scaling_factor: float): + def __init__(self, + dim: int, + max_position_embeddings: int, + base: int, + scaling_factor: float, + device=None): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index ae1c5aa5c0..796601ef36 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -78,21 +78,18 @@ def _rotary_embedding(config: MPTConfig): if config.attn_config['rope_scaling']['type'] == 'no_scaling': return RotaryEmbedding(rope_head_dim, max_position_embeddings=config.max_seq_len, - base=config.attn_config['rope_theta'], - device=config.init_device) + base=config.attn_config['rope_theta']) elif config.attn_config['rope_scaling']['type'] == 'linear': return LinearScalingRotaryEmbedding( rope_head_dim, max_position_embeddings=config.max_seq_len, base=config.attn_config['rope_theta'], - device=config.init_device, scaling_factor=config.attn_config['rope_scaling']['factor']) elif config.attn_config['rope_scaling']['type'] == 'dynamic': return DynamicNTKScalingRotaryEmbedding( rope_head_dim, max_position_embeddings=config.max_seq_len, base=config.attn_config['rope_theta'], - device=config.init_device, 
scaling_factor=config.attn_config['rope_scaling']['factor']) From 2a53de397f9e2a9ec6e9bc5439e2b9ee720d79a6 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 19 Oct 2023 21:00:32 +0000 Subject: [PATCH 048/106] .. --- llmfoundry/models/layers/rotary_embedding.py | 52 +++++++++----------- llmfoundry/models/mpt/modeling_mpt.py | 25 ++++++---- tests/test_flash_triton_torch.py | 12 +++-- 3 files changed, 47 insertions(+), 42 deletions(-) diff --git a/llmfoundry/models/layers/rotary_embedding.py b/llmfoundry/models/layers/rotary_embedding.py index 931ef2f959..8ab71d7746 100644 --- a/llmfoundry/models/layers/rotary_embedding.py +++ b/llmfoundry/models/layers/rotary_embedding.py @@ -1,17 +1,16 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from typing import Union + import torch # Code modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py class RotaryEmbedding(torch.nn.Module): - def __init__(self, - dim: int, - max_position_embeddings: int, - base: int, - device=None): + def __init__(self, dim: int, max_position_embeddings: int, base: int, + device: Union[str, torch.device]): super().__init__() self.dim = dim @@ -23,16 +22,17 @@ def __init__(self, self.max_seq_len_cached = self.max_position_embeddings # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype()) + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, # type: ignore + dtype=torch.get_default_dtype()) # type: ignore - def _set_cos_sin_cache(self, seq_len: int, device: torch.device, + def _set_cos_sin_cache(self, seq_len: int, device: Union[str, torch.device], dtype: torch.dtype): self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, - dtype=self.inv_freq.dtype) + dtype=self.inv_freq.dtype) # type: ignore freqs = torch.einsum('i,j->ij', t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -44,16 +44,16 @@ def _set_cos_sin_cache(self, seq_len: int, device: torch.device, emb.sin().to(dtype), persistent=False) - def forward(self, x: torch.tensor, seq_len: int): + def forward(self, x: torch.Tensor, seq_len: int): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, - dtype=x.dtype) + dtype=x.dtype) # type: ignore return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), + self.cos_cached[:seq_len].to(dtype=x.dtype), # type: ignore + self.sin_cached[:seq_len].to(dtype=x.dtype), # type: ignore ) @@ -63,21 +63,17 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): Credits to the Reddit user /u/kaiokendev """ - def __init__(self, - dim: int, - max_position_embeddings: int, - base: int, - scaling_factor: float, - device=None): + def __init__(self, dim: int, max_position_embeddings: int, base: int, + scaling_factor: float, device: Union[str, torch.device]): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) - def _set_cos_sin_cache(self, seq_len: int, device: torch.device, + def _set_cos_sin_cache(self, seq_len: int, device: Union[str, torch.device], dtype: torch.dtype): self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, - dtype=self.inv_freq.dtype) + 
dtype=self.inv_freq.dtype) # type: ignore t = t / self.scaling_factor freqs = torch.einsum('i,j->ij', t, self.inv_freq) @@ -97,16 +93,12 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): Credits to the Reddit users /u/bloc97 and /u/emozilla """ - def __init__(self, - dim: int, - max_position_embeddings: int, - base: int, - scaling_factor: float, - device=None): + def __init__(self, dim: int, max_position_embeddings: int, base: int, + scaling_factor: float, device: Union[str, torch.device]): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) - def _set_cos_sin_cache(self, seq_len: int, device: torch.device, + def _set_cos_sin_cache(self, seq_len: int, device: Union[str, torch.device], dtype: torch.dtype): self.max_seq_len_cached = seq_len @@ -120,7 +112,7 @@ def _set_cos_sin_cache(self, seq_len: int, device: torch.device, t = torch.arange(self.max_seq_len_cached, device=device, - dtype=self.inv_freq.dtype) + dtype=self.inv_freq.dtype) # type: ignore freqs = torch.einsum('i,j->ij', t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 796601ef36..babdf85a8e 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -76,21 +76,28 @@ def _rotary_embedding(config: MPTConfig): rope_head_dim = config.d_model // config.n_heads if config.attn_config['rope_scaling']['type'] == 'no_scaling': - return RotaryEmbedding(rope_head_dim, - max_position_embeddings=config.max_seq_len, - base=config.attn_config['rope_theta']) + return RotaryEmbedding( + rope_head_dim, + max_position_embeddings=config.max_seq_len, + base=config.attn_config['rope_theta'], + device='cpu' + ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized elif config.attn_config['rope_scaling']['type'] == 'linear': return LinearScalingRotaryEmbedding( rope_head_dim, max_position_embeddings=config.max_seq_len, base=config.attn_config['rope_theta'], - scaling_factor=config.attn_config['rope_scaling']['factor']) + scaling_factor=config.attn_config['rope_scaling']['factor'], + device='cpu' + ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized elif config.attn_config['rope_scaling']['type'] == 'dynamic': return DynamicNTKScalingRotaryEmbedding( rope_head_dim, max_position_embeddings=config.max_seq_len, base=config.attn_config['rope_theta'], - scaling_factor=config.attn_config['rope_scaling']['factor']) + scaling_factor=config.attn_config['rope_scaling']['factor'], + device='cpu' + ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized class MPTPreTrainedModel(PreTrainedModel): @@ -146,6 +153,10 @@ def __init__(self, config: MPTConfig): ]) self.norm_f = norm_class(config.d_model, device=config.init_device) + self.rope = config.attn_config['rope'] + if self.rope: + self.rotary_embedding = _rotary_embedding(config) + if config.init_device != 'meta': log.info( f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.' 
@@ -167,10 +178,6 @@ def __init__(self, config: MPTConfig): use_sequence_id=self.attn_uses_sequence_id, ) - self.rope = config.attn_config['rope'] - if self.rope: - self.rotary_embedding = _rotary_embedding(config) - if config.no_bias: for module in self.modules(): if hasattr(module, 'bias') and isinstance( diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index 1567a807c4..d9559696b5 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -129,18 +129,24 @@ def gen_bias(attn_impl: str): def gen_rotary_emb(): if pos_emb_config['rope_scaling']['type'] == 'no_scaling': return RotaryEmbedding(rope_head_dim, - base=pos_emb_config['rope_theta']) + max_position_embeddings=s, + base=pos_emb_config['rope_theta'], + device='cpu') elif pos_emb_config['rope_scaling']['type'] == 'linear': return LinearScalingRotaryEmbedding( rope_head_dim, + max_position_embeddings=s, base=pos_emb_config['rope_theta'], - scaling_factor=pos_emb_config['rope_scaling']['factor']) + scaling_factor=pos_emb_config['rope_scaling']['factor'], + device='cpu') elif pos_emb_config['rope_scaling']['type'] == 'dynamic': return DynamicNTKScalingRotaryEmbedding( rope_head_dim, + max_position_embeddings=s, base=pos_emb_config['rope_theta'], scaling_factor=pos_emb_config['rope_scaling']['factor'], - max_position_embeddings=s) + max_position_embeddings=s, + device='cpu') else: raise ValueError( 'rope_scaling.type should be one no_scaling, linear, or dynamic' From 34e147c3f0e57182575b19c0c2c6ef557d1acb33 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 19 Oct 2023 22:02:24 +0000 Subject: [PATCH 049/106] .. --- tests/test_flash_triton_torch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index d9559696b5..5faec18787 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -145,7 +145,6 @@ def gen_rotary_emb(): max_position_embeddings=s, base=pos_emb_config['rope_theta'], scaling_factor=pos_emb_config['rope_scaling']['factor'], - max_position_embeddings=s, device='cpu') else: raise ValueError( From 0a9d3af61380c66521a7e7a44b1f3cd1b03ae406 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 19 Oct 2023 22:31:43 +0000 Subject: [PATCH 050/106] .. 
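Point tests/test_flash_triton_torch.py at the in-repo rotary embedding classes instead of the transformers llama ones, and construct them with device=device rather than on CPU (presumably so the cached cos/sin buffers start out on the same device as the test inputs).

The three variants the test builds differ only in how the rotation angles are generated. A rough sketch of that difference, using a hypothetical helper rather than the shipped classes and assuming the scaling formulas from rotary_embedding.py:

    import torch

    def rope_angles(seq_len: int, head_dim: int, base: float = 10000.0,
                    scaling: str = 'no_scaling', factor: float = 1.0,
                    max_position_embeddings: int = 2048) -> torch.Tensor:
        # Dynamic NTK scaling grows the base once seq_len exceeds the training
        # context length; linear scaling instead stretches the position indices.
        if scaling == 'dynamic' and seq_len > max_position_embeddings:
            base = base * ((factor * seq_len / max_position_embeddings) -
                           (factor - 1)) ** (head_dim / (head_dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(seq_len).float()
        if scaling == 'linear':
            t = t / factor
        return torch.outer(t, inv_freq)  # [seq_len, head_dim // 2]

    # With factor=1.0 and seq_len within the training context, all three
    # variants produce identical angles.
    assert torch.allclose(rope_angles(8, 16, scaling='linear'),
                          rope_angles(8, 16, scaling='no_scaling'))
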
--- tests/test_flash_triton_torch.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index 5faec18787..3cb0382491 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -5,12 +5,10 @@ import torch from composer.utils import reproducibility from omegaconf import OmegaConf as om -from transformers.models.llama.modeling_llama import \ - LlamaDynamicNTKScalingRotaryEmbedding as DynamicNTKScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaLinearScalingRotaryEmbedding as LinearScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaRotaryEmbedding as RotaryEmbedding + +from llmfoundry.models.layers.rotary_embedding import ( + DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, + RotaryEmbedding) def allclose_helper(t0: torch.Tensor, @@ -131,21 +129,21 @@ def gen_rotary_emb(): return RotaryEmbedding(rope_head_dim, max_position_embeddings=s, base=pos_emb_config['rope_theta'], - device='cpu') + device=device) elif pos_emb_config['rope_scaling']['type'] == 'linear': return LinearScalingRotaryEmbedding( rope_head_dim, max_position_embeddings=s, base=pos_emb_config['rope_theta'], scaling_factor=pos_emb_config['rope_scaling']['factor'], - device='cpu') + device=device) elif pos_emb_config['rope_scaling']['type'] == 'dynamic': return DynamicNTKScalingRotaryEmbedding( rope_head_dim, max_position_embeddings=s, base=pos_emb_config['rope_theta'], scaling_factor=pos_emb_config['rope_scaling']['factor'], - device='cpu') + device=device) else: raise ValueError( 'rope_scaling.type should be one no_scaling, linear, or dynamic' From 5981ade5863029738c578cdf47acb6cc9213630e Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 20 Oct 2023 01:51:14 +0000 Subject: [PATCH 051/106] added mark for gpu in the rotary embedding test --- tests/test_rotary_embedding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py index fa2ea8f2dc..043e38b98f 100644 --- a/tests/test_rotary_embedding.py +++ b/tests/test_rotary_embedding.py @@ -13,6 +13,7 @@ def allclose_helper(t0: torch.Tensor, return torch.allclose(t0, t1, rtol=rtol, atol=atol) +@pytest.mark.gpu @pytest.mark.parametrize('device', ['cpu', 'cuda']) @pytest.mark.parametrize('dtype', [torch.float32, torch.bfloat16]) @pytest.mark.parametrize('rope_scaling_type', From 9afa082ebea50b31a317ec34101ff4732579f5a6 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 20 Oct 2023 02:57:01 +0000 Subject: [PATCH 052/106] .. --- tests/test_rotary_embedding.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py index 043e38b98f..42bdef7ae2 100644 --- a/tests/test_rotary_embedding.py +++ b/tests/test_rotary_embedding.py @@ -22,7 +22,10 @@ def allclose_helper(t0: torch.Tensor, def test_rotation_scaling_factor_1(device: str, dtype: torch.dtype, rope_scaling_type: str, tensor_type: str): """Checks all the rotation embedding techniques (with scaling factor 1) - produce the expected rotation.""" + + Checks that the rotations produced by the techniques is correct, for + both query and key tensors. 
+ """ from llmfoundry.models.layers.rotary_embedding import ( DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, RotaryEmbedding, apply_rotary_pos_emb) From 9835acd761b60638a4e68a8453ccde7873642bde Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 20 Oct 2023 02:58:54 +0000 Subject: [PATCH 053/106] .. --- tests/test_rotary_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py index 42bdef7ae2..5f23b7d5ac 100644 --- a/tests/test_rotary_embedding.py +++ b/tests/test_rotary_embedding.py @@ -23,8 +23,8 @@ def test_rotation_scaling_factor_1(device: str, dtype: torch.dtype, rope_scaling_type: str, tensor_type: str): """Checks all the rotation embedding techniques (with scaling factor 1) - Checks that the rotations produced by the techniques is correct, for - both query and key tensors. + Checks that the rotations produced by all the techniques are correct, + for both the query and key tensors. """ from llmfoundry.models.layers.rotary_embedding import ( DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, From 1801677b8e4d99863da9e02819cb47a2fed992f8 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 21 Oct 2023 00:35:03 +0000 Subject: [PATCH 054/106] .. --- llmfoundry/models/layers/attention.py | 33 +++++--- llmfoundry/models/layers/blocks.py | 6 +- llmfoundry/models/mpt/configuration_mpt.py | 8 ++ llmfoundry/models/mpt/modeling_mpt.py | 99 +++++++++++++--------- 4 files changed, 92 insertions(+), 54 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index a3f3aaf64a..2bdcc2d670 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -556,7 +556,7 @@ def forward( past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - rotary_emb_w_offset_info: Optional[dict] = None, + rotary_emb_w_meta_info: Optional[dict] = None, is_causal: bool = True, needs_weights: bool = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ @@ -583,20 +583,29 @@ def forward( query = self.q_ln(query).to(dtype) key = self.k_ln(key).to(dtype) - if rotary_emb_w_offset_info is not None: + if rotary_emb_w_meta_info is not None: query = query.view(*(query.shape[:-1]), -1, self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) - rotary_emb = rotary_emb_w_offset_info['rotary_emb'] - seq_len = rotary_emb_w_offset_info['seq_len'] - pos = rotary_emb_w_offset_info['pos'] - (cos, sin) = rotary_emb(query, seq_len) - query, key = apply_rotary_pos_emb(query, - key, - cos, - sin, - pos, - dim_heads_index=2) + rotary_emb = rotary_emb_w_meta_info['rotary_emb'] + seq_len = rotary_emb_w_meta_info['seq_len'] + pos = rotary_emb_w_meta_info['pos'] + if rotary_emb_w_meta_info['imp'] == 'hf_llama': + (cos, sin) = rotary_emb(query, seq_len) + query, key = apply_rotary_pos_emb(query, + key, + cos, + sin, + pos, + dim_heads_index=2) + elif rotary_emb_w_meta_info['imp'] == 'flash': + value = value.view(*(value.shape[:-1]), -1, self.head_dim) + + kv = torch.stack([key, value], dim=2) + query, kv = rotary_emb(query, kv, seqlen_offset=pos, max_seqlen=seq_len) + [key, value] = torch.unbind(kv, dim=2) + + value = value.view(*(value.shape[:-2]), self.kv_n_heads * self.head_dim) query = query.view(*(query.shape[:-2]), self.d_model) key = key.view(*(key.shape[:-2]), self.kv_n_heads * self.head_dim) diff 
--git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 9349c91ce9..0dc9f88423 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -42,6 +42,7 @@ def __init__( 'alibi': False, 'alibi_bias_max': 8, 'rope': False, + 'rope_imp': 'hf_llama', 'rope_theta': 10000, 'rope_scaling': { 'type': 'no_scaling', @@ -69,6 +70,7 @@ def __init__( 'attn_uses_sequence_id', 'alibi_bias_max', 'rope', + 'rope_imp', 'rope_theta', 'rope_scaling', } @@ -106,7 +108,7 @@ def forward( x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, - rotary_emb_w_offset_info: Optional[Dict] = None, + rotary_emb_w_meta_info: Optional[Dict] = None, attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, @@ -117,7 +119,7 @@ def forward( a, past_key_value=past_key_value, attn_bias=attn_bias, - rotary_emb_w_offset_info=rotary_emb_w_offset_info, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 652bbdab47..309d2058a9 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -20,6 +20,7 @@ 'alibi': False, 'alibi_bias_max': 8, 'rope': False, + 'rope_imp': 'hf_llama', 'rope_theta': 10000, 'rope_scaling': { 'type': 'no_scaling', @@ -101,6 +102,7 @@ def __init__( alibi (bool): Whether to use the alibi bias instead of position embeddings. alibi_bias_max (int): The maximum value of the alibi bias. rope (bool): Whether to use rotary positional embeddings. + rope_imp (str): The implementation of rope to use. One of 'hf_llama' or 'flash'. rope_theta (int): The base frequency for rope. rope_scaling (Dict): A dictionary used to configure rope's scaling behavior (when scaling beyond the training length) type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla. @@ -247,6 +249,12 @@ def _validate_config(self) -> None: + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156' ) + if self.attn_config['rope_imp'] not in [ + 'hf_llama', 'flash' + ]: + raise ValueError( + 'rope_imp should be either "hf_llama", or "flash".' 
+ ) if self.attn_config['rope_scaling']['type'] not in [ 'no_scaling', 'linear', 'dynamic' ]: diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index babdf85a8e..db1318a1ed 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -23,6 +23,7 @@ from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity from composer.models import HuggingFaceModel from composer.utils import dist +from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding from omegaconf import DictConfig from omegaconf import OmegaConf as om from transformers import PreTrainedModel, PreTrainedTokenizerBase @@ -75,29 +76,39 @@ def _rotary_embedding(config: MPTConfig): rope_head_dim = config.d_model // config.n_heads - if config.attn_config['rope_scaling']['type'] == 'no_scaling': - return RotaryEmbedding( - rope_head_dim, - max_position_embeddings=config.max_seq_len, - base=config.attn_config['rope_theta'], - device='cpu' - ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized - elif config.attn_config['rope_scaling']['type'] == 'linear': - return LinearScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=config.max_seq_len, - base=config.attn_config['rope_theta'], - scaling_factor=config.attn_config['rope_scaling']['factor'], - device='cpu' - ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized - elif config.attn_config['rope_scaling']['type'] == 'dynamic': - return DynamicNTKScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=config.max_seq_len, - base=config.attn_config['rope_theta'], - scaling_factor=config.attn_config['rope_scaling']['factor'], - device='cpu' - ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized + if config.attn_config['rope_imp'] == 'flash': + return FlashRotaryEmbedding( + dim=rope_head_dim, + base=config.attn_config['rope_theta'], + interleaved=False, + scale_base=None, # "If scale_base is not None, this implements XPos. A recommended value for scale_base is 512." (Source: https://github.com/Dao-AILab/flash-attention/blob/02ac572f3ffc4f402e4183aaa6824b45859d3ed3/flash_attn/layers/rotary.py#L312C1-L314C110) + pos_idx_in_fp32=False, # if True, the position indices [0.0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. 
bf16 rounds position 1995 to 2000, for example, which leads to them having the same positional embedding + device='cpu', + ) + elif config.attn_config['rope_imp'] == 'hf_llama': + if config.attn_config['rope_scaling']['type'] == 'no_scaling': + return RotaryEmbedding( + rope_head_dim, + max_position_embeddings=config.max_seq_len, + base=config.attn_config['rope_theta'], + device='cpu' + ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized + elif config.attn_config['rope_scaling']['type'] == 'linear': + return LinearScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=config.max_seq_len, + base=config.attn_config['rope_theta'], + scaling_factor=config.attn_config['rope_scaling']['factor'], + device='cpu' + ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized + elif config.attn_config['rope_scaling']['type'] == 'dynamic': + return DynamicNTKScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=config.max_seq_len, + base=config.attn_config['rope_theta'], + scaling_factor=config.attn_config['rope_scaling']['factor'], + device='cpu' + ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized class MPTPreTrainedModel(PreTrainedModel): @@ -154,7 +165,9 @@ def __init__(self, config: MPTConfig): self.norm_f = norm_class(config.d_model, device=config.init_device) self.rope = config.attn_config['rope'] + self.rope_imp = None if self.rope: + self.rope_imp = config.attn_config['rope_imp'] self.rotary_embedding = _rotary_embedding(config) if config.init_device != 'meta': @@ -395,7 +408,7 @@ def forward( S <= self.config.max_seq_len ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' - rotary_emb_w_offset_info = None + rotary_emb_w_meta_info = None x = self.wte(input_ids) if self.learned_pos_emb or self.rope: past_position = 0 @@ -420,24 +433,30 @@ def forward( + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.' 
) - pos = torch.arange( - past_position, - S + past_position, - dtype=torch.long, - device=input_ids.device, - ).unsqueeze(0) - if attention_mask is not None: - # adjust the position indices to account for padding tokens - pos = torch.clamp( - pos - torch.cumsum((~attention_mask).to(torch.int32), - dim=1)[:, past_position:], - min=0, - ) + pos = 0 + if self.learned_pos_emb or (self.rope and self.rope_imp == 'hf_llama'): + pos = torch.arange( + past_position, + S + past_position, + dtype=torch.long, + device=input_ids.device, + ).unsqueeze(0) + if attention_mask is not None: + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), + dim=1)[:, past_position:], + min=0, + ) + elif (self.rope and self.rope_imp == 'flash'): + pos = past_position + if self.rope: - rotary_emb_w_offset_info = { + rotary_emb_w_meta_info = { 'rotary_emb': self.rotary_embedding, 'pos': pos, - 'seq_len': S + past_position + 'seq_len': S + past_position, + 'imp': self.rope_imp, } if self.learned_pos_emb: x = x + self.wpe(pos) @@ -477,7 +496,7 @@ def forward( x, past_key_value=past_key_value, attn_bias=attn_bias, - rotary_emb_w_offset_info=rotary_emb_w_offset_info, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), From 0a38037fe4a747bd97909ba14d63400df46074b3 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Mon, 23 Oct 2023 06:35:42 +0000 Subject: [PATCH 055/106] removed thecode for hf implementation of rope --- llmfoundry/models/layers/__init__.py | 7 - llmfoundry/models/layers/attention.py | 27 +--- llmfoundry/models/layers/blocks.py | 13 +- llmfoundry/models/layers/rotary_embedding.py | 146 ------------------- llmfoundry/models/mpt/configuration_mpt.py | 30 ++-- llmfoundry/models/mpt/modeling_mpt.py | 58 ++------ tests/test_rotary_embedding.py | 118 --------------- 7 files changed, 32 insertions(+), 367 deletions(-) delete mode 100644 llmfoundry/models/layers/rotary_embedding.py delete mode 100644 tests/test_rotary_embedding.py diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index 2cb539fadb..68aa0fe7fe 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -10,9 +10,6 @@ from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm -from llmfoundry.models.layers.rotary_embedding import ( - DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, - RotaryEmbedding, apply_rotary_pos_emb) __all__ = [ 'scaled_multihead_dot_product_attention', @@ -32,8 +29,4 @@ 'SharedEmbedding', 'FFN_CLASS_REGISTRY', 'build_ffn', - 'RotaryEmbedding', - 'LinearScalingRotaryEmbedding', - 'DynamicNTKScalingRotaryEmbedding', - 'apply_rotary_pos_emb', ] diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 2bdcc2d670..38c7d731d4 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -15,7 +15,6 @@ from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY -from llmfoundry.models.layers.rotary_embedding import apply_rotary_pos_emb def is_flash_v2_installed(): @@ -586,27 +585,13 @@ def forward( if rotary_emb_w_meta_info is not None: query = query.view(*(query.shape[:-1]), -1, 
self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) + value = value.view(*(value.shape[:-1]), -1, self.head_dim) - rotary_emb = rotary_emb_w_meta_info['rotary_emb'] - seq_len = rotary_emb_w_meta_info['seq_len'] - pos = rotary_emb_w_meta_info['pos'] - if rotary_emb_w_meta_info['imp'] == 'hf_llama': - (cos, sin) = rotary_emb(query, seq_len) - query, key = apply_rotary_pos_emb(query, - key, - cos, - sin, - pos, - dim_heads_index=2) - elif rotary_emb_w_meta_info['imp'] == 'flash': - value = value.view(*(value.shape[:-1]), -1, self.head_dim) - - kv = torch.stack([key, value], dim=2) - query, kv = rotary_emb(query, kv, seqlen_offset=pos, max_seqlen=seq_len) - [key, value] = torch.unbind(kv, dim=2) - - value = value.view(*(value.shape[:-2]), self.kv_n_heads * self.head_dim) - + kv = torch.stack([key, value], dim=2) + query, kv = rotary_emb_w_meta_info['rotary_emb'](query, kv, seqlen_offset=rotary_emb_w_meta_info['pos'], max_seqlen=rotary_emb_w_meta_info['seq_len']) + [key, value] = torch.unbind(kv, dim=2) + + value = value.view(*(value.shape[:-2]), self.kv_n_heads * self.head_dim) query = query.view(*(query.shape[:-2]), self.d_model) key = key.view(*(key.shape[:-2]), self.kv_n_heads * self.head_dim) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 0dc9f88423..39db48e87e 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -42,12 +42,10 @@ def __init__( 'alibi': False, 'alibi_bias_max': 8, 'rope': False, - 'rope_imp': 'hf_llama', + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'no_scaling', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, } if ffn_config is None: @@ -70,9 +68,10 @@ def __init__( 'attn_uses_sequence_id', 'alibi_bias_max', 'rope', - 'rope_imp', + 'rope_type', 'rope_theta', - 'rope_scaling', + 'rope_pos_idx_in_fp32', + 'xpos_scale_base', } attn_config_subset_for_attn_class = { k: v diff --git a/llmfoundry/models/layers/rotary_embedding.py b/llmfoundry/models/layers/rotary_embedding.py deleted file mode 100644 index 8ab71d7746..0000000000 --- a/llmfoundry/models/layers/rotary_embedding.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2022 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -import torch - - -# Code modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py -class RotaryEmbedding(torch.nn.Module): - - def __init__(self, dim: int, max_position_embeddings: int, base: int, - device: Union[str, torch.device]): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base**( - torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer('inv_freq', inv_freq, persistent=False) - - self.max_seq_len_cached = self.max_position_embeddings - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device, # type: ignore - dtype=torch.get_default_dtype()) # type: ignore - - def _set_cos_sin_cache(self, seq_len: int, device: Union[str, torch.device], - dtype: torch.dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, - device=device, - dtype=self.inv_freq.dtype) # type: ignore - - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer('cos_cached', - emb.cos().to(dtype), - persistent=False) - self.register_buffer('sin_cached', - emb.sin().to(dtype), - persistent=False) - - def forward(self, x: torch.Tensor, seq_len: int): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, - device=x.device, - dtype=x.dtype) # type: ignore - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), # type: ignore - self.sin_cached[:seq_len].to(dtype=x.dtype), # type: ignore - ) - - -class LinearScalingRotaryEmbedding(RotaryEmbedding): - """RotaryEmbedding extended with linear scaling. - - Credits to the Reddit user /u/kaiokendev - """ - - def __init__(self, dim: int, max_position_embeddings: int, base: int, - scaling_factor: float, device: Union[str, torch.device]): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len: int, device: Union[str, torch.device], - dtype: torch.dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, - device=device, - dtype=self.inv_freq.dtype) # type: ignore - t = t / self.scaling_factor - - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer('cos_cached', - emb.cos().to(dtype), - persistent=False) - self.register_buffer('sin_cached', - emb.sin().to(dtype), - persistent=False) - - -class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): - """RotaryEmbedding extended with Dynamic NTK scaling. 
- - Credits to the Reddit users /u/bloc97 and /u/emozilla - """ - - def __init__(self, dim: int, max_position_embeddings: int, base: int, - scaling_factor: float, device: Union[str, torch.device]): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len: int, device: Union[str, torch.device], - dtype: torch.dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - - (self.scaling_factor - 1))**(self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base**( - torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer('inv_freq', inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, - device=device, - dtype=self.inv_freq.dtype) # type: ignore - - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer('cos_cached', - emb.cos().to(dtype), - persistent=False) - self.register_buffer('sin_cached', - emb.sin().to(dtype), - persistent=False) - - -def rotate_half(x: torch.Tensor): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb -def apply_rotary_pos_emb(q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - position_ids: torch.Tensor, - dim_heads_index: int = 1): - cos = cos[position_ids].unsqueeze(dim_heads_index) - sin = sin[position_ids].unsqueeze(dim_heads_index) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 309d2058a9..cf273bb5f5 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -20,12 +20,10 @@ 'alibi': False, 'alibi_bias_max': 8, 'rope': False, - 'rope_imp': 'hf_llama', + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'no_scaling', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, } ffn_config_defaults: Dict = { @@ -102,11 +100,10 @@ def __init__( alibi (bool): Whether to use the alibi bias instead of position embeddings. alibi_bias_max (int): The maximum value of the alibi bias. rope (bool): Whether to use rotary positional embeddings. - rope_imp (str): The implementation of rope to use. One of 'hf_llama' or 'flash'. + rope_type (str): The type of rope to use. Options: 'original', 'xpos' rope_theta (int): The base frequency for rope. - rope_scaling (Dict): A dictionary used to configure rope's scaling behavior (when scaling beyond the training length) - type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla. - factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type. + rope_pos_idx_in_fp32 (bool): Whether to use fp32 as the dtype for rope positional indices. 
+ xpos_scale_base (float): The scale base for XPos. kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. ffn_config (Dict): A dictionary used to configure the model's ffn module: ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp @@ -221,6 +218,9 @@ def _validate_config(self) -> None: raise NotImplementedError( 'attn_uses_sequence_id only implemented with torch and triton attention.' ) + if self.attn_config['rope'] and (self.attn_config['rope_type'] not in ['original', 'xpos']): + raise NotImplementedError( + 'rope_type must be one of "original" or "xpos".') if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' @@ -249,18 +249,6 @@ def _validate_config(self) -> None: + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156' ) - if self.attn_config['rope_imp'] not in [ - 'hf_llama', 'flash' - ]: - raise ValueError( - 'rope_imp should be either "hf_llama", or "flash".' - ) - if self.attn_config['rope_scaling']['type'] not in [ - 'no_scaling', 'linear', 'dynamic' - ]: - raise ValueError( - 'rope_scaling.type should be one of "no_scaling", "linear" or "dynamic".' - ) if self.ffn_config['ffn_type'] == 'mptmlp': self.ffn_config['fc_type'] = self.fc_type elif self.ffn_config['ffn_type'] == 'te_ln_mlp': diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index db1318a1ed..e6cc716ae9 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -39,9 +39,6 @@ from llmfoundry.models.layers.ffn import MPTMLP as MPTMLP from llmfoundry.models.layers.ffn import build_ffn as build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY -from llmfoundry.models.layers.rotary_embedding import ( - DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, - RotaryEmbedding) from llmfoundry.models.mpt.configuration_mpt import MPTConfig # NOTE: All utils are imported directly even if unused so that @@ -73,44 +70,6 @@ log = logging.getLogger(__name__) - -def _rotary_embedding(config: MPTConfig): - rope_head_dim = config.d_model // config.n_heads - if config.attn_config['rope_imp'] == 'flash': - return FlashRotaryEmbedding( - dim=rope_head_dim, - base=config.attn_config['rope_theta'], - interleaved=False, - scale_base=None, # "If scale_base is not None, this implements XPos. A recommended value for scale_base is 512." (Source: https://github.com/Dao-AILab/flash-attention/blob/02ac572f3ffc4f402e4183aaa6824b45859d3ed3/flash_attn/layers/rotary.py#L312C1-L314C110) - pos_idx_in_fp32=False, # if True, the position indices [0.0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. 
bf16 rounds position 1995 to 2000, for example, which leads to them having the same positional embedding - device='cpu', - ) - elif config.attn_config['rope_imp'] == 'hf_llama': - if config.attn_config['rope_scaling']['type'] == 'no_scaling': - return RotaryEmbedding( - rope_head_dim, - max_position_embeddings=config.max_seq_len, - base=config.attn_config['rope_theta'], - device='cpu' - ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized - elif config.attn_config['rope_scaling']['type'] == 'linear': - return LinearScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=config.max_seq_len, - base=config.attn_config['rope_theta'], - scaling_factor=config.attn_config['rope_scaling']['factor'], - device='cpu' - ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized - elif config.attn_config['rope_scaling']['type'] == 'dynamic': - return DynamicNTKScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=config.max_seq_len, - base=config.attn_config['rope_theta'], - scaling_factor=config.attn_config['rope_scaling']['factor'], - device='cpu' - ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized - - class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig base_model_prefix = 'model' @@ -165,10 +124,16 @@ def __init__(self, config: MPTConfig): self.norm_f = norm_class(config.d_model, device=config.init_device) self.rope = config.attn_config['rope'] - self.rope_imp = None + breakpoint() if self.rope: - self.rope_imp = config.attn_config['rope_imp'] - self.rotary_embedding = _rotary_embedding(config) + self.rotary_embedding = FlashRotaryEmbedding( + dim=config.d_model // config.n_heads, + base=config.attn_config['rope_theta'], + interleaved=False, + scale_base=config.attn_config['xpos_scale_base'] if (config.attn_config['rope_type']=='xpos') else None, + pos_idx_in_fp32=config.attn_config['rope_pos_idx_in_fp32'], # if True, the position indices [0.0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. bf16 rounds position 1995 to 2000, for example, which leads to them having the same positional embedding + device='cpu', + ) if config.init_device != 'meta': log.info( @@ -434,7 +399,7 @@ def forward( f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.' 
) pos = 0 - if self.learned_pos_emb or (self.rope and self.rope_imp == 'hf_llama'): + if self.learned_pos_emb: pos = torch.arange( past_position, S + past_position, @@ -448,7 +413,7 @@ def forward( dim=1)[:, past_position:], min=0, ) - elif (self.rope and self.rope_imp == 'flash'): + elif self.rope: pos = past_position if self.rope: @@ -456,7 +421,6 @@ def forward( 'rotary_emb': self.rotary_embedding, 'pos': pos, 'seq_len': S + past_position, - 'imp': self.rope_imp, } if self.learned_pos_emb: x = x + self.wpe(pos) diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py deleted file mode 100644 index 5f23b7d5ac..0000000000 --- a/tests/test_rotary_embedding.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2022 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -import pytest -import torch -from composer.utils import reproducibility - - -def allclose_helper(t0: torch.Tensor, - t1: torch.Tensor, - rtol: float = 1e-2, - atol: float = 1e-2): - return torch.allclose(t0, t1, rtol=rtol, atol=atol) - - -@pytest.mark.gpu -@pytest.mark.parametrize('device', ['cpu', 'cuda']) -@pytest.mark.parametrize('dtype', [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize('rope_scaling_type', - ['no_scaling', 'linear', 'dynamic']) -@pytest.mark.parametrize('tensor_type', ['query', 'key']) -def test_rotation_scaling_factor_1(device: str, dtype: torch.dtype, - rope_scaling_type: str, tensor_type: str): - """Checks all the rotation embedding techniques (with scaling factor 1) - - Checks that the rotations produced by all the techniques are correct, - for both the query and key tensors. - """ - from llmfoundry.models.layers.rotary_embedding import ( - DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, - RotaryEmbedding, apply_rotary_pos_emb) - - reproducibility.seed_all(7) - - rope_head_dim = 8 - assert rope_head_dim % 2 == 0 - rope_theta = 5 - rope_scaling_factor = 1.0 - - seq_len = 7 - batch_size = 1 - num_heads = 1 - pos = torch.arange(seq_len, device=device, - dtype=torch.long).repeat(batch_size, 1) # - - # x will test the first half cosine part of the rotation and second half of the sine part - x = torch.ones((batch_size, seq_len, num_heads, rope_head_dim), - device=device, - dtype=dtype) - x[..., rope_head_dim // 2:] = 0.0 - - # y will test the first half sine part of the rotation and second half of the cosine part - y = torch.ones((batch_size, seq_len, num_heads, rope_head_dim), - device=device, - dtype=dtype) - y[..., :rope_head_dim // 2] = 0.0 - - z = torch.zeros((batch_size, seq_len, num_heads, rope_head_dim), - device=device, - dtype=dtype) - - def gen_rotary_emb(): - if rope_scaling_type == 'no_scaling': - rotary_embedding = RotaryEmbedding(rope_head_dim, - max_position_embeddings=seq_len, - base=rope_theta, - device=device) - elif rope_scaling_type == 'linear': - rotary_embedding = LinearScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=seq_len, - base=rope_theta, - device=device, - scaling_factor=rope_scaling_factor) - elif rope_scaling_type == 'dynamic': - rotary_embedding = DynamicNTKScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=seq_len, - base=rope_theta, - device=device, - scaling_factor=rope_scaling_factor) - else: - raise ValueError( - 'rope_scaling_type should be one no_scaling, linear, or dynamic' - ) - return rotary_embedding - - rotary_emb = gen_rotary_emb() - (cos, sin) = rotary_emb(x, seq_len) - assert allclose_helper(cos[:, :rope_head_dim // 2], - cos[:, rope_head_dim // 2:]) - assert 
allclose_helper(sin[:, :rope_head_dim // 2], - sin[:, rope_head_dim // 2:]) - - assert allclose_helper(cos * cos + sin * sin, torch.ones_like(cos)) - - if tensor_type == 'query': - x, _ = apply_rotary_pos_emb(x, z, cos, sin, pos, dim_heads_index=2) - y, _ = apply_rotary_pos_emb(y, z, cos, sin, pos, dim_heads_index=2) - elif tensor_type == 'key': - _, x = apply_rotary_pos_emb(z, x, cos, sin, pos, dim_heads_index=2) - _, y = apply_rotary_pos_emb(z, y, cos, sin, pos, dim_heads_index=2) - - assert allclose_helper(x[..., :rope_head_dim // 2], y[..., - rope_head_dim // 2:]) - assert allclose_helper(x[..., rope_head_dim // 2:], - -y[..., :rope_head_dim // 2]) - - inv_freq = ( - 1.0 / - (rope_theta**(torch.arange(0, rope_head_dim, 2) / rope_head_dim))).to(x) - t = torch.arange(seq_len).to(x) - - expected_rotation_angles = torch.outer(t, inv_freq) - assert allclose_helper(x[..., :rope_head_dim // 2].squeeze(), - expected_rotation_angles.cos()) - assert allclose_helper(x[..., rope_head_dim // 2:].squeeze(), - expected_rotation_angles.sin()) From d86a1a53a268698581dd879cd27f2d72089a84c5 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Mon, 23 Oct 2023 06:41:57 +0000 Subject: [PATCH 056/106] .. --- llmfoundry/models/mpt/configuration_mpt.py | 7 ++++--- llmfoundry/models/mpt/modeling_mpt.py | 10 ++++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index cf273bb5f5..3d03e1476a 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -102,7 +102,7 @@ def __init__( rope (bool): Whether to use rotary positional embeddings. rope_type (str): The type of rope to use. Options: 'original', 'xpos' rope_theta (int): The base frequency for rope. - rope_pos_idx_in_fp32 (bool): Whether to use fp32 as the dtype for rope positional indices. + rope_pos_idx_in_fp32 (bool): If True, the position indices [0.0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. A consequence could be, for example, that bf16 rounds position 1995 to 2000, which leads to them having the same positional embedding. xpos_scale_base (float): The scale base for XPos. kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. ffn_config (Dict): A dictionary used to configure the model's ffn module: @@ -218,9 +218,10 @@ def _validate_config(self) -> None: raise NotImplementedError( 'attn_uses_sequence_id only implemented with torch and triton attention.' ) - if self.attn_config['rope'] and (self.attn_config['rope_type'] not in ['original', 'xpos']): + if self.attn_config['rope'] and (self.attn_config['rope_type'] + not in ['original', 'xpos']): raise NotImplementedError( - 'rope_type must be one of "original" or "xpos".') + 'rope_type must be one of "original" or "xpos".') if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' 
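The hunks above settle the public surface of the rope options: `rope`, `rope_type` ('original' or 'xpos'), `rope_theta`, `rope_pos_idx_in_fp32`, and `xpos_scale_base`, with `_validate_config` now rejecting any other `rope_type`. As a minimal usage sketch (not part of the patch; the constructor arguments mirror the test configs later in this series, and the numeric values are illustrative only):

    from llmfoundry.models.mpt.configuration_mpt import MPTConfig

    # Illustrative config enabling the flash-attn rotary embedding with XPos scaling.
    # 'rope_type' must be 'original' or 'xpos'; any other value raises
    # NotImplementedError in _validate_config.
    cfg = MPTConfig(
        d_model=128,
        n_heads=4,
        n_layers=2,
        max_seq_len=2048,
        attn_config={
            'attn_impl': 'flash',
            'rope': True,
            'rope_type': 'xpos',
            'rope_theta': 10000,
            'rope_pos_idx_in_fp32': True,  # keep position indices in fp32 to avoid bf16 rounding collisions
            'xpos_scale_base': 512,        # only consulted when rope_type == 'xpos'
        },
    )

Note that the tests added later in this series skip rope on CPU ('rope only implemented for gpus'), so a config like this is exercised under the gpu/flash path.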
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index e6cc716ae9..1553c85949 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -70,6 +70,7 @@ log = logging.getLogger(__name__) + class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig base_model_prefix = 'model' @@ -130,10 +131,11 @@ def __init__(self, config: MPTConfig): dim=config.d_model // config.n_heads, base=config.attn_config['rope_theta'], interleaved=False, - scale_base=config.attn_config['xpos_scale_base'] if (config.attn_config['rope_type']=='xpos') else None, - pos_idx_in_fp32=config.attn_config['rope_pos_idx_in_fp32'], # if True, the position indices [0.0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. bf16 rounds position 1995 to 2000, for example, which leads to them having the same positional embedding + scale_base=config.attn_config['xpos_scale_base'] if + (config.attn_config['rope_type'] == 'xpos') else None, + pos_idx_in_fp32=config.attn_config['rope_pos_idx_in_fp32'], device='cpu', - ) + ) if config.init_device != 'meta': log.info( @@ -410,7 +412,7 @@ def forward( # adjust the position indices to account for padding tokens pos = torch.clamp( pos - torch.cumsum((~attention_mask).to(torch.int32), - dim=1)[:, past_position:], + dim=1)[:, past_position:], min=0, ) elif self.rope: From 7e336d27447f18bbd0d72e11e698c72a9bd8e24e Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Mon, 23 Oct 2023 06:55:32 +0000 Subject: [PATCH 057/106] .. --- llmfoundry/models/mpt/modeling_mpt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 1553c85949..e9785d9b59 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -125,7 +125,6 @@ def __init__(self, config: MPTConfig): self.norm_f = norm_class(config.d_model, device=config.init_device) self.rope = config.attn_config['rope'] - breakpoint() if self.rope: self.rotary_embedding = FlashRotaryEmbedding( dim=config.d_model // config.n_heads, From 189735383881634736b1c017a96315caedd92d1e Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 24 Oct 2023 00:35:53 +0000 Subject: [PATCH 058/106] added tests --- llmfoundry/models/layers/attention.py | 6 +- llmfoundry/models/layers/blocks.py | 4 +- llmfoundry/models/mpt/modeling_mpt.py | 25 +- tests/test_flash_triton_torch.py | 79 ++--- tests/test_model.py | 425 ++++++++++++++++---------- 5 files changed, 312 insertions(+), 227 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 38c7d731d4..9775c49ec4 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -555,7 +555,7 @@ def forward( past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - rotary_emb_w_meta_info: Optional[dict] = None, + rotary_emb_w_offset_info: Optional[dict] = None, is_causal: bool = True, needs_weights: bool = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ @@ -582,13 +582,13 @@ def forward( query = self.q_ln(query).to(dtype) key = self.k_ln(key).to(dtype) - if rotary_emb_w_meta_info is not None: + if rotary_emb_w_offset_info is not None: query = query.view(*(query.shape[:-1]), -1, self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) value = value.view(*(value.shape[:-1]), -1, 
self.head_dim) kv = torch.stack([key, value], dim=2) - query, kv = rotary_emb_w_meta_info['rotary_emb'](query, kv, seqlen_offset=rotary_emb_w_meta_info['pos'], max_seqlen=rotary_emb_w_meta_info['seq_len']) + query, kv = rotary_emb_w_offset_info['rotary_embedding'](query, kv, seqlen_offset=rotary_emb_w_offset_info['seqlen_offset'], max_seqlen=rotary_emb_w_offset_info['max_seqlen']) [key, value] = torch.unbind(kv, dim=2) value = value.view(*(value.shape[:-2]), self.kv_n_heads * self.head_dim) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 39db48e87e..9e5b7ce844 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -107,7 +107,7 @@ def forward( x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, - rotary_emb_w_meta_info: Optional[Dict] = None, + rotary_emb_w_offset_info: Optional[Dict] = None, attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, @@ -118,7 +118,7 @@ def forward( a, past_key_value=past_key_value, attn_bias=attn_bias, - rotary_emb_w_meta_info=rotary_emb_w_meta_info, + rotary_emb_w_offset_info=rotary_emb_w_offset_info, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index e9785d9b59..85864ab338 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -23,7 +23,7 @@ from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity from composer.models import HuggingFaceModel from composer.utils import dist -from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding +from flash_attn.layers.rotary import RotaryEmbedding from omegaconf import DictConfig from omegaconf import OmegaConf as om from transformers import PreTrainedModel, PreTrainedTokenizerBase @@ -126,7 +126,7 @@ def __init__(self, config: MPTConfig): self.rope = config.attn_config['rope'] if self.rope: - self.rotary_embedding = FlashRotaryEmbedding( + self.rotary_embedding = RotaryEmbedding( dim=config.d_model // config.n_heads, base=config.attn_config['rope_theta'], interleaved=False, @@ -374,7 +374,7 @@ def forward( S <= self.config.max_seq_len ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' - rotary_emb_w_meta_info = None + rotary_emb_w_offset_info = None x = self.wte(input_ids) if self.learned_pos_emb or self.rope: past_position = 0 @@ -399,7 +399,7 @@ def forward( + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.' 
) - pos = 0 + if self.learned_pos_emb: pos = torch.arange( past_position, @@ -414,17 +414,14 @@ def forward( dim=1)[:, past_position:], min=0, ) - elif self.rope: - pos = past_position + x = x + self.wpe(pos) if self.rope: - rotary_emb_w_meta_info = { - 'rotary_emb': self.rotary_embedding, - 'pos': pos, - 'seq_len': S + past_position, - } - if self.learned_pos_emb: - x = x + self.wpe(pos) + rotary_emb_w_offset_info = { + 'rotary_embedding': self.rotary_embedding, + 'seqlen_offset': past_position, + 'max_seqlen': S + past_position, + } if self.embedding_fraction == 1: x = self.emb_drop(x) @@ -461,7 +458,7 @@ def forward( x, past_key_value=past_key_value, attn_bias=attn_bias, - rotary_emb_w_meta_info=rotary_emb_w_meta_info, + rotary_emb_w_offset_info=rotary_emb_w_offset_info, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index 3cb0382491..e31cc2be00 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -6,9 +6,7 @@ from composer.utils import reproducibility from omegaconf import OmegaConf as om -from llmfoundry.models.layers.rotary_embedding import ( - DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, - RotaryEmbedding) +from flash_attn.layers.rotary import RotaryEmbedding def allclose_helper(t0: torch.Tensor, @@ -32,27 +30,31 @@ def allclose_helper(t0: torch.Tensor, }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'no_scaling', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'linear', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'dynamic', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': True, + 'xpos_scale_base': 512, +}, { + 'alibi': False, + 'rope': True, + 'rope_type': 'xpos', + 'rope_theta': 10000, + 'rope_pos_idx_in_fp32': True, + 'xpos_scale_base': 512, }]) @pytest.mark.parametrize( 'attn_type', @@ -88,7 +90,6 @@ def test_attn_impl(attn_impl_0: str, n, s, f = 2, 16, cfg.d_model assert cfg.d_model % cfg.n_heads == 0 - rope_head_dim = cfg.d_model // cfg.n_heads if attn_type == 'grouped_query_attention': cfg.kv_n_heads = 2 @@ -124,31 +125,6 @@ def gen_bias(attn_impl: str): return attn_bias - def gen_rotary_emb(): - if pos_emb_config['rope_scaling']['type'] == 'no_scaling': - return RotaryEmbedding(rope_head_dim, - max_position_embeddings=s, - base=pos_emb_config['rope_theta'], - device=device) - elif pos_emb_config['rope_scaling']['type'] == 'linear': - return LinearScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=s, - base=pos_emb_config['rope_theta'], - scaling_factor=pos_emb_config['rope_scaling']['factor'], - device=device) - elif pos_emb_config['rope_scaling']['type'] == 'dynamic': - return DynamicNTKScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=s, - base=pos_emb_config['rope_theta'], - scaling_factor=pos_emb_config['rope_scaling']['factor'], - device=device) - else: - raise ValueError( - 'rope_scaling.type should be one no_scaling, linear, or dynamic' - ) - x0 = torch.randn(n, s, f).to(device) x1 = x0.clone().detach() x0.requires_grad = True @@ -156,20 +132,21 @@ def gen_rotary_emb(): with torch.autocast(x0.device.type): 
attn_bias = gen_bias(attn0.attn_impl) - pos = torch.arange(s).unsqueeze(0).to(device=device) - # adjust the position indices to account for padding tokens - pos = torch.clamp( - pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1), - min=0, - ) rotary_emb_w_offset_info = None if rope: - rotary_emb = gen_rotary_emb().to(device) + rotary_embedding = RotaryEmbedding( + dim=cfg.d_model // cfg.n_heads, + base=pos_emb_config['rope_theta'], + interleaved=False, + scale_base=pos_emb_config['xpos_scale_base'] if (pos_emb_config['rope_type'] == 'xpos') else None, + pos_idx_in_fp32=pos_emb_config['rope_pos_idx_in_fp32'], + device='cpu' + ).to(device) rotary_emb_w_offset_info = { - 'rotary_emb': rotary_emb, - 'pos': pos, - 'seq_len': s + 'rotary_embedding': rotary_embedding, + 'seqlen_offset': 0, + 'max_seqlen': s } y0, _, _ = attn0(x0, past_key_value=None, diff --git a/tests/test_model.py b/tests/test_model.py index e8cd0b61c4..380c8b777d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -536,27 +536,31 @@ def test_mpt_creation(norm_type: str, no_bias: bool): }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'no_scaling', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': True, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'linear', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': True, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'dynamic', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, +}, { + 'alibi': False, + 'rope': True, + 'rope_type': 'xpos', + 'rope_theta': 10000, + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }]) def test_forward_with_padding(attention_impl: str, device: str, pos_emb_config: dict): @@ -568,6 +572,10 @@ def test_forward_with_padding(attention_impl: str, device: str, alibi = pos_emb_config['alibi'] if alibi and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') + + rope = pos_emb_config['rope'] + if rope and device == 'cpu': + pytest.skip(f'rope only implemented for gpus.') reproducibility.seed_all(1234) composer_device = get_device(device) @@ -655,10 +663,16 @@ def test_forward_with_padding(attention_impl: str, device: str, attention_mask=batched_attention_mask).logits # check that right padding and left padding produce the same output - assert torch.allclose(right_padding_output[0, :3], - left_padding_output[0, 3:], - atol=1e-6 if attention_impl == 'torch' else 1e-8) - if not alibi: + if rope:########################################## + assert torch.allclose(right_padding_output[0, :3], + left_padding_output[0, 3:], + rtol=1e-2, + atol=1e-2) + else: + assert torch.allclose(right_padding_output[0, :3], + left_padding_output[0, 3:], + atol=1e-6 if attention_impl == 'torch' else 1e-8) + if not (alibi or rope): # check that right padding and middle padding produce the same output # Note: alibi not implemented for middle padding. 
assert torch.allclose( @@ -666,10 +680,17 @@ def test_forward_with_padding(attention_impl: str, device: str, middle_padding_output[0, [0, 1, 5]], atol=1e-6 if attention_impl == 'torch' else 1e-8) # check that right padding and right padding in a batch produce the same output - assert torch.allclose(right_padding_output[0, :3], + + if rope:########################################## + assert torch.allclose(right_padding_output[0, :3], + left_padding_output[0, 3:], + rtol=1e-2, + atol=1e-2) + else: + assert torch.allclose(right_padding_output[0, :3], batched_output[0, :3], atol=1e-6 if attention_impl == 'torch' else 1e-8) - if not alibi: + if not (alibi or rope): # check that middle padding and middle padding in a batch produce the same output # Note: alibi not implemented for middle padding. assert torch.allclose( @@ -736,7 +757,8 @@ def test_advanced_mask_building(attention_impl: str): @pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'), ('flash', 'gpu'), ('triton', 'gpu'), - ('torch', 'gpu')]) + ('torch', 'gpu') + ]) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False, 'rope': False @@ -746,27 +768,31 @@ def test_advanced_mask_building(attention_impl: str): }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'no_scaling', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'linear', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'dynamic', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': True, + 'xpos_scale_base': 512, +}, { + 'alibi': False, + 'rope': True, + 'rope_type': 'xpos', + 'rope_theta': 10000, + 'rope_pos_idx_in_fp32': True, + 'xpos_scale_base': 512, }]) def test_generate(attention_impl: str, device: str, pos_emb_config: dict): # Test that generate works, and produces the same output with or without @@ -777,6 +803,9 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): ) if pos_emb_config['alibi'] and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') + + if pos_emb_config['rope'] and device == 'cpu': + pytest.skip(f'rope only implemented for gpus.') reproducibility.seed_all(1234) composer_device = get_device(device) @@ -961,7 +990,12 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): check_hf_model_equivalence(mpt, mpt2) - +@pytest.mark.parametrize('attn_impl,device', [ + ('torch', 'cpu'), + ('flash', 'gpu'), + ('triton', 'gpu'), + ('torch', 'gpu'), +]) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False, 'rope': False @@ -971,30 +1005,45 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'no_scaling', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'linear', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'dynamic', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': True, + 'xpos_scale_base': 512, +}, { + 'alibi': False, + 'rope': True, + 
'rope_type': 'xpos', + 'rope_theta': 10000, + 'rope_pos_idx_in_fp32': True, + 'xpos_scale_base': 512, }]) -def test_forward_with_cache_and_padding(pos_emb_config: dict): +def test_forward_with_cache_and_padding(attn_impl: str, device: str, pos_emb_config: dict): # Tests that the result is the same with or without padding when using kv caching + if not torch.cuda.is_available() and device == 'gpu': + pytest.skip( + f'This test requires CUDA to be available in order to run with {attn_impl} attention.' + ) + if pos_emb_config['alibi'] and attn_impl == 'flash': + pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['rope'] and device == 'cpu': + pytest.skip(f'rope only implemented for gpus.') + + composer_device = get_device(device) + hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1005,7 +1054,7 @@ def test_forward_with_cache_and_padding(pos_emb_config: dict): emb_pdrop=0.1, resid_pdrop=0.2, attn_config={ - 'attn_impl': 'torch', + 'attn_impl': attn_impl, **pos_emb_config, }, use_cache=True, @@ -1016,47 +1065,63 @@ def test_forward_with_cache_and_padding(pos_emb_config: dict): ) mpt = MPTForCausalLM(hf_config) + mpt = composer_device.module_to_device(mpt) mpt.eval() - - first_input_ids_no_padding = torch.tensor([[11274, 16390, 11]]) - first_attention_mask_no_padding = torch.tensor([[1, 1, 1]]).bool() - - # start with passing the first three tokens through (no padding) - first_output_no_padding = mpt( - first_input_ids_no_padding, - attention_mask=first_attention_mask_no_padding) - - second_input_ids_no_padding = torch.tensor([[11274, 16390, 11, 11274]]) - second_attention_mask_no_padding = torch.tensor([[1, 1, 1, 1]]).bool() - - # pass through the fourth token by itself, using the key-value cache (no padding) - second_output_no_padding = mpt( - second_input_ids_no_padding[:, -1].unsqueeze(-1), - attention_mask=second_attention_mask_no_padding, - past_key_values=first_output_no_padding.past_key_values) - - first_input_ids_padding = torch.tensor([[50256, 11274, 16390, 11]]) - first_attention_mask_padding = torch.tensor([[0, 1, 1, 1]]).bool() - - # start with passing the first three tokens through (with left padding) - first_output_padding = mpt(first_input_ids_padding, - attention_mask=first_attention_mask_padding) - - second_input_ids_padding = torch.tensor([[50256, 11274, 16390, 11, 11274]]) - second_attention_mask_padding = torch.tensor([[0, 1, 1, 1, 1]]).bool() - - # pass through the fourth token by itself, using the key-value cache (with left padding) - second_output_padding = mpt( - second_input_ids_padding[:, -1].unsqueeze(-1), - attention_mask=second_attention_mask_padding, - past_key_values=first_output_padding.past_key_values) - - # check that the outputs are the same with or without padding - torch.testing.assert_close(second_output_no_padding.logits, - second_output_padding.logits[:, - -1, :].unsqueeze(1), - atol=1e-6, - rtol=1e-6) + with get_precision_context('amp_bf16' if composer_device.name == 'gpu' else 'fp32'): + first_input_ids_no_padding = torch.tensor([[11274, 16390, 11]]) + first_input_ids_no_padding = composer_device.tensor_to_device(first_input_ids_no_padding) + first_attention_mask_no_padding = torch.tensor([[1, 1, 1]]).bool() + first_attention_mask_no_padding = composer_device.tensor_to_device(first_attention_mask_no_padding) + + # start with passing the first three tokens through (no padding) + first_output_no_padding = mpt( + first_input_ids_no_padding, + attention_mask=first_attention_mask_no_padding) + + 
second_input_ids_no_padding = torch.tensor([[11274, 16390, 11, 11274]]) + second_input_ids_no_padding = composer_device.tensor_to_device(second_input_ids_no_padding) + second_attention_mask_no_padding = torch.tensor([[1, 1, 1, 1]]).bool() + second_attention_mask_no_padding = composer_device.tensor_to_device(second_attention_mask_no_padding) + + # pass through the fourth token by itself, using the key-value cache (no padding) + second_output_no_padding = mpt( + second_input_ids_no_padding[:, -1].unsqueeze(-1), + attention_mask=second_attention_mask_no_padding, + past_key_values=first_output_no_padding.past_key_values) + + first_input_ids_padding = torch.tensor([[50256, 11274, 16390, 11]]) + first_input_ids_padding = composer_device.tensor_to_device(first_input_ids_padding) + first_attention_mask_padding = torch.tensor([[0, 1, 1, 1]]).bool() + first_attention_mask_padding = composer_device.tensor_to_device(first_attention_mask_padding) + + # start with passing the first three tokens through (with left padding) + first_output_padding = mpt(first_input_ids_padding, + attention_mask=first_attention_mask_padding) + + second_input_ids_padding = torch.tensor([[50256, 11274, 16390, 11, 11274]]) + second_input_ids_padding = composer_device.tensor_to_device(second_input_ids_padding) + second_attention_mask_padding = torch.tensor([[0, 1, 1, 1, 1]]).bool() + second_attention_mask_padding = composer_device.tensor_to_device(second_attention_mask_padding) + + # pass through the fourth token by itself, using the key-value cache (with left padding) + second_output_padding = mpt( + second_input_ids_padding[:, -1].unsqueeze(-1), + attention_mask=second_attention_mask_padding, + past_key_values=first_output_padding.past_key_values) + + # check that the outputs are the same with or without padding + if pos_emb_config['rope']: ########################################## + torch.testing.assert_close(second_output_no_padding.logits, + second_output_padding.logits[:, + -1, :].unsqueeze(1), + atol=1e-2, + rtol=1e-6) + else: + torch.testing.assert_close(second_output_no_padding.logits, + second_output_padding.logits[:, + -1, :].unsqueeze(1), + atol=1e-6, + rtol=1e-6) @pytest.mark.parametrize('attn_impl,device', [ @@ -1074,27 +1139,31 @@ def test_forward_with_cache_and_padding(pos_emb_config: dict): }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'no_scaling', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'linear', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'dynamic', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': True, + 'xpos_scale_base': 512, +}, { + 'alibi': False, + 'rope': True, + 'rope_type': 'xpos', + 'rope_theta': 10000, + 'rope_pos_idx_in_fp32': True, + 'xpos_scale_base': 512, }]) def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): # Test that model forward with and without the key-value cache produces the @@ -1105,6 +1174,9 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): ) if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') + + if pos_emb_config['rope'] and device == 'cpu': + pytest.skip(f'rope only implemented for gpus.') 
composer_device = get_device(device) @@ -1213,27 +1285,31 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'no_scaling', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'linear', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'dynamic', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, +}, { + 'alibi': False, + 'rope': True, + 'rope_type': 'xpos', + 'rope_theta': 10000, + 'rope_pos_idx_in_fp32': True, + 'xpos_scale_base': 512, }]) def test_generate_with_past_kv(pos_emb_config: dict): hf_config = MPTConfig( @@ -1282,7 +1358,12 @@ def test_generate_with_past_kv(pos_emb_config: dict): assert kwargs['past_key_values'][0][0].shape == (1, 3, hf_config.d_model) - +@pytest.mark.parametrize('attn_impl,device', [ + ('torch', 'cpu'), + ('flash', 'gpu'), + ('triton', 'gpu'), + ('torch', 'gpu'), +]) @pytest.mark.parametrize('generation_kwargs', [{ 'max_new_tokens': 2, 'num_beams': 4 @@ -1303,30 +1384,43 @@ def test_generate_with_past_kv(pos_emb_config: dict): }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'no_scaling', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'linear', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'dynamic', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, +}, { + 'alibi': False, + 'rope': True, + 'rope_type': 'xpos', + 'rope_theta': 10000, + 'rope_pos_idx_in_fp32': True, + 'xpos_scale_base': 512, }]) -def test_generation_kwargs_dont_crash(generation_kwargs: Dict[str, Any], +def test_generation_kwargs_dont_crash(attn_impl: str, device: str, generation_kwargs: Dict[str, Any], pos_emb_config: dict): + if pos_emb_config['alibi'] and attn_impl == 'flash': + pytest.skip(f'alibi only implemented with torch and triton attention.') + + if pos_emb_config['rope'] and device == 'cpu': + pytest.skip(f'rope only implemented for gpus.') + + composer_device = get_device(device) + if device=='gpu': + torch.use_deterministic_algorithms(False) hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1337,21 +1431,28 @@ def test_generation_kwargs_dont_crash(generation_kwargs: Dict[str, Any], emb_pdrop=0.1, resid_pdrop=0.2, attn_config={ - 'attn_impl': 'torch', + 'attn_impl': attn_impl, **pos_emb_config, }, use_cache=True, ) mpt = MPTForCausalLM(hf_config) + mpt = composer_device.module_to_device(mpt) mpt.eval() - # no padding in the input - no_padding_input_ids = torch.tensor([[11274, 16390, 11]]) - no_padding_attention_mask = torch.tensor([[1, 1, 1]]) + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): + # no padding in the input + no_padding_input_ids = torch.tensor([[11274, 16390, 11]]) + no_padding_input_ids = composer_device.tensor_to_device(no_padding_input_ids) + no_padding_attention_mask = torch.tensor([[1, 1, 1]]) + 
no_padding_attention_mask = composer_device.tensor_to_device(no_padding_attention_mask) - _ = mpt.generate(input_ids=no_padding_input_ids, - attention_mask=no_padding_attention_mask, - **generation_kwargs) + _ = mpt.generate(input_ids=no_padding_input_ids, + attention_mask=no_padding_attention_mask, + **generation_kwargs) + if device=='gpu': + torch.use_deterministic_algorithms(True) @pytest.mark.gpu @@ -1365,27 +1466,31 @@ def test_generation_kwargs_dont_crash(generation_kwargs: Dict[str, Any], }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'no_scaling', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'linear', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'dynamic', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, +}, { + 'alibi': False, + 'rope': True, + 'rope_type': 'xpos', + 'rope_theta': 10000, + 'rope_pos_idx_in_fp32': True, + 'xpos_scale_base': 512, }]) def test_model_to(attention_impl: str, pos_emb_config: dict): # test that moving the model to diff devices and dtypes in diff ways does not break the model @@ -1433,7 +1538,7 @@ def test_model_to(attention_impl: str, pos_emb_config: dict): mpt = mpt.to('cpu') # verify the model still works - if attention_impl == 'torch': + if attention_impl == 'torch' and (not pos_emb_config['rope']): with torch.autocast('cpu', dtype=torch.bfloat16, enabled=True): _ = mpt(input_ids.to('cpu'), attention_mask=attention_mask.to('cpu')) @@ -1450,7 +1555,7 @@ def test_model_to(attention_impl: str, pos_emb_config: dict): mpt = mpt.float() # verify the model still works - if attention_impl == 'torch': + if attention_impl == 'torch' and (not pos_emb_config['rope']): _ = mpt(input_ids.to('cpu'), attention_mask=attention_mask.to('cpu')) mpt = mpt.half() @@ -1496,27 +1601,31 @@ def test_alibi_vs_hf(): }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'no_scaling', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'linear', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, }, { 'alibi': False, 'rope': True, + 'rope_type': 'original', 'rope_theta': 10000, - 'rope_scaling': { - 'type': 'dynamic', - 'factor': 1.0 - } + 'rope_pos_idx_in_fp32': False, + 'xpos_scale_base': 512, +}, { + 'alibi': False, + 'rope': True, + 'rope_type': 'xpos', + 'rope_theta': 10000, + 'rope_pos_idx_in_fp32': True, + 'xpos_scale_base': 512, }]) @pytest.mark.parametrize('output_attentions', [True, False]) @pytest.mark.parametrize('output_hidden_states', [True, False]) @@ -1532,6 +1641,8 @@ def test_forward_with_output_attentions_and_output_hidden_states( pytest.skip(f'alibi only implemented with torch and triton attention.') if output_attentions and attn_impl in ['flash', 'triton']: pytest.skip(f'output_attentions only implemented with torch attention.') + if pos_emb_config['rope'] and device == 'cpu': + pytest.skip(f'rope only implemented for gpus.') composer_device = get_device(device) From 213cd1484cc112abc1179cb3e8a4c42bbe9aa384 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 
24 Oct 2023 19:53:39 +0000 Subject: [PATCH 059/106] .. --- tests/test_model.py | 115 ++++++++++++++++++++++++++------------------ 1 file changed, 67 insertions(+), 48 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 380c8b777d..03c08b1c56 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -572,7 +572,7 @@ def test_forward_with_padding(attention_impl: str, device: str, alibi = pos_emb_config['alibi'] if alibi and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - + rope = pos_emb_config['rope'] if rope and device == 'cpu': pytest.skip(f'rope only implemented for gpus.') @@ -663,15 +663,16 @@ def test_forward_with_padding(attention_impl: str, device: str, attention_mask=batched_attention_mask).logits # check that right padding and left padding produce the same output - if rope:########################################## + if rope: # RoPE uses bf16 precision, which causes some differences between the outputs of padded and unpadded inputs. assert torch.allclose(right_padding_output[0, :3], left_padding_output[0, 3:], rtol=1e-2, atol=1e-2) else: - assert torch.allclose(right_padding_output[0, :3], - left_padding_output[0, 3:], - atol=1e-6 if attention_impl == 'torch' else 1e-8) + assert torch.allclose( + right_padding_output[0, :3], + left_padding_output[0, 3:], + atol=1e-6 if attention_impl == 'torch' else 1e-8) if not (alibi or rope): # check that right padding and middle padding produce the same output # Note: alibi not implemented for middle padding. @@ -680,16 +681,17 @@ def test_forward_with_padding(attention_impl: str, device: str, middle_padding_output[0, [0, 1, 5]], atol=1e-6 if attention_impl == 'torch' else 1e-8) # check that right padding and right padding in a batch produce the same output - - if rope:########################################## + + if rope: # RoPE uses bf16 precision, which causes some differences between the outputs of padded and unpadded inputs. assert torch.allclose(right_padding_output[0, :3], left_padding_output[0, 3:], rtol=1e-2, atol=1e-2) else: - assert torch.allclose(right_padding_output[0, :3], - batched_output[0, :3], - atol=1e-6 if attention_impl == 'torch' else 1e-8) + assert torch.allclose( + right_padding_output[0, :3], + batched_output[0, :3], + atol=1e-6 if attention_impl == 'torch' else 1e-8) if not (alibi or rope): # check that middle padding and middle padding in a batch produce the same output # Note: alibi not implemented for middle padding. 
@@ -757,8 +759,7 @@ def test_advanced_mask_building(attention_impl: str): @pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'), ('flash', 'gpu'), ('triton', 'gpu'), - ('torch', 'gpu') - ]) + ('torch', 'gpu')]) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False, 'rope': False @@ -803,7 +804,7 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): ) if pos_emb_config['alibi'] and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - + if pos_emb_config['rope'] and device == 'cpu': pytest.skip(f'rope only implemented for gpus.') @@ -990,6 +991,7 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): check_hf_model_equivalence(mpt, mpt2) + @pytest.mark.parametrize('attn_impl,device', [ ('torch', 'cpu'), ('flash', 'gpu'), @@ -1031,7 +1033,8 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): 'rope_pos_idx_in_fp32': True, 'xpos_scale_base': 512, }]) -def test_forward_with_cache_and_padding(attn_impl: str, device: str, pos_emb_config: dict): +def test_forward_with_cache_and_padding(attn_impl: str, device: str, + pos_emb_config: dict): # Tests that the result is the same with or without padding when using kv caching if not torch.cuda.is_available() and device == 'gpu': pytest.skip( @@ -1067,11 +1070,14 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, pos_emb_con mpt = MPTForCausalLM(hf_config) mpt = composer_device.module_to_device(mpt) mpt.eval() - with get_precision_context('amp_bf16' if composer_device.name == 'gpu' else 'fp32'): + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): first_input_ids_no_padding = torch.tensor([[11274, 16390, 11]]) - first_input_ids_no_padding = composer_device.tensor_to_device(first_input_ids_no_padding) + first_input_ids_no_padding = composer_device.tensor_to_device( + first_input_ids_no_padding) first_attention_mask_no_padding = torch.tensor([[1, 1, 1]]).bool() - first_attention_mask_no_padding = composer_device.tensor_to_device(first_attention_mask_no_padding) + first_attention_mask_no_padding = composer_device.tensor_to_device( + first_attention_mask_no_padding) # start with passing the first three tokens through (no padding) first_output_no_padding = mpt( @@ -1079,9 +1085,11 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, pos_emb_con attention_mask=first_attention_mask_no_padding) second_input_ids_no_padding = torch.tensor([[11274, 16390, 11, 11274]]) - second_input_ids_no_padding = composer_device.tensor_to_device(second_input_ids_no_padding) + second_input_ids_no_padding = composer_device.tensor_to_device( + second_input_ids_no_padding) second_attention_mask_no_padding = torch.tensor([[1, 1, 1, 1]]).bool() - second_attention_mask_no_padding = composer_device.tensor_to_device(second_attention_mask_no_padding) + second_attention_mask_no_padding = composer_device.tensor_to_device( + second_attention_mask_no_padding) # pass through the fourth token by itself, using the key-value cache (no padding) second_output_no_padding = mpt( @@ -1090,18 +1098,23 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, pos_emb_con past_key_values=first_output_no_padding.past_key_values) first_input_ids_padding = torch.tensor([[50256, 11274, 16390, 11]]) - first_input_ids_padding = composer_device.tensor_to_device(first_input_ids_padding) + first_input_ids_padding = composer_device.tensor_to_device( + first_input_ids_padding) first_attention_mask_padding = torch.tensor([[0, 1, 
1, 1]]).bool() - first_attention_mask_padding = composer_device.tensor_to_device(first_attention_mask_padding) + first_attention_mask_padding = composer_device.tensor_to_device( + first_attention_mask_padding) # start with passing the first three tokens through (with left padding) first_output_padding = mpt(first_input_ids_padding, - attention_mask=first_attention_mask_padding) + attention_mask=first_attention_mask_padding) - second_input_ids_padding = torch.tensor([[50256, 11274, 16390, 11, 11274]]) - second_input_ids_padding = composer_device.tensor_to_device(second_input_ids_padding) + second_input_ids_padding = torch.tensor( + [[50256, 11274, 16390, 11, 11274]]) + second_input_ids_padding = composer_device.tensor_to_device( + second_input_ids_padding) second_attention_mask_padding = torch.tensor([[0, 1, 1, 1, 1]]).bool() - second_attention_mask_padding = composer_device.tensor_to_device(second_attention_mask_padding) + second_attention_mask_padding = composer_device.tensor_to_device( + second_attention_mask_padding) # pass through the fourth token by itself, using the key-value cache (with left padding) second_output_padding = mpt( @@ -1110,18 +1123,20 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, pos_emb_con past_key_values=first_output_padding.past_key_values) # check that the outputs are the same with or without padding - if pos_emb_config['rope']: ########################################## - torch.testing.assert_close(second_output_no_padding.logits, - second_output_padding.logits[:, - -1, :].unsqueeze(1), - atol=1e-2, - rtol=1e-6) + if pos_emb_config[ + 'rope']: # RoPE uses bf16 precision, which causes some differences between the outputs of padded and unpadded inputs. + breakpoint() + torch.testing.assert_close( + second_output_no_padding.logits, + second_output_padding.logits[:, -1, :].unsqueeze(1), + atol=1e-2, + rtol=1e-6) else: - torch.testing.assert_close(second_output_no_padding.logits, - second_output_padding.logits[:, - -1, :].unsqueeze(1), - atol=1e-6, - rtol=1e-6) + torch.testing.assert_close( + second_output_no_padding.logits, + second_output_padding.logits[:, -1, :].unsqueeze(1), + atol=1e-6, + rtol=1e-6) @pytest.mark.parametrize('attn_impl,device', [ @@ -1174,7 +1189,7 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): ) if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - + if pos_emb_config['rope'] and device == 'cpu': pytest.skip(f'rope only implemented for gpus.') @@ -1358,6 +1373,7 @@ def test_generate_with_past_kv(pos_emb_config: dict): assert kwargs['past_key_values'][0][0].shape == (1, 3, hf_config.d_model) + @pytest.mark.parametrize('attn_impl,device', [ ('torch', 'cpu'), ('flash', 'gpu'), @@ -1410,16 +1426,17 @@ def test_generate_with_past_kv(pos_emb_config: dict): 'rope_pos_idx_in_fp32': True, 'xpos_scale_base': 512, }]) -def test_generation_kwargs_dont_crash(attn_impl: str, device: str, generation_kwargs: Dict[str, Any], +def test_generation_kwargs_dont_crash(attn_impl: str, device: str, + generation_kwargs: Dict[str, Any], pos_emb_config: dict): if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - + if pos_emb_config['rope'] and device == 'cpu': pytest.skip(f'rope only implemented for gpus.') - + reproducibility.seed_all(1234) composer_device = get_device(device) - if device=='gpu': + if device == 'gpu': # Switch deteminism off 
torch.use_deterministic_algorithms(False) hf_config = MPTConfig( init_device='cpu', @@ -1441,18 +1458,20 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, generation_kw mpt.eval() with get_precision_context('amp_bf16' if composer_device.name == - 'gpu' else 'fp32'): + 'gpu' else 'fp32'): # no padding in the input no_padding_input_ids = torch.tensor([[11274, 16390, 11]]) - no_padding_input_ids = composer_device.tensor_to_device(no_padding_input_ids) + no_padding_input_ids = composer_device.tensor_to_device( + no_padding_input_ids) no_padding_attention_mask = torch.tensor([[1, 1, 1]]) - no_padding_attention_mask = composer_device.tensor_to_device(no_padding_attention_mask) + no_padding_attention_mask = composer_device.tensor_to_device( + no_padding_attention_mask) _ = mpt.generate(input_ids=no_padding_input_ids, - attention_mask=no_padding_attention_mask, - **generation_kwargs) - if device=='gpu': - torch.use_deterministic_algorithms(True) + attention_mask=no_padding_attention_mask, + **generation_kwargs) + if device == 'gpu': # Switch deteminism back on + reproducibility.configure_deterministic_mode() # @pytest.mark.gpu From 68d03d3c2093ef880640d95f1571d16bbc9a1af1 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 24 Oct 2023 19:57:54 +0000 Subject: [PATCH 060/106] .. --- llmfoundry/models/layers/attention.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 9775c49ec4..c3feb375ce 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -588,10 +588,15 @@ def forward( value = value.view(*(value.shape[:-1]), -1, self.head_dim) kv = torch.stack([key, value], dim=2) - query, kv = rotary_emb_w_offset_info['rotary_embedding'](query, kv, seqlen_offset=rotary_emb_w_offset_info['seqlen_offset'], max_seqlen=rotary_emb_w_offset_info['max_seqlen']) + query, kv = rotary_emb_w_offset_info['rotary_embedding']( + query, + kv, + seqlen_offset=rotary_emb_w_offset_info['seqlen_offset'], + max_seqlen=rotary_emb_w_offset_info['max_seqlen']) [key, value] = torch.unbind(kv, dim=2) - - value = value.view(*(value.shape[:-2]), self.kv_n_heads * self.head_dim) + + value = value.view(*(value.shape[:-2]), + self.kv_n_heads * self.head_dim) query = query.view(*(query.shape[:-2]), self.d_model) key = key.view(*(key.shape[:-2]), self.kv_n_heads * self.head_dim) From c0da75cc9a6a55d26d70f5eb0d8dd8a67c45aa20 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 24 Oct 2023 21:14:03 +0000 Subject: [PATCH 061/106] ... --- tests/test_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index 03c08b1c56..3a309207dc 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1125,7 +1125,6 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, # check that the outputs are the same with or without padding if pos_emb_config[ 'rope']: # RoPE uses bf16 precision, which causes some differences between the outputs of padded and unpadded inputs. - breakpoint() torch.testing.assert_close( second_output_no_padding.logits, second_output_padding.logits[:, -1, :].unsqueeze(1), From 4600415c120fde2054ecaae04df2ca09bc3aa25f Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 25 Oct 2023 22:05:02 +0000 Subject: [PATCH 062/106] .. 
--- llmfoundry/models/layers/attention.py | 46 +++++++---- llmfoundry/models/layers/blocks.py | 30 ++++---- llmfoundry/models/mpt/configuration_mpt.py | 46 ++++++++--- llmfoundry/models/mpt/modeling_mpt.py | 88 +++++++++++++++++----- 4 files changed, 153 insertions(+), 57 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index c3feb375ce..201661d9af 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -12,6 +12,7 @@ from einops import rearrange from packaging import version from torch import nn +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY @@ -555,7 +556,7 @@ def forward( past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - rotary_emb_w_offset_info: Optional[dict] = None, + rotary_emb_w_meta_info: Optional[dict] = None, is_causal: bool = True, needs_weights: bool = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ @@ -582,23 +583,42 @@ def forward( query = self.q_ln(query).to(dtype) key = self.k_ln(key).to(dtype) - if rotary_emb_w_offset_info is not None: + if rotary_emb_w_meta_info is not None: + rotary_emb = rotary_emb_w_meta_info['rotary_emb'] + seq_len = rotary_emb_w_meta_info['seq_len'] + offset_info = rotary_emb_w_meta_info['offset_info'] + query = query.view(*(query.shape[:-1]), -1, self.head_dim) key = key.view(*(key.shape[:-1]), -1, self.head_dim) - value = value.view(*(value.shape[:-1]), -1, self.head_dim) - kv = torch.stack([key, value], dim=2) - query, kv = rotary_emb_w_offset_info['rotary_embedding']( - query, - kv, - seqlen_offset=rotary_emb_w_offset_info['seqlen_offset'], - max_seqlen=rotary_emb_w_offset_info['max_seqlen']) - [key, value] = torch.unbind(kv, dim=2) + if rotary_emb_w_meta_info['imp'] == 'dail': + value = value.view(*(value.shape[:-1]), -1, self.head_dim) + + kv = torch.stack([key, value], dim=2) + query, kv = rotary_emb(query, + kv, + seqlen_offset=offset_info, + max_seqlen=seq_len) + [key, value] = torch.unbind(kv, dim=2) - value = value.view(*(value.shape[:-2]), + value = value.view(*(value.shape[:-2]), + self.kv_n_heads * self.head_dim) + query = query.view(*(query.shape[:-2]), self.d_model) + key = key.view(*(key.shape[:-2]), self.kv_n_heads * self.head_dim) - query = query.view(*(query.shape[:-2]), self.d_model) - key = key.view(*(key.shape[:-2]), self.kv_n_heads * self.head_dim) + elif rotary_emb_w_meta_info['imp'] == 'hf_llama': + (cos, sin) = rotary_emb(value, seq_len) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + query, key = apply_rotary_pos_emb(query, key, cos, sin, + offset_info) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + breakpoint( + ) # Check if reshape is needed below or we can just use tensor.view + query = query.reshape(*(query.shape[:-2]), self.d_model) + key = key.reshape(*(key.shape[:-2]), + self.kv_n_heads * self.head_dim) context, attn_weights, past_key_value = self.attn_fn( query, diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 9e5b7ce844..976ffc6768 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -42,10 +42,17 @@ def __init__( 'alibi': False, 'alibi_bias_max': 8, 'rope': False, - 'rope_type': 'original', 'rope_theta': 10000, - 
'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, + 'rope_imp': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, } if ffn_config is None: @@ -62,16 +69,9 @@ def __init__( # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs args_to_exclude_in_attn_class = { - 'attn_type', - 'prefix_lm', - 'alibi', - 'attn_uses_sequence_id', - 'alibi_bias_max', - 'rope', - 'rope_type', - 'rope_theta', - 'rope_pos_idx_in_fp32', - 'xpos_scale_base', + 'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', + 'alibi_bias_max', 'rope', 'rope_theta', 'rope_imp', + 'rope_dail_config', 'rope_hf_config' } attn_config_subset_for_attn_class = { k: v @@ -107,7 +107,7 @@ def forward( x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, - rotary_emb_w_offset_info: Optional[Dict] = None, + rotary_emb_w_meta_info: Optional[Dict] = None, attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, @@ -118,7 +118,7 @@ def forward( a, past_key_value=past_key_value, attn_bias=attn_bias, - rotary_emb_w_offset_info=rotary_emb_w_offset_info, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 3d03e1476a..2295d1e034 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -20,10 +20,17 @@ 'alibi': False, 'alibi_bias_max': 8, 'rope': False, - 'rope_type': 'original', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, + 'rope_imp': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, } ffn_config_defaults: Dict = { @@ -100,10 +107,15 @@ def __init__( alibi (bool): Whether to use the alibi bias instead of position embeddings. alibi_bias_max (int): The maximum value of the alibi bias. rope (bool): Whether to use rotary positional embeddings. - rope_type (str): The type of rope to use. Options: 'original', 'xpos' rope_theta (int): The base frequency for rope. - rope_pos_idx_in_fp32 (bool): If True, the position indices [0.0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. A consequence could be, for example, that bf16 rounds position 1995 to 2000, which leads to them having the same positional embedding. - xpos_scale_base (float): The scale base for XPos. + rope_imp (str): The implementation of rope to use. One of 'hf' (to use the implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) or 'dail' (to use the implementation from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py). + rope_dail_config (Dict): The configuration for the dail implementation of rope. + type (str): The type of rotary position embedding to use. Options: 'original' (for https://arxiv.org/pdf/2104.09864.pdf), 'xpos' (for https://arxiv.org/pdf/2212.10554.pdf). + pos_idx_in_fp32 (bool): If True, the position indices [0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. 
A consequence could be, for example, that bf16 rounds position 1995 to 2000, which leads to them having the same positional embedding. + xpos_scale_base (float): The scale base for XPos (if using XPos). + rope_hf_config (Dict): A dictionary used to configure rope's scaling behavior (when scaling beyond the training length). + type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla. + factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type. kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. ffn_config (Dict): A dictionary used to configure the model's ffn module: ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp @@ -218,10 +230,26 @@ def _validate_config(self) -> None: raise NotImplementedError( 'attn_uses_sequence_id only implemented with torch and triton attention.' ) - if self.attn_config['rope'] and (self.attn_config['rope_type'] - not in ['original', 'xpos']): + if self.attn_config['rope'] and (self.attn_config['rope_imp'] + not in ['dail', 'hf']): + raise ValueError( + 'If rope is being used then rope_imp should be either "dail", or "hf".' + ) + if self.attn_config['rope'] and ( + self.attn_config['rope_imp'] + == 'hf') and self.attn_config['rope_hf_config']['type'] not in [ + 'no_scaling', 'linear', 'dynamic' + ]: + raise ValueError( + 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".' + ) + if self.attn_config['rope'] and ( + self.attn_config['rope_imp'] + == 'dail') and (self.attn_config['rope_dail_config']['type'] + not in ['original', 'xpos']): raise NotImplementedError( - 'rope_type must be one of "original" or "xpos".') + 'If using dail implementation of rope, the type should be one of "original" or "xpos".' + ) if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' 
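Taken together, the rope options documented above form one attn_config dictionary. The following is only an illustrative sketch (values are examples, not recommendations; key names mirror attn_config_defaults in this patch):

    attn_config = {
        # ... other attention options unchanged ...
        'rope': True,
        'rope_theta': 10000,
        'rope_imp': 'hf',            # or 'dail' for the flash-attn rotary implementation
        'rope_dail_config': {        # consulted only when rope_imp == 'dail'
            'type': 'original',      # or 'xpos'
            'pos_idx_in_fp32': True,
            'xpos_scale_base': 512,
        },
        'rope_hf_config': {          # consulted only when rope_imp == 'hf'
            'type': 'linear',        # 'no_scaling', 'linear', or 'dynamic'
            'factor': 2.0,           # scaling factor used by 'linear' and 'dynamic'
        },
    }

Any other rope_imp value, or an unsupported type inside the selected sub-config, is rejected by _validate_config above.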
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 85864ab338..ea167449c2 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -23,12 +23,18 @@ from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity from composer.models import HuggingFaceModel from composer.utils import dist -from flash_attn.layers.rotary import RotaryEmbedding +from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding from omegaconf import DictConfig from omegaconf import OmegaConf as om from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.modeling_outputs import (BaseModelOutputWithPast, CausalLMOutputWithPast) +from transformers.models.llama.modeling_llama import \ + DynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + RotaryEmbedding as HFRotaryEmbedding from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias from llmfoundry.models.layers.blocks import MPTBlock @@ -71,6 +77,46 @@ log = logging.getLogger(__name__) +def _rotary_embedding(config: MPTConfig): + rope_head_dim = config.d_model // config.n_heads + if config.attn_config['rope_imp'] == 'dail': + return DAILRotaryEmbedding( + dim=rope_head_dim, + base=config.attn_config['rope_theta'], + interleaved=False, + scale_base=config.attn_config['rope_dail_config']['xpos_scale_base'] + if (config.attn_config['rope_dail_config']['type'] + == 'xpos') else None, + pos_idx_in_fp32=config.attn_config['rope_dail_config'] + ['pos_idx_in_fp32'], + device='cpu', + ) + elif config.attn_config['rope_imp'] == 'hf': + if config.attn_config['rope_hf_config']['type'] == 'no_scaling': + return HFRotaryEmbedding( + rope_head_dim, + max_position_embeddings=config.max_seq_len, + base=config.attn_config['rope_theta'], + device='cpu' + ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized + elif config.attn_config['rope_hf_config']['type'] == 'linear': + return HFLinearScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=config.max_seq_len, + base=config.attn_config['rope_theta'], + scaling_factor=config.attn_config['rope_hf_config']['factor'], + device='cpu' + ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized + elif config.attn_config['rope_hf_config']['type'] == 'dynamic': + return HFDynamicNTKScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=config.max_seq_len, + base=config.attn_config['rope_theta'], + scaling_factor=config.attn_config['rope_hf_config']['factor'], + device='cpu' + ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized + + class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig base_model_prefix = 'model' @@ -125,16 +171,10 @@ def __init__(self, config: MPTConfig): self.norm_f = norm_class(config.d_model, device=config.init_device) self.rope = config.attn_config['rope'] + self.rope_imp = None if self.rope: - self.rotary_embedding = RotaryEmbedding( - dim=config.d_model // config.n_heads, - base=config.attn_config['rope_theta'], - interleaved=False, - 
scale_base=config.attn_config['xpos_scale_base'] if - (config.attn_config['rope_type'] == 'xpos') else None, - pos_idx_in_fp32=config.attn_config['rope_pos_idx_in_fp32'], - device='cpu', - ) + self.rope_imp = config.attn_config['rope_imp'] + self.rotary_embedding = _rotary_embedding(config) if config.init_device != 'meta': log.info( @@ -374,7 +414,7 @@ def forward( S <= self.config.max_seq_len ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' - rotary_emb_w_offset_info = None + rotary_emb_w_meta_info = None x = self.wte(input_ids) if self.learned_pos_emb or self.rope: past_position = 0 @@ -400,7 +440,7 @@ def forward( f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.' ) - if self.learned_pos_emb: + if self.learned_pos_emb or (self.rope and self.rope_imp == 'hf'): pos = torch.arange( past_position, S + past_position, @@ -414,14 +454,22 @@ def forward( dim=1)[:, past_position:], min=0, ) - x = x + self.wpe(pos) - - if self.rope: - rotary_emb_w_offset_info = { + if self.learned_pos_emb: + x = x + self.wpe(pos) + elif self.rope and self.rope_imp == 'hf': + rotary_emb_w_meta_info = { + 'imp': self.rope_imp, + 'rotary_embedding': self.rotary_embedding, + 'offset_info': pos, + 'seq_len': S + past_position, + } + elif self.rope and self.rope_imp == 'dail': + rotary_emb_w_meta_info = { + 'imp': self.rope_imp, 'rotary_embedding': self.rotary_embedding, - 'seqlen_offset': past_position, - 'max_seqlen': S + past_position, - } + 'offset_info': past_position, + 'seq_len': S + past_position, + } if self.embedding_fraction == 1: x = self.emb_drop(x) @@ -458,7 +506,7 @@ def forward( x, past_key_value=past_key_value, attn_bias=attn_bias, - rotary_emb_w_offset_info=rotary_emb_w_offset_info, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), From 5ecda44c951f014b03ac148b00638d328843ccf4 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 25 Oct 2023 22:40:21 +0000 Subject: [PATCH 063/106] .. 
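The position-id adjustment in the forward pass above is easiest to see with a small worked example (illustrative only; it assumes no KV cache, i.e. past_position == 0, so the [:, past_position:] slice is a no-op):

    import torch

    attention_mask = torch.tensor([[0, 0, 1, 1, 1]]).bool()  # two left-padding tokens
    pos = torch.arange(5).unsqueeze(0)                       # [[0, 1, 2, 3, 4]]
    pos = torch.clamp(
        pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1),
        min=0,
    )
    # pos is now [[0, 0, 0, 1, 2]]: padded slots are clamped to position 0 and the
    # three real tokens are numbered 0, 1, 2, matching the unpadded case.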
--- llmfoundry/models/mpt/modeling_mpt.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index ea167449c2..05b8a136f4 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -30,11 +30,11 @@ from transformers.modeling_outputs import (BaseModelOutputWithPast, CausalLMOutputWithPast) from transformers.models.llama.modeling_llama import \ - DynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding + LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding from transformers.models.llama.modeling_llama import \ - LinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding + LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding from transformers.models.llama.modeling_llama import \ - RotaryEmbedding as HFRotaryEmbedding + LlamaRotaryEmbedding as HFRotaryEmbedding from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias from llmfoundry.models.layers.blocks import MPTBlock @@ -89,7 +89,7 @@ def _rotary_embedding(config: MPTConfig): == 'xpos') else None, pos_idx_in_fp32=config.attn_config['rope_dail_config'] ['pos_idx_in_fp32'], - device='cpu', + device='cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) elif config.attn_config['rope_imp'] == 'hf': if config.attn_config['rope_hf_config']['type'] == 'no_scaling': @@ -97,24 +97,24 @@ def _rotary_embedding(config: MPTConfig): rope_head_dim, max_position_embeddings=config.max_seq_len, base=config.attn_config['rope_theta'], - device='cpu' - ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized + device='cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) elif config.attn_config['rope_hf_config']['type'] == 'linear': return HFLinearScalingRotaryEmbedding( rope_head_dim, max_position_embeddings=config.max_seq_len, base=config.attn_config['rope_theta'], scaling_factor=config.attn_config['rope_hf_config']['factor'], - device='cpu' - ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized + device='cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) elif config.attn_config['rope_hf_config']['type'] == 'dynamic': return HFDynamicNTKScalingRotaryEmbedding( rope_head_dim, max_position_embeddings=config.max_seq_len, base=config.attn_config['rope_theta'], scaling_factor=config.attn_config['rope_hf_config']['factor'], - device='cpu' - ) # FSDP does not materialize modules with no parameters, hence if we create meta buffers in rotary embeddings, they will not be materialized + device='cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) class MPTPreTrainedModel(PreTrainedModel): From 64411806788d3ecd47801d0235975f2daa187426 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 26 Oct 2023 00:32:00 +0000 Subject: [PATCH 064/106] .. 
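For orientation before the diff below: each attention layer receives its rotary state through a rotary_emb_w_meta_info dictionary (key names as finalized in this patch) whose contents depend on the chosen implementation. The sketch uses hypothetical values for one step of cached generation (4 tokens already cached, 1 new token) and elides the embedding modules, which the model builds via _rotary_embedding(config):

    import torch

    hf_meta_info = {
        'imp': 'hf',
        'rotary_emb': None,                  # a LlamaRotaryEmbedding (or scaling variant)
        'offset_info': torch.tensor([[4]]),  # per-token position ids, shape (batch, new_tokens)
        'seq_len': 5,                        # total length including the cache
    }
    dail_meta_info = {
        'imp': 'dail',
        'rotary_emb': None,                  # a flash-attn RotaryEmbedding
        'offset_info': 4,                    # integer seqlen offset (= past_position)
        'seq_len': 5,
    }

The hf path rotates queries and keys with explicit per-token position ids, while the dail path only needs the integer offset of the first uncached token.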
--- llmfoundry/models/layers/attention.py | 13 +- llmfoundry/models/mpt/modeling_mpt.py | 16 +- tests/test_flash_triton_torch.py | 136 +++++++--- tests/test_model.py | 367 +++++++++++++------------- 4 files changed, 303 insertions(+), 229 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 201661d9af..f10d2c1f64 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -603,10 +603,7 @@ def forward( value = value.view(*(value.shape[:-2]), self.kv_n_heads * self.head_dim) - query = query.view(*(query.shape[:-2]), self.d_model) - key = key.view(*(key.shape[:-2]), - self.kv_n_heads * self.head_dim) - elif rotary_emb_w_meta_info['imp'] == 'hf_llama': + elif rotary_emb_w_meta_info['imp'] == 'hf': (cos, sin) = rotary_emb(value, seq_len) query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -614,11 +611,9 @@ def forward( offset_info) query = query.transpose(1, 2) key = key.transpose(1, 2) - breakpoint( - ) # Check if reshape is needed below or we can just use tensor.view - query = query.reshape(*(query.shape[:-2]), self.d_model) - key = key.reshape(*(key.shape[:-2]), - self.kv_n_heads * self.head_dim) + + query = query.view(*(query.shape[:-2]), self.d_model) + key = key.view(*(key.shape[:-2]), self.kv_n_heads * self.head_dim) context, attn_weights, past_key_value = self.attn_fn( query, diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 05b8a136f4..fb14161fe1 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -89,7 +89,8 @@ def _rotary_embedding(config: MPTConfig): == 'xpos') else None, pos_idx_in_fp32=config.attn_config['rope_dail_config'] ['pos_idx_in_fp32'], - device='cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu + device= + 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) elif config.attn_config['rope_imp'] == 'hf': if config.attn_config['rope_hf_config']['type'] == 'no_scaling': @@ -97,7 +98,8 @@ def _rotary_embedding(config: MPTConfig): rope_head_dim, max_position_embeddings=config.max_seq_len, base=config.attn_config['rope_theta'], - device='cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) elif config.attn_config['rope_hf_config']['type'] == 'linear': return HFLinearScalingRotaryEmbedding( @@ -105,7 +107,8 @@ def _rotary_embedding(config: MPTConfig): max_position_embeddings=config.max_seq_len, base=config.attn_config['rope_theta'], scaling_factor=config.attn_config['rope_hf_config']['factor'], - device='cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) elif config.attn_config['rope_hf_config']['type'] == 'dynamic': return HFDynamicNTKScalingRotaryEmbedding( @@ -113,7 +116,8 @@ def _rotary_embedding(config: MPTConfig): max_position_embeddings=config.max_seq_len, base=config.attn_config['rope_theta'], scaling_factor=config.attn_config['rope_hf_config']['factor'], - device='cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) @@ -459,14 +463,14 @@ def forward( elif self.rope and self.rope_imp == 'hf': 
rotary_emb_w_meta_info = { 'imp': self.rope_imp, - 'rotary_embedding': self.rotary_embedding, + 'rotary_emb': self.rotary_embedding, 'offset_info': pos, 'seq_len': S + past_position, } elif self.rope and self.rope_imp == 'dail': rotary_emb_w_meta_info = { 'imp': self.rope_imp, - 'rotary_embedding': self.rotary_embedding, + 'rotary_emb': self.rotary_embedding, 'offset_info': past_position, 'seq_len': S + past_position, } diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index e31cc2be00..e87d70223f 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -4,9 +4,14 @@ import pytest import torch from composer.utils import reproducibility +from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding from omegaconf import OmegaConf as om - -from flash_attn.layers.rotary import RotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LlamaRotaryEmbedding as HFRotaryEmbedding def allclose_helper(t0: torch.Tensor, @@ -30,31 +35,31 @@ def allclose_helper(t0: torch.Tensor, }, { 'alibi': False, 'rope': True, - 'rope_type': 'original', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'xpos', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'original', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': True, - 'xpos_scale_base': 512, + 'rope_imp': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }, { 'alibi': False, 'rope': True, - 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': True, - 'xpos_scale_base': 512, + 'rope_imp': 'hf', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }]) @pytest.mark.parametrize( 'attn_type', @@ -125,6 +130,55 @@ def gen_bias(attn_impl: str): return attn_bias + def gen_rotary_embedding(rope_head_dim: int, pos_emb_config: dict, + max_seq_len: int): + if pos_emb_config['rope_imp'] == 'dail': + return DAILRotaryEmbedding( + dim=rope_head_dim, + base=pos_emb_config['rope_theta'], + interleaved=False, + scale_base=pos_emb_config['rope_dail_config']['xpos_scale_base'] + if (pos_emb_config['rope_dail_config']['type'] + == 'xpos') else None, + pos_idx_in_fp32=pos_emb_config['rope_dail_config'] + ['pos_idx_in_fp32'], + device= + 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif pos_emb_config['rope_imp'] == 'hf': + if pos_emb_config['rope_hf_config']['type'] == 'no_scaling': + return HFRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + base=pos_emb_config['rope_theta'], + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif pos_emb_config['rope_hf_config']['type'] == 'linear': + return HFLinearScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + base=pos_emb_config['rope_theta'], + scaling_factor=pos_emb_config['rope_hf_config']['factor'], + 
device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif pos_emb_config['rope_hf_config']['type'] == 'dynamic': + return HFDynamicNTKScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + base=pos_emb_config['rope_theta'], + scaling_factor=pos_emb_config['rope_hf_config']['factor'], + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + else: + raise ValueError( + f'Invalid scaling type: {pos_emb_config["rope_hf_config"]["type"]}' + ) + else: + raise ValueError(f'Invalid rope_imp: {pos_emb_config["rope_imp"]}') + x0 = torch.randn(n, s, f).to(device) x1 = x0.clone().detach() x0.requires_grad = True @@ -133,33 +187,41 @@ def gen_bias(attn_impl: str): with torch.autocast(x0.device.type): attn_bias = gen_bias(attn0.attn_impl) - rotary_emb_w_offset_info = None + rotary_emb_w_meta_info = None if rope: - rotary_embedding = RotaryEmbedding( - dim=cfg.d_model // cfg.n_heads, - base=pos_emb_config['rope_theta'], - interleaved=False, - scale_base=pos_emb_config['xpos_scale_base'] if (pos_emb_config['rope_type'] == 'xpos') else None, - pos_idx_in_fp32=pos_emb_config['rope_pos_idx_in_fp32'], - device='cpu' - ).to(device) - rotary_emb_w_offset_info = { - 'rotary_embedding': rotary_embedding, - 'seqlen_offset': 0, - 'max_seqlen': s + rotary_embedding = gen_rotary_embedding( + rope_head_dim=cfg.d_model // cfg.n_heads, + pos_emb_config=pos_emb_config, + max_seq_len=s).to(device) + pos = torch.arange(s).unsqueeze(0).to(device=device) + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1), + min=0, + ) + rotary_emb_w_meta_info = { + 'imp': + pos_emb_config['rope_imp'], + 'rotary_emb': + rotary_embedding, + 'offset_info': + pos if (pos_emb_config['rope_imp'] == 'hf') else 0, + 'seq_len': + s, } + y0, _, _ = attn0(x0, past_key_value=None, attn_bias=attn_bias, attention_mask=attention_mask, - rotary_emb_w_offset_info=rotary_emb_w_offset_info, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True) attn_bias = gen_bias(attn1.attn_impl) y1, _, _ = attn1(x1, past_key_value=None, attn_bias=attn_bias, attention_mask=attention_mask, - rotary_emb_w_offset_info=rotary_emb_w_offset_info, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) diff --git a/tests/test_model.py b/tests/test_model.py index 3a309207dc..5d8b069fc1 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -536,31 +536,31 @@ def test_mpt_creation(norm_type: str, no_bias: bool): }, { 'alibi': False, 'rope': True, - 'rope_type': 'original', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': True, - 'xpos_scale_base': 512, + 'rope_imp': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }, { 'alibi': False, 'rope': True, - 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': True, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'original', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'xpos', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, + 'rope_imp': 'hf', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 
'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }]) def test_forward_with_padding(attention_impl: str, device: str, pos_emb_config: dict): @@ -574,8 +574,9 @@ def test_forward_with_padding(attention_impl: str, device: str, pytest.skip(f'alibi only implemented with torch and triton attention.') rope = pos_emb_config['rope'] - if rope and device == 'cpu': - pytest.skip(f'rope only implemented for gpus.') + if rope and pos_emb_config['rope_imp'] == 'dail' and device == 'cpu': + pytest.skip( + f'dail implementation of rope is only implemented for gpus.') reproducibility.seed_all(1234) composer_device = get_device(device) @@ -663,7 +664,8 @@ def test_forward_with_padding(attention_impl: str, device: str, attention_mask=batched_attention_mask).logits # check that right padding and left padding produce the same output - if rope: # RoPE uses bf16 precision, which causes some differences between the outputs of padded and unpadded inputs. + if rope and pos_emb_config[ + 'rope_imp'] == 'dail': # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. assert torch.allclose(right_padding_output[0, :3], left_padding_output[0, 3:], rtol=1e-2, @@ -682,7 +684,8 @@ def test_forward_with_padding(attention_impl: str, device: str, atol=1e-6 if attention_impl == 'torch' else 1e-8) # check that right padding and right padding in a batch produce the same output - if rope: # RoPE uses bf16 precision, which causes some differences between the outputs of padded and unpadded inputs. + if rope and pos_emb_config[ + 'rope_imp'] == 'dail': # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. 
assert torch.allclose(right_padding_output[0, :3], left_padding_output[0, 3:], rtol=1e-2, @@ -769,31 +772,31 @@ def test_advanced_mask_building(attention_impl: str): }, { 'alibi': False, 'rope': True, - 'rope_type': 'original', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, + 'rope_imp': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }, { 'alibi': False, 'rope': True, - 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'original', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': True, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'xpos', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': True, - 'xpos_scale_base': 512, + 'rope_imp': 'hf', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }]) def test_generate(attention_impl: str, device: str, pos_emb_config: dict): # Test that generate works, and produces the same output with or without @@ -805,8 +808,10 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): if pos_emb_config['alibi'] and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - if pos_emb_config['rope'] and device == 'cpu': - pytest.skip(f'rope only implemented for gpus.') + if pos_emb_config['rope'] and pos_emb_config[ + 'rope_imp'] == 'dail' and device == 'cpu': + pytest.skip( + f'dail implementation of rope is only implemented for gpus.') reproducibility.seed_all(1234) composer_device = get_device(device) @@ -1007,31 +1012,31 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): }, { 'alibi': False, 'rope': True, - 'rope_type': 'original', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, + 'rope_imp': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }, { 'alibi': False, 'rope': True, - 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'original', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': True, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'xpos', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': True, - 'xpos_scale_base': 512, + 'rope_imp': 'hf', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }]) def test_forward_with_cache_and_padding(attn_impl: str, device: str, pos_emb_config: dict): @@ -1042,8 +1047,10 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, ) if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - if pos_emb_config['rope'] and device == 'cpu': - pytest.skip(f'rope only implemented for gpus.') + if pos_emb_config['rope'] and pos_emb_config[ + 'rope_imp'] == 'dail' and device == 'cpu': + pytest.skip( + f'dail implementation of rope is only implemented for gpus.') composer_device = get_device(device) @@ -1123,8 +1130,8 @@ def 
test_forward_with_cache_and_padding(attn_impl: str, device: str, past_key_values=first_output_padding.past_key_values) # check that the outputs are the same with or without padding - if pos_emb_config[ - 'rope']: # RoPE uses bf16 precision, which causes some differences between the outputs of padded and unpadded inputs. + if pos_emb_config['rope'] and pos_emb_config[ + 'rope_imp'] == 'dail': # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. torch.testing.assert_close( second_output_no_padding.logits, second_output_padding.logits[:, -1, :].unsqueeze(1), @@ -1153,31 +1160,31 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, }, { 'alibi': False, 'rope': True, - 'rope_type': 'original', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'xpos', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'original', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': True, - 'xpos_scale_base': 512, + 'rope_imp': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }, { 'alibi': False, 'rope': True, - 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': True, - 'xpos_scale_base': 512, + 'rope_imp': 'hf', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }]) def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): # Test that model forward with and without the key-value cache produces the @@ -1189,8 +1196,10 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - if pos_emb_config['rope'] and device == 'cpu': - pytest.skip(f'rope only implemented for gpus.') + if pos_emb_config['rope'] and pos_emb_config[ + 'rope_imp'] == 'dail' and device == 'cpu': + pytest.skip( + f'dail implementation of rope is only implemented for gpus.') composer_device = get_device(device) @@ -1299,31 +1308,31 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): }, { 'alibi': False, 'rope': True, - 'rope_type': 'original', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'xpos', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'original', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, + 'rope_imp': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }, { 'alibi': False, 'rope': True, - 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': True, - 'xpos_scale_base': 512, + 'rope_imp': 'hf', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }]) def test_generate_with_past_kv(pos_emb_config: 
dict): hf_config = MPTConfig( @@ -1399,31 +1408,31 @@ def test_generate_with_past_kv(pos_emb_config: dict): }, { 'alibi': False, 'rope': True, - 'rope_type': 'original', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, + 'rope_imp': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }, { 'alibi': False, 'rope': True, - 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'original', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'xpos', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': True, - 'xpos_scale_base': 512, + 'rope_imp': 'hf', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }]) def test_generation_kwargs_dont_crash(attn_impl: str, device: str, generation_kwargs: Dict[str, Any], @@ -1431,8 +1440,10 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - if pos_emb_config['rope'] and device == 'cpu': - pytest.skip(f'rope only implemented for gpus.') + if pos_emb_config['rope'] and pos_emb_config[ + 'rope_imp'] == 'dail' and device == 'cpu': + pytest.skip( + f'dail implementation of rope is only implemented for gpus.') reproducibility.seed_all(1234) composer_device = get_device(device) if device == 'gpu': # Switch deteminism off @@ -1470,7 +1481,7 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, attention_mask=no_padding_attention_mask, **generation_kwargs) if device == 'gpu': # Switch deteminism back on - reproducibility.configure_deterministic_mode() # + reproducibility.configure_deterministic_mode() @pytest.mark.gpu @@ -1484,31 +1495,31 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, }, { 'alibi': False, 'rope': True, - 'rope_type': 'original', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, + 'rope_imp': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }, { 'alibi': False, 'rope': True, - 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'original', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'xpos', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': True, - 'xpos_scale_base': 512, + 'rope_imp': 'hf', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }]) def test_model_to(attention_impl: str, pos_emb_config: dict): # test that moving the model to diff devices and dtypes in diff ways does not break the model @@ -1619,31 +1630,31 @@ def test_alibi_vs_hf(): }, { 'alibi': False, 'rope': True, - 'rope_type': 'original', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, + 'rope_imp': 'dail', + 'rope_dail_config': { + 'type': 
'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }, { 'alibi': False, 'rope': True, - 'rope_type': 'xpos', 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'original', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': False, - 'xpos_scale_base': 512, -}, { - 'alibi': False, - 'rope': True, - 'rope_type': 'xpos', - 'rope_theta': 10000, - 'rope_pos_idx_in_fp32': True, - 'xpos_scale_base': 512, + 'rope_imp': 'hf', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, }]) @pytest.mark.parametrize('output_attentions', [True, False]) @pytest.mark.parametrize('output_hidden_states', [True, False]) @@ -1659,8 +1670,10 @@ def test_forward_with_output_attentions_and_output_hidden_states( pytest.skip(f'alibi only implemented with torch and triton attention.') if output_attentions and attn_impl in ['flash', 'triton']: pytest.skip(f'output_attentions only implemented with torch attention.') - if pos_emb_config['rope'] and device == 'cpu': - pytest.skip(f'rope only implemented for gpus.') + if pos_emb_config['rope'] and pos_emb_config[ + 'rope_imp'] == 'dail' and device == 'cpu': + pytest.skip( + f'dail implementation of rope is only implemented for gpus.') composer_device = get_device(device) From d71a2a0ea98c76d53804ced948f4de7b58a6c1e9 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 26 Oct 2023 02:22:31 +0000 Subject: [PATCH 065/106] .. --- tests/test_rope_dail_vs_hf.py | 191 ++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 tests/test_rope_dail_vs_hf.py diff --git a/tests/test_rope_dail_vs_hf.py b/tests/test_rope_dail_vs_hf.py new file mode 100644 index 0000000000..cda49f759f --- /dev/null +++ b/tests/test_rope_dail_vs_hf.py @@ -0,0 +1,191 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from composer.core.precision import get_precision_context +from composer.utils import reproducibility +from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding +from omegaconf import OmegaConf as om +from transformers.models.llama.modeling_llama import \ + LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LlamaRotaryEmbedding as HFRotaryEmbedding + + +@pytest.mark.gpu +@pytest.mark.parametrize('clip_qkv', [True, False]) +@pytest.mark.parametrize('qk_ln', [True, False]) +@pytest.mark.parametrize( + 'attn_type', + ['multihead_attention', 'multiquery_attention', 'grouped_query_attention']) +@pytest.mark.parametrize('seq_len', [1, 233, 2048]) +def test_rope_dail_vs_hf(clip_qkv: bool, + qk_ln: bool, + attn_type: str, + seq_len: int, + device: str = 'cuda'): + # compare rope rotations for the dail vs hf implementations + def gen_rotary_embedding(rope_head_dim: int, pos_emb_config: dict, + max_seq_len: int): + if pos_emb_config['rope_imp'] == 'dail': + return DAILRotaryEmbedding( + dim=rope_head_dim, + base=pos_emb_config['rope_theta'], + interleaved=False, + scale_base=pos_emb_config['rope_dail_config']['xpos_scale_base'] + if (pos_emb_config['rope_dail_config']['type'] + == 'xpos') else None, 
+ pos_idx_in_fp32=pos_emb_config['rope_dail_config'] + ['pos_idx_in_fp32'], + device= + 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif pos_emb_config['rope_imp'] == 'hf': + if pos_emb_config['rope_hf_config']['type'] == 'no_scaling': + return HFRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + base=pos_emb_config['rope_theta'], + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif pos_emb_config['rope_hf_config']['type'] == 'linear': + return HFLinearScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + base=pos_emb_config['rope_theta'], + scaling_factor=pos_emb_config['rope_hf_config']['factor'], + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif pos_emb_config['rope_hf_config']['type'] == 'dynamic': + return HFDynamicNTKScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + base=pos_emb_config['rope_theta'], + scaling_factor=pos_emb_config['rope_hf_config']['factor'], + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + else: + raise ValueError( + f'Invalid scaling type: {pos_emb_config["rope_hf_config"]["type"]}' + ) + else: + raise ValueError(f'Invalid rope_imp: {pos_emb_config["rope_imp"]}') + + from llmfoundry.models.layers import attention + + reproducibility.seed_all(7) + + cfg = om.create({ + 'attn_impl': 'flash', + 'd_model': 128, + 'n_heads': 4, + 'attn_pdrop': 0, + 'clip_qkv': clip_qkv, + 'qk_ln': qk_ln, + }) + + batch_size = 2 + assert cfg.d_model % cfg.n_heads == 0 + if attn_type == 'grouped_query_attention': + cfg.kv_n_heads = 2 + + attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) + attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) + + attn1.load_state_dict(attn0.state_dict()) + x0 = torch.randn(batch_size, seq_len, cfg.d_model).to(device) + x1 = x0.clone().detach() + x0.requires_grad = True + x1.requires_grad = True + attention_mask = torch.ones(batch_size, seq_len).to(device).bool() + + with get_precision_context('amp_bf16'): + dail_rope_config = { + 'rope_theta': 10000, + 'rope_imp': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + } + } + hf_rope_config = { + 'rope_theta': 10000, + 'rope_imp': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + } + } + + dail_rope = gen_rotary_embedding(rope_head_dim=cfg.d_model // + cfg.n_heads, + pos_emb_config=dail_rope_config, + max_seq_len=seq_len).to('cuda') + dail_rope_w_meta_info = { + 'imp': 'dail', + 'rotary_emb': dail_rope, + 'offset_info': 0, + 'seq_len': seq_len, + } + + hf_rope = gen_rotary_embedding(rope_head_dim=cfg.d_model // cfg.n_heads, + pos_emb_config=hf_rope_config, + max_seq_len=seq_len).to('cuda') + pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda') + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1), + min=0, + ) + hf_rope_w_meta_info = { + 'imp': 'hf', + 'rotary_emb': hf_rope, + 'offset_info': pos, + 'seq_len': seq_len, + } + + y0, _, _ = attn0(x0, + past_key_value=None, + attn_bias=None, + attention_mask=attention_mask, + rotary_emb_w_meta_info=dail_rope_w_meta_info, + is_causal=True) + + y1, _, _ = attn1(x1, + past_key_value=None, + attn_bias=None, + attention_mask=attention_mask, + 
rotary_emb_w_meta_info=hf_rope_w_meta_info, + is_causal=True) + + y0 *= attention_mask.unsqueeze(-1) + y1 *= attention_mask.unsqueeze(-1) + + loss0 = y0.sum() + loss1 = y1.sum() + + loss0.backward() + loss1.backward() + + torch.testing.assert_close(y0, y1, rtol=1e-2, atol=1e-2) + + torch_name_param_map = {n: p for n, p in attn1.named_parameters()} + for n, p in attn0.named_parameters(): + tp = torch_name_param_map[n] + assert p.grad is not None + assert tp.grad is not None + torch.testing.assert_close(p, tp, rtol=1e-2, atol=1e-2) + # Relaxed to a l2-norm based check. + assert torch.norm(tp.grad - p.grad) <= 1e-2 + 1e-2 * torch.norm(p.grad) + + assert x0.grad is not None + assert x1.grad is not None + # Relaxed to a l2-norm based check. + assert torch.norm(x0.grad - x1.grad) <= 1e-2 + 1e-2 * torch.norm(x0.grad) From f33ed5f61025aaa133a64cb49cdf47b855160c16 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 26 Oct 2023 03:04:46 +0000 Subject: [PATCH 066/106] .. --- tests/test_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 5d8b069fc1..69aa05a362 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -675,9 +675,10 @@ def test_forward_with_padding(attention_impl: str, device: str, right_padding_output[0, :3], left_padding_output[0, 3:], atol=1e-6 if attention_impl == 'torch' else 1e-8) - if not (alibi or rope): + if not (alibi or (rope and pos_emb_config['rope_imp'] == 'dail')): # check that right padding and middle padding produce the same output # Note: alibi not implemented for middle padding. + # Note: dail implementation of rope does not support middle padding. assert torch.allclose( right_padding_output[0, :3], middle_padding_output[0, [0, 1, 5]], @@ -695,9 +696,10 @@ def test_forward_with_padding(attention_impl: str, device: str, right_padding_output[0, :3], batched_output[0, :3], atol=1e-6 if attention_impl == 'torch' else 1e-8) - if not (alibi or rope): + if not (alibi or (rope and pos_emb_config['rope_imp'] == 'dail')): # check that middle padding and middle padding in a batch produce the same output # Note: alibi not implemented for middle padding. + # Note: dail implementation of rope does not support middle padding. 
assert torch.allclose( middle_padding_output[0], batched_output[1, :], From 7efb6b18506fa23f54ddd4242a7719cbfbeb8c7e Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 26 Oct 2023 21:38:07 +0000 Subject: [PATCH 067/106] fixed the tests after the merge --- tests/test_model.py | 3 +-- tests/test_rope_dail_vs_hf.py | 3 --- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index ef76396da2..db16ebb333 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -16,7 +16,7 @@ from composer.core.precision import Precision, get_precision_context from composer.optim import DecoupledAdamW from composer.trainer.dist_strategy import prepare_fsdp_module -from composer.utils import dist, get_device +from composer.utils import dist, get_device, reproducibility from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, @@ -1433,7 +1433,6 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, 'rope_imp'] == 'dail' and device == 'cpu': pytest.skip( f'dail implementation of rope is only implemented for gpus.') - reproducibility.seed_all(1234) composer_device = get_device(device) if device == 'gpu': # Switch deteminism off torch.use_deterministic_algorithms(False) diff --git a/tests/test_rope_dail_vs_hf.py b/tests/test_rope_dail_vs_hf.py index cda49f759f..9bd9fe5db0 100644 --- a/tests/test_rope_dail_vs_hf.py +++ b/tests/test_rope_dail_vs_hf.py @@ -4,7 +4,6 @@ import pytest import torch from composer.core.precision import get_precision_context -from composer.utils import reproducibility from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding from omegaconf import OmegaConf as om from transformers.models.llama.modeling_llama import \ @@ -79,8 +78,6 @@ def gen_rotary_embedding(rope_head_dim: int, pos_emb_config: dict, from llmfoundry.models.layers import attention - reproducibility.seed_all(7) - cfg = om.create({ 'attn_impl': 'flash', 'd_model': 128, From 3a056a8ed318f5864e86027ec2f5b8056fbe128f Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 26 Oct 2023 21:53:46 +0000 Subject: [PATCH 068/106] minor change --- scripts/inference/run_mpt_with_ft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/inference/run_mpt_with_ft.py b/scripts/inference/run_mpt_with_ft.py index 3b5bce6b3a..10ccf6b78b 100644 --- a/scripts/inference/run_mpt_with_ft.py +++ b/scripts/inference/run_mpt_with_ft.py @@ -280,7 +280,7 @@ def main(): shared_contexts_ratio = args.shared_contexts_ratio layernorm_eps = args.layernorm_eps use_attention_linear_bias = args.alibi - has_positional_encoding = not args.alibi # TODO: Should probably be: has_positional_encoding = not (args.alibi or args.rope) + has_positional_encoding = not args.alibi print('\n=================== Arguments ===================') for k, v in vars(args).items(): From 327ddeda5885f7d9d88977e19599bc01a17ea709 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 27 Oct 2023 00:05:43 +0000 Subject: [PATCH 069/106] Fixed some tests failing due to a transformers library bug --- llmfoundry/models/mpt/configuration_mpt.py | 11 +++++++++-- tests/test_flash_triton_torch.py | 9 +++++++++ tests/test_model.py | 10 ++++++++++ tests/test_rope_dail_vs_hf.py | 10 ++++++++++ 4 files changed, 38 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 2295d1e034..f19434b56a 100644 --- 
a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -8,6 +8,8 @@ from transformers import PretrainedConfig +from llmfoundry.models.layers.attention import is_flash_v2_installed + attn_config_defaults: Dict = { 'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, @@ -243,12 +245,17 @@ def _validate_config(self) -> None: raise ValueError( 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".' ) + if self.attn_config['rope'] and (self.attn_config['rope_imp'] == 'dail' + ) and (not is_flash_v2_installed()): + raise NotImplementedError( + 'If using the dail implementation of rope, flash-attention 2 should be installed. Please install flash_attn==2.3.2`.' + ) if self.attn_config['rope'] and ( self.attn_config['rope_imp'] == 'dail') and (self.attn_config['rope_dail_config']['type'] not in ['original', 'xpos']): - raise NotImplementedError( - 'If using dail implementation of rope, the type should be one of "original" or "xpos".' + raise ValueError( + 'If using the dail implementation of rope, the type should be one of "original" or "xpos".' ) if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index 2441d0824a..3616bdbc12 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -3,6 +3,15 @@ import pytest import torch +import transformers + +from llmfoundry.models.layers.attention import is_flash_v1_installed + +# Before importing any transformers models, we need to disable transformers flash attention if +# we are in an environment with flash attention version <2. Transformers hard errors on a not properly +# gated import otherwise. +if is_flash_v1_installed(): + transformers.utils.is_flash_attn_available = lambda: False from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding from omegaconf import OmegaConf as om diff --git a/tests/test_model.py b/tests/test_model.py index db16ebb333..f687da440e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,6 +9,16 @@ from typing import Any, Dict, Union, cast from unittest import mock +import transformers + +from llmfoundry.models.layers.attention import is_flash_v1_installed + +# Before importing any transformers models, we need to disable transformers flash attention if +# we are in an environment with flash attention version <2. Transformers hard errors on a not properly +# gated import otherwise. +if is_flash_v1_installed(): + transformers.utils.is_flash_attn_available = lambda: False + import pytest import torch import torch.nn as nn diff --git a/tests/test_rope_dail_vs_hf.py b/tests/test_rope_dail_vs_hf.py index 9bd9fe5db0..3708d7aa7c 100644 --- a/tests/test_rope_dail_vs_hf.py +++ b/tests/test_rope_dail_vs_hf.py @@ -3,7 +3,17 @@ import pytest import torch +import transformers from composer.core.precision import get_precision_context + +from llmfoundry.models.layers.attention import is_flash_v1_installed + +# Before importing any transformers models, we need to disable transformers flash attention if +# we are in an environment with flash attention version <2. Transformers hard errors on a not properly +# gated import otherwise. 
+if is_flash_v1_installed(): + transformers.utils.is_flash_attn_available = lambda: False + from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding from omegaconf import OmegaConf as om from transformers.models.llama.modeling_llama import \ From 9c00106f182982eff2238f94994864b50e002248 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 27 Oct 2023 00:26:36 +0000 Subject: [PATCH 070/106] added check for flash_attention before importing their rotary embedding --- llmfoundry/models/mpt/modeling_mpt.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index fb14161fe1..da8552a88a 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -23,7 +23,12 @@ from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity from composer.models import HuggingFaceModel from composer.utils import dist -from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding + +from llmfoundry.models.layers.attention import is_flash_v2_installed + +if is_flash_v2_installed(): + from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding + from omegaconf import DictConfig from omegaconf import OmegaConf as om from transformers import PreTrainedModel, PreTrainedTokenizerBase From 999209cd814cedc6c503ed68ffc2cca61997fc61 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 27 Oct 2023 00:59:07 +0000 Subject: [PATCH 071/106] added check for flash_attention in tests before using dail rope --- tests/test_flash_triton_torch.py | 11 +++++++++-- tests/test_model.py | 27 ++++++++++++++------------- tests/test_rope_dail_vs_hf.py | 11 +++++++++-- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index 3616bdbc12..5563455d69 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -5,7 +5,8 @@ import torch import transformers -from llmfoundry.models.layers.attention import is_flash_v1_installed +from llmfoundry.models.layers.attention import (is_flash_v1_installed, + is_flash_v2_installed) # Before importing any transformers models, we need to disable transformers flash attention if # we are in an environment with flash attention version <2. Transformers hard errors on a not properly @@ -13,7 +14,8 @@ if is_flash_v1_installed(): transformers.utils.is_flash_attn_available = lambda: False -from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding +if is_flash_v2_installed(): + from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding from omegaconf import OmegaConf as om from transformers.models.llama.modeling_llama import \ LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding @@ -91,6 +93,11 @@ def test_attn_impl(attn_impl_0: str, if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'): pytest.xfail('flash attn does not support alibi') + if rope and (pos_emb_config['rope_imp']=='dail') and (not is_flash_v2_installed()): + pytest.skip( + 'dail implementation of rope requires flash attention 2.' 
+ ) + cfg = om.create({ 'attn_impl': 'flash', 'd_model': 128, diff --git a/tests/test_model.py b/tests/test_model.py index f687da440e..51ee9c1560 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -11,7 +11,8 @@ import transformers -from llmfoundry.models.layers.attention import is_flash_v1_installed +from llmfoundry.models.layers.attention import (is_flash_v1_installed, + is_flash_v2_installed) # Before importing any transformers models, we need to disable transformers flash attention if # we are in an environment with flash attention version <2. Transformers hard errors on a not properly @@ -579,9 +580,9 @@ def test_forward_with_padding(attention_impl: str, device: str, pytest.skip(f'alibi only implemented with torch and triton attention.') rope = pos_emb_config['rope'] - if rope and pos_emb_config['rope_imp'] == 'dail' and device == 'cpu': + if rope and pos_emb_config['rope_imp'] == 'dail' and (device == 'cpu' or not is_flash_v2_installed()): pytest.skip( - f'dail implementation of rope is only implemented for gpus.') + f'dail implementation of rope requires gpu and flash attention 2.') composer_device = get_device(device) @@ -815,9 +816,9 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): pytest.skip(f'alibi only implemented with torch and triton attention.') if pos_emb_config['rope'] and pos_emb_config[ - 'rope_imp'] == 'dail' and device == 'cpu': + 'rope_imp'] == 'dail' and (device == 'cpu' or not is_flash_v2_installed()): pytest.skip( - f'dail implementation of rope is only implemented for gpus.') + f'dail implementation of rope requires gpu and flash attention 2.') composer_device = get_device(device) @@ -1051,9 +1052,9 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') if pos_emb_config['rope'] and pos_emb_config[ - 'rope_imp'] == 'dail' and device == 'cpu': + 'rope_imp'] == 'dail' and (device == 'cpu' or not is_flash_v2_installed()): pytest.skip( - f'dail implementation of rope is only implemented for gpus.') + f'dail implementation of rope requires gpu and flash attention 2.') composer_device = get_device(device) @@ -1200,9 +1201,9 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): pytest.skip(f'alibi only implemented with torch and triton attention.') if pos_emb_config['rope'] and pos_emb_config[ - 'rope_imp'] == 'dail' and device == 'cpu': + 'rope_imp'] == 'dail' and (device == 'cpu' or not is_flash_v2_installed()): pytest.skip( - f'dail implementation of rope is only implemented for gpus.') + f'dail implementation of rope requires gpu and flash attention 2.') composer_device = get_device(device) @@ -1440,9 +1441,9 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, pytest.skip(f'alibi only implemented with torch and triton attention.') if pos_emb_config['rope'] and pos_emb_config[ - 'rope_imp'] == 'dail' and device == 'cpu': + 'rope_imp'] == 'dail' and (device == 'cpu' or not is_flash_v2_installed()): pytest.skip( - f'dail implementation of rope is only implemented for gpus.') + f'dail implementation of rope requires gpu and flash attention 2.') composer_device = get_device(device) if device == 'gpu': # Switch deteminism off torch.use_deterministic_algorithms(False) @@ -1668,9 +1669,9 @@ def test_forward_with_output_attentions_and_output_hidden_states( if output_attentions and attn_impl in ['flash', 'triton']: pytest.skip(f'output_attentions only 
implemented with torch attention.') if pos_emb_config['rope'] and pos_emb_config[ - 'rope_imp'] == 'dail' and device == 'cpu': + 'rope_imp'] == 'dail' and (device == 'cpu' or not is_flash_v2_installed()): pytest.skip( - f'dail implementation of rope is only implemented for gpus.') + f'dail implementation of rope requires gpu and flash attention 2.') composer_device = get_device(device) diff --git a/tests/test_rope_dail_vs_hf.py b/tests/test_rope_dail_vs_hf.py index 3708d7aa7c..bf999219df 100644 --- a/tests/test_rope_dail_vs_hf.py +++ b/tests/test_rope_dail_vs_hf.py @@ -6,7 +6,8 @@ import transformers from composer.core.precision import get_precision_context -from llmfoundry.models.layers.attention import is_flash_v1_installed +from llmfoundry.models.layers.attention import (is_flash_v1_installed, + is_flash_v2_installed) # Before importing any transformers models, we need to disable transformers flash attention if # we are in an environment with flash attention version <2. Transformers hard errors on a not properly @@ -14,7 +15,8 @@ if is_flash_v1_installed(): transformers.utils.is_flash_attn_available = lambda: False -from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding +if is_flash_v2_installed(): + from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding from omegaconf import OmegaConf as om from transformers.models.llama.modeling_llama import \ LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding @@ -36,6 +38,11 @@ def test_rope_dail_vs_hf(clip_qkv: bool, attn_type: str, seq_len: int, device: str = 'cuda'): + if not is_flash_v2_installed(): + pytest.skip( + 'dail implementation of rope requires flash attention 2.' + ) + # compare rope rotations for the dail vs hf implementations def gen_rotary_embedding(rope_head_dim: int, pos_emb_config: dict, max_seq_len: int): From a681b645d97aef4eb1f0579cba053258f72b020f Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 27 Oct 2023 01:45:15 +0000 Subject: [PATCH 072/106] fixed tests --- llmfoundry/models/layers/attention.py | 7 ++++++- tests/test_flash_triton_torch.py | 10 +--------- tests/test_model.py | 11 +---------- tests/test_rope_dail_vs_hf.py | 10 +--------- 4 files changed, 9 insertions(+), 29 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index f10d2c1f64..708e953be9 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -12,7 +12,6 @@ from einops import rearrange from packaging import version from torch import nn -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY @@ -33,6 +32,12 @@ def is_flash_v1_installed(): return False return version.parse(flash_attn.__version__) < version.parse('2.0.0') +if is_flash_v1_installed(): + import transformers + transformers.utils.is_flash_attn_available = lambda: False + +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool: diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index 5563455d69..ccabf616d2 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -3,16 +3,8 @@ import pytest import torch -import transformers -from llmfoundry.models.layers.attention import (is_flash_v1_installed, - is_flash_v2_installed) 
- -# Before importing any transformers models, we need to disable transformers flash attention if -# we are in an environment with flash attention version <2. Transformers hard errors on a not properly -# gated import otherwise. -if is_flash_v1_installed(): - transformers.utils.is_flash_attn_available = lambda: False +from llmfoundry.models.layers.attention import is_flash_v2_installed if is_flash_v2_installed(): from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding diff --git a/tests/test_model.py b/tests/test_model.py index 51ee9c1560..edc8bd4c95 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,16 +9,7 @@ from typing import Any, Dict, Union, cast from unittest import mock -import transformers - -from llmfoundry.models.layers.attention import (is_flash_v1_installed, - is_flash_v2_installed) - -# Before importing any transformers models, we need to disable transformers flash attention if -# we are in an environment with flash attention version <2. Transformers hard errors on a not properly -# gated import otherwise. -if is_flash_v1_installed(): - transformers.utils.is_flash_attn_available = lambda: False +from llmfoundry.models.layers.attention import is_flash_v2_installed import pytest import torch diff --git a/tests/test_rope_dail_vs_hf.py b/tests/test_rope_dail_vs_hf.py index bf999219df..51689bfc75 100644 --- a/tests/test_rope_dail_vs_hf.py +++ b/tests/test_rope_dail_vs_hf.py @@ -3,17 +3,9 @@ import pytest import torch -import transformers from composer.core.precision import get_precision_context -from llmfoundry.models.layers.attention import (is_flash_v1_installed, - is_flash_v2_installed) - -# Before importing any transformers models, we need to disable transformers flash attention if -# we are in an environment with flash attention version <2. Transformers hard errors on a not properly -# gated import otherwise. -if is_flash_v1_installed(): - transformers.utils.is_flash_attn_available = lambda: False +from llmfoundry.models.layers.attention import is_flash_v2_installed if is_flash_v2_installed(): from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding From 766fa753aa7f28b1de777cbae1abf15e23ee3bde Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 27 Oct 2023 02:13:45 +0000 Subject: [PATCH 073/106] .. 
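Several of the tests touched in this patch repeat the same guard: the dail RoPE path is only exercised when the test runs on a GPU and flash-attention v2 is installed. A minimal sketch of that guard factored into a shared pytest helper is shown below; the helper name and its placement are assumptions for illustration only, since the tests in this series inline the check instead.

```python
# Illustrative sketch (not part of this patch series): the repeated skip
# condition for dail RoPE, expressed once as a reusable pytest helper.
import pytest

from llmfoundry.models.layers.attention import is_flash_v2_installed


def skip_if_dail_rope_unsupported(pos_emb_config: dict, device: str) -> None:
    """Skip tests that exercise dail RoPE when the environment cannot run it."""
    uses_dail_rope = pos_emb_config.get('rope', False) and pos_emb_config.get(
        'rope_imp') == 'dail'
    if uses_dail_rope and (device != 'gpu' or not is_flash_v2_installed()):
        pytest.skip(
            'dail implementation of rope requires gpu and flash attention 2.')
```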
--- tests/test_model.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index edc8bd4c95..26bebc909b 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,8 +9,6 @@ from typing import Any, Dict, Union, cast from unittest import mock -from llmfoundry.models.layers.attention import is_flash_v2_installed - import pytest import torch import torch.nn as nn @@ -31,6 +29,7 @@ ComposerHFPrefixLM) from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias +from llmfoundry.models.layers.attention import is_flash_v2_installed from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils import build_tokenizer @@ -571,7 +570,8 @@ def test_forward_with_padding(attention_impl: str, device: str, pytest.skip(f'alibi only implemented with torch and triton attention.') rope = pos_emb_config['rope'] - if rope and pos_emb_config['rope_imp'] == 'dail' and (device == 'cpu' or not is_flash_v2_installed()): + if rope and pos_emb_config['rope_imp'] == 'dail' and ( + device == 'cpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') @@ -806,8 +806,8 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): if pos_emb_config['alibi'] and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - if pos_emb_config['rope'] and pos_emb_config[ - 'rope_imp'] == 'dail' and (device == 'cpu' or not is_flash_v2_installed()): + if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( + device == 'cpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') @@ -1042,8 +1042,8 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, ) if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - if pos_emb_config['rope'] and pos_emb_config[ - 'rope_imp'] == 'dail' and (device == 'cpu' or not is_flash_v2_installed()): + if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( + device == 'cpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') @@ -1191,8 +1191,8 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - if pos_emb_config['rope'] and pos_emb_config[ - 'rope_imp'] == 'dail' and (device == 'cpu' or not is_flash_v2_installed()): + if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( + device == 'cpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') @@ -1428,11 +1428,15 @@ def test_generate_with_past_kv(pos_emb_config: dict): def test_generation_kwargs_dont_crash(attn_impl: str, device: str, generation_kwargs: Dict[str, Any], pos_emb_config: dict): + if not torch.cuda.is_available() and device == 'gpu': + pytest.skip( + f'This test requires CUDA to be available in order to run with {attn_impl} attention.' 
+ ) if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - if pos_emb_config['rope'] and pos_emb_config[ - 'rope_imp'] == 'dail' and (device == 'cpu' or not is_flash_v2_installed()): + if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( + device == 'cpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') composer_device = get_device(device) @@ -1659,8 +1663,8 @@ def test_forward_with_output_attentions_and_output_hidden_states( pytest.skip(f'alibi only implemented with torch and triton attention.') if output_attentions and attn_impl in ['flash', 'triton']: pytest.skip(f'output_attentions only implemented with torch attention.') - if pos_emb_config['rope'] and pos_emb_config[ - 'rope_imp'] == 'dail' and (device == 'cpu' or not is_flash_v2_installed()): + if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( + device == 'cpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') From 21a4f31f27d5df51029c3bf5da2e8616dfdc5010 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 27 Oct 2023 03:05:51 +0000 Subject: [PATCH 074/106] .. --- tests/test_model.py | 79 +++++++++++++++++++++++++++++---------------- 1 file changed, 52 insertions(+), 27 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 26bebc909b..5aec05bf22 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -571,7 +571,7 @@ def test_forward_with_padding(attention_impl: str, device: str, rope = pos_emb_config['rope'] if rope and pos_emb_config['rope_imp'] == 'dail' and ( - device == 'cpu' or not is_flash_v2_installed()): + device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') @@ -807,7 +807,7 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): pytest.skip(f'alibi only implemented with torch and triton attention.') if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( - device == 'cpu' or not is_flash_v2_installed()): + device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') @@ -1043,7 +1043,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( - device == 'cpu' or not is_flash_v2_installed()): + device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') @@ -1192,7 +1192,7 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): pytest.skip(f'alibi only implemented with torch and triton attention.') if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( - device == 'cpu' or not is_flash_v2_installed()): + device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') @@ -1290,6 +1290,12 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): ) +@pytest.mark.parametrize('attn_impl,device', [ + ('torch', 'cpu'), + ('flash', 'gpu'), + ('triton', 'gpu'), + ('torch', 'gpu'), +]) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False, 'rope': False @@ -1325,7 
+1331,20 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): 'factor': 1.0, }, }]) -def test_generate_with_past_kv(pos_emb_config: dict): +def test_generate_with_past_kv(attn_impl: str, device: str, pos_emb_config: dict): + if not torch.cuda.is_available() and device == 'gpu': + pytest.skip( + f'This test requires CUDA to be available in order to run with {attn_impl} attention.' + ) + if pos_emb_config['alibi'] and attn_impl == 'flash': + pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + + composer_device = get_device(device) + hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1336,7 +1355,7 @@ def test_generate_with_past_kv(pos_emb_config: dict): emb_pdrop=0.1, resid_pdrop=0.2, attn_config={ - 'attn_impl': 'torch', + 'attn_impl': attn_impl, **pos_emb_config, }, use_cache=True, @@ -1346,31 +1365,37 @@ def test_generate_with_past_kv(pos_emb_config: dict): }, ) mpt = MPTForCausalLM(hf_config) + mpt = composer_device.module_to_device(mpt) mpt.eval() # no padding in the input no_padding_input_ids = torch.tensor([[11274, 16390, 11]]) + no_padding_input_ids = composer_device.tensor_to_device( + no_padding_input_ids) no_padding_attention_mask = torch.tensor([[1, 1, 1]]) + no_padding_attention_mask = composer_device.tensor_to_device(no_padding_attention_mask) - with mock.patch.object(MPTForCausalLM, 'forward', - autospec=True) as forward_mocked: - forward_mocked.return_value = CausalLMOutputWithPast( - logits=torch.randn((1, 3, hf_config.vocab_size)), - past_key_values=[(torch.randn(1, 3, hf_config.d_model), - torch.randn(1, 3, hf_config.d_model)) - for _ in range(hf_config.n_layers)]) - _ = mpt.generate(input_ids=no_padding_input_ids, - attention_mask=no_padding_attention_mask, - max_new_tokens=2) - - assert forward_mocked.call_count == 2 - _, _, kwargs = forward_mocked.mock_calls[0] - assert kwargs['past_key_values'] is None - _, _, kwargs = forward_mocked.mock_calls[1] - assert kwargs['past_key_values'] is not None - assert len(kwargs['past_key_values']) == hf_config.n_layers - assert kwargs['past_key_values'][0][0].shape == (1, 3, - hf_config.d_model) + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): + with mock.patch.object(MPTForCausalLM, 'forward', + autospec=True) as forward_mocked: + forward_mocked.return_value = CausalLMOutputWithPast( + logits=torch.randn((1, 3, hf_config.vocab_size)), + past_key_values=[(torch.randn(1, 3, hf_config.d_model), + torch.randn(1, 3, hf_config.d_model)) + for _ in range(hf_config.n_layers)]) + _ = mpt.generate(input_ids=no_padding_input_ids, + attention_mask=no_padding_attention_mask, + max_new_tokens=2) + + assert forward_mocked.call_count == 2 + _, _, kwargs = forward_mocked.mock_calls[0] + assert kwargs['past_key_values'] is None + _, _, kwargs = forward_mocked.mock_calls[1] + assert kwargs['past_key_values'] is not None + assert len(kwargs['past_key_values']) == hf_config.n_layers + assert kwargs['past_key_values'][0][0].shape == (1, 3, + hf_config.d_model) @pytest.mark.parametrize('attn_impl,device', [ @@ -1436,7 +1461,7 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, pytest.skip(f'alibi only implemented with torch and triton attention.') if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( - device == 
'cpu' or not is_flash_v2_installed()): + device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') composer_device = get_device(device) @@ -1664,7 +1689,7 @@ def test_forward_with_output_attentions_and_output_hidden_states( if output_attentions and attn_impl in ['flash', 'triton']: pytest.skip(f'output_attentions only implemented with torch attention.') if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( - device == 'cpu' or not is_flash_v2_installed()): + device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') From dbac1e0fb7f043529fdc83915ee77b89da1b40b2 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 27 Oct 2023 15:04:42 +0000 Subject: [PATCH 075/106] temporary fix --- llmfoundry/models/mpt/configuration_mpt.py | 27 +++++++++++----------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index f19434b56a..1bbc832b64 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -8,8 +8,6 @@ from transformers import PretrainedConfig -from llmfoundry.models.layers.attention import is_flash_v2_installed - attn_config_defaults: Dict = { 'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, @@ -245,18 +243,19 @@ def _validate_config(self) -> None: raise ValueError( 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".' ) - if self.attn_config['rope'] and (self.attn_config['rope_imp'] == 'dail' - ) and (not is_flash_v2_installed()): - raise NotImplementedError( - 'If using the dail implementation of rope, flash-attention 2 should be installed. Please install flash_attn==2.3.2`.' - ) - if self.attn_config['rope'] and ( - self.attn_config['rope_imp'] - == 'dail') and (self.attn_config['rope_dail_config']['type'] - not in ['original', 'xpos']): - raise ValueError( - 'If using the dail implementation of rope, the type should be one of "original" or "xpos".' - ) + if self.attn_config['rope'] and (self.attn_config['rope_imp'] + == 'dail'): + from llmfoundry.models.layers.attention import is_flash_v2_installed + if not is_flash_v2_installed(): + raise NotImplementedError( + 'If using the dail implementation of rope, flash-attention 2 should be installed. Please install flash_attn==2.3.2`.' + ) + if self.attn_config['rope_dail_config']['type'] not in [ + 'original', 'xpos' + ]: + raise ValueError( + 'If using the dail implementation of rope, the type should be one of "original" or "xpos".' + ) if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' From 5d62dfe6e22134b1c04a225cbae4e79efa170c24 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 27 Oct 2023 16:00:09 +0000 Subject: [PATCH 076/106] .. 
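The previous patch defers the `is_flash_v2_installed` import until a dail RoPE config is actually being validated, so importing `configuration_mpt` never requires flash-attn to be present; the test change below skips dail RoPE for the same reason. A condensed sketch of that validation path is given here; the standalone function name is an assumption for readability (the real checks live inline in `MPTConfig._validate_config`) and the error messages are abbreviated.

```python
# Condensed sketch of the dail-RoPE checks in MPTConfig._validate_config.
# Factoring them into a helper function is for illustration only.
def _validate_dail_rope(attn_config: dict) -> None:
    if attn_config['rope'] and attn_config['rope_imp'] == 'dail':
        # Deferred import keeps configuration_mpt importable without flash-attn.
        from llmfoundry.models.layers.attention import is_flash_v2_installed
        if not is_flash_v2_installed():
            raise NotImplementedError(
                'The dail implementation of rope requires flash-attn v2 '
                '(e.g. flash_attn==2.3.2).')
        if attn_config['rope_dail_config']['type'] not in ['original', 'xpos']:
            raise ValueError(
                'The dail rope type should be one of "original" or "xpos".')
```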
--- tests/test_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_model.py b/tests/test_model.py index 5aec05bf22..6e07368a79 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1548,6 +1548,10 @@ def test_model_to(attention_impl: str, pos_emb_config: dict): ) if pos_emb_config['alibi'] and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') + + if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and not is_flash_v2_installed(): + pytest.skip( + f'dail implementation of rope requires flash attention 2.') hf_config = MPTConfig( init_device='cpu', From ca57151a06b267d1b671e0a67d31b7fb99f5c2d8 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 27 Oct 2023 16:47:15 +0000 Subject: [PATCH 077/106] .. --- llmfoundry/models/mpt/configuration_mpt.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 1bbc832b64..7988b607c2 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -243,19 +243,13 @@ def _validate_config(self) -> None: raise ValueError( 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".' ) - if self.attn_config['rope'] and (self.attn_config['rope_imp'] - == 'dail'): - from llmfoundry.models.layers.attention import is_flash_v2_installed - if not is_flash_v2_installed(): - raise NotImplementedError( - 'If using the dail implementation of rope, flash-attention 2 should be installed. Please install flash_attn==2.3.2`.' - ) - if self.attn_config['rope_dail_config']['type'] not in [ - 'original', 'xpos' - ]: - raise ValueError( - 'If using the dail implementation of rope, the type should be one of "original" or "xpos".' - ) + if self.attn_config['rope'] and ( + self.attn_config['rope_imp'] + == 'dail') and (self.attn_config['rope_dail_config']['type'] + not in ['original', 'xpos']): + raise ValueError( + 'If using the dail implementation of rope, the type should be one of "original" or "xpos".' + ) if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' From 99a81a106431a98235999f48f18291a2dd6868fe Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 27 Oct 2023 19:57:02 +0000 Subject: [PATCH 078/106] fixed a test --- llmfoundry/models/mpt/modeling_mpt.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index da8552a88a..cc8a8213a6 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -27,7 +27,10 @@ from llmfoundry.models.layers.attention import is_flash_v2_installed if is_flash_v2_installed(): - from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding + try: # This try...except is needed because transformers requires it despite the 'if' statement above + from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding + except Exception as e: + raise e from omegaconf import DictConfig from omegaconf import OmegaConf as om From b674e83ce3cffd546a4bf5b71638ff921679cc5e Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 27 Oct 2023 21:16:26 +0000 Subject: [PATCH 079/106] .. 
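Both this fix and the previous one hinge on the flash-attn version helpers defined near the top of `attention.py`. A minimal sketch of that version-gating pattern is shown below; it assumes only that `flash_attn` exposes a `__version__` attribute, and the v2 helper is written here as the mirror image of the v1 helper visible in the hunks above.

```python
# Minimal sketch of the flash-attn version gating used by attention.py; the
# real helpers live in llmfoundry/models/layers/attention.py.
from packaging import version


def is_flash_v1_installed() -> bool:
    try:
        import flash_attn
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) < version.parse('2.0.0')


def is_flash_v2_installed() -> bool:
    try:
        import flash_attn
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
```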
--- llmfoundry/models/layers/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 708e953be9..419765323c 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -32,6 +32,7 @@ def is_flash_v1_installed(): return False return version.parse(flash_attn.__version__) < version.parse('2.0.0') + if is_flash_v1_installed(): import transformers transformers.utils.is_flash_attn_available = lambda: False From 8be09ab001cc451d566a7a4ad6871d4e8e0a12ea Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 27 Oct 2023 23:10:24 +0000 Subject: [PATCH 080/106] minor change --- tests/test_model.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 6e07368a79..d39a0b740a 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1331,7 +1331,8 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): 'factor': 1.0, }, }]) -def test_generate_with_past_kv(attn_impl: str, device: str, pos_emb_config: dict): +def test_generate_with_past_kv(attn_impl: str, device: str, + pos_emb_config: dict): if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attn_impl} attention.' @@ -1341,10 +1342,10 @@ def test_generate_with_past_kv(attn_impl: str, device: str, pos_emb_config: dict if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( device != 'gpu' or not is_flash_v2_installed()): pytest.skip( - f'dail implementation of rope requires gpu and flash attention 2.') + f'dail implementation of rope requires gpu and flash attention 2.') composer_device = get_device(device) - + hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1373,20 +1374,21 @@ def test_generate_with_past_kv(attn_impl: str, device: str, pos_emb_config: dict no_padding_input_ids = composer_device.tensor_to_device( no_padding_input_ids) no_padding_attention_mask = torch.tensor([[1, 1, 1]]) - no_padding_attention_mask = composer_device.tensor_to_device(no_padding_attention_mask) + no_padding_attention_mask = composer_device.tensor_to_device( + no_padding_attention_mask) with get_precision_context('amp_bf16' if composer_device.name == - 'gpu' else 'fp32'): + 'gpu' else 'fp32'): with mock.patch.object(MPTForCausalLM, 'forward', - autospec=True) as forward_mocked: + autospec=True) as forward_mocked: forward_mocked.return_value = CausalLMOutputWithPast( logits=torch.randn((1, 3, hf_config.vocab_size)), past_key_values=[(torch.randn(1, 3, hf_config.d_model), - torch.randn(1, 3, hf_config.d_model)) - for _ in range(hf_config.n_layers)]) + torch.randn(1, 3, hf_config.d_model)) + for _ in range(hf_config.n_layers)]) _ = mpt.generate(input_ids=no_padding_input_ids, - attention_mask=no_padding_attention_mask, - max_new_tokens=2) + attention_mask=no_padding_attention_mask, + max_new_tokens=2) assert forward_mocked.call_count == 2 _, _, kwargs = forward_mocked.mock_calls[0] @@ -1395,7 +1397,7 @@ def test_generate_with_past_kv(attn_impl: str, device: str, pos_emb_config: dict assert kwargs['past_key_values'] is not None assert len(kwargs['past_key_values']) == hf_config.n_layers assert kwargs['past_key_values'][0][0].shape == (1, 3, - hf_config.d_model) + hf_config.d_model) @pytest.mark.parametrize('attn_impl,device', [ @@ -1548,10 +1550,10 @@ def test_model_to(attention_impl: str, pos_emb_config: dict): 
) if pos_emb_config['alibi'] and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - - if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and not is_flash_v2_installed(): - pytest.skip( - f'dail implementation of rope requires flash attention 2.') + + if pos_emb_config['rope'] and pos_emb_config[ + 'rope_imp'] == 'dail' and not is_flash_v2_installed(): + pytest.skip(f'dail implementation of rope requires flash attention 2.') hf_config = MPTConfig( init_device='cpu', @@ -1589,7 +1591,8 @@ def test_model_to(attention_impl: str, pos_emb_config: dict): mpt = mpt.to('cpu') # verify the model still works - if attention_impl == 'torch' and (not pos_emb_config['rope']): + if attention_impl == 'torch' and not (pos_emb_config['rope'] and + pos_emb_config['rope_imp'] == 'dail'): with torch.autocast('cpu', dtype=torch.bfloat16, enabled=True): _ = mpt(input_ids.to('cpu'), attention_mask=attention_mask.to('cpu')) @@ -1606,7 +1609,8 @@ def test_model_to(attention_impl: str, pos_emb_config: dict): mpt = mpt.float() # verify the model still works - if attention_impl == 'torch' and (not pos_emb_config['rope']): + if attention_impl == 'torch' and not (pos_emb_config['rope'] and + pos_emb_config['rope_imp'] == 'dail'): _ = mpt(input_ids.to('cpu'), attention_mask=attention_mask.to('cpu')) mpt = mpt.half() From 067439e7da72b84a29831b08d800adb9714e2495 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 27 Oct 2023 23:52:58 +0000 Subject: [PATCH 081/106] minor changes --- llmfoundry/models/layers/attention.py | 2 + tests/test_flash_triton_torch.py | 68 ++--------------- tests/test_rope_dail_vs_hf.py | 104 +++++++++++++------------- 3 files changed, 58 insertions(+), 116 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 419765323c..3263c709a9 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -611,10 +611,12 @@ def forward( self.kv_n_heads * self.head_dim) elif rotary_emb_w_meta_info['imp'] == 'hf': (cos, sin) = rotary_emb(value, seq_len) + # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb query = query.transpose(1, 2) key = key.transpose(1, 2) query, key = apply_rotary_pos_emb(query, key, cos, sin, offset_info) + # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb query = query.transpose(1, 2) key = key.transpose(1, 2) diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index ccabf616d2..c5768b5d0c 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -3,18 +3,10 @@ import pytest import torch +from omegaconf import OmegaConf as om from llmfoundry.models.layers.attention import is_flash_v2_installed - -if is_flash_v2_installed(): - from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding -from omegaconf import OmegaConf as om -from transformers.models.llama.modeling_llama import \ - LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaRotaryEmbedding as HFRotaryEmbedding +from tests.test_rope_dail_vs_hf import 
gen_rotary_embedding def allclose_helper(t0: torch.Tensor, @@ -85,10 +77,9 @@ def test_attn_impl(attn_impl_0: str, if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'): pytest.xfail('flash attn does not support alibi') - if rope and (pos_emb_config['rope_imp']=='dail') and (not is_flash_v2_installed()): - pytest.skip( - 'dail implementation of rope requires flash attention 2.' - ) + if rope and (pos_emb_config['rope_imp'] + == 'dail') and (not is_flash_v2_installed()): + pytest.skip('dail implementation of rope requires flash attention 2.') cfg = om.create({ 'attn_impl': 'flash', @@ -136,55 +127,6 @@ def gen_bias(attn_impl: str): return attn_bias - def gen_rotary_embedding(rope_head_dim: int, pos_emb_config: dict, - max_seq_len: int): - if pos_emb_config['rope_imp'] == 'dail': - return DAILRotaryEmbedding( - dim=rope_head_dim, - base=pos_emb_config['rope_theta'], - interleaved=False, - scale_base=pos_emb_config['rope_dail_config']['xpos_scale_base'] - if (pos_emb_config['rope_dail_config']['type'] - == 'xpos') else None, - pos_idx_in_fp32=pos_emb_config['rope_dail_config'] - ['pos_idx_in_fp32'], - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif pos_emb_config['rope_imp'] == 'hf': - if pos_emb_config['rope_hf_config']['type'] == 'no_scaling': - return HFRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=pos_emb_config['rope_theta'], - device= - 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif pos_emb_config['rope_hf_config']['type'] == 'linear': - return HFLinearScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=pos_emb_config['rope_theta'], - scaling_factor=pos_emb_config['rope_hf_config']['factor'], - device= - 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif pos_emb_config['rope_hf_config']['type'] == 'dynamic': - return HFDynamicNTKScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=pos_emb_config['rope_theta'], - scaling_factor=pos_emb_config['rope_hf_config']['factor'], - device= - 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - else: - raise ValueError( - f'Invalid scaling type: {pos_emb_config["rope_hf_config"]["type"]}' - ) - else: - raise ValueError(f'Invalid rope_imp: {pos_emb_config["rope_imp"]}') - x0 = torch.randn(n, s, f).to(device) x1 = x0.clone().detach() x0.requires_grad = True diff --git a/tests/test_rope_dail_vs_hf.py b/tests/test_rope_dail_vs_hf.py index 51689bfc75..76d5cd50b0 100644 --- a/tests/test_rope_dail_vs_hf.py +++ b/tests/test_rope_dail_vs_hf.py @@ -18,6 +18,55 @@ LlamaRotaryEmbedding as HFRotaryEmbedding +def gen_rotary_embedding(rope_head_dim: int, pos_emb_config: dict, + max_seq_len: int): + if pos_emb_config['rope_imp'] == 'dail': + return DAILRotaryEmbedding( + dim=rope_head_dim, + base=pos_emb_config['rope_theta'], + interleaved=False, + scale_base=pos_emb_config['rope_dail_config']['xpos_scale_base'] if + (pos_emb_config['rope_dail_config']['type'] == 'xpos') else None, + pos_idx_in_fp32=pos_emb_config['rope_dail_config'] + ['pos_idx_in_fp32'], + device= + 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif pos_emb_config['rope_imp'] == 'hf': + if pos_emb_config['rope_hf_config']['type'] == 'no_scaling': + return HFRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + 
base=pos_emb_config['rope_theta'], + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif pos_emb_config['rope_hf_config']['type'] == 'linear': + return HFLinearScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + base=pos_emb_config['rope_theta'], + scaling_factor=pos_emb_config['rope_hf_config']['factor'], + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif pos_emb_config['rope_hf_config']['type'] == 'dynamic': + return HFDynamicNTKScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + base=pos_emb_config['rope_theta'], + scaling_factor=pos_emb_config['rope_hf_config']['factor'], + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + else: + raise ValueError( + f'Invalid scaling type: {pos_emb_config["rope_hf_config"]["type"]}' + ) + else: + raise ValueError(f'Invalid rope_imp: {pos_emb_config["rope_imp"]}') + + @pytest.mark.gpu @pytest.mark.parametrize('clip_qkv', [True, False]) @pytest.mark.parametrize('qk_ln', [True, False]) @@ -30,60 +79,9 @@ def test_rope_dail_vs_hf(clip_qkv: bool, attn_type: str, seq_len: int, device: str = 'cuda'): - if not is_flash_v2_installed(): - pytest.skip( - 'dail implementation of rope requires flash attention 2.' - ) - # compare rope rotations for the dail vs hf implementations - def gen_rotary_embedding(rope_head_dim: int, pos_emb_config: dict, - max_seq_len: int): - if pos_emb_config['rope_imp'] == 'dail': - return DAILRotaryEmbedding( - dim=rope_head_dim, - base=pos_emb_config['rope_theta'], - interleaved=False, - scale_base=pos_emb_config['rope_dail_config']['xpos_scale_base'] - if (pos_emb_config['rope_dail_config']['type'] - == 'xpos') else None, - pos_idx_in_fp32=pos_emb_config['rope_dail_config'] - ['pos_idx_in_fp32'], - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif pos_emb_config['rope_imp'] == 'hf': - if pos_emb_config['rope_hf_config']['type'] == 'no_scaling': - return HFRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=pos_emb_config['rope_theta'], - device= - 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif pos_emb_config['rope_hf_config']['type'] == 'linear': - return HFLinearScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=pos_emb_config['rope_theta'], - scaling_factor=pos_emb_config['rope_hf_config']['factor'], - device= - 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif pos_emb_config['rope_hf_config']['type'] == 'dynamic': - return HFDynamicNTKScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=pos_emb_config['rope_theta'], - scaling_factor=pos_emb_config['rope_hf_config']['factor'], - device= - 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - else: - raise ValueError( - f'Invalid scaling type: {pos_emb_config["rope_hf_config"]["type"]}' - ) - else: - raise ValueError(f'Invalid rope_imp: {pos_emb_config["rope_imp"]}') + if not is_flash_v2_installed(): + pytest.skip('dail implementation of rope requires flash attention 2.') from llmfoundry.models.layers import attention From b325097675d404f95c1475efd177f1348958aab4 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Mon, 30 Oct 2023 23:27:29 +0000 Subject: [PATCH 
082/106] added documentation --- TUTORIAL.md | 45 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index 36993bc409..2f861c4459 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -8,27 +8,41 @@ Forging LLMs can be quite complicated — you have to get your data prepared, se This tutorial will provide a brief intro to the repo’s structure and underlying tools (all courtesy of MosaicML, of course), will go over a few example workflows and point you to the related resources within the repo, and will finally cover a number of FAQs that we have encountered since release. +- [LLM Foundry Tutorial](#llm-foundry-tutorial) - [Intro](#intro) - [How this repo is structured](#how-this-repo-is-structured) - [Key components](#key-components) + - [Composer](#composer) + - [StreamingDataset](#streamingdataset) + - [MCLI](#mcli) - [How the YAMLs work](#how-the-yamls-work) - [Example Workflows](#example-workflows) - [Workflow 1: I want to play with a HF model like MPT-7B locally](#workflow-1-i-want-to-play-with-a-hf-model-like-mpt-7b-locally) - [Workflow 2: I want to deploy an inference endpoint with a HF model like MPT-7B](#workflow-2-i-want-to-deploy-an-inference-endpoint-with-a-hf-model-like-mpt-7b) - [Workflow 3: I want to finetune a HF model like MPT-7B](#workflow-3-i-want-to-finetune-a-hf-model-like-mpt-7b) + - [Supervised FineTuning and Instruction FineTuning](#supervised-finetuning-and-instruction-finetuning) + - [Domain Adaptation and Sequence Length Adaptation](#domain-adaptation-and-sequence-length-adaptation) + - [Data](#data) + - [Modeling](#modeling) - [Workflow 4: I want to train a new HF model from scratch](#workflow-4-i-want-to-train-a-new-hf-model-from-scratch) - [FAQs](#faqs) - - [Why is the script only using 1 out of N GPUs?](#why-is-the-script-only-using-1-out-of-n-gpus) - - [I’m running into an Out-Of-Memory (OOM) error. What do I do?](#im-running-into-an-out-of-memory-oom-error-what-do-i-do) - - [What hardware can I train on?](#what-hardware-can-i-train-on) - - [What hardware can I run eval on?](#what-hardware-can-i-run-eval-on) - - [What is FSDP?](#what-is-fsdp) - - [What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--triton-for-mpt-and-which-one-should-i-use) - - [Can I finetune using PEFT / LORA?](#can-i-finetune-using-peft--lora) - - [Can I quantize these models and/or run on CPU?](#can-i-quantize-these-models-andor-run-on-cpu) - - [How do I deploy with ONNX/FasterTransformer?](#how-do-i-deploy-with-onnxfastertransformer) - - [How expensive is it to build LLMs?](#how-expensive-is-it-to-build-llms) - - [Common installation issues](#common-installation-issues) + - [Why is the script only using 1 out of N GPUs?](#why-is-the-script-only-using-1-out-of-n-gpus) + - [I’m running into an Out-Of-Memory (OOM) error. 
What do I do?](#im-running-into-an-out-of-memory-oom-error-what-do-i-do) + - [What hardware can I train on?](#what-hardware-can-i-train-on) + - [What hardware can I run eval on?](#what-hardware-can-i-run-eval-on) + - [What hardware can I run inference on?](#what-hardware-can-i-run-inference-on) + - [What is FSDP?](#what-is-fsdp) + - [What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--triton--for-mpt-and-which-one-should-i-use) + - [Limitations](#limitations) + - [What is `triton-pre-mlir`?](#what-is-triton-pre-mlir) + - [Known issue with sm86+ GPUs](#known-issue-with-sm86-gpus) + - [What kinds of positional embeddings does LLM Foundry support?](#what-kinds-of-positional-embeddings-does-llm-foundry-support) + - [Can I finetune using PEFT / LoRA?](#can-i-finetune-using-peft--lora) + - [Can I quantize these models and/or run on CPU?](#can-i-quantize-these-models-andor-run-on-cpu) + - [How do I deploy with ONNX/FasterTransformer?](#how-do-i-deploy-with-onnxfastertransformer) + - [TransformerEngine and amp\_fp8 support](#transformerengine-and-amp_fp8-support) + - [How expensive is it to build LLMs?](#how-expensive-is-it-to-build-llms) + - [Common installation issues](#common-installation-issues) Let’s get started! @@ -328,6 +342,15 @@ The majority of our training setups use `triton`. --> Updating to LLVM14 (or LLVM15) cannot be done because there are breaking changes. What is the result of this? Although sm89+ is not **formally** supported until LLVM15, our testing on H100 GPUs shows that `attn_impl=triton` still works well and still runs fast. The only issue is that when the network is starting to run, LLVM might throw a warning like: `'sm_90' is not a recognized processor for this target (ignoring processor)`. This warning does not seem to affect performance. +### What kinds of positional embeddings does LLM Foundry support? +Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.03762.pdf), [Attention with Linear Biases (ALiBi)](https://arxiv.org/pdf/2108.12409.pdf), and [Rotary Positional Embeddings (RoPE)](https://arxiv.org/pdf/2104.09864.pdf). There is also an option to switch off all of these embeddings to get (No Positional Embedding)[https://arxiv.org/pdf/2203.16634.pdf]. + +| Name | YAML Config | Training MFU on MPT-7B trained on 8 A-100 80GB GPUs | Notes | +|------------------------------------|-------------------------------------------------------------------|-----------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Learned Positional Embeddings | model: learned_pos_emb: True | 65.7 | | +| ALiBi | model: attn_config: alibi: True | 64.5 | Requires Triton or Torch attention. | +| RoPE (Dao-AILab Implementation) | model: attn_config: rope: True rope_imp: dail | 64.5 | Requires GPU and [flash-attention v2.0.1 or higher](https://github.com/Dao-AILab/flash-attention). Note that attention implementation can still be torch, triton, or flash. | +| RoPE (Huggin Face Implementation) | model: attn_config: rope: True rope_imp: hf | 62.3 | | ### Can I finetune using PEFT / LoRA? - The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. 
However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so: From 1b35c0b600f81d7e067ef214251258d9ad63de5e Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Mon, 30 Oct 2023 23:39:06 +0000 Subject: [PATCH 083/106] added documentation --- TUTORIAL.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index 2f861c4459..0a04bc88f1 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -347,10 +347,10 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706. | Name | YAML Config | Training MFU on MPT-7B trained on 8 A-100 80GB GPUs | Notes | |------------------------------------|-------------------------------------------------------------------|-----------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Learned Positional Embeddings | model: learned_pos_emb: True | 65.7 | | -| ALiBi | model: attn_config: alibi: True | 64.5 | Requires Triton or Torch attention. | -| RoPE (Dao-AILab Implementation) | model: attn_config: rope: True rope_imp: dail | 64.5 | Requires GPU and [flash-attention v2.0.1 or higher](https://github.com/Dao-AILab/flash-attention). Note that attention implementation can still be torch, triton, or flash. | -| RoPE (Huggin Face Implementation) | model: attn_config: rope: True rope_imp: hf | 62.3 | | +| Learned Positional Embeddings | model:
learned_pos_emb: True | 65.7 | | +| ALiBi | model:
attn_config:
alibi: True | 64.5 | Requires Triton or Torch attention. | +| RoPE (Dao-AILab Implementation) | model:
attn_config:
rope: True
rope_imp: dail | 64.5 | Requires GPU and [flash-attention v2.0.1 or higher](https://github.com/Dao-AILab/flash-attention). Note that attention implementation can still be torch, triton, or flash. | +| RoPE (Huggin Face Implementation) | model:
attn_config:
rope: True
rope_imp: hf | 62.3 | | ### Can I finetune using PEFT / LoRA? - The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so: From 3988b57d23863639377cb5b5c01ab4ed6d320518 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 31 Oct 2023 01:01:09 +0000 Subject: [PATCH 084/106] temp commit --- TUTORIAL.md | 8 ++++---- llmfoundry/models/mpt/configuration_mpt.py | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index 0a04bc88f1..19fb75c762 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -347,10 +347,10 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706. | Name | YAML Config | Training MFU on MPT-7B trained on 8 A-100 80GB GPUs | Notes | |------------------------------------|-------------------------------------------------------------------|-----------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Learned Positional Embeddings | model:
learned_pos_emb: True | 65.7 | | -| ALiBi | model:
attn_config:
alibi: True | 64.5 | Requires Triton or Torch attention. | -| RoPE (Dao-AILab Implementation) | model:
attn_config:
rope: True
rope_imp: dail | 64.5 | Requires GPU and [flash-attention v2.0.1 or higher](https://github.com/Dao-AILab/flash-attention). Note that attention implementation can still be torch, triton, or flash. | -| RoPE (Huggin Face Implementation) | model:
attn_config:
rope: True
rope_imp: hf | 62.3 | | +| Learned Positional Embeddings | ```model:
learned_pos_emb: True ```| 65.7 | | +| ALiBi | ```model:
attn_config:
alibi: True ```| 64.5 | Requires Triton or Torch attention. | +| RoPE (Dao-AILab Implementation) | ```model:
attn_config:
rope: True
rope_imp: dail ```| 64.5 | Requires GPU and [flash-attention v2.0.1 or higher](https://github.com/Dao-AILab/flash-attention). Note that attention implementation can still be torch, triton, or flash. | +| RoPE (Huggin Face Implementation) |``` model:
attn_config:
rope: True
rope_imp: hf ```| 62.3 | | ### Can I finetune using PEFT / LoRA? - The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so: diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 7988b607c2..71e11c80bc 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -188,6 +188,8 @@ def _set_config_defaults(self, config: Dict[str, Any], for k, v in config_defaults.items(): if k not in config: config[k] = v + elif isinstance(v, dict): + config[k] = self._set_config_defaults(config[k], v) return config def _validate_config(self) -> None: From 82ce2d94b51934ce97a8d7d31c2a30d9a8e5937b Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 31 Oct 2023 05:03:54 +0000 Subject: [PATCH 085/106] made _set_config_defaults recursive --- llmfoundry/models/mpt/configuration_mpt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 71e11c80bc..97749a365a 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -189,7 +189,8 @@ def _set_config_defaults(self, config: Dict[str, Any], if k not in config: config[k] = v elif isinstance(v, dict): - config[k] = self._set_config_defaults(config[k], v) + config[k] = self._set_config_defaults( + config[k] if (config[k] is not None) else {}, v) return config def _validate_config(self) -> None: From d2930f94bcdc2b25e3a0d9ecd15f26c9f2c68565 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 31 Oct 2023 05:37:53 +0000 Subject: [PATCH 086/106] minor changes --- tests/test_flash_triton_torch.py | 9 ---- tests/test_model.py | 72 -------------------------------- 2 files changed, 81 deletions(-) diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index c5768b5d0c..6900da633f 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -37,20 +37,11 @@ def allclose_helper(t0: torch.Tensor, 'pos_idx_in_fp32': True, 'xpos_scale_base': 512, }, - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, 'rope_imp': 'hf', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, diff --git a/tests/test_model.py b/tests/test_model.py index d39a0b740a..ab14fb1b53 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -539,20 +539,11 @@ def test_mpt_creation(norm_type: str, no_bias: bool): 'pos_idx_in_fp32': True, 'xpos_scale_base': 512, }, - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, 'rope_imp': 'hf', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, 'rope_hf_config': { 'type': 
'no_scaling', 'factor': 1.0, @@ -777,20 +768,11 @@ def test_advanced_mask_building(attention_impl: str): 'pos_idx_in_fp32': True, 'xpos_scale_base': 512, }, - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, 'rope_imp': 'hf', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, @@ -1014,20 +996,11 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): 'pos_idx_in_fp32': True, 'xpos_scale_base': 512, }, - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, 'rope_imp': 'hf', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, @@ -1162,20 +1135,11 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, 'pos_idx_in_fp32': True, 'xpos_scale_base': 512, }, - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, 'rope_imp': 'hf', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, @@ -1312,20 +1276,11 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): 'pos_idx_in_fp32': True, 'xpos_scale_base': 512, }, - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, 'rope_imp': 'hf', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, @@ -1433,20 +1388,11 @@ def test_generate_with_past_kv(attn_impl: str, device: str, 'pos_idx_in_fp32': True, 'xpos_scale_base': 512, }, - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, 'rope_imp': 'hf', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, @@ -1523,20 +1469,11 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, 'pos_idx_in_fp32': True, 'xpos_scale_base': 512, }, - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, 'rope_imp': 'hf', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, @@ -1663,20 +1600,11 @@ def test_alibi_vs_hf(): 'pos_idx_in_fp32': True, 'xpos_scale_base': 512, }, - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, }, { 'alibi': False, 'rope': True, 'rope_theta': 10000, 'rope_imp': 'hf', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, From b022b128c3393859502cebbea5952f1a81ee44b9 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 31 Oct 2023 06:03:26 +0000 Subject: [PATCH 087/106] reformatted tutorial table --- TUTORIAL.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index 19fb75c762..9f85074990 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -347,10 +347,10 @@ Currently we support [Learned Positional 
Embeddings](https://arxiv.org/pdf/1706. | Name | YAML Config | Training MFU on MPT-7B trained on 8 A-100 80GB GPUs | Notes | |------------------------------------|-------------------------------------------------------------------|-----------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Learned Positional Embeddings | ```model:
learned_pos_emb: True ```| 65.7 | | -| ALiBi | ```model:
attn_config:
alibi: True ```| 64.5 | Requires Triton or Torch attention. | -| RoPE (Dao-AILab Implementation) | ```model:
attn_config:
rope: True
rope_imp: dail ```| 64.5 | Requires GPU and [flash-attention v2.0.1 or higher](https://github.com/Dao-AILab/flash-attention). Note that attention implementation can still be torch, triton, or flash. | -| RoPE (Huggin Face Implementation) |``` model:
attn_config:
rope: True
rope_imp: hf ```| 62.3 | | +| Learned Positional Embeddings | model:
learned_pos_emb: True
| 65.7 | | +| ALiBi | model:
attn_config:
alibi: True
| 64.5 | Requires Triton or Torch attention. | +| RoPE (Dao-AILab Implementation) | model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires GPU and [flash-attention v2.0.1 or higher](https://github.com/Dao-AILab/flash-attention). Note that attention implementation can still be torch, triton, or flash. | +| RoPE (Huggin Face Implementation) | model:
attn_config:
rope: True
rope_imp: hf
| 62.3 | | ### Can I finetune using PEFT / LoRA? - The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so: From dbc0f8492303c456294fa62dfaeac12848fc2451 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 31 Oct 2023 06:13:57 +0000 Subject: [PATCH 088/106] reformatted tutorial table --- TUTORIAL.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index 9f85074990..bfdee5bb76 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -347,10 +347,10 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706. | Name | YAML Config | Training MFU on MPT-7B trained on 8 A-100 80GB GPUs | Notes | |------------------------------------|-------------------------------------------------------------------|-----------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Learned Positional Embeddings | model:
learned_pos_emb: True
| 65.7 | | -| ALiBi | model:
attn_config:
alibi: True
| 64.5 | Requires Triton or Torch attention. | -| RoPE (Dao-AILab Implementation) | model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires GPU and [flash-attention v2.0.1 or higher](https://github.com/Dao-AILab/flash-attention). Note that attention implementation can still be torch, triton, or flash. | -| RoPE (Huggin Face Implementation) | model:
attn_config:
rope: True
rope_imp: hf
| 62.3 | | +| Learned Positional Embeddings |
model:
learned_pos_emb: True
| 65.7 | | +| ALiBi |
model:
attn_config:
alibi: True
| 64.5 | Requires Triton or Torch attention. | +| RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires GPU and [flash-attention v2.0.1 or higher](https://github.com/Dao-AILab/flash-attention). Note that attention implementation can still be torch, triton, or flash. | +| RoPE (Huggin Face Implementation) |
model:
attn_config:
rope: True
rope_imp: hf
| 62.3 | | ### Can I finetune using PEFT / LoRA? - The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so: From 54d83043605f6a1b3322f745b1ae0d09875a2450 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 31 Oct 2023 06:17:12 +0000 Subject: [PATCH 089/106] reformatted tutorial table --- TUTORIAL.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index bfdee5bb76..b5d620af36 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -347,10 +347,10 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706. | Name | YAML Config | Training MFU on MPT-7B trained on 8 A-100 80GB GPUs | Notes | |------------------------------------|-------------------------------------------------------------------|-----------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Learned Positional Embeddings |
model:
learned_pos_emb: True
| 65.7 | | -| ALiBi |
model:
attn_config:
alibi: True
| 64.5 | Requires Triton or Torch attention. | -| RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires GPU and [flash-attention v2.0.1 or higher](https://github.com/Dao-AILab/flash-attention). Note that attention implementation can still be torch, triton, or flash. | -| RoPE (Huggin Face Implementation) |
model:
attn_config:
rope: True
rope_imp: hf
| 62.3 | | +| Learned Positional Embeddings |
model:
learned_pos_emb: True
| 65.7 | | +| ALiBi |
model:
attn_config:
alibi: True
| 64.5 | Requires Triton or Torch attention. | +| RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires GPU and [flash-attention v2.0.1 or higher](https://github.com/Dao-AILab/flash-attention). Note that attention implementation can still be torch, triton, or flash. | +| RoPE (Huggin Face Implementation) |
model:
attn_config:
rope: True
rope_imp: hf
| 62.3 | | ### Can I finetune using PEFT / LoRA? - The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so: From c05d34d9f2e1c10df32ebf3bab7c26874b212752 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 31 Oct 2023 08:03:52 +0000 Subject: [PATCH 090/106] added documentation on how to install flash attention 2 --- TUTORIAL.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index b5d620af36..53c11cc1b4 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -342,14 +342,17 @@ The majority of our training setups use `triton`. --> Updating to LLVM14 (or LLVM15) cannot be done because there are breaking changes. What is the result of this? Although sm89+ is not **formally** supported until LLVM15, our testing on H100 GPUs shows that `attn_impl=triton` still works well and still runs fast. The only issue is that when the network is starting to run, LLVM might throw a warning like: `'sm_90' is not a recognized processor for this target (ignoring processor)`. This warning does not seem to affect performance. +#### Support for FlashAttention-2 +- (FlashAttention-2)[https://arxiv.org/pdf/2307.08691.pdf] improves upon FlashAttention to get even faster and more efficient attention computation. LLM-Foundry now supports FlashAttention-2 by simply using the (new MosaicML Docker image)[https://github.com/mosaicml/llm-foundry#mosaicml-docker-images], then (following the instructions here)[https://github.com/mosaicml/llm-foundry#with-docker-recommended], and then running pip install -e ".[gpu-flash2]". This will also install the (flash-attn library)[https://github.com/Dao-AILab/flash-attention] v2.3.2. + ### What kinds of positional embeddings does LLM Foundry support? Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.03762.pdf), [Attention with Linear Biases (ALiBi)](https://arxiv.org/pdf/2108.12409.pdf), and [Rotary Positional Embeddings (RoPE)](https://arxiv.org/pdf/2104.09864.pdf). There is also an option to switch off all of these embeddings to get (No Positional Embedding)[https://arxiv.org/pdf/2203.16634.pdf]. | Name | YAML Config | Training MFU on MPT-7B trained on 8 A-100 80GB GPUs | Notes | -|------------------------------------|-------------------------------------------------------------------|-----------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +|:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | Learned Positional Embeddings |
model:
learned_pos_emb: True
| 65.7 | | | ALiBi |
model:
attn_config:
alibi: True
| 64.5 | Requires Triton or Torch attention. | -| RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires GPU and [flash-attention v2.0.1 or higher](https://github.com/Dao-AILab/flash-attention). Note that attention implementation can still be torch, triton, or flash. | +| RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires a CUDA GPU and [the flash-attn library](https://github.com/Dao-AILab/flash-attention) (v2.0.1 or higher) to be installed. Please see the instructions in the section above on how to install the flash-attn library v2.3.2. Note that attention implementation can still be torch, triton, or flash, just that this needs the the flash-attn library (v2.0.1 or higher) since we use their RotaryEmbedding class. | | RoPE (Huggin Face Implementation) |
model:
attn_config:
rope: True
rope_imp: hf
| 62.3 | | ### Can I finetune using PEFT / LoRA? From c8f099d546340a42d7eb156a44356eaab53fb066 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 31 Oct 2023 08:08:19 +0000 Subject: [PATCH 091/106] minor changes --- TUTORIAL.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index 53c11cc1b4..b7e018cafd 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -36,6 +36,7 @@ This tutorial will provide a brief intro to the repo’s structure and underlyin - [Limitations](#limitations) - [What is `triton-pre-mlir`?](#what-is-triton-pre-mlir) - [Known issue with sm86+ GPUs](#known-issue-with-sm86-gpus) + - [Support for FlashAttention-2](#support-for-flashattention-2) - [What kinds of positional embeddings does LLM Foundry support?](#what-kinds-of-positional-embeddings-does-llm-foundry-support) - [Can I finetune using PEFT / LoRA?](#can-i-finetune-using-peft--lora) - [Can I quantize these models and/or run on CPU?](#can-i-quantize-these-models-andor-run-on-cpu) @@ -343,7 +344,7 @@ The majority of our training setups use `triton`. --> What is the result of this? Although sm89+ is not **formally** supported until LLVM15, our testing on H100 GPUs shows that `attn_impl=triton` still works well and still runs fast. The only issue is that when the network is starting to run, LLVM might throw a warning like: `'sm_90' is not a recognized processor for this target (ignoring processor)`. This warning does not seem to affect performance. #### Support for FlashAttention-2 -- (FlashAttention-2)[https://arxiv.org/pdf/2307.08691.pdf] improves upon FlashAttention to get even faster and more efficient attention computation. LLM-Foundry now supports FlashAttention-2 by simply using the (new MosaicML Docker image)[https://github.com/mosaicml/llm-foundry#mosaicml-docker-images], then (following the instructions here)[https://github.com/mosaicml/llm-foundry#with-docker-recommended], and then running pip install -e ".[gpu-flash2]". This will also install the (flash-attn library)[https://github.com/Dao-AILab/flash-attention] v2.3.2. +- [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster and more efficient attention computation. LLM-Foundry now supports FlashAttention-2 by simply using the [new MosaicML Docker image](https://github.com/mosaicml/llm-foundry#mosaicml-docker-images), then [following the instructions here](https://github.com/mosaicml/llm-foundry#with-docker-recommended), and then running pip install -e ".[gpu-flash2]". Then setting attn_impl: flash uses FlashAttention2. This will also install the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.3.2. ### What kinds of positional embeddings does LLM Foundry support? Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.03762.pdf), [Attention with Linear Biases (ALiBi)](https://arxiv.org/pdf/2108.12409.pdf), and [Rotary Positional Embeddings (RoPE)](https://arxiv.org/pdf/2104.09864.pdf). There is also an option to switch off all of these embeddings to get (No Positional Embedding)[https://arxiv.org/pdf/2203.16634.pdf]. 
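As a minimal sketch of what the `YAML Config` column in the tables above is abbreviating, the same positional-embedding switches can be set programmatically through the `attn_config` dictionary that `MPTConfig` consumes. This sketch assumes `llm-foundry` is installed and that `MPTConfig` / `MPTForCausalLM` are importable from `llmfoundry.models.mpt`; the model sizes are illustrative, and the key spellings follow the final form in this series (`rope_impl`, which a later patch renames from `rope_imp`).

```python
# Sketch only: assumes llm-foundry is installed and that MPTConfig /
# MPTForCausalLM are importable from llmfoundry.models.mpt.
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

attn_config = {
    'attn_impl': 'torch',      # 'torch', 'triton', or 'flash'
    'alibi': False,            # set True (and rope False) to use ALiBi instead
    'rope': True,              # turn on rotary positional embeddings
    'rope_theta': 10000,
    'rope_impl': 'hf',         # 'hf' runs anywhere; 'dail' needs flash-attn >= 2.0.1
    'rope_hf_config': {
        'type': 'no_scaling',  # or 'linear' / 'dynamic'
        'factor': 1.0,
    },
}

config = MPTConfig(
    d_model=768,               # illustrative small model
    n_heads=12,
    n_layers=12,
    max_seq_len=2048,
    learned_pos_emb=False,     # usually switched off when RoPE or ALiBi is on
    attn_config=attn_config,
)
model = MPTForCausalLM(config)
```

In a training YAML the same keys simply nest under `model` and `model.attn_config`, which is what each cell of the table is compressing into one line.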
From a48afa209b6fbb9b18b4a786980218849eb22a2f Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 31 Oct 2023 08:10:51 +0000 Subject: [PATCH 092/106] minor changes --- TUTORIAL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index b7e018cafd..03b1173a7a 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -344,7 +344,7 @@ The majority of our training setups use `triton`. --> What is the result of this? Although sm89+ is not **formally** supported until LLVM15, our testing on H100 GPUs shows that `attn_impl=triton` still works well and still runs fast. The only issue is that when the network is starting to run, LLVM might throw a warning like: `'sm_90' is not a recognized processor for this target (ignoring processor)`. This warning does not seem to affect performance. #### Support for FlashAttention-2 -- [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster and more efficient attention computation. LLM-Foundry now supports FlashAttention-2 by simply using the [new MosaicML Docker image](https://github.com/mosaicml/llm-foundry#mosaicml-docker-images), then [following the instructions here](https://github.com/mosaicml/llm-foundry#with-docker-recommended), and then running pip install -e ".[gpu-flash2]". Then setting attn_impl: flash uses FlashAttention2. This will also install the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.3.2. +- [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster attention computation. LLM-Foundry now supports FlashAttention-2 by simply using the [new MosaicML Docker image](https://github.com/mosaicml/llm-foundry#mosaicml-docker-images), then [following the instructions here](https://github.com/mosaicml/llm-foundry#with-docker-recommended), and then running pip install -e ".[gpu-flash2]". Then setting attn_impl: flash uses FlashAttention2. This will also install the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.3.2. ### What kinds of positional embeddings does LLM Foundry support? Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.03762.pdf), [Attention with Linear Biases (ALiBi)](https://arxiv.org/pdf/2108.12409.pdf), and [Rotary Positional Embeddings (RoPE)](https://arxiv.org/pdf/2104.09864.pdf). There is also an option to switch off all of these embeddings to get (No Positional Embedding)[https://arxiv.org/pdf/2203.16634.pdf]. From fa46318ad343c52a66a73dde6c29858f0636909f Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 31 Oct 2023 08:12:13 +0000 Subject: [PATCH 093/106] minor changes --- TUTORIAL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index 03b1173a7a..8cbf484e09 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -353,7 +353,7 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706. |:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | Learned Positional Embeddings |
model:
learned_pos_emb: True
| 65.7 | | | ALiBi |
model:
attn_config:
alibi: True
| 64.5 | Requires Triton or Torch attention. | -| RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires a CUDA GPU and [the flash-attn library](https://github.com/Dao-AILab/flash-attention) (v2.0.1 or higher) to be installed. Please see the instructions in the section above on how to install the flash-attn library v2.3.2. Note that attention implementation can still be torch, triton, or flash, just that this needs the the flash-attn library (v2.0.1 or higher) since we use their RotaryEmbedding class. | +| RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires a CUDA GPU and [the flash-attn library](https://github.com/Dao-AILab/flash-attention) (v2.0.1 or higher) to be installed. Please see the instructions in the section above on how to install the flash-attn library v2.3.2. Note that attention implementation can still be torch, triton, or flash, just that this needs the the flash-attn library (v2.0.1 or higher) since we import their RotaryEmbedding class. | | RoPE (Huggin Face Implementation) |
model:
attn_config:
rope: True
rope_imp: hf
| 62.3 | | ### Can I finetune using PEFT / LoRA? From 5b1016456ff69e04391f14fa66dfbcc45f6644ad Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 31 Oct 2023 15:24:35 +0000 Subject: [PATCH 094/106] minor changes --- TUTORIAL.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index 8cbf484e09..c0eeb078ca 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -344,7 +344,7 @@ The majority of our training setups use `triton`. --> What is the result of this? Although sm89+ is not **formally** supported until LLVM15, our testing on H100 GPUs shows that `attn_impl=triton` still works well and still runs fast. The only issue is that when the network is starting to run, LLVM might throw a warning like: `'sm_90' is not a recognized processor for this target (ignoring processor)`. This warning does not seem to affect performance. #### Support for FlashAttention-2 -- [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster attention computation. LLM-Foundry now supports FlashAttention-2 by simply using the [new MosaicML Docker image](https://github.com/mosaicml/llm-foundry#mosaicml-docker-images), then [following the instructions here](https://github.com/mosaicml/llm-foundry#with-docker-recommended), and then running pip install -e ".[gpu-flash2]". Then setting attn_impl: flash uses FlashAttention2. This will also install the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.3.2. +- [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster attention computation. LLM-Foundry now supports FlashAttention-2, please follow the instructions [here](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#flashattention). ### What kinds of positional embeddings does LLM Foundry support? Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.03762.pdf), [Attention with Linear Biases (ALiBi)](https://arxiv.org/pdf/2108.12409.pdf), and [Rotary Positional Embeddings (RoPE)](https://arxiv.org/pdf/2104.09864.pdf). There is also an option to switch off all of these embeddings to get (No Positional Embedding)[https://arxiv.org/pdf/2203.16634.pdf]. @@ -353,7 +353,7 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706. |:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | Learned Positional Embeddings |
model:
learned_pos_emb: True
| 65.7 | | | ALiBi |
model:
attn_config:
alibi: True
| 64.5 | Requires Triton or Torch attention. | -| RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires a CUDA GPU and [the flash-attn library](https://github.com/Dao-AILab/flash-attention) (v2.0.1 or higher) to be installed. Please see the instructions in the section above on how to install the flash-attn library v2.3.2. Note that attention implementation can still be torch, triton, or flash, just that this needs the the flash-attn library (v2.0.1 or higher) since we import their RotaryEmbedding class. | +| RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires a CUDA GPU and [the flash-attn library](https://github.com/Dao-AILab/flash-attention) (v2.0.1 or higher) to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn library v2. Note that attention implementation can still be torch, triton, or flash, just that this needs the the flash-attn library (v2.0.1 or higher) since we import their RotaryEmbedding class. | | RoPE (Huggin Face Implementation) |
model:
attn_config:
rope: True
rope_imp: hf
| 62.3 | | ### Can I finetune using PEFT / LoRA? From 64d4a57e9d9abc2d2874c33d4a679d8d4c04782a Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 31 Oct 2023 15:25:34 +0000 Subject: [PATCH 095/106] minor changes --- TUTORIAL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index c0eeb078ca..e2cdddd947 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -353,7 +353,7 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706. |:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | Learned Positional Embeddings |
model:
learned_pos_emb: True
| 65.7 | | | ALiBi |
model:
attn_config:
alibi: True
| 64.5 | Requires Triton or Torch attention. | -| RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires a CUDA GPU and [the flash-attn library](https://github.com/Dao-AILab/flash-attention) (v2.0.1 or higher) to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn library v2. Note that attention implementation can still be torch, triton, or flash, just that this needs the the flash-attn library (v2.0.1 or higher) since we import their RotaryEmbedding class. | +| RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires a CUDA GPU and [the flash-attn library](https://github.com/Dao-AILab/flash-attention) (v2.0.1 or higher) to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn library v2. Note that attention implementation can still be torch, triton, or flash, just that this needs the the flash-attn library (v2.0.1 or higher) installed since we import their RotaryEmbedding class. | | RoPE (Huggin Face Implementation) |
model:
attn_config:
rope: True
rope_imp: hf
| 62.3 | | ### Can I finetune using PEFT / LoRA? From 04452e66da75800e1bfa7eed4d1d516fbd6b4092 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 31 Oct 2023 15:29:08 +0000 Subject: [PATCH 096/106] minor changes --- TUTORIAL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index e2cdddd947..aa85d6b66b 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -354,7 +354,7 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706. | Learned Positional Embeddings |
model:
learned_pos_emb: True
| 65.7 | | | ALiBi |
model:
attn_config:
alibi: True
| 64.5 | Requires Triton or Torch attention. | | RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires a CUDA GPU and [the flash-attn library](https://github.com/Dao-AILab/flash-attention) (v2.0.1 or higher) to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn library v2. Note that attention implementation can still be torch, triton, or flash, just that this needs the the flash-attn library (v2.0.1 or higher) installed since we import their RotaryEmbedding class. | -| RoPE (Huggin Face Implementation) |
model:
attn_config:
rope: True
rope_imp: hf
| 62.3 | | +| RoPE (Hugging Face Implementation) |
model:
attn_config:
rope: True
rope_imp: hf
| 62.3 | | ### Can I finetune using PEFT / LoRA? - The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so: From 1eff648fb460943e31230eb7e2af813c0dc426a9 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 31 Oct 2023 17:25:57 +0000 Subject: [PATCH 097/106] .. --- llmfoundry/models/mpt/modeling_mpt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index cc8a8213a6..bdd752fc41 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -27,8 +27,9 @@ from llmfoundry.models.layers.attention import is_flash_v2_installed if is_flash_v2_installed(): - try: # This try...except is needed because transformers requires it despite the 'if' statement above - from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding + try: # This try...except is needed because transformers requires it despite the 'if' statement above + from flash_attn.layers.rotary import \ + RotaryEmbedding as DAILRotaryEmbedding except Exception as e: raise e From 1e59de5ebcdd0f0b6cc358e367aa4f35ea8a016e Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 2 Nov 2023 18:31:27 +0000 Subject: [PATCH 098/106] resolved some comments from the PR --- TUTORIAL.md | 4 +- llmfoundry/models/layers/attention.py | 20 +++-- llmfoundry/models/layers/blocks.py | 52 ++++++------ llmfoundry/models/mpt/configuration_mpt.py | 36 ++------ llmfoundry/models/mpt/modeling_mpt.py | 63 +++++++------- tests/test_flash_triton_torch.py | 10 +-- tests/test_model.py | 99 ++++++++++------------ tests/test_rope_dail_vs_hf.py | 10 +-- 8 files changed, 138 insertions(+), 156 deletions(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index 3e1dd6ed89..117a61f617 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -353,8 +353,8 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706. |:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | Learned Positional Embeddings |
model:
learned_pos_emb: True
| 65.7 | | | ALiBi |
model:
attn_config:
alibi: True
| 64.5 | Requires Triton or Torch attention. | -| RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_imp: dail
| 64.5 | Requires a CUDA GPU and [the flash-attn library](https://github.com/Dao-AILab/flash-attention) (v2.0.1 or higher) to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn library v2. Note that attention implementation can still be torch, triton, or flash, just that this needs the the flash-attn library (v2.0.1 or higher) installed since we import their RotaryEmbedding class. | -| RoPE (Hugging Face Implementation) |
model:
attn_config:
rope: True
rope_imp: hf
| 62.3 | | +| RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_impl: dail
| 64.5 | Requires a CUDA GPU and [the flash-attn library](https://github.com/Dao-AILab/flash-attention) (v2.0.1 or higher) to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn library v2. Note that attention implementation can still be torch, triton, or flash, just that this needs the the flash-attn library (v2.0.1 or higher) installed since we import their RotaryEmbedding class. | +| RoPE (Hugging Face Implementation) |
model:
attn_config:
rope: True
rope_impl: hf
| 62.3 | | ### Can I finetune using PEFT / LoRA? - The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so: diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 3263c709a9..5bec1dbb70 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -33,6 +33,9 @@ def is_flash_v1_installed(): return version.parse(flash_attn.__version__) < version.parse('2.0.0') +# Before importing any transformers models, we need to disable transformers flash attention if +# we are in an environment with flash attention version <2. Transformers hard errors on a not properly +# gated import otherwise. if is_flash_v1_installed(): import transformers transformers.utils.is_flash_attn_available = lambda: False @@ -593,12 +596,14 @@ def forward( rotary_emb = rotary_emb_w_meta_info['rotary_emb'] seq_len = rotary_emb_w_meta_info['seq_len'] offset_info = rotary_emb_w_meta_info['offset_info'] - - query = query.view(*(query.shape[:-1]), -1, self.head_dim) - key = key.view(*(key.shape[:-1]), -1, self.head_dim) + assert query.shape[:2] == key.shape[:2] + assert query.shape[:2] == key.shape[:2] + bsz, seqlen = query.shape[:2] + query = query.view(bsz, seqlen, -1, self.head_dim) + key = key.view(bsz, seqlen, -1, self.head_dim) if rotary_emb_w_meta_info['imp'] == 'dail': - value = value.view(*(value.shape[:-1]), -1, self.head_dim) + value = value.view(bsz, seqlen, -1, self.head_dim) kv = torch.stack([key, value], dim=2) query, kv = rotary_emb(query, @@ -607,8 +612,7 @@ def forward( max_seqlen=seq_len) [key, value] = torch.unbind(kv, dim=2) - value = value.view(*(value.shape[:-2]), - self.kv_n_heads * self.head_dim) + value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim) elif rotary_emb_w_meta_info['imp'] == 'hf': (cos, sin) = rotary_emb(value, seq_len) # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb @@ -620,8 +624,8 @@ def forward( query = query.transpose(1, 2) key = key.transpose(1, 2) - query = query.view(*(query.shape[:-2]), self.d_model) - key = key.view(*(key.shape[:-2]), self.kv_n_heads * self.head_dim) + query = query.view(bsz, seqlen, self.d_model) + key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim) context, attn_weights, past_key_value = self.attn_fn( query, diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 976ffc6768..6605807c6b 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -12,6 +12,31 @@ from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY +attn_config_defaults: Dict = { + 'attn_type': 'multihead_attention', + 'attn_pdrop': 0.0, + 'attn_impl': 'triton', + 'qk_ln': False, + 'clip_qkv': None, + 'softmax_scale': None, + 'prefix_lm': False, + 
'attn_uses_sequence_id': False, + 'alibi': False, + 'alibi_bias_max': 8, + 'rope': False, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +} + class MPTBlock(nn.Module): @@ -30,30 +55,7 @@ def __init__( **kwargs: Any, ): if attn_config is None: - attn_config = { - 'attn_type': 'multihead_attention', - 'attn_pdrop': 0.0, - 'attn_impl': 'triton', - 'qk_ln': False, - 'clip_qkv': None, - 'softmax_scale': None, - 'prefix_lm': False, - 'attn_uses_sequence_id': False, - 'alibi': False, - 'alibi_bias_max': 8, - 'rope': False, - 'rope_theta': 10000, - 'rope_imp': 'dail', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, - } + attn_config = attn_config_defaults if ffn_config is None: ffn_config = { @@ -70,7 +72,7 @@ def __init__( # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs args_to_exclude_in_attn_class = { 'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', - 'alibi_bias_max', 'rope', 'rope_theta', 'rope_imp', + 'alibi_bias_max', 'rope', 'rope_theta', 'rope_impl', 'rope_dail_config', 'rope_hf_config' } attn_config_subset_for_attn_class = { diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 97749a365a..9c6f4287f7 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -8,30 +8,7 @@ from transformers import PretrainedConfig -attn_config_defaults: Dict = { - 'attn_type': 'multihead_attention', - 'attn_pdrop': 0.0, - 'attn_impl': 'triton', - 'qk_ln': False, - 'clip_qkv': None, - 'softmax_scale': None, - 'prefix_lm': False, - 'attn_uses_sequence_id': False, - 'alibi': False, - 'alibi_bias_max': 8, - 'rope': False, - 'rope_theta': 10000, - 'rope_imp': 'dail', - 'rope_dail_config': { - 'type': 'original', - 'pos_idx_in_fp32': True, - 'xpos_scale_base': 512, - }, - 'rope_hf_config': { - 'type': 'no_scaling', - 'factor': 1.0, - }, -} +from llmfoundry.models.layers.blocks import attn_config_defaults ffn_config_defaults: Dict = { 'ffn_type': 'mptmlp', @@ -108,7 +85,7 @@ def __init__( alibi_bias_max (int): The maximum value of the alibi bias. rope (bool): Whether to use rotary positional embeddings. rope_theta (int): The base frequency for rope. - rope_imp (str): The implementation of rope to use. One of 'hf' (to use the implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) or 'dail' (to use the implementation from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py). + rope_impl (str): The implementation of rope to use. One of 'hf' (to use the implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) or 'dail' (to use the implementation from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py). rope_dail_config (Dict): The configuration for the dail implementation of rope. type (str): The type of rotary position embedding to use. Options: 'original' (for https://arxiv.org/pdf/2104.09864.pdf), 'xpos' (for https://arxiv.org/pdf/2212.10554.pdf). 
pos_idx_in_fp32 (bool): If True, the position indices [0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. A consequence could be, for example, that bf16 rounds position 1995 to 2000, which leads to them having the same positional embedding. @@ -189,6 +166,7 @@ def _set_config_defaults(self, config: Dict[str, Any], if k not in config: config[k] = v elif isinstance(v, dict): + # recursively set default values for any sub-dicts config[k] = self._set_config_defaults( config[k] if (config[k] is not None) else {}, v) return config @@ -233,13 +211,13 @@ def _validate_config(self) -> None: raise NotImplementedError( 'attn_uses_sequence_id only implemented with torch and triton attention.' ) - if self.attn_config['rope'] and (self.attn_config['rope_imp'] + if self.attn_config['rope'] and (self.attn_config['rope_impl'] not in ['dail', 'hf']): raise ValueError( - 'If rope is being used then rope_imp should be either "dail", or "hf".' + 'If rope is being used then rope_impl should be either "dail", or "hf".' ) if self.attn_config['rope'] and ( - self.attn_config['rope_imp'] + self.attn_config['rope_impl'] == 'hf') and self.attn_config['rope_hf_config']['type'] not in [ 'no_scaling', 'linear', 'dynamic' ]: @@ -247,7 +225,7 @@ def _validate_config(self) -> None: 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".' ) if self.attn_config['rope'] and ( - self.attn_config['rope_imp'] + self.attn_config['rope_impl'] == 'dail') and (self.attn_config['rope_dail_config']['type'] not in ['original', 'xpos']): raise ValueError( diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index bdd752fc41..0d52ee214d 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -86,45 +86,44 @@ log = logging.getLogger(__name__) -def _rotary_embedding(config: MPTConfig): - rope_head_dim = config.d_model // config.n_heads - if config.attn_config['rope_imp'] == 'dail': +def _rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int, + rope_dail_config: dict, rope_hf_config: dict, + max_seq_len: int): + if rope_impl == 'dail': return DAILRotaryEmbedding( dim=rope_head_dim, - base=config.attn_config['rope_theta'], + base=rope_theta, interleaved=False, - scale_base=config.attn_config['rope_dail_config']['xpos_scale_base'] - if (config.attn_config['rope_dail_config']['type'] - == 'xpos') else None, - pos_idx_in_fp32=config.attn_config['rope_dail_config'] - ['pos_idx_in_fp32'], + scale_base=rope_dail_config['xpos_scale_base'] if + (rope_dail_config['type'] == 'xpos') else None, + pos_idx_in_fp32=rope_dail_config['pos_idx_in_fp32'], device= 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) - elif config.attn_config['rope_imp'] == 'hf': - if config.attn_config['rope_hf_config']['type'] == 'no_scaling': + elif rope_impl == 'hf': + if rope_hf_config['type'] == 'no_scaling': return HFRotaryEmbedding( rope_head_dim, - max_position_embeddings=config.max_seq_len, - base=config.attn_config['rope_theta'], + max_position_embeddings=max_seq_len, + base=rope_theta, device= 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) - elif config.attn_config['rope_hf_config']['type'] == 'linear': + elif rope_hf_config['type'] == 'linear': return HFLinearScalingRotaryEmbedding( rope_head_dim, - max_position_embeddings=config.max_seq_len, - base=config.attn_config['rope_theta'], - 
scaling_factor=config.attn_config['rope_hf_config']['factor'], + max_position_embeddings=max_seq_len, + base=rope_theta, + scaling_factor=rope_hf_config['factor'], device= 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) - elif config.attn_config['rope_hf_config']['type'] == 'dynamic': + elif rope_hf_config['type'] == 'dynamic': return HFDynamicNTKScalingRotaryEmbedding( rope_head_dim, - max_position_embeddings=config.max_seq_len, - base=config.attn_config['rope_theta'], - scaling_factor=config.attn_config['rope_hf_config']['factor'], + max_position_embeddings=max_seq_len, + base=rope_theta, + scaling_factor=rope_hf_config['factor'], device= 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) @@ -184,10 +183,16 @@ def __init__(self, config: MPTConfig): self.norm_f = norm_class(config.d_model, device=config.init_device) self.rope = config.attn_config['rope'] - self.rope_imp = None + self.rope_impl = None if self.rope: - self.rope_imp = config.attn_config['rope_imp'] - self.rotary_embedding = _rotary_embedding(config) + self.rope_impl = config.attn_config['rope_impl'] + self.rotary_embedding = _rotary_embedding( + rope_head_dim=config.d_model // config.n_heads, + rope_impl=self.rope_impl, + rope_theta=config.attn_config['rope_theta'], + rope_dail_config=config.attn_config['rope_dail_config'], + rope_hf_config=config.attn_config['rope_hf_config'], + max_seq_len=self.config.max_seq_len) if config.init_device != 'meta': log.info( @@ -453,7 +458,7 @@ def forward( f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.' ) - if self.learned_pos_emb or (self.rope and self.rope_imp == 'hf'): + if self.learned_pos_emb or (self.rope and self.rope_impl == 'hf'): pos = torch.arange( past_position, S + past_position, @@ -469,16 +474,16 @@ def forward( ) if self.learned_pos_emb: x = x + self.wpe(pos) - elif self.rope and self.rope_imp == 'hf': + elif self.rope and self.rope_impl == 'hf': rotary_emb_w_meta_info = { - 'imp': self.rope_imp, + 'imp': self.rope_impl, 'rotary_emb': self.rotary_embedding, 'offset_info': pos, 'seq_len': S + past_position, } - elif self.rope and self.rope_imp == 'dail': + elif self.rope and self.rope_impl == 'dail': rotary_emb_w_meta_info = { - 'imp': self.rope_imp, + 'imp': self.rope_impl, 'rotary_emb': self.rotary_embedding, 'offset_info': past_position, 'seq_len': S + past_position, diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index 6900da633f..cae9490cc3 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -31,7 +31,7 @@ def allclose_helper(t0: torch.Tensor, 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'dail', + 'rope_impl': 'dail', 'rope_dail_config': { 'type': 'original', 'pos_idx_in_fp32': True, @@ -41,7 +41,7 @@ def allclose_helper(t0: torch.Tensor, 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'hf', + 'rope_impl': 'hf', 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, @@ -68,7 +68,7 @@ def test_attn_impl(attn_impl_0: str, if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'): pytest.xfail('flash attn does not support alibi') - if rope and (pos_emb_config['rope_imp'] + if rope and (pos_emb_config['rope_impl'] == 'dail') and (not is_flash_v2_installed()): pytest.skip('dail implementation of rope requires flash attention 2.') @@ -140,11 +140,11 @@ def gen_bias(attn_impl: str): ) rotary_emb_w_meta_info = { 'imp': - 
pos_emb_config['rope_imp'], + pos_emb_config['rope_impl'], 'rotary_emb': rotary_embedding, 'offset_info': - pos if (pos_emb_config['rope_imp'] == 'hf') else 0, + pos if (pos_emb_config['rope_impl'] == 'hf') else 0, 'seq_len': s, } diff --git a/tests/test_model.py b/tests/test_model.py index a85a3050c7..41b62f0ccf 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -528,7 +528,7 @@ def test_mpt_creation(norm_type: str, no_bias: bool): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'dail', + 'rope_impl': 'dail', 'rope_dail_config': { 'type': 'original', 'pos_idx_in_fp32': True, @@ -538,7 +538,7 @@ def test_mpt_creation(norm_type: str, no_bias: bool): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'hf', + 'rope_impl': 'hf', 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, @@ -556,7 +556,7 @@ def test_forward_with_padding(attention_impl: str, device: str, pytest.skip(f'alibi only implemented with torch and triton attention.') rope = pos_emb_config['rope'] - if rope and pos_emb_config['rope_imp'] == 'dail' and ( + if rope and pos_emb_config['rope_impl'] == 'dail' and ( device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') @@ -646,18 +646,18 @@ def test_forward_with_padding(attention_impl: str, device: str, attention_mask=batched_attention_mask).logits # check that right padding and left padding produce the same output - if rope and pos_emb_config[ - 'rope_imp'] == 'dail': # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. - assert torch.allclose(right_padding_output[0, :3], - left_padding_output[0, 3:], - rtol=1e-2, - atol=1e-2) - else: - assert torch.allclose( - right_padding_output[0, :3], - left_padding_output[0, 3:], - atol=1e-6 if attention_impl == 'torch' else 1e-8) - if not (alibi or (rope and pos_emb_config['rope_imp'] == 'dail')): + right_pad_v_left_pad_rtol = 1e-5 + right_pad_v_left_pad_atol = 1e-6 if attention_impl == 'torch' else 1e-8 + if rope and pos_emb_config['rope_impl'] == 'dail': + # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. + right_pad_v_left_pad_rtol = 1e-2 + right_pad_v_left_pad_atol = 1e-2 + assert torch.allclose(right_padding_output[0, :3], + left_padding_output[0, 3:], + rtol=right_pad_v_left_pad_rtol, + atol=right_pad_v_left_pad_atol) + + if not (alibi or (rope and pos_emb_config['rope_impl'] == 'dail')): # check that right padding and middle padding produce the same output # Note: alibi not implemented for middle padding. # Note: dail implementation of rope does not support middle padding. @@ -665,20 +665,13 @@ def test_forward_with_padding(attention_impl: str, device: str, right_padding_output[0, :3], middle_padding_output[0, [0, 1, 5]], atol=1e-6 if attention_impl == 'torch' else 1e-8) + # check that right padding and right padding in a batch produce the same output + assert torch.allclose(right_padding_output[0, :3], + batched_output[0, :3], + atol=1e-6 if attention_impl == 'torch' else 1e-8) - if rope and pos_emb_config[ - 'rope_imp'] == 'dail': # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. 
- assert torch.allclose(right_padding_output[0, :3], - left_padding_output[0, 3:], - rtol=1e-2, - atol=1e-2) - else: - assert torch.allclose( - right_padding_output[0, :3], - batched_output[0, :3], - atol=1e-6 if attention_impl == 'torch' else 1e-8) - if not (alibi or (rope and pos_emb_config['rope_imp'] == 'dail')): + if not (alibi or (rope and pos_emb_config['rope_impl'] == 'dail')): # check that middle padding and middle padding in a batch produce the same output # Note: alibi not implemented for middle padding. # Note: dail implementation of rope does not support middle padding. @@ -757,7 +750,7 @@ def test_advanced_mask_building(attention_impl: str): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'dail', + 'rope_impl': 'dail', 'rope_dail_config': { 'type': 'original', 'pos_idx_in_fp32': True, @@ -767,7 +760,7 @@ def test_advanced_mask_building(attention_impl: str): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'hf', + 'rope_impl': 'hf', 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, @@ -783,7 +776,7 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): if pos_emb_config['alibi'] and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') @@ -985,7 +978,7 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'dail', + 'rope_impl': 'dail', 'rope_dail_config': { 'type': 'original', 'pos_idx_in_fp32': True, @@ -995,7 +988,7 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'hf', + 'rope_impl': 'hf', 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, @@ -1010,7 +1003,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, ) if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') @@ -1094,7 +1087,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, # check that the outputs are the same with or without padding if pos_emb_config['rope'] and pos_emb_config[ - 'rope_imp'] == 'dail': # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. + 'rope_impl'] == 'dail': # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. 
torch.testing.assert_close( second_output_no_padding.logits, second_output_padding.logits[:, -1, :].unsqueeze(1), @@ -1124,7 +1117,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'dail', + 'rope_impl': 'dail', 'rope_dail_config': { 'type': 'original', 'pos_idx_in_fp32': True, @@ -1134,7 +1127,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'hf', + 'rope_impl': 'hf', 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, @@ -1150,7 +1143,7 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') @@ -1265,7 +1258,7 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'dail', + 'rope_impl': 'dail', 'rope_dail_config': { 'type': 'original', 'pos_idx_in_fp32': True, @@ -1275,7 +1268,7 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'hf', + 'rope_impl': 'hf', 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, @@ -1289,7 +1282,7 @@ def test_generate_with_past_kv(attn_impl: str, device: str, ) if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') @@ -1377,7 +1370,7 @@ def test_generate_with_past_kv(attn_impl: str, device: str, 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'dail', + 'rope_impl': 'dail', 'rope_dail_config': { 'type': 'original', 'pos_idx_in_fp32': True, @@ -1387,7 +1380,7 @@ def test_generate_with_past_kv(attn_impl: str, device: str, 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'hf', + 'rope_impl': 'hf', 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, @@ -1403,7 +1396,7 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') @@ -1458,7 +1451,7 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'dail', + 'rope_impl': 'dail', 'rope_dail_config': { 'type': 'original', 'pos_idx_in_fp32': True, @@ -1468,7 +1461,7 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'hf', + 'rope_impl': 'hf', 'rope_hf_config': { 'type': 
'no_scaling', 'factor': 1.0, @@ -1484,7 +1477,7 @@ def test_model_to(attention_impl: str, pos_emb_config: dict): pytest.skip(f'alibi only implemented with torch and triton attention.') if pos_emb_config['rope'] and pos_emb_config[ - 'rope_imp'] == 'dail' and not is_flash_v2_installed(): + 'rope_impl'] == 'dail' and not is_flash_v2_installed(): pytest.skip(f'dail implementation of rope requires flash attention 2.') hf_config = MPTConfig( @@ -1523,8 +1516,8 @@ def test_model_to(attention_impl: str, pos_emb_config: dict): mpt = mpt.to('cpu') # verify the model still works - if attention_impl == 'torch' and not (pos_emb_config['rope'] and - pos_emb_config['rope_imp'] == 'dail'): + if attention_impl == 'torch' and not ( + pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'): with torch.autocast('cpu', dtype=torch.bfloat16, enabled=True): _ = mpt(input_ids.to('cpu'), attention_mask=attention_mask.to('cpu')) @@ -1541,8 +1534,8 @@ def test_model_to(attention_impl: str, pos_emb_config: dict): mpt = mpt.float() # verify the model still works - if attention_impl == 'torch' and not (pos_emb_config['rope'] and - pos_emb_config['rope_imp'] == 'dail'): + if attention_impl == 'torch' and not ( + pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'): _ = mpt(input_ids.to('cpu'), attention_mask=attention_mask.to('cpu')) mpt = mpt.half() @@ -1589,7 +1582,7 @@ def test_alibi_vs_hf(): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'dail', + 'rope_impl': 'dail', 'rope_dail_config': { 'type': 'original', 'pos_idx_in_fp32': True, @@ -1599,7 +1592,7 @@ def test_alibi_vs_hf(): 'alibi': False, 'rope': True, 'rope_theta': 10000, - 'rope_imp': 'hf', + 'rope_impl': 'hf', 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, @@ -1619,7 +1612,7 @@ def test_forward_with_output_attentions_and_output_hidden_states( pytest.skip(f'alibi only implemented with torch and triton attention.') if output_attentions and attn_impl in ['flash', 'triton']: pytest.skip(f'output_attentions only implemented with torch attention.') - if pos_emb_config['rope'] and pos_emb_config['rope_imp'] == 'dail' and ( + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') diff --git a/tests/test_rope_dail_vs_hf.py b/tests/test_rope_dail_vs_hf.py index 76d5cd50b0..55c6536871 100644 --- a/tests/test_rope_dail_vs_hf.py +++ b/tests/test_rope_dail_vs_hf.py @@ -20,7 +20,7 @@ def gen_rotary_embedding(rope_head_dim: int, pos_emb_config: dict, max_seq_len: int): - if pos_emb_config['rope_imp'] == 'dail': + if pos_emb_config['rope_impl'] == 'dail': return DAILRotaryEmbedding( dim=rope_head_dim, base=pos_emb_config['rope_theta'], @@ -32,7 +32,7 @@ def gen_rotary_embedding(rope_head_dim: int, pos_emb_config: dict, device= 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) - elif pos_emb_config['rope_imp'] == 'hf': + elif pos_emb_config['rope_impl'] == 'hf': if pos_emb_config['rope_hf_config']['type'] == 'no_scaling': return HFRotaryEmbedding( rope_head_dim, @@ -64,7 +64,7 @@ def gen_rotary_embedding(rope_head_dim: int, pos_emb_config: dict, f'Invalid scaling type: {pos_emb_config["rope_hf_config"]["type"]}' ) else: - raise ValueError(f'Invalid rope_imp: {pos_emb_config["rope_imp"]}') + raise ValueError(f'Invalid rope_impl: {pos_emb_config["rope_impl"]}') @pytest.mark.gpu @@ -112,7 +112,7 @@ def test_rope_dail_vs_hf(clip_qkv: bool, 
with get_precision_context('amp_bf16'): dail_rope_config = { 'rope_theta': 10000, - 'rope_imp': 'dail', + 'rope_impl': 'dail', 'rope_dail_config': { 'type': 'original', 'pos_idx_in_fp32': True, @@ -121,7 +121,7 @@ def test_rope_dail_vs_hf(clip_qkv: bool, } hf_rope_config = { 'rope_theta': 10000, - 'rope_imp': 'hf', + 'rope_impl': 'hf', 'rope_hf_config': { 'type': 'no_scaling', 'factor': 1.0, From ac0fd40651a504ac4a902464e4009fdf9698402b Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Thu, 2 Nov 2023 22:36:36 +0000 Subject: [PATCH 099/106] fixed tests --- llmfoundry/models/layers/attention.py | 4 ++-- llmfoundry/models/mpt/configuration_mpt.py | 28 ++++++++++++++++------ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 5bec1dbb70..1b02c41c29 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -17,12 +17,12 @@ from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY -def is_flash_v2_installed(): +def is_flash_v2_installed(v2_version: str = '2.0.0'): try: import flash_attn as flash_attn except: return False - return version.parse(flash_attn.__version__) >= version.parse('2.0.0') + return version.parse(flash_attn.__version__) >= version.parse(v2_version) def is_flash_v1_installed(): diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 9c6f4287f7..c4ca68d733 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -8,8 +8,17 @@ from transformers import PretrainedConfig +from llmfoundry.models.layers.attention import is_flash_v2_installed from llmfoundry.models.layers.blocks import attn_config_defaults +# NOTE: All utils are imported directly even if unused so that +# HuggingFace can detect all the needed files to copy into its modules folder. +# Otherwise, certain modules are missing. +# isort: off +from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note) +from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note) +from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note) + ffn_config_defaults: Dict = { 'ffn_type': 'mptmlp', } @@ -224,13 +233,18 @@ def _validate_config(self) -> None: raise ValueError( 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".' ) - if self.attn_config['rope'] and ( - self.attn_config['rope_impl'] - == 'dail') and (self.attn_config['rope_dail_config']['type'] - not in ['original', 'xpos']): - raise ValueError( - 'If using the dail implementation of rope, the type should be one of "original" or "xpos".' - ) + if self.attn_config['rope'] and (self.attn_config['rope_impl'] + == 'dail'): + if self.attn_config['rope_dail_config']['type'] not in [ + 'original', 'xpos' + ]: + raise ValueError( + 'If using the dail implementation of rope, the type should be one of "original" or "xpos".' + ) + if not is_flash_v2_installed(v2_version='2.0.1'): + raise ImportError( + 'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support' + ) if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' 
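The check above gates the dail RoPE path on flash-attn v2.0.1 or higher. As a minimal, self-contained sketch of the same version-gating pattern (the helper name `flash_attn_at_least` is hypothetical and simply mirrors `is_flash_v2_installed`):

```python
# Minimal, self-contained sketch of the version gate used for dail RoPE.
# `flash_attn_at_least` is a hypothetical name mirroring `is_flash_v2_installed`.
from packaging import version


def flash_attn_at_least(min_version: str = '2.0.0') -> bool:
    try:
        import flash_attn
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) >= version.parse(min_version)


if __name__ == '__main__':
    if not flash_attn_at_least('2.0.1'):
        raise ImportError(
            'The dail implementation of rope requires flash-attn v2.0.1 or higher.')
```

Comparing parsed versions rather than raw strings avoids pitfalls such as `'2.10.0' < '2.2.0'` evaluating to True lexicographically.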
From e59f78414cc5ee45aafbc2d3a7e2c801badc780b Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 3 Nov 2023 05:34:47 +0000 Subject: [PATCH 100/106] modified is_flash_v2_installed --- llmfoundry/models/layers/attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 1b02c41c29..a663875a2d 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -18,6 +18,7 @@ def is_flash_v2_installed(v2_version: str = '2.0.0'): + assert version.parse(v2_version) >= version.parse('2.0.0') try: import flash_attn as flash_attn except: @@ -597,7 +598,7 @@ def forward( seq_len = rotary_emb_w_meta_info['seq_len'] offset_info = rotary_emb_w_meta_info['offset_info'] assert query.shape[:2] == key.shape[:2] - assert query.shape[:2] == key.shape[:2] + assert query.shape[:2] == value.shape[:2] bsz, seqlen = query.shape[:2] query = query.view(bsz, seqlen, -1, self.head_dim) key = key.view(bsz, seqlen, -1, self.head_dim) From c602a06a8e88b8702ea7de6d9dbb112d184a9c54 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Fri, 3 Nov 2023 18:14:36 +0000 Subject: [PATCH 101/106] minor changes --- llmfoundry/models/mpt/modeling_mpt.py | 9 +-- tests/test_flash_triton_torch.py | 7 ++- tests/test_rope_dail_vs_hf.py | 82 ++++++--------------------- 3 files changed, 26 insertions(+), 72 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 0d52ee214d..059011a847 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -86,9 +86,9 @@ log = logging.getLogger(__name__) -def _rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int, - rope_dail_config: dict, rope_hf_config: dict, - max_seq_len: int): +def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int, + rope_dail_config: dict, rope_hf_config: dict, + max_seq_len: int): if rope_impl == 'dail': return DAILRotaryEmbedding( dim=rope_head_dim, @@ -127,6 +127,7 @@ def _rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int, device= 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) + raise ValueError('rope_impl needs to be either dail or hf') class MPTPreTrainedModel(PreTrainedModel): @@ -186,7 +187,7 @@ def __init__(self, config: MPTConfig): self.rope_impl = None if self.rope: self.rope_impl = config.attn_config['rope_impl'] - self.rotary_embedding = _rotary_embedding( + self.rotary_embedding = gen_rotary_embedding( rope_head_dim=config.d_model // config.n_heads, rope_impl=self.rope_impl, rope_theta=config.attn_config['rope_theta'], diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index cae9490cc3..3f79bf0c7e 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -6,7 +6,7 @@ from omegaconf import OmegaConf as om from llmfoundry.models.layers.attention import is_flash_v2_installed -from tests.test_rope_dail_vs_hf import gen_rotary_embedding +from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding def allclose_helper(t0: torch.Tensor, @@ -130,7 +130,10 @@ def gen_bias(attn_impl: str): if rope: rotary_embedding = gen_rotary_embedding( rope_head_dim=cfg.d_model // cfg.n_heads, - pos_emb_config=pos_emb_config, + rope_impl=pos_emb_config['rope_impl'], + rope_theta=pos_emb_config['rope_theta'], + rope_dail_config=pos_emb_config.get('rope_dail_config', {}), + 
rope_hf_config=pos_emb_config.get('rope_hf_config', {}), max_seq_len=s).to(device) pos = torch.arange(s).unsqueeze(0).to(device=device) # adjust the position indices to account for padding tokens diff --git a/tests/test_rope_dail_vs_hf.py b/tests/test_rope_dail_vs_hf.py index 55c6536871..9b2d471e19 100644 --- a/tests/test_rope_dail_vs_hf.py +++ b/tests/test_rope_dail_vs_hf.py @@ -4,67 +4,10 @@ import pytest import torch from composer.core.precision import get_precision_context +from omegaconf import OmegaConf as om from llmfoundry.models.layers.attention import is_flash_v2_installed - -if is_flash_v2_installed(): - from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding -from omegaconf import OmegaConf as om -from transformers.models.llama.modeling_llama import \ - LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaRotaryEmbedding as HFRotaryEmbedding - - -def gen_rotary_embedding(rope_head_dim: int, pos_emb_config: dict, - max_seq_len: int): - if pos_emb_config['rope_impl'] == 'dail': - return DAILRotaryEmbedding( - dim=rope_head_dim, - base=pos_emb_config['rope_theta'], - interleaved=False, - scale_base=pos_emb_config['rope_dail_config']['xpos_scale_base'] if - (pos_emb_config['rope_dail_config']['type'] == 'xpos') else None, - pos_idx_in_fp32=pos_emb_config['rope_dail_config'] - ['pos_idx_in_fp32'], - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif pos_emb_config['rope_impl'] == 'hf': - if pos_emb_config['rope_hf_config']['type'] == 'no_scaling': - return HFRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=pos_emb_config['rope_theta'], - device= - 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif pos_emb_config['rope_hf_config']['type'] == 'linear': - return HFLinearScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=pos_emb_config['rope_theta'], - scaling_factor=pos_emb_config['rope_hf_config']['factor'], - device= - 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif pos_emb_config['rope_hf_config']['type'] == 'dynamic': - return HFDynamicNTKScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=pos_emb_config['rope_theta'], - scaling_factor=pos_emb_config['rope_hf_config']['factor'], - device= - 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - else: - raise ValueError( - f'Invalid scaling type: {pos_emb_config["rope_hf_config"]["type"]}' - ) - else: - raise ValueError(f'Invalid rope_impl: {pos_emb_config["rope_impl"]}') +from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding @pytest.mark.gpu @@ -128,10 +71,13 @@ def test_rope_dail_vs_hf(clip_qkv: bool, } } - dail_rope = gen_rotary_embedding(rope_head_dim=cfg.d_model // - cfg.n_heads, - pos_emb_config=dail_rope_config, - max_seq_len=seq_len).to('cuda') + dail_rope = gen_rotary_embedding( + rope_head_dim=cfg.d_model // cfg.n_heads, + rope_impl=dail_rope_config['rope_impl'], + rope_theta=dail_rope_config['rope_theta'], + rope_dail_config=dail_rope_config['rope_dail_config'], + rope_hf_config={}, + max_seq_len=seq_len).to('cuda') dail_rope_w_meta_info = { 'imp': 'dail', 'rotary_emb': dail_rope, @@ 
-139,9 +85,13 @@ def test_rope_dail_vs_hf(clip_qkv: bool, 'seq_len': seq_len, } - hf_rope = gen_rotary_embedding(rope_head_dim=cfg.d_model // cfg.n_heads, - pos_emb_config=hf_rope_config, - max_seq_len=seq_len).to('cuda') + hf_rope = gen_rotary_embedding( + rope_head_dim=cfg.d_model // cfg.n_heads, + rope_impl=hf_rope_config['rope_impl'], + rope_theta=hf_rope_config['rope_theta'], + rope_dail_config={}, + rope_hf_config=hf_rope_config['rope_hf_config'], + max_seq_len=seq_len).to('cuda') pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda') # adjust the position indices to account for padding tokens pos = torch.clamp( From 5744a3e8a7a923ec8cbd64b847fa74192029cbf7 Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Sat, 4 Nov 2023 10:48:20 -0700 Subject: [PATCH 102/106] Update TUTORIAL.md Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- TUTORIAL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index 117a61f617..a7dd7b4d37 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -347,7 +347,7 @@ The majority of our training setups use `triton`. --> - [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster attention computation. LLM-Foundry now supports FlashAttention-2, please follow the instructions [here](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#flashattention). ### What kinds of positional embeddings does LLM Foundry support? -Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.03762.pdf), [Attention with Linear Biases (ALiBi)](https://arxiv.org/pdf/2108.12409.pdf), and [Rotary Positional Embeddings (RoPE)](https://arxiv.org/pdf/2104.09864.pdf). There is also an option to switch off all of these embeddings to get (No Positional Embedding)[https://arxiv.org/pdf/2203.16634.pdf]. +Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.03762.pdf), [Attention with Linear Biases (ALiBi)](https://arxiv.org/pdf/2108.12409.pdf), and [Rotary Positional Embeddings (RoPE)](https://arxiv.org/pdf/2104.09864.pdf). There is also an option to switch off all of these embeddings to get [No Positional Embedding](https://arxiv.org/pdf/2203.16634.pdf). | Name | YAML Config | Training MFU on MPT-7B trained on 8 A-100 80GB GPUs | Notes | |:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| From 0036ce7130b96808840b7a38d5aedea3de365d0d Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Sat, 4 Nov 2023 10:48:33 -0700 Subject: [PATCH 103/106] Update TUTORIAL.md Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- TUTORIAL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index a7dd7b4d37..085a5e4a40 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -353,7 +353,7 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706. 
|:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Learned Positional Embeddings | <pre>model:<br>  learned_pos_emb: True</pre> | 65.7 | |
| ALiBi | <pre>model:<br>  attn_config:<br>    alibi: True</pre> | 64.5 | Requires Triton or Torch attention. |
-| RoPE (Dao-AILab Implementation) | <pre>model:<br>  attn_config:<br>    rope: True<br>    rope_impl: dail</pre> | 64.5 | Requires a CUDA GPU and [the flash-attn library](https://github.com/Dao-AILab/flash-attention) (v2.0.1 or higher) to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn library v2. Note that attention implementation can still be torch, triton, or flash, just that this needs the the flash-attn library (v2.0.1 or higher) installed since we import their RotaryEmbedding class. |
+| RoPE (Dao-AILab Implementation) | <pre>model:<br>  attn_config:<br>    rope: True<br>    rope_impl: dail</pre> | 64.5 | Requires a CUDA GPU and the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.0.1 or higher to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn v2. Note that the attention implementation can still be `torch`, `triton`, or `flash`. |
| RoPE (Hugging Face Implementation) | <pre>model:<br>  attn_config:<br>    rope: True<br>    rope_impl: hf</pre>
| 62.3 | | ### Can I finetune using PEFT / LoRA? From 9ebd2e7a3fddf316ea905f95b4bfc23d89ebfa7e Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Sat, 4 Nov 2023 10:48:42 -0700 Subject: [PATCH 104/106] Update TUTORIAL.md Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- TUTORIAL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index 085a5e4a40..bbc720c0e7 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -349,7 +349,7 @@ The majority of our training setups use `triton`. --> ### What kinds of positional embeddings does LLM Foundry support? Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.03762.pdf), [Attention with Linear Biases (ALiBi)](https://arxiv.org/pdf/2108.12409.pdf), and [Rotary Positional Embeddings (RoPE)](https://arxiv.org/pdf/2104.09864.pdf). There is also an option to switch off all of these embeddings to get [No Positional Embedding](https://arxiv.org/pdf/2203.16634.pdf). -| Name | YAML Config | Training MFU on MPT-7B trained on 8 A-100 80GB GPUs | Notes | +| Name | YAML Config | Training MFU on MPT-7B trained on 8 A100 80GB GPUs | Notes | |:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | Learned Positional Embeddings |
<pre>model:<br>  learned_pos_emb: True</pre> | 65.7 | |
| ALiBi | <pre>model:<br>  attn_config:<br>    alibi: True</pre>
| 64.5 | Requires Triton or Torch attention. | From 4874713212c381b36b7badd3a24ea7fd30b8a546 Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Sat, 4 Nov 2023 10:49:12 -0700 Subject: [PATCH 105/106] Update TUTORIAL.md Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- TUTORIAL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index bbc720c0e7..86bd9829e9 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -344,7 +344,7 @@ The majority of our training setups use `triton`. --> What is the result of this? Although sm89+ is not **formally** supported until LLVM15, our testing on H100 GPUs shows that `attn_impl=triton` still works well and still runs fast. The only issue is that when the network is starting to run, LLVM might throw a warning like: `'sm_90' is not a recognized processor for this target (ignoring processor)`. This warning does not seem to affect performance. #### Support for FlashAttention-2 -- [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster attention computation. LLM-Foundry now supports FlashAttention-2, please follow the instructions [here](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#flashattention). +- [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster attention computation. LLM Foundry supports FlashAttention-2. Please follow the instructions [here](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#flashattention). ### What kinds of positional embeddings does LLM Foundry support? Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.03762.pdf), [Attention with Linear Biases (ALiBi)](https://arxiv.org/pdf/2108.12409.pdf), and [Rotary Positional Embeddings (RoPE)](https://arxiv.org/pdf/2104.09864.pdf). There is also an option to switch off all of these embeddings to get [No Positional Embedding](https://arxiv.org/pdf/2203.16634.pdf). 
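The two RoPE rows in the table above map onto the model's `attn_config`. The sketch below uses the same values that the tests in this patch series pass (`rope_theta=10000`, the `'original'` dail type, the `'no_scaling'` hf type); the `xpos_scale_base` entry is illustrative and is only consulted when the dail type is `'xpos'`:

```python
# Sketch: attn_config fragments matching the two RoPE rows in the table above.
# Values mirror the pos_emb_config dicts used by the tests in this patch series;
# xpos_scale_base is illustrative and only read when the dail type is 'xpos'.
dail_rope_attn_config = {
    'rope': True,
    'rope_theta': 10000,
    'rope_impl': 'dail',  # Dao-AILab (flash-attn) rotary embedding; needs flash-attn >= 2.0.1
    'rope_dail_config': {
        'type': 'original',  # or 'xpos'
        'pos_idx_in_fp32': True,
        'xpos_scale_base': 512,
    },
}

hf_rope_attn_config = {
    'rope': True,
    'rope_theta': 10000,
    'rope_impl': 'hf',  # Hugging Face (LLaMA-style) rotary embedding
    'rope_hf_config': {
        'type': 'no_scaling',  # or 'linear' / 'dynamic'
        'factor': 1.0,
    },
}
```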
From e0d8b75c5243c4433542d282478f2b5ac671498f Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Sat, 4 Nov 2023 18:24:01 +0000 Subject: [PATCH 106/106] resolved PR comments --- llmfoundry/models/layers/attention.py | 6 ++---- llmfoundry/models/mpt/modeling_mpt.py | 4 ++-- tests/test_flash_triton_torch.py | 2 +- tests/test_rope_dail_vs_hf.py | 4 ++-- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index a663875a2d..0503d6d75a 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -597,13 +597,11 @@ def forward( rotary_emb = rotary_emb_w_meta_info['rotary_emb'] seq_len = rotary_emb_w_meta_info['seq_len'] offset_info = rotary_emb_w_meta_info['offset_info'] - assert query.shape[:2] == key.shape[:2] - assert query.shape[:2] == value.shape[:2] bsz, seqlen = query.shape[:2] query = query.view(bsz, seqlen, -1, self.head_dim) key = key.view(bsz, seqlen, -1, self.head_dim) - if rotary_emb_w_meta_info['imp'] == 'dail': + if rotary_emb_w_meta_info['impl'] == 'dail': value = value.view(bsz, seqlen, -1, self.head_dim) kv = torch.stack([key, value], dim=2) @@ -614,7 +612,7 @@ def forward( [key, value] = torch.unbind(kv, dim=2) value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim) - elif rotary_emb_w_meta_info['imp'] == 'hf': + elif rotary_emb_w_meta_info['impl'] == 'hf': (cos, sin) = rotary_emb(value, seq_len) # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb query = query.transpose(1, 2) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 059011a847..0cb3ebd56c 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -477,14 +477,14 @@ def forward( x = x + self.wpe(pos) elif self.rope and self.rope_impl == 'hf': rotary_emb_w_meta_info = { - 'imp': self.rope_impl, + 'impl': self.rope_impl, 'rotary_emb': self.rotary_embedding, 'offset_info': pos, 'seq_len': S + past_position, } elif self.rope and self.rope_impl == 'dail': rotary_emb_w_meta_info = { - 'imp': self.rope_impl, + 'impl': self.rope_impl, 'rotary_emb': self.rotary_embedding, 'offset_info': past_position, 'seq_len': S + past_position, diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index 3f79bf0c7e..3f2c229d6d 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -142,7 +142,7 @@ def gen_bias(attn_impl: str): min=0, ) rotary_emb_w_meta_info = { - 'imp': + 'impl': pos_emb_config['rope_impl'], 'rotary_emb': rotary_embedding, diff --git a/tests/test_rope_dail_vs_hf.py b/tests/test_rope_dail_vs_hf.py index 9b2d471e19..598e308546 100644 --- a/tests/test_rope_dail_vs_hf.py +++ b/tests/test_rope_dail_vs_hf.py @@ -79,7 +79,7 @@ def test_rope_dail_vs_hf(clip_qkv: bool, rope_hf_config={}, max_seq_len=seq_len).to('cuda') dail_rope_w_meta_info = { - 'imp': 'dail', + 'impl': 'dail', 'rotary_emb': dail_rope, 'offset_info': 0, 'seq_len': seq_len, @@ -99,7 +99,7 @@ def test_rope_dail_vs_hf(clip_qkv: bool, min=0, ) hf_rope_w_meta_info = { - 'imp': 'hf', + 'impl': 'hf', 'rotary_emb': hf_rope, 'offset_info': pos, 'seq_len': seq_len,
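With this final rename in place, the metadata dict handed to attention uses the `'impl'` key throughout. A usage sketch, assuming llm-foundry with these patches is installed, mirroring how the tests construct the rotary embedding and its metadata for the hf implementation:

```python
# Usage sketch, assuming llm-foundry with these patches is installed: build a
# rotary embedding and the metadata dict that the attention layer consumes,
# using the renamed 'impl' key. Sizes here are arbitrary small values.
import torch

from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding

d_model, n_heads, seq_len = 128, 8, 64
rotary_embedding = gen_rotary_embedding(
    rope_head_dim=d_model // n_heads,
    rope_impl='hf',
    rope_theta=10000,
    rope_dail_config={},
    rope_hf_config={'type': 'no_scaling', 'factor': 1.0},
    max_seq_len=seq_len)

pos = torch.arange(seq_len).unsqueeze(0)  # absolute position ids for the 'hf' path
rotary_emb_w_meta_info = {
    'impl': 'hf',
    'rotary_emb': rotary_embedding,
    'offset_info': pos,
    'seq_len': seq_len,
}
```

For the 'dail' implementation, `offset_info` is instead the integer `past_position`, as the forward pass above shows.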