add gradient checkpointing for the llama's final layernorm module #23990
Conversation
Without this, when tuning with LoRA + gradient checkpointing, the last transformer layer's LoRA weights won't be updated! For example, if we use the callback below to log the weight statistics of the LoRA weights in each layer, we will see in TensorBoard that the last layer's weights never change.

```
from transformers import Trainer
from transformers.integrations import TensorBoardCallback


class ParamsTensorBoardCallback(TensorBoardCallback):
    def __init__(self, tb_writer=None, params=None, process_name=lambda x: x):
        super().__init__(tb_writer)
        self.params = params
        self._process_name = process_name

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.logging_steps == 0:
            dict_ = {}
            model = kwargs["model"]
            for name in self.params:
                param = model.get_parameter(name)
                param = param.flatten()
                name_p = self._process_name(name)
                dict_tmp = {
                    f"{name_p}_mean": param.mean().item(),
                    f"{name_p}_max": param.max().item(),
                    f"{name_p}_q75": param.quantile(0.75).item(),
                    f"{name_p}_q25": param.quantile(0.25).item(),
                    f"{name_p}_min": param.min().item(),
                    f"{name_p}_median": param.median().item(),
                    f"{name_p}_std": param.std().item(),
                }
                dict_.update(dict_tmp)
            self.on_log(args, state, control, logs=dict_, **kwargs)


def get_params_for_logging(model):
    ls_params = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            ls_params.append(name)
    return ls_params


# `model`, `train_data`, `val_data`, `args`, and `data_collator` come from the
# surrounding training setup.
ls_params = get_params_for_logging(model)
tb_cb = ParamsTensorBoardCallback(
    None, ls_params, process_name=lambda x: x[30:]
)
trainer = Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=args,
    data_collator=data_collator,
    callbacks=[tb_cb],
)
```
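For readers skimming the discussion, here is a minimal, self-contained sketch of the idea behind this PR (an illustration by analogy, not the actual change to modeling_llama.py): when gradient checkpointing is enabled, the final norm is routed through torch.utils.checkpoint just like the decoder layers. All names below are illustrative.

```
import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyDecoder(nn.Module):
    """Toy stand-in for a decoder: a stack of blocks followed by a final norm."""

    def __init__(self, hidden=32, n_layers=2, gradient_checkpointing=True):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(n_layers))
        self.norm = nn.LayerNorm(hidden)  # stands in for the final LlamaRMSNorm
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        # the idea proposed here: run the final norm through checkpointing as well
        if self.gradient_checkpointing and self.training:
            hidden_states = torch.utils.checkpoint.checkpoint(self.norm, hidden_states)
        else:
            hidden_states = self.norm(hidden_states)
        return hidden_states


# quick smoke test
model = TinyDecoder().train()
x = torch.randn(1, 4, 32, requires_grad=True)
model(x).sum().backward()
```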
Hi @zhaoqf123
Thanks for bringing this up!
Sadly I couldn't reproduce the issue, here is the snippet I used:
```
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

model_id = "huggyllama/llama-7b"

config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
# this should activate gradient checkpointing
model = prepare_model_for_int8_training(model)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model = get_peft_model(model, config)
assert model.training and model.is_gradient_checkpointing

dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)
logits = model(dummy_input).logits
loss = logits.sum()
loss.backward()
optimizer.step()

for n, param in model.named_parameters():
    if "lora" in n:
        assert param.grad is not None
```
And as you can see, the gradients are always non-None. Per my understanding, as long as a weight has an associated gradient, its value will be updated.
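To make that last point concrete, here is a tiny PyTorch-only illustration (my addition, not from the original discussion): optimizers skip parameters whose .grad is None, so a parameter that never receives a gradient is never updated.

```
import torch

a = torch.nn.Parameter(torch.ones(2))
b = torch.nn.Parameter(torch.ones(2))
opt = torch.optim.Adam([a, b], lr=1.0)

loss = (a * 2).sum()   # only `a` participates in the loss
loss.backward()
print(a.grad, b.grad)  # tensor([2., 2.]) and None

opt.step()
print(a, b)            # `a` has moved, `b` is still all ones
```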
@younesbelkada Thank you for your reply. I modified your script based on my training setup (a V100 GPU) as follows, and the issue can be reproduced.

```
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

# 1. load pretrained model
# model_id = "huggyllama/llama-7b"
model_id = "decapoda-research/llama-7b-hf"
cache_dir = "/mnt/workspace/kgg/hf_models"
model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_dir, device_map="auto", load_in_8bit=True)
# this should activate gradient checkpointing
model = prepare_model_for_int8_training(model)

# 2. config peft model
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # target_modules=["layers.31.self_attn.q_proj"]
)
model = get_peft_model(model, config)
assert model.training and model.is_gradient_checkpointing

# 3. set up optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 4. train
with torch.autocast("cuda"):
    dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)
    model.train()
    logits = model(dummy_input).logits
    loss = logits.sum()
    loss.backward()
    optimizer.step()

for n, param in model.named_parameters():
    if "lora" in n:
        print(n)
        assert param.grad is not None
```

You can see that the params of the last layer (layer 31) have a None grad. The main differences from your code are three: the model checkpoint (decapoda-research/llama-7b-hf instead of huggyllama/llama-7b), the optimizer being created after get_peft_model rather than before, and the forward/backward pass running under torch.autocast("cuda").
By the way, my torch version is 2.1.0a0+fe05266
Indeed, I also managed to reproduce, this time with the latest stable version of torch. Note that this bug also occurs with any other model, for instance OPT.

```
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

# 1. load pretrained model
# model_id = "huggyllama/llama-7b"
model_id = "facebook/opt-350m"
# model_id = "decapoda-research/llama-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
# this should activate gradient checkpointing
model = prepare_model_for_int8_training(model)

# 2. config peft model
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # target_modules=["layers.31.self_attn.q_proj"]
)
model = get_peft_model(model, config)
assert model.training and model.is_gradient_checkpointing

# 3. set up optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 4. train
with torch.autocast("cuda"):
    dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)
    model.train()
    logits = model(dummy_input).logits
    loss = logits.sum()
    loss.backward()
    optimizer.step()

for n, param in model.named_parameters():
    if "lora" in n:
        print(n)
        assert param.grad is not None
```

However, it seems that the bug disappears when the …

```
import torch
from transformers import AutoModelForCausalLM

model_id = "facebook/opt-350m"
model = AutoModelForCausalLM.from_pretrained(model_id).to(0)
model.gradient_checkpointing_enable()
model.train()
assert model.training and model.is_gradient_checkpointing

# 3. set up optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 4. train
with torch.cuda.amp.autocast(True, dtype=torch.float16):
    dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)
    model.train()
    logits = model(dummy_input).logits
    loss = logits.sum()
    loss.backward()
    optimizer.step()

for n, param in model.named_parameters():
    if param.grad is None:
        print(n)
```

And this gives:

```
model.decoder.layers.23.self_attn.k_proj.weight
model.decoder.layers.23.self_attn.k_proj.bias
model.decoder.layers.23.self_attn.v_proj.weight
model.decoder.layers.23.self_attn.v_proj.bias
model.decoder.layers.23.self_attn.q_proj.weight
model.decoder.layers.23.self_attn.q_proj.bias
model.decoder.layers.23.self_attn.out_proj.weight
model.decoder.layers.23.self_attn.out_proj.bias
model.decoder.layers.23.fc1.weight
model.decoder.layers.23.fc1.bias
model.decoder.layers.23.fc2.weight
model.decoder.layers.23.fc2.bias
```

Meaning the entire last layer doesn't get updated. From what I can see in the trainer, we currently support mixed precision autocast (…
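As a hedged aside (this describes general PyTorch checkpointing behaviour, and is not necessarily the root cause here nor the fix that was eventually adopted): the default reentrant checkpointing silently drops gradients for a checkpointed segment whose inputs do not require grad, while the non-reentrant variant does not have that limitation. A minimal sketch:

```
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Linear(8, 8)   # trainable segment that we checkpoint
head = nn.Linear(8, 1)    # trainable head, so the loss itself is differentiable
x = torch.randn(1, 8)     # note: the segment's input does NOT require grad

# default (reentrant) checkpointing: PyTorch warns and the segment's weights
# silently end up with grad=None, even though the loss is differentiable via `head`
out = checkpoint(block, x, use_reentrant=True)
head(out).sum().backward()
print(block.weight.grad)              # None

# non-reentrant checkpointing keeps the segment's parameters in the graph
out = checkpoint(block, x, use_reentrant=False)
head(out).sum().backward()
print(block.weight.grad is not None)  # True
```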
Not sure we should merge this as it is a workaround for a bug in PyTorch and could have unintended consequences outside of the bug it fixes. We can definitely leave this open to show anyone how to work around the exact issue, but I feel this should be fixed in PyTorch.
In line with what @sgugger said, I'm also not sure it even makes a lot of sense to checkpoint something as small as the layer norm grads? Thanks for flagging the issue and proposing a fix!
Hi @zhaoqf123
Thanks a lot for helping us find this important issue. After some digging and internal discussion, we have found a broader fix that covers most models that support gradient checkpointing: #24247. To credit you for your help, I have added you as a co-author on that PR, and we will close this PR once #24247 gets merged.
Thanks a lot!
@younesbelkada Thank you for your acknowledgement. Although I have several years of experience in machine learning (with tf), I have only just started using … Thank you very much for the …
Closing the PR as #24247 has been merged.
Hi @zhaoqf123, for reference, check: #24420 (comment)
Hi @zhaoqf123 |
@younesbelkada Sorry for the late reply. I was on vacation the last 3 days. Yes, I also noticed that memory consumption increases a lot when making the last layer updatable. For llama 7B on a V100-32GB, VRAM usage increases from 34% to 42%, which is not proportional to the increase in updatable params.
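For anyone wanting to pin those numbers down more precisely, here is a minimal sketch using the CUDA allocator's peak statistics (a generic measurement pattern, not code from this thread; the Linear layer is only a stand-in for one training step of the scripts above):

```
import torch

# reset the allocator's peak counters, run one step, then read the peaks back;
# note these report allocated/reserved bytes, not the nvidia-smi percentages quoted above
torch.cuda.reset_peak_memory_stats()

layer = torch.nn.Linear(4096, 4096).cuda()
out = layer(torch.randn(8, 4096, device="cuda"))
out.sum().backward()

print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")
print(f"peak reserved:  {torch.cuda.max_memory_reserved() / 2**20:.1f} MiB")
```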
@younesbelkada May I know how I should try it out? For example, re-install transformers: …
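For readers following along: installing transformers from source here presumably means the standard workflow, e.g. `pip install git+https://github.com/huggingface/transformers.git` or `pip install -e .` from a local clone of the repository (my gloss, not a command quoted from the thread).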
@zhaoqf123 thanks for the reply!
@younesbelkada After re-installing transformers from source, on my V100, if I remove … On my 3090 GPU, it works by removing …