diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py
index bbbbe6bff7..23ffa4e6b4 100644
--- a/llmfoundry/__init__.py
+++ b/llmfoundry/__init__.py
@@ -5,6 +5,8 @@
 try:
     import warnings
 
+    # bitsandbytes is a very noisy library. A lot of it is print statements that we can't easily suppress,
+    # but we can at least suppress a bunch of spurious warnings.
     warnings.filterwarnings('ignore',
                             category=UserWarning,
                             module='bitsandbytes')
@@ -13,13 +15,12 @@
 
     from llmfoundry.utils.logging_utils import SpecificWarningFilter
 
-    # Filter out Hugging Face warning
+    # Filter out Hugging Face warning for not using a pinned revision of the model
     hf_dynamic_modules_logger = logging.getLogger(
         'transformers.dynamic_module_utils')
     new_files_warning_filter = SpecificWarningFilter(
         'A new version of the following files was downloaded from')
 
-    # We will trim examples later in the collate_fn, so we want to silence this warning from Hugging Face
     hf_dynamic_modules_logger.addFilter(new_files_warning_filter)
 
     # Before importing any transformers models, we need to disable transformers flash attention if
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 469b1161a4..dd1bb72697 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -108,11 +108,9 @@ def __init__(self, om_model_config: Union[DictConfig,
             trust_remote_code=trust_remote_code,
             use_auth_token=use_auth_token,
             attn_implementation=requested_attention_implementation,
-            use_cache=False,
+            use_cache=False,  # Necessary due to https://github.com/huggingface/transformers/issues/28056
         )
 
-        # config._flash_attn_2_enabled = use_flash_attention_2
-
         # This is not ideal, however Hugging Face's _autoset_attn_implementation function
         # forces you to load the model in fp16/bf16 if you want to use flash attention. Rather than loading
         # the model and then casting it back to fp32, we are monkeypatching their check.