diff --git a/README.md b/README.md index 0299e43710..d9b75b7617 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ DBRX is a state-of-the-art open source LLM trained by Databricks Mosaic team. It | DBRX Base | 32768 | https://huggingface.co/databricks/dbrx-base | | DBRX Instruct | 32768 | https://huggingface.co/databricks/dbrx-instruct | -Our model weights and code are licensed for both researchers and commercial entities. The Databricks Open Source License can be found at [LICENSE](https://github.com/databricks/dbrx/LICENSE), and our Acceptable Use Policy can be found [here](https://www.databricks.com/legal/acceptable-use-policy-open-model). +Our model weights and code are licensed for both researchers and commercial entities. The Databricks Open Source License can be found at [LICENSE](https://github.com/databricks/dbrx/blob/main/LICENSE), and our Acceptable Use Policy can be found [here](https://www.databricks.com/legal/acceptable-use-policy-open-model). For more information about the DBRX models, see https://github.com/databricks/dbrx. diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 7127d37f40..a186f67f14 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -376,6 +376,17 @@ def transform_config( copied_config.ffn_config['moe_world_size'] = 1 return copied_config + def pre_register_edit(self, local_save_path: str): + """Edit the model before registering with MLflow. + + This allows a subclass to modify the model before registering with MLflow. The base class implementation will + make no modifications. + + Args: + local_save_path (str): The path to the model to be transformed. + """ + pass + def transform_model_pre_registration( self, model: PreTrainedModel, @@ -455,9 +466,9 @@ def tensor_hook( state_dict[fqn] = tensor else: state_dict[fqn] = None - # Convert the state dict to the requested precision - if isinstance(tensor, torch.Tensor): - state_dict[fqn] = tensor.to(dtype=self.dtype) + + if isinstance(state_dict[fqn], torch.Tensor): + state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype) del tensor if dist.get_global_rank() != 0: state_dict = {} @@ -602,6 +613,12 @@ def tensor_hook( ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext( ) with context_manager: + # Add the pip requirements directly to avoid mlflow + # attempting to run inference on the model + model_saving_kwargs['pip_requirements'] = [ + 'transformers', + 'torch', + ] mlflow_logger.save_model(**model_saving_kwargs) # Upload the license file generated by mlflow during the model saving. @@ -618,6 +635,8 @@ def tensor_hook( os.path.join(local_save_path, license_filename), ) + self.pre_register_edit(local_save_path,) + # Spawn a new process to register the model. process = SpawnProcess( target=_register_model_with_run_id_multiprocess, diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py index 77bb9dbcfe..c925e6e586 100644 --- a/llmfoundry/command_utils/train.py +++ b/llmfoundry/command_utils/train.py @@ -544,6 +544,7 @@ def train(cfg: DictConfig) -> Trainer: dist_timeout=train_cfg.dist_timeout, profiler=profiler, compile_config=compile_config, + spin_dataloaders=train_cfg.spin_dataloaders, ) # Optionally just save an HF checkpoint diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index 0c5cb1418b..a6fdf34953 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -424,6 +424,8 @@ def profile_packing( dataloader_cfg = copy.deepcopy(dataloader_cfg) dataloader_cfg.update({ 'drop_last': False, + 'num_workers': 0, + 'prefetch_factor': None, 'persistent_workers': False, }) dataloader_cfg['dataset']['packing_ratio'] = 1.0 diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 071310d69e..f7f372f5fa 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -284,7 +284,7 @@ def _autoset_attn_implementation_monkeypatch( # the different processes. To avoid this contention, we first create the model (on meta device) on local rank # zero. This will set up the transformers model cache and avoid the future contention. if dist.get_local_rank() == 0: - if os.path.isdir(pretrained_model_name_or_path): + if pretrained and os.path.isdir(pretrained_model_name_or_path): with init_empty_weights(include_buffers=False): with warnings.catch_warnings(): warnings.simplefilter('ignore', UserWarning) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 8e740be2b3..c7fdb5b987 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -415,6 +415,7 @@ def __init__( softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, @@ -520,6 +521,7 @@ def __init__( self.q_ln = build_norm( name=norm_type.lower(), normalized_shape=norm_size, + eps=norm_eps, device=device, ) if self.reuse_kv_layer_idx is None: @@ -528,6 +530,7 @@ def __init__( self.k_ln = build_norm( name=norm_type.lower(), normalized_shape=norm_size, + eps=norm_eps, device=device, ) @@ -796,6 +799,7 @@ def __init__( softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, @@ -814,6 +818,7 @@ def __init__( softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, + norm_eps=norm_eps, fc_type=fc_type, device=device, bias=bias, @@ -841,6 +846,7 @@ def __init__( softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, @@ -859,6 +865,7 @@ def __init__( softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, + norm_eps=norm_eps, fc_type=fc_type, device=device, bias=bias, diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index c6988b7bd7..92735cc489 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -42,6 +42,7 @@ def __init__( ffn_config: Optional[Dict] = None, resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, no_bias: bool = False, @@ -84,6 +85,7 @@ def __init__( fc_type=fc_type, resid_pdrop=resid_pdrop, norm_type=norm_type, + norm_eps=norm_eps, device=device, no_bias=no_bias, ) @@ -99,6 +101,7 @@ def __init__( self.norm_1 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) self.attn = build_attention_layer( @@ -117,6 +120,7 @@ def __init__( self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) @@ -260,6 +264,7 @@ def __init__( fc_type: Optional[dict[str, Any]] = None, resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, device: Optional[str] = None, no_bias: bool = False, **kwargs: Any, @@ -283,6 +288,7 @@ def __init__( self.norm_1 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) self.attn = build_attention_layer( @@ -302,6 +308,7 @@ def __init__( self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) self.resid_attn_dropout = nn.Dropout(resid_pdrop) diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index 69d2059bad..d5fd1d37d4 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -26,10 +26,12 @@ def build_norm( name: str, normalized_shape: Union[int, List[int], torch.Size], + eps: Optional[float] = 1e-5, device: Optional[str] = None, ): kwargs = { 'normalized_shape': normalized_shape, + 'eps': eps, 'device': device, } diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 3de3744745..86cc3519ba 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -44,6 +44,7 @@ def __init__( no_bias: bool = False, embedding_fraction: float = 1.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, use_cache: bool = False, init_config: Optional[Dict] = None, fc_type: Union[str, Dict] = 'torch', @@ -101,6 +102,7 @@ def __init__( no_bias (bool): Whether to use bias in all layers. embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. norm_type (str): choose type of norm to use + norm_eps (float): epsilon value for norm layer use_cache (bool): Whether or not the model should return the last key/values attentions init_config (Dict): A dictionary used to configure the model initialization: init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', @@ -168,6 +170,7 @@ def __init__( self.no_bias = no_bias self.embedding_fraction = embedding_fraction self.norm_type = norm_type + self.norm_eps = norm_eps self.use_cache = use_cache self.init_config = init_config if init_config is not None else copy.deepcopy( init_config_defaults, @@ -306,6 +309,7 @@ def _validate_config(self) -> None: 'no_scaling', 'linear', 'dynamic', + 'llama3', ]: raise ValueError( 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".', diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 40b3aaa6ee..6f9b6bf806 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -49,12 +49,10 @@ BaseModelOutputWithPast, CausalLMOutputWithPast, ) -from transformers.models.llama.modeling_llama import \ - LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaRotaryEmbedding as HFRotaryEmbedding +from transformers.models.llama.modeling_llama import ( + LlamaConfig, + LlamaRotaryEmbedding, +) from llmfoundry.layers_registry import norms, param_init_fns from llmfoundry.models.layers.attention import ( @@ -88,14 +86,62 @@ log = logging.getLogger(__name__) +class InvalidConfigAccessError(KeyError): + pass + + +_ALLOWED_LLAMA_CONFIG_KEYS = { + # These are the only config keys that are set and are safe to read from + 'rope_scaling', + 'rope_theta', + 'max_position_embeddings', + 'hidden_size', + 'num_attention_heads', + + # Not set but llama modeling code tries to read this attribute + 'partial_rotary_factor', + + # Benign transformers attributes needed for __init__ + '_get_generation_defaults', + 'label2id', + 'id2label', + 'torch_dtype', + 'problem_type', + '__class__', +} + + +class PartialLlamaConfig(LlamaConfig): + """Holds the rope config for Llama models and throws. + + an `InvalidConfigAccessError` if any other config elements are read. This + class is necessary because the `LlamaRotaryEmbedding` class takes a full + `LlamaConfig` now instead of the old keyword arguments. + """ + + def __getattribute__(self, key: str): + if key not in _ALLOWED_LLAMA_CONFIG_KEYS: + raise InvalidConfigAccessError(key) + + return super().__getattribute__(key) + + def __getitem__(self, key: str): + if key not in _ALLOWED_LLAMA_CONFIG_KEYS: + raise InvalidConfigAccessError(key) + + return super().__getitem__(key) + + def gen_rotary_embedding( - rope_head_dim: int, rope_impl: str, rope_theta: int, rope_dail_config: dict, rope_hf_config: dict, max_seq_len: int, + d_model: int, + n_heads: int, ): + rope_head_dim = d_model // n_heads if rope_impl == 'dail': return DAILRotaryEmbedding( dim=rope_head_dim, @@ -108,32 +154,21 @@ def gen_rotary_embedding( 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) elif rope_impl == 'hf': + llama_rope_config = {**rope_hf_config} + llama_rope_config['rope_type'] = llama_rope_config.pop('type') + if llama_rope_config['rope_type'] == 'no_scaling': + llama_rope_config['rope_type'] = 'default' + partial_llama_config = PartialLlamaConfig( + rope_scaling=llama_rope_config, + rope_theta=rope_theta, + max_position_embeddings=max_seq_len, + hidden_size=d_model, + num_attention_heads=n_heads, + ) if rope_hf_config['type'] == 'no_scaling': - return HFRotaryEmbeddingFoundry( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif rope_hf_config['type'] == 'linear': - return HFLinearScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - scaling_factor=rope_hf_config['factor'], - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif rope_hf_config['type'] == 'dynamic': - return HFDynamicNTKScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - scaling_factor=rope_hf_config['factor'], - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) + return LlamaRotaryEmbeddingFoundry(config=partial_llama_config) + elif rope_hf_config['type'] in {'llama3', 'linear', 'dynamic'}: + return LlamaRotaryEmbedding(config=partial_llama_config) raise ValueError('rope_impl needs to be either dail or hf') @@ -306,7 +341,7 @@ def apply_sequence_id( return attn_bias -class HFRotaryEmbeddingFoundry(HFRotaryEmbedding): +class LlamaRotaryEmbeddingFoundry(LlamaRotaryEmbedding): @torch.no_grad() def forward( @@ -391,6 +426,7 @@ def __init__(self, config: MPTConfig): self.norm_f = build_norm( name=config.norm_type.lower(), normalized_shape=config.d_model, + eps=config.norm_eps, device=config.init_device, ) @@ -399,12 +435,13 @@ def __init__(self, config: MPTConfig): if self.rope: self.rope_impl = config.attn_config['rope_impl'] self.rotary_embedding = gen_rotary_embedding( - rope_head_dim=config.d_model // config.n_heads, rope_impl=self.rope_impl, rope_theta=config.attn_config['rope_theta'], rope_dail_config=config.attn_config['rope_dail_config'], rope_hf_config=config.attn_config['rope_hf_config'], max_seq_len=self.config.max_seq_len, + d_model=config.d_model, + n_heads=config.n_heads, ) if config.init_device != 'meta': diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 48290bd7c5..84a3376718 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -167,6 +167,7 @@ class TrainConfig: # Dataloader device_train_microbatch_size: Union[str, int, float] = 'auto' global_train_batch_size: Optional[int] = None + spin_dataloaders: bool = True # Eval dataloader eval_subset_num_batches: int = -1 @@ -531,7 +532,6 @@ def process_init_device(model_cfg: Dict[str, Any], fsdp_config: Optional[Dict]): fsdp_config['sync_module_states'] = True # Set defaults for mixed initialization - fsdp_config.setdefault('use_orig_params', False) fsdp_config.setdefault('load_monolith_rank0_only', True) # Set ffn_config.device_mesh to fsdp_config.device_mesh diff --git a/setup.py b/setup.py index 97b8962069..52eb9e5578 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ 'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.23.4,<0.24', 'mlflow>=2.14.1,<2.15', 'accelerate>=0.25,<0.33', # for HF inference `device_map` - 'transformers>=4.43.1,<4.44', + 'transformers>=4.43.2,<4.44', 'mosaicml-streaming>=0.7.6,<0.8', 'torch>=2.3.0,<2.4', 'datasets>=2.19,<2.20', diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index ffdb09ca98..cd47b2df7c 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -388,6 +388,9 @@ def test_huggingface_conversion_callback_interval( checkpointer_callback.transform_model_pre_registration = MagicMock( wraps=checkpointer_callback.transform_model_pre_registration, ) + checkpointer_callback.pre_register_edit = MagicMock( + wraps=checkpointer_callback.pre_register_edit, + ) trainer = Trainer( model=original_model, device='gpu', @@ -411,11 +414,14 @@ def test_huggingface_conversion_callback_interval( task='llm/v1/completions', input_example=ANY, metadata={}, + pip_requirements=ANY, ) assert checkpointer_callback.transform_model_pre_registration.call_count == 1 + assert checkpointer_callback.pre_register_edit.call_count == 1 assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 else: assert checkpointer_callback.transform_model_pre_registration.call_count == 0 + assert checkpointer_callback.pre_register_edit.call_count == 0 assert mlflow_logger_mock.save_model.call_count == 0 assert mlflow_logger_mock.register_model_with_run_id.call_count == 0 @@ -589,6 +595,7 @@ def _assert_mlflow_logger_calls( 'task': 'llm/v1/completions', 'input_example': default_input_example, 'metadata': {}, + 'pip_requirements': ANY, } mlflow_logger_mock.save_model.assert_called_with(**expectation) assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index 01d982052f..4bfdfb84dc 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -251,12 +251,13 @@ def gen_bias(attn_impl: str): rotary_emb_w_meta_info = None if rope: rotary_embedding = gen_rotary_embedding( - rope_head_dim=cfg.d_model // cfg.n_heads, rope_impl=pos_emb_config['rope_impl'], rope_theta=pos_emb_config['rope_theta'], rope_dail_config=pos_emb_config.get('rope_dail_config', {}), rope_hf_config=pos_emb_config.get('rope_hf_config', {}), max_seq_len=s, + d_model=cfg.d_model, + n_heads=cfg.n_heads, ).to(device) pos = torch.arange(s).unsqueeze(0).to(device=device) # adjust the position indices to account for padding tokens @@ -664,12 +665,13 @@ def gen_bias(attn_impl: str): rotary_emb_w_meta_info = None if rope: rotary_embedding = gen_rotary_embedding( - rope_head_dim=cfg['d_model'] // cfg['n_heads'], rope_impl=pos_emb_config['rope_impl'], rope_theta=pos_emb_config['rope_theta'], rope_dail_config=pos_emb_config.get('rope_dail_config', {}), rope_hf_config=pos_emb_config.get('rope_hf_config', {}), max_seq_len=s, + d_model=cfg['d_model'], + n_heads=cfg['n_heads'], ).to(device) pos = torch.arange(s).unsqueeze(0).to(device=device) # adjust the position indices to account for padding tokens diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 45378e42bd..ed40e7a88a 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -35,8 +35,7 @@ ) from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.bloom.modeling_bloom import build_alibi_tensor -from transformers.models.llama.modeling_llama import \ - LlamaRotaryEmbedding as HFRotaryEmbedding +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from llmfoundry import ComposerHFCausalLM from llmfoundry.layers_registry import norms @@ -48,7 +47,7 @@ ) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM, MPTModel -from llmfoundry.models.mpt.modeling_mpt import HFRotaryEmbeddingFoundry +from llmfoundry.models.mpt.modeling_mpt import LlamaRotaryEmbeddingFoundry from llmfoundry.utils import build_tokenizer from llmfoundry.utils.builders import build_composer_model from llmfoundry.utils.config_utils import to_dict_container @@ -2924,7 +2923,7 @@ def test_hf_rotary_child_class_builds(): list(range(max_seq_len)), ] * bsz) - rot_emb_mp = HFRotaryEmbeddingFoundry( + rot_emb_mp = LlamaRotaryEmbeddingFoundry( rope_head_dim, max_seq_len, rope_theta, @@ -2932,7 +2931,7 @@ def test_hf_rotary_child_class_builds(): ) cos_mp, sin_mp = rot_emb_mp(value, position_ids) - rot_emb = HFRotaryEmbedding( + rot_emb = LlamaRotaryEmbedding( rope_head_dim, max_seq_len, rope_theta, diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py index 6a41e64f48..34fb23f670 100644 --- a/tests/models/test_rope_dail_vs_hf.py +++ b/tests/models/test_rope_dail_vs_hf.py @@ -77,12 +77,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): } dail_rope = gen_rotary_embedding( - rope_head_dim=cfg.d_model // cfg.n_heads, rope_impl=dail_rope_config['rope_impl'], rope_theta=dail_rope_config['rope_theta'], rope_dail_config=dail_rope_config['rope_dail_config'], rope_hf_config={}, max_seq_len=seq_len, + d_model=cfg.d_model, + n_heads=cfg.n_heads, ).to('cuda') dail_rope_w_meta_info = { 'impl': 'dail', @@ -92,12 +93,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): } hf_rope = gen_rotary_embedding( - rope_head_dim=cfg.d_model // cfg.n_heads, rope_impl=hf_rope_config['rope_impl'], rope_theta=hf_rope_config['rope_theta'], rope_dail_config={}, rope_hf_config=hf_rope_config['rope_hf_config'], max_seq_len=seq_len, + d_model=cfg.d_model, + n_heads=cfg.n_heads, ).to('cuda') pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda') # adjust the position indices to account for padding tokens diff --git a/tests/models/test_rope_scaling.py b/tests/models/test_rope_scaling.py new file mode 100644 index 0000000000..484ac2b23a --- /dev/null +++ b/tests/models/test_rope_scaling.py @@ -0,0 +1,35 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding + +from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding + +rope_config = { + 'rope_theta': 500000.0, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'factor': 8.0, + 'low_freq_factor': 1.0, + 'high_freq_factor': 4.0, + 'original_max_position_embeddings': 8192, + 'type': 'llama3', + }, +} + +rope_dail_config = {} + + +def test_rope_scaling(): + d_model = 128 + n_heads = 32 + max_seq_len = 65536 + + embedding = gen_rotary_embedding( + d_model=d_model, + n_heads=n_heads, + rope_dail_config=rope_dail_config, + max_seq_len=max_seq_len, + **rope_config, + ) + + assert isinstance(embedding, LlamaRotaryEmbedding)