From 8fed63a244f901a8a241cd624d7444a0da9fc910 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Tue, 23 Jul 2024 21:04:39 +0000 Subject: [PATCH 01/18] support rope scaling --- llmfoundry/models/mpt/modeling_mpt.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 3b2744f867..0dedc0eac5 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -134,6 +134,8 @@ def gen_rotary_embedding( device= 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) + elif rope_hf_config['type'] == 'llama3': + raise NotImplementedError() raise ValueError('rope_impl needs to be either dail or hf') From aa4736f60d0408308161f7445d1d9f0257483c44 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Tue, 23 Jul 2024 21:34:20 +0000 Subject: [PATCH 02/18] use rope scaling --- llmfoundry/models/mpt/modeling_mpt.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 0dedc0eac5..9fff8dcc7e 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -45,6 +45,7 @@ import logging from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -135,7 +136,14 @@ def gen_rotary_embedding( 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) elif rope_hf_config['type'] == 'llama3': - raise NotImplementedError() + return LlamaRotaryEmbedding( + dim=rope_head_dim, + max_position_embeddings=max_seq_len, + base=rope_theta, + scaling_factor=rope_hf_config['factor'], + device= + 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) raise ValueError('rope_impl needs to be either dail or hf') From 1e7d8d354437420572ec0dea359377a596da3c31 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 03:39:55 +0000 Subject: [PATCH 03/18] update to use rope config --- llmfoundry/models/mpt/modeling_mpt.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 9fff8dcc7e..9abc2aa281 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -45,15 +45,16 @@ import logging from transformers import PreTrainedModel, PreTrainedTokenizerBase -from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, ) +from transformers.models.llama.modeling_llama import LlamaConfig from transformers.models.llama.modeling_llama import \ LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding from transformers.models.llama.modeling_llama import \ LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from transformers.models.llama.modeling_llama import \ LlamaRotaryEmbedding as HFRotaryEmbedding @@ -137,12 +138,10 @@ def gen_rotary_embedding( ) elif rope_hf_config['type'] == 'llama3': return LlamaRotaryEmbedding( - dim=rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - scaling_factor=rope_hf_config['factor'], - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu + config=LlamaConfig( + rope_scaling=rope_hf_config, + rope_theta=rope_theta, + ), ) raise ValueError('rope_impl needs to be either dail or hf') From a032676f293803869999fa78c53c76f74ef9e95e Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 03:56:41 +0000 Subject: [PATCH 04/18] update config args --- llmfoundry/models/mpt/modeling_mpt.py | 7 +++++++ tests/models/layers/test_flash_torch.py | 4 ++++ tests/models/test_rope_dail_vs_hf.py | 4 ++++ 3 files changed, 15 insertions(+) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index aad7429d72..ad57eda372 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -97,6 +97,8 @@ def gen_rotary_embedding( rope_dail_config: dict, rope_hf_config: dict, max_seq_len: int, + d_model: int, + n_heads: int, ): if rope_impl == 'dail': return DAILRotaryEmbedding( @@ -141,6 +143,9 @@ def gen_rotary_embedding( config=LlamaConfig( rope_scaling=rope_hf_config, rope_theta=rope_theta, + max_position_embeddings=max_seq_len, + hidden_size=d_model, + num_attention_heads=n_heads, ), ) raise ValueError('rope_impl needs to be either dail or hf') @@ -414,6 +419,8 @@ def __init__(self, config: MPTConfig): rope_dail_config=config.attn_config['rope_dail_config'], rope_hf_config=config.attn_config['rope_hf_config'], max_seq_len=self.config.max_seq_len, + d_model=config.d_model, + n_heads=config.n_heads, ) if config.init_device != 'meta': diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index 01d982052f..82b221ae95 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -257,6 +257,8 @@ def gen_bias(attn_impl: str): rope_dail_config=pos_emb_config.get('rope_dail_config', {}), rope_hf_config=pos_emb_config.get('rope_hf_config', {}), max_seq_len=s, + d_model=cfg.d_model, + n_heads=cfg.n_heads, ).to(device) pos = torch.arange(s).unsqueeze(0).to(device=device) # adjust the position indices to account for padding tokens @@ -670,6 +672,8 @@ def gen_bias(attn_impl: str): rope_dail_config=pos_emb_config.get('rope_dail_config', {}), rope_hf_config=pos_emb_config.get('rope_hf_config', {}), max_seq_len=s, + d_model=cfg['d_model'], + n_heads=cfg['n_heads'], ).to(device) pos = torch.arange(s).unsqueeze(0).to(device=device) # adjust the position indices to account for padding tokens diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py index 6a41e64f48..00592caa77 100644 --- a/tests/models/test_rope_dail_vs_hf.py +++ b/tests/models/test_rope_dail_vs_hf.py @@ -83,6 +83,8 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): rope_dail_config=dail_rope_config['rope_dail_config'], rope_hf_config={}, max_seq_len=seq_len, + d_model=cfg.d_model, + n_heads=cfg.n_heads, ).to('cuda') dail_rope_w_meta_info = { 'impl': 'dail', @@ -98,6 +100,8 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): rope_dail_config={}, rope_hf_config=hf_rope_config['rope_hf_config'], max_seq_len=seq_len, + d_model=cfg.d_model, + n_heads=cfg.n_heads, ).to('cuda') pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda') # adjust the position indices to account for padding tokens From 0d190a2f900d3a84bfd35eeb2fa2eef0ea917d29 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 04:27:04 +0000 Subject: [PATCH 05/18] use allowlist for config to enforce hygeine --- llmfoundry/models/mpt/modeling_mpt.py | 28 ++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index ad57eda372..a7ce93d236 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -90,6 +90,32 @@ log = logging.getLogger(__name__) +class InvalidConfigAccessError(KeyError): + pass + + +class PartialLlamaConfig(LlamaConfig): + _ALLOWED_KEYS = { + 'rope_scaling', + 'rope_theta', + 'max_position_embeddings', + 'hidden_size', + 'num_attention_heads', + } + + def __getattribute__(self, key: str): + if key not in self._ALLOWED_KEYS: + raise InvalidConfigAccessError(key) + + return super().__getattribute__(key) + + def __getitem__(self, key: str): + if key not in self._ALLOWED_KEYS: + raise InvalidConfigAccessError(key) + + return super().__getitem__(key) + + def gen_rotary_embedding( rope_head_dim: int, rope_impl: str, @@ -140,7 +166,7 @@ def gen_rotary_embedding( ) elif rope_hf_config['type'] == 'llama3': return LlamaRotaryEmbedding( - config=LlamaConfig( + config=PartialLlamaConfig( rope_scaling=rope_hf_config, rope_theta=rope_theta, max_position_embeddings=max_seq_len, From 604f0b9321f4f340a7e1958a5fbd56eeb4336f52 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 06:05:24 +0000 Subject: [PATCH 06/18] allow llama3 rope config --- llmfoundry/models/mpt/configuration_mpt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 3de3744745..8ac5a8ac49 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -306,6 +306,7 @@ def _validate_config(self) -> None: 'no_scaling', 'linear', 'dynamic', + 'llama3', ]: raise ValueError( 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".', From 0682283fc66468bc97358cf5567f9f10276d9f1e Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 06:49:16 +0000 Subject: [PATCH 07/18] add unit test --- llmfoundry/models/mpt/modeling_mpt.py | 37 +++++++++++++++++-------- tests/models/layers/test_flash_torch.py | 2 -- tests/models/test_rope_dail_vs_hf.py | 2 -- tests/models/test_rope_scaling.py | 36 ++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 16 deletions(-) create mode 100644 tests/models/test_rope_scaling.py diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index a7ce93d236..b3f4636fb6 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -94,30 +94,41 @@ class InvalidConfigAccessError(KeyError): pass +_ALLOWED_LLAMA_CONFIG_KEYS = { + 'rope_scaling', + 'rope_theta', + 'max_position_embeddings', + 'hidden_size', + 'num_attention_heads', + '_get_generation_defaults', + 'label2id', + 'id2label', + 'torch_dtype', + 'problem_type', + '__class__', + 'partial_rotary_factor', +} + + class PartialLlamaConfig(LlamaConfig): - _ALLOWED_KEYS = { - 'rope_scaling', - 'rope_theta', - 'max_position_embeddings', - 'hidden_size', - 'num_attention_heads', - } def __getattribute__(self, key: str): - if key not in self._ALLOWED_KEYS: + if key not in _ALLOWED_LLAMA_CONFIG_KEYS: raise InvalidConfigAccessError(key) return super().__getattribute__(key) def __getitem__(self, key: str): - if key not in self._ALLOWED_KEYS: + if key not in _ALLOWED_LLAMA_CONFIG_KEYS: raise InvalidConfigAccessError(key) return super().__getitem__(key) + def _get_generation_defaults(self): + return {} + def gen_rotary_embedding( - rope_head_dim: int, rope_impl: str, rope_theta: int, rope_dail_config: dict, @@ -126,6 +137,7 @@ def gen_rotary_embedding( d_model: int, n_heads: int, ): + rope_head_dim = d_model // n_heads if rope_impl == 'dail': return DAILRotaryEmbedding( dim=rope_head_dim, @@ -165,9 +177,11 @@ def gen_rotary_embedding( 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) elif rope_hf_config['type'] == 'llama3': + llama_rope_config = {**rope_hf_config} + llama_rope_config['rope_type'] = rope_hf_config.pop('type') return LlamaRotaryEmbedding( config=PartialLlamaConfig( - rope_scaling=rope_hf_config, + rope_scaling=llama_rope_config, rope_theta=rope_theta, max_position_embeddings=max_seq_len, hidden_size=d_model, @@ -439,7 +453,6 @@ def __init__(self, config: MPTConfig): if self.rope: self.rope_impl = config.attn_config['rope_impl'] self.rotary_embedding = gen_rotary_embedding( - rope_head_dim=config.d_model // config.n_heads, rope_impl=self.rope_impl, rope_theta=config.attn_config['rope_theta'], rope_dail_config=config.attn_config['rope_dail_config'], diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index 82b221ae95..4bfdfb84dc 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -251,7 +251,6 @@ def gen_bias(attn_impl: str): rotary_emb_w_meta_info = None if rope: rotary_embedding = gen_rotary_embedding( - rope_head_dim=cfg.d_model // cfg.n_heads, rope_impl=pos_emb_config['rope_impl'], rope_theta=pos_emb_config['rope_theta'], rope_dail_config=pos_emb_config.get('rope_dail_config', {}), @@ -666,7 +665,6 @@ def gen_bias(attn_impl: str): rotary_emb_w_meta_info = None if rope: rotary_embedding = gen_rotary_embedding( - rope_head_dim=cfg['d_model'] // cfg['n_heads'], rope_impl=pos_emb_config['rope_impl'], rope_theta=pos_emb_config['rope_theta'], rope_dail_config=pos_emb_config.get('rope_dail_config', {}), diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py index 00592caa77..34fb23f670 100644 --- a/tests/models/test_rope_dail_vs_hf.py +++ b/tests/models/test_rope_dail_vs_hf.py @@ -77,7 +77,6 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): } dail_rope = gen_rotary_embedding( - rope_head_dim=cfg.d_model // cfg.n_heads, rope_impl=dail_rope_config['rope_impl'], rope_theta=dail_rope_config['rope_theta'], rope_dail_config=dail_rope_config['rope_dail_config'], @@ -94,7 +93,6 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): } hf_rope = gen_rotary_embedding( - rope_head_dim=cfg.d_model // cfg.n_heads, rope_impl=hf_rope_config['rope_impl'], rope_theta=hf_rope_config['rope_theta'], rope_dail_config={}, diff --git a/tests/models/test_rope_scaling.py b/tests/models/test_rope_scaling.py new file mode 100644 index 0000000000..076f0396b5 --- /dev/null +++ b/tests/models/test_rope_scaling.py @@ -0,0 +1,36 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding + +from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding + +rope_config = { + 'rope_theta': 500000.0, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'factor': 8.0, + 'low_freq_factor': 1.0, + 'high_freq_factor': 4.0, + 'original_max_position_embeddings': 8192, + 'type': 'llama3', + }, +} + +rope_dail_config = {} + + +def test_rope_scaling(): + d_model = 128 + n_heads = 32 + max_seq_len = 131_000 + + embedding = gen_rotary_embedding( + d_model=d_model, + n_heads=n_heads, + rope_dail_config=rope_dail_config, + max_seq_len=max_seq_len, + **rope_config, + ) + + assert isinstance(embedding, LlamaRotaryEmbedding) From ff0d3dece1cfa957db155631153dae78137caf5e Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 06:54:33 +0000 Subject: [PATCH 08/18] documented allowed llama config keys --- llmfoundry/models/mpt/modeling_mpt.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index b3f4636fb6..1115ddc14b 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -95,18 +95,23 @@ class InvalidConfigAccessError(KeyError): _ALLOWED_LLAMA_CONFIG_KEYS = { + # these are the only config keys that are set and are safe to read from 'rope_scaling', 'rope_theta', 'max_position_embeddings', 'hidden_size', 'num_attention_heads', + + # not set but llama modeling code tries to read this attribute + 'partial_rotary_factor', + + # benign transformers attributes needed for __init__ '_get_generation_defaults', 'label2id', 'id2label', 'torch_dtype', 'problem_type', '__class__', - 'partial_rotary_factor', } From ef6c8c25cf81d27741c58c17584f060887c0ca7b Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Wed, 24 Jul 2024 04:44:31 -0400 Subject: [PATCH 09/18] Update llmfoundry/models/mpt/modeling_mpt.py --- llmfoundry/models/mpt/modeling_mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 1115ddc14b..b3f0b361d8 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -183,7 +183,7 @@ def gen_rotary_embedding( ) elif rope_hf_config['type'] == 'llama3': llama_rope_config = {**rope_hf_config} - llama_rope_config['rope_type'] = rope_hf_config.pop('type') + llama_rope_config['rope_type'] = rope_hf_config.get('type') return LlamaRotaryEmbedding( config=PartialLlamaConfig( rope_scaling=llama_rope_config, From dd1de377c2630cdf4db92d7eb573ce976cd4f67c Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 11:46:46 -0400 Subject: [PATCH 10/18] Address comments 1 --- llmfoundry/models/mpt/modeling_mpt.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index b3f0b361d8..2de2f6b325 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -116,6 +116,13 @@ class InvalidConfigAccessError(KeyError): class PartialLlamaConfig(LlamaConfig): + """Holds the rope config for Llama models and throws + + an `InvalidConfigAccessError` if any other config elements + are read. This class is necessary because the + `LlamaRotaryEmbedding` class takes a full `LlamaConfig` now + instead of the old keyword arguments. + """ def __getattribute__(self, key: str): if key not in _ALLOWED_LLAMA_CONFIG_KEYS: @@ -129,9 +136,6 @@ def __getitem__(self, key: str): return super().__getitem__(key) - def _get_generation_defaults(self): - return {} - def gen_rotary_embedding( rope_impl: str, From 151570842a5369d291f3f3c57d9fae4702446615 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 11:47:39 -0400 Subject: [PATCH 11/18] Apply suggestions from code review Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/models/mpt/modeling_mpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 2de2f6b325..d8696d6600 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -102,10 +102,10 @@ class InvalidConfigAccessError(KeyError): 'hidden_size', 'num_attention_heads', - # not set but llama modeling code tries to read this attribute + # Not set but llama modeling code tries to read this attribute 'partial_rotary_factor', - # benign transformers attributes needed for __init__ + # Benign transformers attributes needed for __init__ '_get_generation_defaults', 'label2id', 'id2label', From 8da6165c44c757678a4547727e8730c8275f49e6 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 11:47:48 -0400 Subject: [PATCH 12/18] Apply suggestions from code review Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/models/mpt/modeling_mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index d8696d6600..a118808ffe 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -95,7 +95,7 @@ class InvalidConfigAccessError(KeyError): _ALLOWED_LLAMA_CONFIG_KEYS = { - # these are the only config keys that are set and are safe to read from + # These are the only config keys that are set and are safe to read from 'rope_scaling', 'rope_theta', 'max_position_embeddings', From b0700c9f828b56d4989fd618de90e694147550f2 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 16:02:32 +0000 Subject: [PATCH 13/18] use same codepath for all the hf rotary embeddings --- llmfoundry/models/mpt/modeling_mpt.py | 53 ++++++--------------------- 1 file changed, 12 insertions(+), 41 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index a118808ffe..a7b07ecbc0 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -50,10 +50,6 @@ CausalLMOutputWithPast, ) from transformers.models.llama.modeling_llama import LlamaConfig -from transformers.models.llama.modeling_llama import \ - LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from transformers.models.llama.modeling_llama import \ LlamaRotaryEmbedding as HFRotaryEmbedding @@ -159,44 +155,19 @@ def gen_rotary_embedding( 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) elif rope_impl == 'hf': + llama_rope_config = {**rope_hf_config} + llama_rope_config['rope_type'] = rope_hf_config.get('type') + partial_llama_config = PartialLlamaConfig( + rope_scaling=llama_rope_config, + rope_theta=rope_theta, + max_position_embeddings=max_seq_len, + hidden_size=d_model, + num_attention_heads=n_heads, + ) if rope_hf_config['type'] == 'no_scaling': - return HFRotaryEmbeddingFoundry( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif rope_hf_config['type'] == 'linear': - return HFLinearScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - scaling_factor=rope_hf_config['factor'], - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif rope_hf_config['type'] == 'dynamic': - return HFDynamicNTKScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - scaling_factor=rope_hf_config['factor'], - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif rope_hf_config['type'] == 'llama3': - llama_rope_config = {**rope_hf_config} - llama_rope_config['rope_type'] = rope_hf_config.get('type') - return LlamaRotaryEmbedding( - config=PartialLlamaConfig( - rope_scaling=llama_rope_config, - rope_theta=rope_theta, - max_position_embeddings=max_seq_len, - hidden_size=d_model, - num_attention_heads=n_heads, - ), - ) + return HFRotaryEmbeddingFoundry(config=partial_llama_config) + elif rope_hf_config['type'] in {'llama3', 'linear', 'dynamic'}: + return LlamaRotaryEmbedding(config=partial_llama_config) raise ValueError('rope_impl needs to be either dail or hf') From 518a3a1bdf15f8a082ed7a88971565332daefc77 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 16:29:31 +0000 Subject: [PATCH 14/18] fix --- llmfoundry/models/mpt/modeling_mpt.py | 10 +++++----- tests/models/test_rope_scaling.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index a7b07ecbc0..9d34c2cfc3 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -112,12 +112,11 @@ class InvalidConfigAccessError(KeyError): class PartialLlamaConfig(LlamaConfig): - """Holds the rope config for Llama models and throws + """Holds the rope config for Llama models and throws. - an `InvalidConfigAccessError` if any other config elements - are read. This class is necessary because the - `LlamaRotaryEmbedding` class takes a full `LlamaConfig` now - instead of the old keyword arguments. + an `InvalidConfigAccessError` if any other config elements are read. This + class is necessary because the `LlamaRotaryEmbedding` class takes a full + `LlamaConfig` now instead of the old keyword arguments. """ def __getattribute__(self, key: str): @@ -165,6 +164,7 @@ def gen_rotary_embedding( num_attention_heads=n_heads, ) if rope_hf_config['type'] == 'no_scaling': + llama_rope_config['rope_type'] = 'default' return HFRotaryEmbeddingFoundry(config=partial_llama_config) elif rope_hf_config['type'] in {'llama3', 'linear', 'dynamic'}: return LlamaRotaryEmbedding(config=partial_llama_config) diff --git a/tests/models/test_rope_scaling.py b/tests/models/test_rope_scaling.py index 076f0396b5..c68efa1184 100644 --- a/tests/models/test_rope_scaling.py +++ b/tests/models/test_rope_scaling.py @@ -23,7 +23,7 @@ def test_rope_scaling(): d_model = 128 n_heads = 32 - max_seq_len = 131_000 + max_seq_len = 65536 embedding = gen_rotary_embedding( d_model=d_model, From 44ce115cca17ff12f18b495d04ba67b42470788e Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 16:44:11 +0000 Subject: [PATCH 15/18] update --- llmfoundry/models/mpt/modeling_mpt.py | 3 +- tests/models/test_rope_scaling.py | 77 +++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 9d34c2cfc3..94460a65d2 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -156,6 +156,8 @@ def gen_rotary_embedding( elif rope_impl == 'hf': llama_rope_config = {**rope_hf_config} llama_rope_config['rope_type'] = rope_hf_config.get('type') + if llama_rope_config['rope_type'] == 'no_scaling': + llama_rope_config['rope_type'] = 'default' partial_llama_config = PartialLlamaConfig( rope_scaling=llama_rope_config, rope_theta=rope_theta, @@ -164,7 +166,6 @@ def gen_rotary_embedding( num_attention_heads=n_heads, ) if rope_hf_config['type'] == 'no_scaling': - llama_rope_config['rope_type'] = 'default' return HFRotaryEmbeddingFoundry(config=partial_llama_config) elif rope_hf_config['type'] in {'llama3', 'linear', 'dynamic'}: return LlamaRotaryEmbedding(config=partial_llama_config) diff --git a/tests/models/test_rope_scaling.py b/tests/models/test_rope_scaling.py index c68efa1184..94e37c0243 100644 --- a/tests/models/test_rope_scaling.py +++ b/tests/models/test_rope_scaling.py @@ -1,6 +1,10 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import math +from typing import Tuple + +import torch from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding @@ -20,6 +24,66 @@ rope_dail_config = {} +def apply_scaling(freqs: torch.Tensor): + # Values obtained from grid search + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - + low_freq_factor) / (high_freq_factor - low_freq_factor) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + +def precompute_freqs_cis( + dim: int, + end: int, + theta: float = 10000.0, + use_scaled: bool = False, +): + freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + if use_scaled: + freqs = apply_scaling(freqs) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + def test_rope_scaling(): d_model = 128 n_heads = 32 @@ -34,3 +98,16 @@ def test_rope_scaling(): ) assert isinstance(embedding, LlamaRotaryEmbedding) + + x = torch.randn(1, max_seq_len, d_model) + position_ids = torch.arange(max_seq_len).unsqueeze(0) + + freqs_cis = precompute_freqs_cis( + d_model, + max_seq_len, + rope_config['rope_theta'], + ) + + rope_embeddings = embedding.forward(x, position_ids) + + # ??? WIP From d29739556d452b0e34035b31b1974dfec57703bb Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 17:25:45 +0000 Subject: [PATCH 16/18] test WIP but fix get/pop --- llmfoundry/models/mpt/modeling_mpt.py | 2 +- tests/models/test_rope_scaling.py | 26 ++++++++++++++++++++++---- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 94460a65d2..35b864c2d5 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -155,7 +155,7 @@ def gen_rotary_embedding( ) elif rope_impl == 'hf': llama_rope_config = {**rope_hf_config} - llama_rope_config['rope_type'] = rope_hf_config.get('type') + llama_rope_config['rope_type'] = rope_hf_config.pop('type') if llama_rope_config['rope_type'] == 'no_scaling': llama_rope_config['rope_type'] = 'default' partial_llama_config = PartialLlamaConfig( diff --git a/tests/models/test_rope_scaling.py b/tests/models/test_rope_scaling.py index 94e37c0243..33f7f0b383 100644 --- a/tests/models/test_rope_scaling.py +++ b/tests/models/test_rope_scaling.py @@ -99,15 +99,33 @@ def test_rope_scaling(): assert isinstance(embedding, LlamaRotaryEmbedding) - x = torch.randn(1, max_seq_len, d_model) + xq = torch.randn(1, max_seq_len, n_heads, d_model // n_heads) + xk = torch.randn(1, max_seq_len, n_heads, d_model // n_heads) position_ids = torch.arange(max_seq_len).unsqueeze(0) freqs_cis = precompute_freqs_cis( - d_model, + d_model // n_heads, max_seq_len, rope_config['rope_theta'], + use_scaled=True, ) - rope_embeddings = embedding.forward(x, position_ids) + rope_embeddings_q, rope_embeddings_k = embedding.forward( + xq, + position_ids, + ), embedding.forward(xk, position_ids) - # ??? WIP + rope_embeddings_q, rope_embeddings_k = torch.stack( + rope_embeddings_q, + dim=-1, + ), torch.stack(rope_embeddings_k, dim=-1) + + rope_embeddings_q, rope_embeddings_k = rope_embeddings_q.reshape( + *rope_embeddings_q.shape[:-2], + -1, + ), rope_embeddings_k.reshape(*rope_embeddings_k.shape[:-2], -1) + + expected_q, expected_k = apply_rotary_emb(xq, xk, freqs_cis) + + torch.testing.assert_close(rope_embeddings_q, expected_q) + torch.testing.assert_close(rope_embeddings_k, expected_k) From cf460ec5b75ea94e058c637a5b4622b2506bfdbc Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 17:28:57 +0000 Subject: [PATCH 17/18] change the thing being popped --- llmfoundry/models/mpt/modeling_mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 35b864c2d5..7dfaf8562b 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -155,7 +155,7 @@ def gen_rotary_embedding( ) elif rope_impl == 'hf': llama_rope_config = {**rope_hf_config} - llama_rope_config['rope_type'] = rope_hf_config.pop('type') + llama_rope_config['rope_type'] = llama_rope_config.pop('type') if llama_rope_config['rope_type'] == 'no_scaling': llama_rope_config['rope_type'] = 'default' partial_llama_config = PartialLlamaConfig( From d48bcb4608cc799050e50c7f2490d212ad5fce38 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 24 Jul 2024 17:32:16 +0000 Subject: [PATCH 18/18] give up on testing hf --- tests/models/test_rope_scaling.py | 96 ------------------------------- 1 file changed, 96 deletions(-) diff --git a/tests/models/test_rope_scaling.py b/tests/models/test_rope_scaling.py index 33f7f0b383..484ac2b23a 100644 --- a/tests/models/test_rope_scaling.py +++ b/tests/models/test_rope_scaling.py @@ -1,10 +1,5 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 - -import math -from typing import Tuple - -import torch from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding @@ -24,66 +19,6 @@ rope_dail_config = {} -def apply_scaling(freqs: torch.Tensor): - # Values obtained from grid search - scale_factor = 8 - low_freq_factor = 1 - high_freq_factor = 4 - old_context_len = 8192 # original llama3 length - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - new_freqs = [] - for freq in freqs: - wavelen = 2 * math.pi / freq - if wavelen < high_freq_wavelen: - new_freqs.append(freq) - elif wavelen > low_freq_wavelen: - new_freqs.append(freq / scale_factor) - else: - assert low_freq_wavelen != high_freq_wavelen - smooth = (old_context_len / wavelen - - low_freq_factor) / (high_freq_factor - low_freq_factor) - new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) - return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) - - -def precompute_freqs_cis( - dim: int, - end: int, - theta: float = 10000.0, - use_scaled: bool = False, -): - freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device, dtype=torch.float32) - if use_scaled: - freqs = apply_scaling(freqs) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - ndim = x.ndim - assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - def test_rope_scaling(): d_model = 128 n_heads = 32 @@ -98,34 +33,3 @@ def test_rope_scaling(): ) assert isinstance(embedding, LlamaRotaryEmbedding) - - xq = torch.randn(1, max_seq_len, n_heads, d_model // n_heads) - xk = torch.randn(1, max_seq_len, n_heads, d_model // n_heads) - position_ids = torch.arange(max_seq_len).unsqueeze(0) - - freqs_cis = precompute_freqs_cis( - d_model // n_heads, - max_seq_len, - rope_config['rope_theta'], - use_scaled=True, - ) - - rope_embeddings_q, rope_embeddings_k = embedding.forward( - xq, - position_ids, - ), embedding.forward(xk, position_ids) - - rope_embeddings_q, rope_embeddings_k = torch.stack( - rope_embeddings_q, - dim=-1, - ), torch.stack(rope_embeddings_k, dim=-1) - - rope_embeddings_q, rope_embeddings_k = rope_embeddings_q.reshape( - *rope_embeddings_q.shape[:-2], - -1, - ), rope_embeddings_k.reshape(*rope_embeddings_k.shape[:-2], -1) - - expected_q, expected_k = apply_rotary_emb(xq, xk, freqs_cis) - - torch.testing.assert_close(rope_embeddings_q, expected_q) - torch.testing.assert_close(rope_embeddings_k, expected_k)