
Add gradient checkpointing for Llama's final layernorm module #23990

Closed · wants to merge 1 commit
Conversation

zhaoqf123 (Contributor) commented Jun 4, 2023

Without this, when tuning with LoRA + gradient checkpointing, the last transformer layer's LoRA weights won't be updated!

For example, if we use the callback below to log the LoRA weights of each layer, we will see in TensorBoard that there is no weight update for the last layer.

```
# Imports assumed by this snippet (TensorBoardCallback lives in transformers.integrations).
from transformers import Trainer
from transformers.integrations import TensorBoardCallback

class ParamsTensorBoardCallback(TensorBoardCallback):
    def __init__(self, tb_writer=None, params=None, process_name=lambda x: x):
        super().__init__(tb_writer)
        self.params = params
        self._process_name = process_name

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.logging_steps == 0:
            dict_ = {}
            model = kwargs["model"]
            for name in self.params:
                param = model.get_parameter(name)
                param = param.flatten()
                name_p = self._process_name(name)
                dict_tmp = {
                    f"{name_p}_mean": param.mean().item(),
                    f"{name_p}_max": param.max().item(),
                    f"{name_p}_q75": param.quantile(0.75).item(),
                    f"{name_p}_q25": param.quantile(0.25).item(),
                    f"{name_p}_min": param.min().item(),
                    f"{name_p}_median": param.median().item(),
                    f"{name_p}_std": param.std().item(),
                }
                dict_.update(dict_tmp)
            self.on_log(args, state, control, logs=dict_, **kwargs)

def get_params_for_logging(model):
    ls_params = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            ls_params.append(name)
    return ls_params

# model, train_data, val_data, args and data_collator come from the usual LoRA training setup.
ls_params = get_params_for_logging(model)
tb_cb = ParamsTensorBoardCallback(
    None, ls_params, process_name=lambda x: x[30:]
)

trainer = Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=args,
    data_collator=data_collator,
    callbacks=[tb_cb],
)
```

What does this PR do?

Fixes # (issue)
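In short, this PR wraps Llama's final layer norm call in gradient checkpointing when checkpointing is enabled, so its backward pass behaves like the checkpointed decoder layers. A minimal, illustrative sketch of that kind of change (not the exact diff) inside LlamaModel.forward:

```
# Illustrative sketch only, not the exact diff of this PR: inside LlamaModel.forward
# (where torch is already imported), checkpoint the final layer norm the same way
# the decoder layers are checkpointed when gradient checkpointing is active.
if self.gradient_checkpointing and self.training:
    hidden_states = torch.utils.checkpoint.checkpoint(self.norm, hidden_states)
else:
    hidden_states = self.norm(hidden_states)
```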

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

sgugger (Collaborator) commented Jun 5, 2023

cc @younesbelkada

HuggingFaceDocBuilderDev commented Jun 5, 2023

The documentation is not available anymore as the PR was closed or merged.

younesbelkada (Contributor) left a comment

Hi @zhaoqf123
Thanks for bringing this up!
Sadly, I couldn't reproduce the issue; here is the snippet I used:

```
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

model_id = "huggyllama/llama-7b"

config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)

# this should activate gradient checkpointing
model = prepare_model_for_int8_training(model)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model = get_peft_model(model, config)

assert model.training and model.is_gradient_checkpointing

dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)
logits = model(dummy_input).logits
loss = logits.sum()
loss.backward()
optimizer.step()

for n, param in model.named_parameters():
    if "lora" in n:
        assert param.grad is not None
```

As you can see, the gradients are always non-None. To my understanding, as long as a weight has an associated gradient, its value will be updated.
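As a side note, a complementary check is to compare the weights themselves before and after a step. A minimal sketch (illustrative only, reusing model, dummy_input and optimizer from the snippet above, and assuming the optimizer was built after get_peft_model so it actually holds the LoRA parameters):

```
# Illustrative only: snapshot the LoRA weights, take one more optimizer step,
# and flag any LoRA parameter whose value did not change.
before = {n: p.detach().clone() for n, p in model.named_parameters() if "lora" in n}

logits = model(dummy_input).logits
loss = logits.sum()
loss.backward()
optimizer.step()

for n, p in model.named_parameters():
    if "lora" in n and torch.equal(before[n], p.detach()):
        print(f"not updated: {n}")
```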

zhaoqf123 (Contributor, Author) commented Jun 6, 2023

> Hi @zhaoqf123 Thanks for bringing this up! Sadly I couldn't reproduce the issue, here is the snippet I used: […]

@younesbelkada Thank you for your reply. I modified your script to match my training setup on a V100 GPU as follows, and the issue can be reproduced.

```
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

# 1. load pretrained model
# model_id = "huggyllama/llama-7b"
model_id = "decapoda-research/llama-7b-hf"
cache_dir = "/mnt/workspace/kgg/hf_models"

model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_dir, device_map="auto", load_in_8bit=True)

# this should activate gradient checkpointing
model = prepare_model_for_int8_training(model)

# 2. config peft model
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # target_modules=["layers.31.self_attn.q_proj"]
)
model = get_peft_model(model, config)

assert model.training and model.is_gradient_checkpointing

# 3. set up optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 4. train
with torch.autocast("cuda"):
    dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)
    model.train()
    logits = model(dummy_input).logits
    loss = logits.sum()

    loss.backward()
    optimizer.step()

    for n, param in model.named_parameters():
        if "lora" in n:
            print(n)
            assert param.grad is not None
```

You can see that the params of the last layer (layer 31) have None grads.

The main differences from your code are these three:

  1. The optimizer is set up after get_peft_model
  2. with torch.autocast("cuda") is used
  3. model.train() is called, as in the transformers/trainer.py script

By the way, my torch version is 2.1.0a0+fe05266

younesbelkada (Contributor) commented Jun 6, 2023

Indeed, I also managed to reproduce it, this time with the latest stable version of torch. Note that this bug also occurs with other models, for instance OPT.

```
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

# 1. load pretrained model
# model_id = "huggyllama/llama-7b"
model_id = "facebook/opt-350m"
# model_id = "decapoda-research/llama-7b-hf"

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)

# this should activate gradient checkpointing
model = prepare_model_for_int8_training(model)

# 2. config peft model
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # target_modules=["layers.31.self_attn.q_proj"]
)
model = get_peft_model(model, config)

assert model.training and model.is_gradient_checkpointing

# 3. set up optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 4. train
with torch.autocast("cuda"):
    dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)
    model.train()
    logits = model(dummy_input).logits
    loss = logits.sum()

    loss.backward()
    optimizer.step()

    for n, param in model.named_parameters():
        if "lora" in n:
            print(n)
            assert param.grad is not None
```

However, it seems that the bug disappears when the torch.autocast("cuda") context manager is removed.
It appears the issue can be reproduced even without PEFT:

```
import torch
from transformers import AutoModelForCausalLM

model_id = "facebook/opt-350m"

model = AutoModelForCausalLM.from_pretrained(model_id).to(0)
model.gradient_checkpointing_enable()
model.train()

assert model.training and model.is_gradient_checkpointing

# 3. set up optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 4. train
with torch.cuda.amp.autocast(True, dtype=torch.float16):
    dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)
    model.train()
    logits = model(dummy_input).logits
    loss = logits.sum()

    loss.backward()
    optimizer.step()

    for n, param in model.named_parameters():
        if param.grad is None:
            print(n)
```

And this gives:

```
model.decoder.layers.23.self_attn.k_proj.weight
model.decoder.layers.23.self_attn.k_proj.bias
model.decoder.layers.23.self_attn.v_proj.weight
model.decoder.layers.23.self_attn.v_proj.bias
model.decoder.layers.23.self_attn.q_proj.weight
model.decoder.layers.23.self_attn.q_proj.bias
model.decoder.layers.23.self_attn.out_proj.weight
model.decoder.layers.23.self_attn.out_proj.bias
model.decoder.layers.23.fc1.weight
model.decoder.layers.23.fc1.bias
model.decoder.layers.23.fc2.weight
model.decoder.layers.23.fc2.bias
```

Meaning the entire last layer doesn't get updated.

From what I can see in the Trainer, we currently support mixed-precision autocast (torch.xxx.amp) context managers: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L2688-L2705 - and replacing the context manager you used with torch.cuda.amp.autocast(True, dtype=torch.float16) also reproduces the bug.
I am not sure whether this is a bug on the transformers side or in torch, but I would say it is OK to merge this and apply the same patch to other common architectures (by opening a good first issue, maybe?).

Wdyt @sgugger @amyeroberts @ArthurZucker

sgugger (Collaborator) left a comment

Not sure we should merge this as it is a workaround for a bug in PyTorch and could have unintended consequences outside of the bug it fixes. We can definitely leave this open to show anyone how to work around the exact issue, but I feel this should be fixed in PyTorch.

ArthurZucker (Collaborator) commented

In line with what @sgugger said, I'm also not sure it even makes a lot of sense to checkpoint something as small as the layer norm's grads. Thanks for flagging the issue and proposing a fix!

younesbelkada (Contributor) left a comment

Hi @zhaoqf123
Thanks a lot for helping us find this important issue. After some digging and internal discussion, we have found a broader fix that covers most models that support gradient checkpointing: #24247. To credit you for your help, I have added you as a co-author on that PR, and we will close this one once #24247 is merged.
Thanks a lot!

zhaoqf123 (Contributor, Author) commented

> Hi @zhaoqf123 Thanks a lot for helping us finding this important issue. After some digging and internal discussion, we have found a broader fix that includes most models that supports gradient checkpointing: #24247 . To credit you from your help, I have added you as a co-author in that PR and we will close this PR once #24247 will get merged Thanks a lot !

@younesbelkada Thank you for the acknowledgement. Although I have several years of experience in machine learning (with TF), I have only been using transformers and PyTorch for a couple of months. It really took me 4 days and nights to figure out where the bug occurs and to find a workaround.

Thank you very much for the transformers and peft projects. They are really helpful.

younesbelkada (Contributor) commented

Closing the PR as #24247 has been merged.
Again thanks so much @zhaoqf123 for all your help on this and your great investigation!

younesbelkada (Contributor) commented Jun 22, 2023

Hi @zhaoqf123
Some training setups that were running fine on a single T4 (with 7GB peak memory) now OOM with that PR. I wanted to double-check whether you observe the same behaviour in your case as well?

For reference, check: #24420 (comment)

younesbelkada (Contributor) commented Jun 22, 2023

Hi @zhaoqf123
@pacman100 has found the root cause of your original issue, and it turns out that the recent accelerate integration in Trainer silently fixed your bug. I can confirm I don't get any None grads with llama models using Trainer + autocast: #24420 (comment). I believe that 3 weeks ago the Trainer + accelerate integration was not released yet, which could explain why you hit the bug.
Can you try out your script after we revert the PR and let us know?
Thanks !

zhaoqf123 (Contributor, Author) commented

> hi @zhaoqf123 Some training setups that were running fine in a single T4 (with 7GB peak memory) now OOM with that PR, I wanted to double check if you observe the same behaviour in your case as well?
>
> For reference, check: #24420 (comment)

@younesbelkada Sorry for the late reply. I was on vacation the last 3 days.

Yes, I also noticed that memory consumption increased a lot when making the last layer updatable. For llama 7B on a V100-32GB, VRAM usage increases from 34% to 42%, which is not proportional to the increase in updatable params.

zhaoqf123 (Contributor, Author) commented

> Hi @zhaoqf123 @pacman100 has found the rootcause of your original issue and we found out that the recent accelerate integration of Trainer silently fixed your bug. I can confirm I don't get any None grad with llama models using Trainer + autocast: #24420 (comment) | I believe 3 weeks ago the Trainer + accelerate integration was not released yet that could explain why you had the bug Can you try out your script after we revert the PR and let us know? Thanks !

@younesbelkada May I know how I should try it out? For example, reinstall transformers with pip install --upgrade git+https://github.com/huggingface/transformers.git, and then run my script without the with torch.autocast("cuda"): context manager?

younesbelkada (Contributor) commented

@zhaoqf123 thanks for the reply!
Yes, you can try it out that way: uninstall your current transformers lib, reinstall it from source, and see if the original bug still persists.
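For example, something along these lines (standard pip commands, nothing specific to this PR):

```
pip uninstall -y transformers
pip install git+https://github.com/huggingface/transformers.git
```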

zhaoqf123 (Contributor, Author) commented

> @zhaoqf123 thanks for the reply! Yes you can try out that way, uninstall your current transformers lib, reinstall it from source and see if the original bug still persists

@younesbelkada After reinstalling transformers from source, on my V100: if I remove with torch.autocast("cuda"), I encounter this issue; if I don't remove with torch.autocast("cuda"), the last layer is still not updated.

On my 3090 GPU, it works when removing with torch.autocast("cuda"). It could be due to the bitsandbytes implementation for GPUs with compute capability < 7.5: because those GPUs do not have int8 tensor cores, bitsandbytes does the int8 multiplication using fp16.

Check also this issue and this issue
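For reference, the compute capability can be checked directly with torch; a small sketch (device index 0 is an assumption):

```
import torch

# Quick check of the GPU's compute capability: a V100 reports (7, 0), i.e. below
# the 7.5 threshold where bitsandbytes can use int8 tensor cores, while a 3090
# reports (8, 6).
print(torch.cuda.get_device_capability(0))
```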
