deepcopy fails after accelerate==0.23.0 #2248

Closed · prathikr opened this issue Dec 14, 2023 · 13 comments · Fixed by bitsandbytes-foundation/bitsandbytes#1060
Labels: wip (Work in progress)

@prathikr (Contributor) commented Dec 14, 2023

System Info

- `Accelerate` version: 0.25.0
- Platform: Linux-5.15.0-1050-azure-x86_64-with-glibc2.17
- Python version: 3.8.17
- Numpy version: 1.24.3
- PyTorch version (GPU?): 2.1.1+cu121 (True)
- GPU type: Tesla V100-PCIE-16GB

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

import copy
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, BitsAndBytesConfig
import torch

model_name_or_path = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    max_seq_length=512,
    pad_to_max_length=True,
)
tokenizer.pad_token = tokenizer.eos_token

config = AutoConfig.from_pretrained(
    model_name_or_path,
    num_labels=1,
    finetuning_task="text-classification",
)
config.pad_token_id = config.eos_token_id

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="fp4",
    bnb_4bit_compute_dtype=torch.float16
)

model = AutoModelForSequenceClassification.from_pretrained(
    model_name_or_path,
    config=config,
    load_in_4bit=True,
    quantization_config=nf4_config,
    torch_dtype=torch.float16,
)
model.config.pad_token_id = model.config.eos_token_id

data = tokenizer(["test string #1", "test string #2"])
data["input_ids"] = torch.as_tensor(data["input_ids"])
data["attention_mask"] = torch.as_tensor(data["attention_mask"])

output = model(**data) # succeeds

model_copy = copy.deepcopy(model)

output_copy = model_copy(**data) # fails w/ AssertionError: assert quant_state is not None

Expected behavior

For accelerate versions later than 0.23.0, copying the model via copy.deepcopy() fails to carry over the quant_state of the quantized parameters during QLoRA-enabled training. The copy.deepcopy() call is part of the ONNX export path when fine-tuning meta-llama/Llama-2-7b-hf with ONNX Runtime Training. The stand-alone script above reproduces the error.

@BenjaminBossan (Member)

Thanks for reporting. I can confirm that this fails with the current accelerate but works with accelerate==0.23. As accelerate is not directly invoked, it probably has to do with the transformers integration, but I don't know how exactly. If others don't have an idea what change caused it, we could run a git bisect to identify the commit, but it would take quite some time.

(btw I tried removing all the PEFT-related code and got exactly the same error, so the snippet could be simplified for debugging purposes)

@prathikr (Contributor, Author)

Thanks @BenjaminBossan, I simplified the repro as you suggested. Please let me know if there is anything else I can do.

@prathikr (Contributor, Author)

@BenjaminBossan I found a related issue with an accelerate patch that works around this problem. Does this help in identifying the root cause?

bitsandbytes-foundation/bitsandbytes#825 (comment)
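
For anyone who needs to unblock before an upstream fix, a rough user-side workaround is to monkeypatch a __deepcopy__ onto bitsandbytes' Params4bit so the quantization state is copied explicitly. This is only an illustrative sketch, not the accelerate patch from the linked comment, and the constructor arguments used below (data, requires_grad, quant_state) are assumptions about the Params4bit API of the installed bitsandbytes version.

import copy

import bitsandbytes as bnb

def _params4bit_deepcopy(self, memo):
    # Rebuild the parameter from deep copies of the packed tensor data and
    # the quantization state, which the default Parameter deepcopy drops.
    # NOTE: the keyword arguments are assumed to match the installed bnb version.
    new = type(self)(
        data=copy.deepcopy(self.data, memo),
        requires_grad=self.requires_grad,
        quant_state=copy.deepcopy(self.quant_state, memo),
    )
    memo[id(self)] = new
    return new

bnb.nn.Params4bit.__deepcopy__ = _params4bit_deepcopy

Applying this before copy.deepcopy(model) should keep quant_state populated on the copy; the proper fix belongs in bitsandbytes itself, which is what bitsandbytes-foundation/bitsandbytes#1060 (listed above as closing this issue) eventually provides.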

@prathikr (Contributor, Author)

Another related issue: huggingface/transformers#26801

@SunMarc (Member) commented Dec 19, 2023

Hi @prathikr, thanks for reporting. The breaking change comes from #1971. Before that PR, the hooks were not copied properly, meaning they still referenced the forward of the original model; hence it merely looked like the model was being copied correctly. To make this work, we need to implement a __deepcopy__ method on the Params4bit class, since self.quant_state is not copied properly (we get None). cc @Titus-von-Koeller

import copy
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, BitsAndBytesConfig
import torch


model_name_or_path = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    max_seq_length=512,
    pad_to_max_length=True,
)
tokenizer.pad_token = tokenizer.eos_token

config = AutoConfig.from_pretrained(
    model_name_or_path,
    num_labels=1,
    finetuning_task="text-classification",
)
config.pad_token_id = config.eos_token_id

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="fp4",
    bnb_4bit_compute_dtype=torch.float16
)

model = AutoModelForSequenceClassification.from_pretrained(
    model_name_or_path,
    config=config,
    load_in_4bit=True,
    quantization_config=nf4_config,
    torch_dtype=torch.float16,
)

print(model.model.layers[0].self_attn.q_proj.weight.quant_state)
copied_module = copy.deepcopy(model.model.layers[0].self_attn.q_proj.weight)
print(copied_module.quant_state)
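
To make the failure mode concrete, here is a minimal, self-contained sketch with a toy Parameter subclass standing in for Params4bit (class and attribute layout are illustrative, not the real bitsandbytes implementation). The point is that torch.nn.Parameter's default deepcopy rebuilds the subclass from the cloned data and requires_grad only, so an extra attribute like quant_state is dropped unless the subclass defines its own __deepcopy__:

import copy

import torch

class QuantParam(torch.nn.Parameter):  # toy stand-in for Params4bit
    def __new__(cls, data, requires_grad=False, quant_state=None):
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.quant_state = quant_state  # extra metadata the default copy path loses
        return self

    def __deepcopy__(self, memo):
        # Rebuild the subclass explicitly, deep-copying both the tensor data
        # and the attached quantization state.
        new = type(self)(
            copy.deepcopy(self.data, memo),
            requires_grad=self.requires_grad,
            quant_state=copy.deepcopy(self.quant_state, memo),
        )
        memo[id(self)] = new
        return new

p = QuantParam(torch.zeros(2, 2), quant_state={"blocksize": 64, "quant_type": "fp4"})
assert copy.deepcopy(p).quant_state is not None  # state survives the copy

With the __deepcopy__ override removed, the same assert fails because quant_state comes back as None, which mirrors the assert quant_state is not None error in the original report.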

@Titus-von-Koeller

Hey!

Thanks for ccing me on this. Sure, I can put implementing a __deepcopy__ method in Params4bit on my list.

How critical is this fix?

@Titus-von-Koeller

I had a quick chat with @SunMarc. Since this is the first issue about this in two months, we agreed that it's not critical (let me know if anyone disagrees).

I'll put it on my list and try to add deepcopy support on the bnb side in the next few weeks.

@prathikr (Contributor, Author)

@Titus-von-Koeller within the next few weeks sounds good; I will check back in 2-3 weeks.

@muellerzr added the wip (Work in progress) label on Dec 20, 2023
@prathikr (Contributor, Author) commented Jan 8, 2024

@Titus-von-Koeller any updates on this bug?

@Titus-von-Koeller

Hi @prathikr,

I was mostly off the last two weeks, partly due to illness. This was among the things that got delayed; right now I'm catching up and doing some high-impact work around FSDP that takes priority.

I have this on my to-do list for next week, so I won't miss it.

@prathikr (Contributor, Author)

@Titus-von-Koeller no problem, thank you for the update.

@prathikr (Contributor, Author)

@Titus-von-Koeller a few others on my team have also encountered this issue. Any updates on a resolution?

@prathikr (Contributor, Author) commented Feb 5, 2024

@SunMarc @Titus-von-Koeller any updates on this bug?
