diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 81a27321f3..d52633a09b 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -5,16 +5,18 @@

 import logging
 import os
+import warnings
 from typing import Mapping, Union

 # required for loading a python model into composer
 import transformers
-from composer.metrics.nlp import (
-    InContextLearningCodeEvalAccuracy,
-    InContextLearningLMAccuracy, InContextLearningLMExpectedCalibrationError,
-    InContextLearningMCExpectedCalibrationError,
-    InContextLearningMultipleChoiceAccuracy, InContextLearningQAAccuracy,
-    LanguageCrossEntropy, LanguagePerplexity)
+from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy,
+                                  InContextLearningLMAccuracy,
+                                  InContextLearningLMExpectedCalibrationError,
+                                  InContextLearningMCExpectedCalibrationError,
+                                  InContextLearningMultipleChoiceAccuracy,
+                                  InContextLearningQAAccuracy,
+                                  LanguageCrossEntropy, LanguagePerplexity)
 from composer.utils import dist
 from omegaconf import DictConfig
 from torch import nn
@@ -156,6 +158,24 @@ def __init__(self, om_model_config: Union[DictConfig,
         if dist.get_local_rank() != 0 and init_device == 'mixed':
             om_model_config.pretrained = False

+        # If the HuggingFace model is coming from a local folder, Hugging Face copies the modules into the
+        # transformers modules cache. On particular systems, this operation seems to cause contention between
+        # the different processes. To avoid this contention, we first create the model (on meta device) on local rank
+        # zero. This will set up the transformers model cache and avoid the future contention.
+        if dist.get_local_rank() == 0 and os.path.isdir(
+                om_model_config.pretrained_model_name_or_path):
+            with init_empty_weights(include_buffers=False):
+                with warnings.catch_warnings():
+                    warnings.simplefilter('ignore', UserWarning)
+                    AutoModelForCausalLM.from_pretrained(
+                        om_model_config.pretrained_model_name_or_path,
+                        trust_remote_code=trust_remote_code,
+                        use_auth_token=use_auth_token,
+                        config=config,
+                    )
+
+        dist.barrier()
+
         # initialize the model on the correct device
         if resolved_init_device == 'cpu':
             if om_model_config.pretrained:
diff --git a/mcli/mcli-hf-eval.yaml b/mcli/mcli-hf-eval.yaml
index 8ab4cec481..accff7d5c0 100644
--- a/mcli/mcli-hf-eval.yaml
+++ b/mcli/mcli-hf-eval.yaml
@@ -1,7 +1,7 @@
 integrations:
 - integration_type: git_repo
   git_repo: mosaicml/llm-foundry
-  git_branch: execution_prediction # v0.2.0
+  git_branch: v0.3.0
   # git_commit: # OR use your commit hash
   pip_install: -e ".[gpu]"
   ssh_clone: false # Should be true if using a private repo
@@ -11,10 +11,10 @@ command: |
   composer eval/eval.py /mnt/config/parameters.yaml

 # Mosaic Cloud will use run_name (with a unique suffix) to populate the env var $RUN_NAME
-run_name: mpt-gauntlet-v0.1
+run_name: mpt-eval
 gpu_num: 8
-gpu_type: a100_80gb
-cluster: r1z1 # replace with your cluster here!
+# gpu_type:
+# cluster: # replace with your cluster here!

 image: mosaicml/llm-foundry:2.0.1_cu118-latest

@@ -31,7 +31,7 @@ parameters:
     model_name: mosaicml/mpt-7b-instruct
     # Tokenizer
     tokenizer:
-      name: mosaicml/mpt-7b-instruct
+      name: EleutherAI/gpt-neox-20b
       kwargs:
         model_max_length: ${max_seq_len}

@@ -41,14 +41,14 @@ parameters:
       init_device: mixed
       pretrained: true
       use_auth_token: false
-

     # FSDP config for model sharding
     fsdp_config:
       sharding_strategy: FULL_SHARD
-      mixed_precision: PURE
+      mixed_precision: FULL
       forward_prefetch: True
       limit_all_gathers: True
+

   icl_tasks: 'eval/yamls/tasks.yaml'
   eval_gauntlet: 'eval/yamls/eval_gauntlet.yaml'
diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py
index 069ffb8dc4..02a5d1f862 100644
--- a/scripts/eval/eval.py
+++ b/scripts/eval/eval.py
@@ -117,6 +117,7 @@ def evaluate_model(
     tokenizer_name = tokenizer_cfg['name']
     tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
     tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
+
     evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(
         icl_tasks, eval_gauntlet_config, tokenizer, device_eval_batch_size,
         max_seq_len, icl_subset_num_batches)
@@ -173,6 +174,7 @@ def evaluate_model(
         dist_timeout=dist_timeout,
         python_log_level=python_log_level,
     )
+
     if torch.cuda.is_available():
         torch.cuda.synchronize()
     a = time.time()
@@ -315,7 +317,6 @@ def main(cfg: DictConfig):
             row = {'model_name': model_cfg['model_name']}
             row.update(
                 {k.split('/')[-1]: v for k, v in composite_scores.items()})
-
             eval_gauntlet_df = pd.concat(
                 [eval_gauntlet_df, pd.DataFrame([row])], ignore_index=True)
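Reviewer note (not part of the patch): the hunk added to `llmfoundry/models/hf/hf_causal_lm.py` above implements a rank-zero warm-up of the transformers modules cache before the other ranks construct their models. Below is a minimal standalone sketch of that pattern; the helper name `warm_transformers_cache` and its signature are illustrative assumptions, not code from the PR, where the logic lives inline in `ComposerHFCausalLM.__init__`.

```python
# Sketch only: the helper name and arguments are hypothetical; the PR inlines
# this logic in ComposerHFCausalLM.__init__.
import os
import warnings

from accelerate import init_empty_weights
from composer.utils import dist
from transformers import AutoModelForCausalLM


def warm_transformers_cache(pretrained_model_name_or_path: str,
                            config,
                            trust_remote_code: bool = False,
                            use_auth_token: bool = False) -> None:
    """Pre-populate the transformers modules cache on local rank zero.

    Loading a checkpoint from a local folder copies any remote-code modules into
    the transformers modules cache; doing that from every process at once can
    cause filesystem contention. Instantiating the model once on meta device
    (no weights are materialized) from local rank zero sets up the cache before
    the other ranks build their models.
    """
    if dist.get_local_rank() == 0 and os.path.isdir(
            pretrained_model_name_or_path):
        with init_empty_weights(include_buffers=False):
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', UserWarning)
                AutoModelForCausalLM.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=trust_remote_code,
                    use_auth_token=use_auth_token,
                    config=config,
                )

    # Every rank waits here, so no rank begins its real model construction
    # until rank zero has populated the cache.
    dist.barrier()
```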