diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index d9778e0b43..2a2ac21950 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -108,6 +108,7 @@ def __init__(self, om_model_config: Union[DictConfig, trust_remote_code=trust_remote_code, use_auth_token=use_auth_token, attn_implementation=requested_attention_implementation, + use_cache=False, ) # config._flash_attn_2_enabled = use_flash_attention_2