Add flash attention for gpt_bigcode #26479

Merged (14 commits) on Oct 31, 2023

Conversation

@susnato (Contributor) commented Sep 28, 2023

What does this PR do?

Adds Flash Attention 2 for GPTBigCode (StarCoder), as discussed in this issue: #26350.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc: @younesbelkada

@susnato (Contributor, Author) commented Sep 28, 2023

All tests are passing except the test_flash_attn_2_generate_use_cache tests (same as OPT).
(Screenshot of test results.)

Error Traceback:

=========================================================== FAILURES ===========================================================
__________________________________ GPTBigCodeModelTest.test_flash_attn_2_generate_use_cache ___________________________________

self = <tests.models.gpt_bigcode.test_modeling_gpt_bigcode.GPTBigCodeModelTest testMethod=test_flash_attn_2_generate_use_cache>

@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
    import torch

    for model_class in self.all_generative_model_classes:
        if not model_class._supports_flash_attn_2:
            return

        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

            dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
            dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)

            model = model_class.from_pretrained(
                tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
            ).to(torch_device)

            # Just test that a large cache works as expected
          _ = model.generate(
                dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False
            )

tests/test_modeling_common.py:2936:


../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27: in decorate_context
return func(*args, **kwargs)
src/transformers/generation/utils.py:1606: in generate
return self.greedy_search(
src/transformers/generation/utils.py:2454: in greedy_search
outputs = self(
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/nn/modules/module.py:1194: in _call_impl
return forward_call(*input, **kwargs)
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py:1054: in forward
transformer_outputs = self.transformer(
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/nn/modules/module.py:1194: in _call_impl
return forward_call(*input, **kwargs)
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py:917: in forward
outputs = block(
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/nn/modules/module.py:1194: in _call_impl
return forward_call(*input, **kwargs)
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py:548: in forward
attn_outputs = self.attn(
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/nn/modules/module.py:1194: in _call_impl
return forward_call(*input, **kwargs)
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py:376: in forward
attn_output = self._flash_attention_forward(
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py:443: in _flash_attention_forward
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
../../anaconda3/envs/transformers/lib/python3.9/site-packages/flash_attn/bert_padding.py:208: in pad_input
output = index_put_first_axis(hidden_states, indices, batch * seqlen)


ctx = <torch.autograd.function.IndexPutFirstAxisBackward object at 0x7fea8a486400>
values = tensor([[[-0.0145, 0.0878, -0.0264, 0.0285, 0.0206, 0.0969, 0.0740,
0.0528],
[-0.0148, 0.087...0304, 0.1120, -0.0429, 0.0525, -0.0332, 0.0903, 0.0680,
0.0412]]], device='cuda:0', dtype=torch.float16)
indices = tensor([0, 1], device='cuda:0', dtype=torch.int32), first_axis_dim = 2

@staticmethod
def forward(ctx, values, indices, first_axis_dim):
    ctx.save_for_backward(indices)
    assert indices.ndim == 1
    assert values.ndim >= 2
    output = torch.zeros(
        first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
    )
    # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
  output[indices] = values

E IndexError: tensors used as indices must be long, byte or bool tensors

../../anaconda3/envs/transformers/lib/python3.9/site-packages/flash_attn/bert_padding.py:51: IndexError
---------------------------------------------------- Captured stderr call -----------------------------------------------------
You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with model.to('cuda').
(GPTBigCodeMHAModelTest.test_flash_attn_2_generate_use_cache fails with the same traceback and the same IndexError: tensors used as indices must be long, byte or bool tensors.)

At this point, I think the test failures might be caused by some unknown error on my machine rather than by the code, since this test fails for all flash-attention-supported models on my setup.

If you don't mind, @younesbelkada, could you please check out this branch, run the tests on your end, and let me know the results?
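
For reference, a minimal sketch of the operation that trips here, assuming the int32 indices shown in the traceback; it only illustrates the IndexError and a local cast workaround, not the fix adopted in this PR:

import torch

# flash_attn's pad_input scatters the unpadded rows back with output[indices] = values;
# some torch builds only accept long/byte/bool tensors as indices, hence the IndexError.
values = torch.randn(2, 4, 8)
indices = torch.tensor([0, 1], dtype=torch.int32)
output = torch.zeros(2, 4, 8)

try:
    output[indices] = values           # raises IndexError on affected torch builds
except IndexError:
    output[indices.long()] = values    # casting the indices to int64 sidesteps it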

@younesbelkada (Contributor) commented:
Again, thanks a lot for your amazing contribution, @susnato!
Yes, OK, I will check that tomorrow and let you know.

@susnato (Contributor, Author) commented Sep 28, 2023

Thanks a lot, @younesbelkada!
By the way, if the tests pass on your machine, will it then be ready to merge? Or will we need to investigate further why it fails on my end?

@younesbelkada (Contributor) left a review:

LGTM! I can confirm the FA tests pass, and I have also tried generation locally and it seems to work great. Thanks for your huge work! 🙏

@@ -33,6 +33,7 @@ We natively support Flash Attention 2 for the following models:

- Llama
- Falcon
- GPTBigCode
Comment (Contributor):
Suggested change
- GPTBigCode
- GPTBigCode (Starcoder)

Reply (Contributor, Author):

done

Comment on lines 365 to 374
input_dtype = query.dtype
if input_dtype == torch.float32:
    logger.warning_once(
        "The input hidden states seems to be silently casted in float32, this might be related to"
        " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        " float16."
    )
    query = query.to(torch.float16)
    key = key.to(torch.float16)
    value = value.to(torch.float16)
Comment (Contributor):

This can be addressed in the global fix I want to apply in #26451, as a follow-up PR that I will take care of.

Reply (Contributor, Author):

Should I remove this block then, or are we keeping it for now?

Reply (Contributor):

I would say we can keep it for now

Reply (Contributor, Author):

ok

Comment (Collaborator):

Let's wait until your PR is merged @younesbelkada 😉

@ArthurZucker (Collaborator) left a review:

Hey! Great work here. Going forward, it would be great to have an estimate of what kind of speedup we are expecting (and maybe add it to the readme here), or at least a benchmark showing that we indeed have a speedup!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@susnato (Contributor, Author) commented Sep 29, 2023

Hey! Great work here. Going forward, it would be great to have an estimate of what kind of speedup we are expecting (and maybe add it to the readme here), or at least a benchmark showing that we indeed have a speedup!

I would love to, but I don't really have access to a high-quality GPU to show the full performance of Flash Attention 2.

@younesbelkada (Contributor) commented:
@susnato I will run benchmarks for you :)

@susnato (Contributor, Author) commented Sep 29, 2023

Thanks @younesbelkada!

@younesbelkada (Contributor) left a review:

Thanks again!
I ran the benchmarks and I am getting a 15%–40% speedup depending on the sequence length. Given that it also makes training and inference more memory-efficient, I think this is great to have.

(Benchmark screenshots attached.)
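
For anyone wanting to reproduce a comparison like this, a rough sketch of a timing harness follows; the checkpoint, prompt and generation length are illustrative, and it assumes a CUDA GPU with flash-attn installed and a transformers version from around this PR, where the flag is use_flash_attention_2:

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/starcoderbase-1b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("cuda")

def time_generate(use_fa2: bool) -> float:
    # Load in fp16 so the Flash Attention 2 kernels can actually be used.
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype=torch.float16, use_flash_attention_2=use_fa2
    ).to("cuda")
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=256, do_sample=False)
    torch.cuda.synchronize()
    return time.perf_counter() - start

print(f"eager attention:   {time_generate(False):.2f}s")
print(f"flash attention 2: {time_generate(True):.2f}s")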

@ArthurZucker (Collaborator) left a review:

Thanks! If you can update the README following this, that would be great!

@@ -33,6 +33,7 @@ We natively support Flash Attention 2 for the following models:

- Llama
- Falcon
- GPTBigCode (Starcoder)
Comment (Collaborator):

Same comment as for Llama, let's add a section in the readme of starcoder about expected gains!

@susnato (Contributor, Author) commented Oct 5, 2023

Thanks for updating the readme of starcoder, @younesbelkada!
Hi @ArthurZucker, anything more required to make this ready for merge?

@younesbelkada (Contributor) commented:
Hi @susnato, the changes look great to me; I will let @ArthurZucker give a final pass and merge!

docs/source/en/perf_infer_gpu_one.md (outdated, resolved)
Comment on lines 365 to 374
input_dtype = query.dtype
if input_dtype == torch.float32:
    logger.warning_once(
        "The input hidden states seems to be silently casted in float32, this might be related to"
        " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        " float16."
    )
    query = query.to(torch.float16)
    key = key.to(torch.float16)
    value = value.to(torch.float16)
Comment (Collaborator):

Let's wait until your PR is merged @younesbelkada 😉

@ArthurZucker (Collaborator) left a review:

Sorry to block again!

Comment on lines 234 to 235
padding_mask: Optional[torch.LongTensor] = None,
encoder_padding_mask: Optional[torch.LongTensor] = None,
Comment (Collaborator):

I think we should wait again for the padding mask refactor and not pass padding mask! #26792

@cmosguy commented Oct 29, 2023

@younesbelkada @ArthurZucker I was trying to run the blog tutorial code here: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/personal_copilot/training/train.py

I just pulled in the main line and am getting this error.

This was working back in August, but now I am getting errors related to the Flash Attention 2.0 error message above in this issue. Why was this working back in August but broken now? Is there a pinned version of transformers I should be using?

encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
Comment (Contributor):

There should be no more padding_mask now that #26792 has been merged. Let me know if you want to handle this; otherwise I can have a look!
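
For context, a rough sketch of how the attention forward signature looks once padding_mask is dropped; the remaining argument names mirror the hunks quoted in this thread and should be treated as illustrative rather than the exact final diff:

from typing import Optional
import torch

class AttentionSketch:  # placeholder class, not the real GPTBigCode module
    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
        # Padding information now travels only through attention_mask;
        # the separate padding_mask/encoder_padding_mask arguments are gone.
        ...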

Reply (Contributor, Author):

Hey, I am currently working on fixing this.

@susnato (Contributor, Author) commented Oct 30, 2023

Hello @younesbelkada, I have pushed the commit to remove padding_mask and all the tests are passing too!

(Screenshot of passing tests.)

Please let me know if any more changes are needed.

@susnato susnato requested a review from younesbelkada October 30, 2023 11:35
@younesbelkada (Contributor) left a review:

Thanks very much for your great work! LGTM with only 1 nit!

input_dtype = query.dtype
if input_dtype == torch.float32:
    # Handle the case where the model is quantized
    target_dtype = getattr(self.config, "_pre_quantization_dtype", self.c_attn.weight.dtype)
Comment (Contributor):

Can you replace that with hasattr logic? The reason behind it is presented here: #26916
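
For reference, a minimal sketch of the hasattr-based variant being asked for, as a drop-in for the getattr line quoted above (same intent; see #26916 for the reasoning):

# Prefer the dtype the model used before quantization, if the config records one;
# otherwise fall back to the dtype of the attention projection weights.
if hasattr(self.config, "_pre_quantization_dtype"):
    target_dtype = self.config._pre_quantization_dtype
else:
    target_dtype = self.c_attn.weight.dtype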

@susnato (Contributor, Author) commented Oct 30, 2023

Hi @younesbelkada, done!

@cmosguy commented Oct 30, 2023

@susnato @younesbelkada, sorry to keep bugging you guys here in this PR, but can you please tell me how folks were able to run with Flash Attention back in August using this tutorial from @pacman100?

When I run with the flash attention flag, I keep getting:

 File "/opt/lib/python3.10/site-packages/transformers/modeling_utils.py", line 1265, in _check_and_enable_flash_attn_2
        raise ValueError(raise ValueError(

ValueError: The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to request support for this architecture: https://github.com/huggingface/transformers/issues/newValueError
: The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to request support for this architecture: https://github.com/huggingface/transformers/issues/new
    raise ValueError(
ValueError: The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to request support for this architecture: https://github.com/huggingface/transformers/issues/new
    raise ValueError(
ValueError: The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to request support for this architecture: https://github.com/huggingface/transformers/issues/new

@susnato (Contributor, Author) commented Oct 30, 2023

Hi @cmosguy, could you please check out this branch (susnato:flash_attn_starcoder), install the library from it, re-run your tutorial, and let us know whether the error is solved?

Right now, StarCoder does not support flash attention; this PR adds the flash attention feature to the model. You can either wait for the PR to be merged or, if it is urgent, check out my branch and install from it (as described above).

@cmosguy commented Oct 30, 2023

@susnato

I just checked out your branch and ran pip install -e . in the transformers library. After installing, I get the following output (sorry for the long text):

  File "/opt/../DHS-LLM-Workshop/personal_copilot/training/train.py", line 275, in create_and_prepare_model
        model = AutoModelForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(  File "/opt/ds_research/transformers/src/transformers/models/auto/auto_factory.py", line 566, in from_pretrained

  File "/opt/ds_research/transformers/src/transformers/models/auto/auto_factory.py", line 566, in from_pretrained
    model = AutoModelForCausalLM.from_pretrained(
  File "/opt/ds_research/transformers/src/transformers/models/auto/auto_factory.py", line 566, in from_pretrained
    model = AutoModelForCausalLM.from_pretrained(
  File "/opt/ds_research/transformers/src/transformers/models/auto/auto_factory.py", line 566, in from_pretrained
        return model_class.from_pretrained(
  File "/opt/ds_research/transformers/src/transformers/modeling_utils.py", line 3175, in from_pretrained
return model_class.from_pretrained(
  File "/opt/ds_research/transformers/src/transformers/modeling_utils.py", line 3175, in from_pretrained
    return model_class.from_pretrained(
  File "/opt/ds_research/transformers/src/transformers/modeling_utils.py", line 3175, in from_pretrained
    return model_class.from_pretrained(
  File "/opt/ds_research/transformers/src/transformers/modeling_utils.py", line 3175, in from_pretrained
    config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map)
  File "/opt/ds_research/transformers/src/transformers/modeling_utils.py", line 1275, in _check_and_enable_flash_attn_2
        config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map)config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map)
  File "/opt/ds_research/transformers/src/transformers/modeling_utils.py", line 1275, in _check_and_enable_flash_attn_2

  File "/opt/ds_research/transformers/src/transformers/modeling_utils.py", line 1275, in _check_and_enable_flash_attn_2
    config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map)
  File "/opt/ds_research/transformers/src/transformers/modeling_utils.py", line 1275, in _check_and_enable_flash_attn_2
    raise ImportError(
ImportError: Flash Attention 2 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for installing it. Make sure to have at least the version 2.1.0
            raise ImportError(raise ImportError(

raise ImportError(
ImportErrorImportErrorImportError: : : Flash Attention 2 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for installing it. Make sure to have at least the version 2.1.0Flash Attention 2 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for installing it. Make sure to have at least the version 2.1.0Flash Attention 2 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for installing it. Make sure to have at least the version 2.1.0


[2023-10-30 07:55:00,169] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 42849) of binary: /opt/bin/python
Traceback (most recent call last):
  File "/opt/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/opt/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/opt/lib/python3.10/site-packages/accelerate/commands/launch.py", line 985, in launch_command
    multi_gpu_launcher(args)
  File "/opt/lib/python3.10/site-packages/accelerate/commands/launch.py", line 654, in multi_gpu_launcher
    distrib_run.run(args)
  File "/opt/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/opt/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
/opt/../DHS-LLM-Workshop/personal_copilot/training/train.py FAILED
------------------------------------------------------------

@susnato (Contributor, Author) commented Oct 30, 2023

Hi @cmosguy, could you please run pip install flash-attn --no-build-isolation and then re-run the script?

Or could you check the flash attention version on your system with pip freeze | grep flash (it should show something like flash-attn==2.3.0)?

@cmosguy commented Oct 30, 2023

Hey @susnato -

I had no idea that I had to install that additional package. Shouldn't that have been installed by the setup.py of the transformers library?

Anyways, I repeated the commands you mentioned:

pip freeze | grep flash
flash-attn==2.3.3

So it looks like it is loading:

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.

I am assuming at this point the flash-attn is kicking in.

But then it does this:

    trainer.train()
  File "/ds_research/transformers/src/transformers/trainer.py", line 1511, in train
    ...
    return model_forward(*args, **kwargs)
  File "/opt/lib/python3.10/site-packages/accelerate/utils/operations.py", line 647, in __call__
    ...
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'GPTBigCodeFlashAttention2' object has no attribute 'dropout'

@susnato (Contributor, Author) commented Oct 30, 2023

Hey @cmosguy, I will check out the script on Wednesday and get back to you if I find a solution :)

@amyeroberts (Collaborator) left a review:

Thanks for all of the work adding this!

Just a few small nits

docs/source/en/model_doc/gpt_bigcode.md (outdated, resolved)
@@ -34,6 +34,7 @@ We natively support Flash Attention 2 for the following models:
- Llama
- Mistral
- Falcon
- [GPTBigCode (Starcoder)](model_doc/gpt_bigcode#)
Comment (Collaborator):

Why do we link to the model page here but not for e.g. falcon or llama?

@susnato (Contributor, Author) replied Oct 30, 2023:

Hey, if I remember correctly, it was asked for by @ArthurZucker, here.

Reply (Collaborator):

OK, it doesn't really matter, so I'm happy to leave it and find out what the reasoning was when Arthur's back :)

Reply (Contributor, Author):

Ok then leaving it as it is.

@susnato (Contributor, Author) commented Oct 30, 2023

Hi @amyeroberts, I have pushed the changes.

@amyeroberts merged commit b5db8ca into huggingface:main on Oct 31, 2023 (3 checks passed)
@susnato (Contributor, Author) commented Nov 1, 2023

Hello @cmosguy, the dropout error is fixed now on the main branch.

Could you please install again from the main branch and re-run your script? It should work now.

@susnato (Contributor, Author) commented Nov 1, 2023

Hi @cmosguy, I ran the script with some minor hyper-parameter changes (to suit my GPU) and it's working!

(Screenshots of the training run attached.)

Make sure to override use_flash_attention_2=True in this line if you feel that Flash Attention is not being used.

Also, don't forget to re-install from main :).

@cmosguy commented Nov 2, 2023

Hey @susnato, thanks for your efforts here. Yes, I was able to see the training start with the lines:

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.

Thanks for your help here.

If you don't mind me asking, what happened back in August? The tutorial script I mentioned had things working before, but you made a lot of edits, which indicates this is only now being added. Was it there before and then removed? I am trying to understand, because I may be interested in swapping in other models and I cannot fully work out when flash attention can or cannot be used, and on which models.

@susnato (Contributor, Author) commented Nov 2, 2023

Hello @cmosguy, I am glad that you were able to start training :).

If you do not mind me asking, what happened back in August? The tutorial I mentioned with script had things working before, but you made a lot of edits that indicates this is just now being added. Was it there before then removed?

As far as I know, Flash Attention has only just been added for this model.
Actually, if we look at this commit from the same tutorial you provided, back in August whenever you loaded any of the Falcon, Llama or StarCoder models, the model's attention forward code would be replaced by custom flash attention code, so you could use Flash Attention without any error (replace_starcoder_attn_with_flash_attn, replace_llama_attn_with_flash_attn, replace_falcon_attn_with_flash_attn).

But then Flash Attention was added to the Falcon and Llama models (on transformers main), and @pacman100 removed the block which used to modify the attention code. That was fine for Falcon and Llama, but since StarCoder hadn't got that feature yet, you were getting errors whenever you tried to use StarCoder. Now StarCoder has Flash Attention, so it's all good.

Also, please note that the flash attention code from the tutorial didn't handle attention_mask, but it is handled properly on transformers main.

I hope this explanation helps, otherwise feel free to tag me if you need any help. :)

I may be interested in swapping in other models and I cannot fully comprehend when flash attention can be used or not on which model.

Every model that supports Flash Attention will have _supports_flash_attn_2 set to True. You can just load the model and check that, for example:

from transformers import AutoModel

model = AutoModel.from_pretrained("bigcode/starcoderbase-1b")
model._supports_flash_attn_2  # True, since GPTBigCode (StarCoder) now supports Flash Attention

model = AutoModel.from_pretrained("gpt2")
model._supports_flash_attn_2  # False, since GPT-2 does not support Flash Attention (for now)
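
As a related sketch, this is roughly how Flash Attention 2 is then switched on at load time once a model class supports it; the flag is the one used in this PR's tests, while the checkpoint and dtype are illustrative:

import torch
from transformers import AutoModelForCausalLM

# Load StarCoder with the Flash Attention 2 code path enabled; the FA2 kernels
# expect fp16/bf16 inputs, hence the explicit torch_dtype.
model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoderbase-1b",
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
).to("cuda")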

@cmosguy commented Nov 2, 2023

@susnato, wow, I am learning so much hanging out here. Thank you for walking me through what happened; this really explains a lot. I appreciate you taking the time to investigate the issue and write a coherent explanation. I totally did not see that earlier commit, thank you for bringing it to my attention (pun intended). OK, I am off to the races with the training in the meantime, cheers!

@susnato (Contributor, Author) commented Nov 2, 2023

Hey @cmosguy, it was fun for me too!

cheers!

EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* added flash attention of gpt_bigcode

* changed docs

* Update src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py

* add FA-2 docs

* oops

* Update docs/source/en/perf_infer_gpu_one.md Last Nit

Co-authored-by: Arthur <[email protected]>

* fix

* oops

* remove padding_mask

* change getattr->hasattr logic

* changed .md file

---------

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: younesbelkada <[email protected]>
Co-authored-by: Arthur <[email protected]>