8-bit precision error with fine tuning of gemma #1355
Comments
Hi @smreddy05
@younesbelkada thanks for your suggestion, and I am hitting a new issue.
@smreddy05
Hey @younesbelkada, I was using flash attention from the moment I hit the 8-bit precision error, and I tried reducing batch_size but still hit the same issue; the same code works for llama2. Not sure what's wrong here. I will give it a try with previous versions of trl and accelerate.
@younesbelkada sorry for not being clear, I was referring to the llama2-70B model; as of now I am on accelerate 0.27.2 and trl 0.7.10.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. |
Hi @smreddy05! Were you able to find a solution to fix the OutOfMemoryError?
@VIS-WA, sorry, I haven't spent time on this. But if we set do_eval=False then we cannot run any evaluation on the validation set, and because of that it might be tricky to judge how good the fine-tuned model is.
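One possible middle ground (a sketch with illustrative values, not something tested in this thread) is to keep evaluation enabled but shrink its memory footprint through the eval-specific `TrainingArguments` fields:

```python
# Sketch only: keep evaluation on instead of do_eval=False, but make it cheaper.
# output_dir and the step counts below are illustrative placeholders.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="steps",   # still evaluate periodically during training
    eval_steps=200,
    per_device_eval_batch_size=1,  # smaller eval batches than training batches
    eval_accumulation_steps=4,     # offload accumulated predictions to CPU regularly
)
```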
I am trying to fine-tune gemma-7b on 4 A100 80 GB GPUs using 4-bit quantization:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer

model_id = "google/gemma-7b"

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

print("initiating model download")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    use_cache=False,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=access_token,
)

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)

# prepare model for training
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)

args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=15,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
)

max_seq_length = 2048  # max sequence length for model and packing of the dataset
trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    formatting_func=generate_prompt,  # applies the prompt mapping to all training and test examples
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)
trainer.train()
```
This is throwing:

```
ValueError: You can't train a model that has been loaded in 8-bit precision on a different device than the one you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device() or device_map={'':torch.xpu.current_device()}`
```
The same script works for other models such as llama2.
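For reference, a minimal sketch of the loading change the error message points at (assuming the script is launched with one process per GPU, and reusing `model_id` and `bnb_config` from above):

```python
# Sketch only: instead of sharding the quantized model across GPUs with
# device_map="auto", pin the whole model to the GPU owned by the current
# process, as the ValueError suggests.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_id,                                      # "google/gemma-7b" from above
    quantization_config=bnb_config,                # the 4-bit config from above
    torch_dtype=torch.bfloat16,
    device_map={"": torch.cuda.current_device()},  # one full copy per process
)
```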
Versions used:
transformers: 4.38.1
trl: 0.7.11