From 9025b833d06127eda230f745b2c706e98da2f0a1 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Tue, 3 Oct 2023 11:45:13 -0400
Subject: [PATCH 1/3] Add flag to disable train metrics (#642)

* free mem

* lint

* lint
---
 llmfoundry/models/hf/hf_causal_lm.py  | 8 ++++----
 llmfoundry/models/mpt/modeling_mpt.py | 4 +++-
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index d5ef2435f9..ce398a8b2d 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -65,10 +65,7 @@ def __init__(self, om_model_config: Union[DictConfig,
                                               nn.Module],
                  tokenizer: PreTrainedTokenizerBase):
         # set up training and eval metrics
-        train_metrics = [
-            LanguageCrossEntropy(),
-            LanguagePerplexity(),
-        ]
+        train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
         eval_metrics = [
             LanguageCrossEntropy(),
             LanguagePerplexity(),
@@ -92,6 +89,9 @@ def __init__(self, om_model_config: Union[DictConfig,
                 'which is not significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.'
             )
 
+        if not om_model_config.get('use_train_metrics', True):
+            train_metrics = []
+
         # load the model config
         trust_remote_code = om_model_config.get('trust_remote_code', True)
         use_auth_token = om_model_config.get('use_auth_token', False)
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index b1dff15398..cd162195b6 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -694,7 +694,9 @@ def __init__(
         hf_config = MPTConfig.from_dict(resolved_om_model_config)
         model = MPTForCausalLM(hf_config)
 
-        train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
+        use_train_metrics = om_model_config.get('use_train_metrics', True)
+        train_metrics = [LanguageCrossEntropy(),
+                         LanguagePerplexity()] if use_train_metrics else []
         eval_metrics = [
             LanguageCrossEntropy(),
             LanguagePerplexity(),

From cf015dd6a418abea854a85e5670522ee146cfd0a Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Tue, 3 Oct 2023 09:59:13 -0700
Subject: [PATCH 2/3] fix pins (#646)

---
 mcli/mcli-llama2-finetune.yaml        | 1 +
 mcli/mcli-openai-eval.yaml            | 4 ++--
 mcli/mcli-pretokenize-oci-upload.yaml | 2 +-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/mcli/mcli-llama2-finetune.yaml b/mcli/mcli-llama2-finetune.yaml
index 89c9c0cd9c..ae8f57abb6 100644
--- a/mcli/mcli-llama2-finetune.yaml
+++ b/mcli/mcli-llama2-finetune.yaml
@@ -2,6 +2,7 @@ integrations:
 - integration_type: git_repo
   git_repo: mosaicml/llm-foundry
   git_branch: v0.3.0
+  # git_commit: # OR use your commit hash
   pip_install: -e .[gpu]
   ssh_clone: false  # Should be true if using a private repo
 
diff --git a/mcli/mcli-openai-eval.yaml b/mcli/mcli-openai-eval.yaml
index 6275d9d578..0b770626b9 100644
--- a/mcli/mcli-openai-eval.yaml
+++ b/mcli/mcli-openai-eval.yaml
@@ -1,8 +1,8 @@
 integrations:
 - integration_type: git_repo
   git_repo: mosaicml/llm-foundry
-  git_branch: # use your branch
-  # git_commit: 29d65cc26853c09f6de7542978056ddb0b07e98c # OR use your commit hash
+  git_branch: v0.3.0
+  # git_commit: # OR use your commit hash
   pip_install: -e ".[gpu,openai]"
   ssh_clone: false  # Should be true if using a private repo
 
diff --git a/mcli/mcli-pretokenize-oci-upload.yaml b/mcli/mcli-pretokenize-oci-upload.yaml
index 8163d8c3bd..b585b5f5f2 100644
--- a/mcli/mcli-pretokenize-oci-upload.yaml
+++ b/mcli/mcli-pretokenize-oci-upload.yaml
@@ -14,7 +14,7 @@ integrations:
   - oci-cli==3.23.2
 - integration_type: git_repo
   git_repo: mosaicml/llm-foundry
-  git_branch: v0.2.0
+  git_branch: v0.3.0
   # git_commit: # OR use your commit hash
   pip_install: '.'
   ssh_clone: false  # Should be true if using a private repo

From cb1d94aed6e0aa4ef7c27d98eb15824db8935b86 Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Tue, 3 Oct 2023 14:56:02 -0700
Subject: [PATCH 3/3] Fix overriding of rope_scaling config (#644)

---
 llmfoundry/models/hf/hf_causal_lm.py |  5 +++++
 tests/test_hf_config.py              | 28 ++++++++++++++++++++++++++++
 2 files changed, 33 insertions(+)

diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index ce398a8b2d..13857e9bb9 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -109,6 +109,7 @@ def __init__(self, om_model_config: Union[DictConfig,
                     )
 
                 attr = getattr(config, k)
+                # attempt to disallow typos in nested configs
                 if isinstance(attr, Mapping):
                     extra_keys = [
                         _k for _k in v.keys() if _k not in attr.keys()
@@ -120,6 +121,10 @@ def __init__(self, om_model_config: Union[DictConfig,
                             f'Expected (a subset of) keys: {list(attr.keys())}.'
                         )
                     getattr(config, k).update(v)
+                # necessary case to allow for rope_scaling to be overridden in llama config
+                elif attr is None and isinstance(v, Mapping):
+                    setattr(config, k, {})
+                    getattr(config, k).update(v)
                 else:
                     setattr(config, k, v)
 
diff --git a/tests/test_hf_config.py b/tests/test_hf_config.py
index 99d01f309f..5b3bb3d150 100644
--- a/tests/test_hf_config.py
+++ b/tests/test_hf_config.py
@@ -1,6 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import os
 import tempfile
 from copy import deepcopy
 from pathlib import Path
@@ -139,3 +140,30 @@ def test_hf_config_override(
             assert getattr(hf_model.config, k)[_k] == _v
         else:
             assert getattr(hf_model.config, k) == v
+
+
+@pytest.mark.skipif('HUGGING_FACE_HUB_TOKEN' not in os.environ,
+                    reason='CI does not have access to llama2')
+def test_rope_scaling_override():
+    model_cfg = {
+        'name': 'hf_causal_lm',
+        'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf',
+        'config_overrides': {
+            'num_hidden_layers': 2,
+            'hidden_size': 32,
+            'intermediate_size': 64,
+            'rope_scaling': {
+                'type': 'dynamic',
+                'factor': 0.5
+            }
+        },
+        'use_auth_token': True,
+        'pretrained': False,
+        'init_device': 'cpu',
+    }
+    model_cfg = om.create(model_cfg)
+
+    model = COMPOSER_MODEL_REGISTRY[model_cfg.name](model_cfg, tokenizer=None)
+    # This would error if the config isn't parsed into a proper dictionary
+    model.get_metadata()
+    assert model.config.rope_scaling == {'type': 'dynamic', 'factor': 0.5}
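
Usage note for [PATCH 1/3], a minimal sketch rather than an excerpt from the
repo: the new flag is read with `om_model_config.get('use_train_metrics', True)`,
so it should be settable from the `model` section of a training config. The
surrounding keys below mirror a small MPT model section and are illustrative
only; `use_train_metrics` is the key this patch adds.

    model:
      name: mpt_causal_lm
      init_device: meta
      d_model: 768
      n_heads: 12
      n_layers: 12
      expansion_ratio: 4
      max_seq_len: 2048
      vocab_size: 50368
      use_train_metrics: false  # new flag; skips LanguageCrossEntropy/LanguagePerplexity during training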
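
Likewise for [PATCH 3/3], a sketch rather than repo text: `rope_scaling` is
`None` on the base Llama 2 config, so before the new `elif attr is None and
isinstance(v, Mapping)` branch the override fell through to
`setattr(config, k, v)` and stored the raw omegaconf mapping; the
`model.get_metadata()` call in the new test is what trips on that case. With
the fix, a model section like the following (values illustrative, mirroring
the new test) leaves a plain dict on `config.rope_scaling`:

    model:
      name: hf_causal_lm
      pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf
      pretrained: true
      use_auth_token: true
      config_overrides:
        rope_scaling:
          type: dynamic
          factor: 2.0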