diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py
index 85f96aadb9..87504d26b3 100644
--- a/llmfoundry/__init__.py
+++ b/llmfoundry/__init__.py
@@ -4,6 +4,26 @@
 import torch
 
 try:
+    import warnings
+
+    # bitsandbytes is a very noisy library. A lot of it is print statements that we can't easily suppress,
+    # but we can at least suppress a bunch of spurious warnings.
+    warnings.filterwarnings('ignore',
+                            category=UserWarning,
+                            module='bitsandbytes')
+
+    import logging
+
+    from llmfoundry.utils.logging_utils import SpecificWarningFilter
+
+    # Filter out Hugging Face warning for not using a pinned revision of the model
+    hf_dynamic_modules_logger = logging.getLogger(
+        'transformers.dynamic_module_utils')
+    new_files_warning_filter = SpecificWarningFilter(
+        'A new version of the following files was downloaded from')
+
+    hf_dynamic_modules_logger.addFilter(new_files_warning_filter)
+
     # Before importing any transformers models, we need to disable transformers flash attention if
     # we are in an environment with flash attention version <2. Transformers hard errors on a not properly
     # gated import otherwise.
diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 4b80ffef54..21c3558b2d 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -43,6 +43,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 from streaming import StreamingDataset
 from transformers import PreTrainedTokenizerBase
 
+from llmfoundry.utils.logging_utils import SpecificWarningFilter
+
 log = logging.getLogger(__name__)
 
 __all__ = ['dataset_constructor']
@@ -236,7 +238,7 @@ def wrapper(func: Callable) -> Callable:
 
     def print_registered_tasks(self) -> None:
         tasks = sorted(self._task_preprocessing_registry.keys())
-        print('\n'.join(tasks))
+        log.info('\n'.join(tasks))
 
     def get_preprocessing_fn_from_dict(
             self, mapping: Union[Dict, DictConfig]
@@ -365,6 +367,15 @@ def build_from_hf(
         with dist.local_rank_zero_download_and_wait(signal_file_path):
             pass
 
+        hf_tokenization_logger = logging.getLogger(
+            'transformers.tokenization_utils_base')
+        sequence_length_warning_filter = SpecificWarningFilter(
+            'Token indices sequence length is longer than the specified maximum sequence length'
+        )
+
+        # We will trim examples later in the collate_fn, so we want to silence this warning from Hugging Face
+        hf_tokenization_logger.addFilter(sequence_length_warning_filter)
+
         error: Optional[Exception] = None
         filtered_dataset = None
         try:
@@ -433,6 +444,9 @@ def filter_long_or_empty_examples(example: Dict) -> bool:
             log.error('Error during data prep')
             raise error
         log.debug('All ranks finished data prep')
+
+        hf_tokenization_logger.removeFilter(sequence_length_warning_filter)
+
         assert filtered_dataset is not None
         return filtered_dataset
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index d52633a09b..fcac57d817 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -20,7 +20,7 @@
 from composer.utils import dist
 from omegaconf import DictConfig
 from torch import nn
-from transformers import (AutoConfig, AutoModelForCausalLM,
+from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel,
                           PreTrainedTokenizerBase)
 
 from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
@@ -102,20 +102,27 @@ def __init__(self, om_model_config: Union[DictConfig,
                     'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
                     + 'Please install flash_attn==2.3.2`.')
 
+        requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
         config = AutoConfig.from_pretrained(
             om_model_config.pretrained_model_name_or_path,
             trust_remote_code=trust_remote_code,
             use_auth_token=use_auth_token,
+            attn_implementation=requested_attention_implementation,
+            use_cache=
+            False,  # Necessary due to https://github.com/huggingface/transformers/issues/28056
         )
 
-        # This is not how you are supposed to set this, but transformers currently only
-        # supports enabling flash attention 2 when using the from_pretrained API.
-        # We need to support it for both from_pretrained and from_config, so we have to
-        # set the private attribute here. This will just skip all of transformers'
-        # validation logic that it is ok to use flash attention 2, so we check
-        # whether it is installed above, and whether the chosen config supports it here.
-        # https://github.com/huggingface/transformers/issues/26878
-        config._flash_attn_2_enabled = use_flash_attention_2
+        # This is not ideal, however Hugging Face's _autoset_attn_implementation function
+        # forces you to load the model in fp16/bf16 if you want to use flash attention. Rather than loading
+        # the model and then casting it back to fp32, we are monkeypatching their check.
+        # https://github.com/huggingface/transformers/issues/28052
+        def _autoset_attn_implementation_monkeypatch(
+                cls, config, *args, **kwargs):  # type: ignore
+            config._attn_implementation = requested_attention_implementation
+            return config
+
+        PreTrainedModel._autoset_attn_implementation = classmethod(
+            _autoset_attn_implementation_monkeypatch)
 
         # set config overrides
         for k, v in om_model_config.get('config_overrides', {}).items():
@@ -184,7 +191,8 @@ def __init__(self, om_model_config: Union[DictConfig,
                     trust_remote_code=trust_remote_code,
                     use_auth_token=use_auth_token,
                     load_in_8bit=load_in_8bit,
-                    config=config)
+                    config=config,
+                )
             else:
                 model = AutoModelForCausalLM.from_config(
                     config,
diff --git a/llmfoundry/utils/logging_utils.py b/llmfoundry/utils/logging_utils.py
new file mode 100644
index 0000000000..081a06fefb
--- /dev/null
+++ b/llmfoundry/utils/logging_utils.py
@@ -0,0 +1,21 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+
+
+class SpecificWarningFilter(logging.Filter):
+
+    def __init__(self, message_to_suppress: str):
+        """Filter out a specific warning message based on its content.
+
+        This can be useful for filtering out specific warning messages from third party packages.
+
+        Args:
+            message_to_suppress (str): The warning message to suppress.
+        """
+        super().__init__()
+        self.message_to_suppress = message_to_suppress
+
+    def filter(self, record: logging.LogRecord) -> bool:
+        return self.message_to_suppress not in record.getMessage()
diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py
index 369a894720..5c74b9fd8f 100644
--- a/scripts/eval/eval.py
+++ b/scripts/eval/eval.py
@@ -26,6 +26,8 @@
                                        build_tokenizer)
 from llmfoundry.utils.config_utils import pop_config, process_init_device
 
+log = logging.getLogger(__name__)
+
 
 def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
                     num_retries: int) -> ComposerModel:
@@ -65,7 +67,7 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
             if retries >= num_retries:
                 raise e
             else:
-                print(
+                log.info(
                     f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining'
                 )
@@ -89,7 +91,7 @@ def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
             if retries >= num_retries:
                 raise e
             else:
-                print(
+                log.info(
                     f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining'
                 )
@@ -116,7 +118,7 @@ def evaluate_model(
     icl_subset_num_batches: Optional[int],
 ):
 
-    print(f'Evaluating model: {model_cfg.model_name}', flush=True)
+    log.info(f'Evaluating model: {model_cfg.model_name}')
 
     # Build tokenizer and model
     tokenizer_cfg: Dict[str, Any] = om.to_container(model_cfg.tokenizer,
@@ -200,7 +202,7 @@ def evaluate_model(
     if torch.cuda.is_available():
         torch.cuda.synchronize()
     b = time.time()
-    print(f'Ran {model_cfg.model_name} eval in: {b-a} seconds')
+    log.info(f'Ran {model_cfg.model_name} eval in: {b-a} seconds')
 
     return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df)
 
@@ -215,7 +217,7 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
                                                   must_exist=False,
                                                   default_value=None)
     if eval_gauntlet_config:
-        print(
+        warnings.warn(
             'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`'
         )
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 809f2fb09c..2c1099ff00 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -34,6 +34,8 @@
                                        process_init_device,
                                        update_batch_size_info)
 
+log = logging.getLogger(__name__)
+
 
 def validate_config(cfg: DictConfig):
     """Validates compatible model and dataloader selection."""
@@ -138,17 +140,17 @@ def build_composer_peft_model(
             + f'Error encountered: {e}')
 
     # 1) loads a hf model, 2) adds peft modules, 3) wraps it in a ComposerHFCausalLM.
-    print('Building Lora config...')
+    log.info('Building Lora config...')
     lora_cfg = LoraConfig(**lora_args)
 
-    print('Building model from HuggingFace checkpoint...')
+    log.info('Building model from HuggingFace checkpoint...')
     model = MPTForCausalLM.from_pretrained(pretrained_model_name_or_path,
                                            trust_remote_code=True)
-    print('Model built!')
+    log.info('Model built!')
 
-    print('Adding Lora modules...')
+    log.info('Adding Lora modules...')
     model = get_peft_model(model, lora_cfg)
-    print('Lora modules added!')
+    log.info('Lora modules added!')
 
     model = ComposerHFCausalLM(model, tokenizer)
 
@@ -163,7 +165,7 @@ def print_trainable_parameters(model: torch.nn.Module) -> None:
         all_param += param.numel()
         if param.requires_grad:
             trainable_params += param.numel()
-    print(
+    log.info(
         f'trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}'
     )
@@ -260,9 +262,9 @@ def main(cfg: DictConfig) -> Trainer:
                                                   must_exist=False,
                                                   default_value=None)
     if eval_gauntlet_config is not None:
-        print(
-            'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`'
-        )
+        warnings.warn(
+            'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`',
+            DeprecationWarning)
     icl_subset_num_batches: Optional[int] = pop_config(cfg,
                                                        'icl_subset_num_batches',
                                                        must_exist=False,
@@ -398,7 +400,7 @@ def main(cfg: DictConfig) -> Trainer:
         autoresume_default = True
 
     if cfg.get('autoresume') is None and autoresume_default:
-        print('As run_name, save_folder, and save_latest_filename are set, \
+        log.info('As run_name, save_folder, and save_latest_filename are set, \
                 changing autoresume default to True...')
 
     autoresume: bool = pop_config(cfg,
@@ -514,7 +516,7 @@ def main(cfg: DictConfig) -> Trainer:
     ] if algorithm_configs else None
 
     # Dataloaders
-    print('Building train loader...')
+    log.info('Building train loader...')
     train_loader = build_dataloader(
         train_loader_config,
         tokenizer,
@@ -525,7 +527,7 @@ def main(cfg: DictConfig) -> Trainer:
         mosaicml_logger.log_metrics({'data_validated': time.time()})
 
     ## Evaluation
-    print('Building eval loader...')
+    log.info('Building eval loader...')
     eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len
     evaluators, _, eval_gauntlet_callback = build_evaluators(
         eval_loader_config,
@@ -541,7 +543,7 @@ def main(cfg: DictConfig) -> Trainer:
         callbacks.append(eval_gauntlet_callback)
 
     # Build Model
-    print('Initializing model...')
+    log.info('Initializing model...')
     with init_context:
         if lora_config is not None:  # frozen model + trainable lora modules
             model: ComposerHFCausalLM = build_composer_peft_model(
@@ -570,7 +572,7 @@ def main(cfg: DictConfig) -> Trainer:
         evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)
 
     # Build the Trainer
-    print('Building trainer...')
+    log.info('Building trainer...')
     trainer = Trainer(
         run_name=run_name,
         seed=seed,
@@ -609,7 +611,7 @@ def main(cfg: DictConfig) -> Trainer:
         compile_config=compile_config,
     )
 
-    print('Logging config')
+    log.info('Logging config')
     log_config(logged_cfg)
     torch.cuda.empty_cache()
     gc.collect()
@@ -618,10 +620,10 @@ def main(cfg: DictConfig) -> Trainer:
     if eval_first and trainer.state.timestamp.batch.value == 0:
         trainer.eval()
 
-    print('Starting training...')
+    log.info('Starting training...')
     trainer.fit()
 
-    print('Done.')
+    log.info('Done.')
 
     return trainer
diff --git a/setup.py b/setup.py
index 9853aa17bf..2283e60d9c 100644
--- a/setup.py
+++ b/setup.py
@@ -48,11 +48,11 @@
 
 install_requires = [
     'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.1,<0.18',
-    'accelerate>=0.20,<0.21',  # for HF inference `device_map`
-    'transformers>=4.34.1,<4.35',
+    'accelerate>=0.25,<0.26',  # for HF inference `device_map`
+    'transformers>=4.36,<4.37',
     'mosaicml-streaming>=0.7.1,<0.8',
     'torch>=2.1,<2.1.1',
-    'datasets>=2.14.5,<2.15',
+    'datasets==2.15.0',
     'fsspec==2023.6.0',  # newer version results in a bug in datasets that duplicates data
     'sentencepiece==0.1.97',
     'einops==0.5.0',
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index 94a2d66c6e..28fb9219f8 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -258,7 +258,7 @@ def test_callback_inits():
 @pytest.mark.parametrize(
     'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
     [('3ba', '2ba', '4ba', 2, 2), ('1dur', '2ba', '1ep', 1, 2)])
-@patch('os.cpu_count', MagicMock(return_value=None))
+@patch('os.cpu_count', MagicMock(return_value=1))
 def test_huggingface_conversion_callback_interval(
         tmp_path: pathlib.Path, log_to_mlflow: bool, hf_save_interval: str,
         save_interval: str, max_duration: str, expected_hf_checkpoints: int,
@@ -381,7 +381,7 @@ def test_huggingface_conversion_callback_interval(
 @pytest.mark.parametrize(
     'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
     [('1ba', '1ba', '1ba', 1, 1)])
-@patch('os.cpu_count', MagicMock(return_value=None))
+@patch('os.cpu_count', MagicMock(return_value=1))
 def test_huggingface_conversion_callback(
         model: str,
         tmp_path: pathlib.Path,
diff --git a/tests/fixtures/data.py b/tests/fixtures/data.py
index 16dd01347d..9ba053ffe8 100644
--- a/tests/fixtures/data.py
+++ b/tests/fixtures/data.py
@@ -26,7 +26,7 @@ def tiny_ft_dataset_path(tmp_path: Path, dataset_size: int = 4) -> Path:
 
 
 @fixture
-@patch('os.cpu_count', MagicMock(return_value=None))
+@patch('os.cpu_count', MagicMock(return_value=1))
 def tiny_ft_dataloader(tiny_ft_dataset_path: Path,
                        mpt_tokenizer: PreTrainedTokenizerBase,
                        max_seq_len: int = 128,
diff --git a/tests/models/layers/test_huggingface_flash.py b/tests/models/layers/test_huggingface_flash.py
index 70c08c4eb1..411aab77a2 100644
--- a/tests/models/layers/test_huggingface_flash.py
+++ b/tests/models/layers/test_huggingface_flash.py
@@ -159,9 +159,11 @@ def test_attn_patch_integration(patch: str):
 
 
 @pytest.mark.gpu
+@pytest.mark.world_size(2)
 @pytest.mark.parametrize('model_name', ['llama2', 'mistral'])
 @pytest.mark.parametrize('use_flash_attention_2', [True, False])
-def test_flash2(model_name: str, use_flash_attention_2: bool):
+@pytest.mark.parametrize('init_device', ['cpu', 'mixed', 'meta'])
+def test_flash2(model_name: str, use_flash_attention_2: bool, init_device: str):
     if model_name == 'llama2':
         if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
             pytest.skip(
@@ -177,7 +179,7 @@ def test_flash2(model_name: str, use_flash_attention_2: bool):
             },
             'use_auth_token': True,
             'pretrained': False,
-            'init_device': 'cpu',
+            'init_device': init_device,
         }
 
         tokenizer_name = 'meta-llama/Llama-2-7b-hf'
@@ -228,21 +230,27 @@ def test_flash2(model_name: str, use_flash_attention_2: bool):
     model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, tokenizer)
 
     # check that it actually used flash attention 2
-    assert model.model.config._flash_attn_2_enabled if use_flash_attention_2 else not model.model.config._flash_attn_2_enabled
+    assert model.model.config._attn_implementation == (
+        'flash_attention_2' if use_flash_attention_2 else 'eager')
     attention_layer = rgetattr(
         rgetattr(model, attention_layers_attr)[0], attention_attr)
     assert isinstance(attention_layer, flash_attn_class)
 
-    tokenized_input = tokenizer(['Hello world blah blah', 'Goodbye world'],
-                                return_tensors='pt',
-                                padding=True)
-    tokenized_input['labels'] = tokenized_input['input_ids'].clone()
-
-    tokenized_input = {k: v.cuda() for k, v in tokenized_input.items()}
-    model.to('cuda')
-
-    with get_precision_context('amp_bf16'):
-        # We're just testing that flash attention 2 runs okay
-        outputs = model(tokenized_input)
-        loss = outputs.loss
-        loss.backward()
+    # Skip attempting to run forward/backward when some devices have meta params
+    # because we are not instantiating a full Trainer here, which contains the logic
+    # to move params off of meta device.
+    if init_device == 'cpu':
+        tokenized_input = tokenizer(
+            ['Hello world blah blah', 'Goodbye world'],
+            return_tensors='pt',
+            padding=True)
+        tokenized_input['labels'] = tokenized_input['input_ids'].clone()
+
+        tokenized_input = {k: v.cuda() for k, v in tokenized_input.items()}
+        model.to('cuda')
+
+        with get_precision_context('amp_bf16'):
+            # We're just testing that flash attention 2 runs okay
+            outputs = model(tokenized_input)
+            loss = outputs.loss
+            loss.backward()
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index c61e963e55..3b2fc22ee3 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -1940,7 +1940,7 @@ def test_hf_init(tmp_path: pathlib.Path,
     precision = Precision('amp_bf16')
 
     hf_config = MPTConfig(
-        init_device=init_device,
+        init_device='cpu',
         d_model=32,
         n_heads=4,
         n_layers=1,
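
Note: the SpecificWarningFilter added in llmfoundry/utils/logging_utils.py is a plain logging.Filter, so the attach/remove pattern used by this patch can be sketched on its own. The snippet below is illustrative only: the logger name and message string are copied from llmfoundry/__init__.py and llmfoundry/data/finetuning/tasks.py above; nothing else is implied by the diff.

import logging

from llmfoundry.utils.logging_utils import SpecificWarningFilter

# Suppress the Hugging Face "A new version of the following files was downloaded"
# warning, mirroring what llmfoundry/__init__.py does at import time.
hf_dynamic_modules_logger = logging.getLogger('transformers.dynamic_module_utils')
new_files_warning_filter = SpecificWarningFilter(
    'A new version of the following files was downloaded from')
hf_dynamic_modules_logger.addFilter(new_files_warning_filter)

# ... run the noisy code path here ...

# Remove the filter once the noisy section is done, as build_from_hf does for the
# tokenizer warning in llmfoundry/data/finetuning/tasks.py.
hf_dynamic_modules_logger.removeFilter(new_files_warning_filter)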