
Commit

temp comment out
dakinggg committed Dec 15, 2023
1 parent 3407d54 commit 6ab25b4
Showing 1 changed file with 7 additions and 7 deletions.
llmfoundry/models/hf/hf_causal_lm.py: 14 changes (7 additions, 7 deletions)
@@ -107,20 +107,20 @@ def __init__(self, om_model_config: Union[DictConfig,
             om_model_config.pretrained_model_name_or_path,
             trust_remote_code=trust_remote_code,
             use_auth_token=use_auth_token,
-            attn_implementation=requested_attention_implementation,
+            # attn_implementation=requested_attention_implementation,
         )
 
         # This is not ideal, however Hugging Face's _autoset_attn_implementation function
         # forces you to load the model in fp16/bf16 if you want to use flash attention. Rather than loading
         # the model and then casting it back to fp32, we are monkeypatching their check.
         # https://github.com/huggingface/transformers/issues/28052
-        def _autoset_attn_implementation_monkeypatch(
-                cls, config, *args, **kwargs):  # type: ignore
-            config._attn_implementation = requested_attention_implementation
-            return config
+        # def _autoset_attn_implementation_monkeypatch(
+        #         cls, config, *args, **kwargs):  # type: ignore
+        #     config._attn_implementation = requested_attention_implementation
+        #     return config
 
-        PreTrainedModel._autoset_attn_implementation = classmethod(
-            _autoset_attn_implementation_monkeypatch)
+        # PreTrainedModel._autoset_attn_implementation = classmethod(
+        #     _autoset_attn_implementation_monkeypatch)
 
         # set config overrides
         for k, v in om_model_config.get('config_overrides', {}).items():
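For context, the lines this commit comments out are a workaround for transformers issue 28052: _autoset_attn_implementation refuses to set up flash attention unless the model is loaded in fp16/bf16, so llm-foundry overrides that classmethod to simply record the requested implementation on the config. Below is a minimal, self-contained sketch of that workaround, assuming a transformers version that exposes PreTrainedModel._autoset_attn_implementation; the model name and the chosen implementation value are illustrative and not part of this commit.

# Sketch of the monkeypatch this commit temporarily disables (see
# https://github.com/huggingface/transformers/issues/28052). Assumes a
# transformers version with PreTrainedModel._autoset_attn_implementation and,
# for 'flash_attention_2', that flash-attn is installed.
from transformers import AutoModelForCausalLM, PreTrainedModel

requested_attention_implementation = 'flash_attention_2'  # or 'eager' / 'sdpa'


def _autoset_attn_implementation_monkeypatch(cls, config, *args, **kwargs):  # type: ignore
    # Skip the upstream dtype check and just record the requested implementation.
    config._attn_implementation = requested_attention_implementation
    return config


# Replace the classmethod so from_pretrained does not force fp16/bf16 weights.
PreTrainedModel._autoset_attn_implementation = classmethod(
    _autoset_attn_implementation_monkeypatch)

model = AutoModelForCausalLM.from_pretrained(
    'mosaicml/mpt-7b',  # illustrative checkpoint, not taken from this commit
    trust_remote_code=True,
)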
