Remove graph breaks for torch.compile() in padding free branch in DataCollatorForCompletionOnlyLM #2158
base: main
Changes from 23 commits
@@ -654,6 +654,50 @@ def test_data_collator_completion_lm_with_multiple_text(self):
            result_text = tokenizer.decode(batch["input_ids"][i, last_pad_idx + 1 :])
            self.assertEqual(result_text, "I have not been masked correctly.")

    def test_data_collator_completion_lm_without_padding(self):
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
        torch_dtype = getattr(torch, "bfloat16", None)
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="flash_attention_2")
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

        formatted_dataset = lambda example: {
            "output": f"### prompt:\n{example['prompt'].strip()}\n\n### completion:\n{example['completion'].strip()}{tokenizer.eos_token}"
        }
Review comment: is dataset formatting required here, or can we drop it?

Reply: dataset formatting is required because the …
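To illustrate why the formatting matters, here is a minimal sketch of what the lambda produces for one record; the prompt/completion strings and the </s> EOS token below are made-up values, not taken from the test dataset. The point is that the response template "### completion:\n" has to appear in the formatted text so the collator can locate where the completion starts.

# Illustrative sketch only: example values are hypothetical.
example = {"prompt": "What is the capital of France?", "completion": "Paris."}
formatted = {
    "output": f"### prompt:\n{example['prompt'].strip()}\n\n"
              f"### completion:\n{example['completion'].strip()}</s>"
}
# formatted["output"] now contains the "### completion:\n" marker, which is
# what DataCollatorForCompletionOnlyLM searches for to mask the prompt tokens.
print(formatted["output"])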
        train_dataset = self.standard_prompt_completion_dataset["train"].map(formatted_dataset)

        response_template = "### completion:\n"
        data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, padding_free=True)

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                dataloader_drop_last=True,
                max_steps=2,
                per_device_train_batch_size=2,
                gradient_accumulation_steps=1,
                save_steps=2,
                learning_rate=1e-5,
                dataset_text_field="output",
                torch_compile=True,
                torch_compile_backend="inductor",
                torch_compile_mode="default",
            )

            trainer = SFTTrainer(
                model=model,
                tokenizer=tokenizer,
                train_dataset=train_dataset,
                data_collator=data_collator,
                args=training_args,
            )

            trainer.train()
            assert trainer.state.log_history[-1]["train_loss"] is not None
            assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
        del os.environ["CUDA_VISIBLE_DEVICES"]

    def test_data_collator_chat_completion_lm(self):
        instruction_template = "### Human:"
        assistant_template = "### Assistant:"
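For readers unfamiliar with the padding-free path this test exercises, here is a rough standalone sketch of driving the collator directly, outside the trainer. The exact keys of the returned batch are an assumption about how the padding-free branch behaves, not something taken from this diff.

# Rough sketch, assuming the padding_free branch flattens the batch.
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")
collator = DataCollatorForCompletionOnlyLM(
    "### completion:\n", tokenizer=tokenizer, padding_free=True
)

texts = [
    "### prompt:\nSay hi\n\n### completion:\nHi there.",
    "### prompt:\nCount to three\n\n### completion:\nOne, two, three.",
]
features = [tokenizer(t) for t in texts]
batch = collator(features)
# With padding_free=True the sequences are expected (assumption) to be packed
# without padding tokens, using position_ids rather than an attention_mask,
# which is the branch whose graph breaks this PR removes under torch.compile().
print(batch.keys())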
Review comment: Does the issue only occur with a CUDA device? In other words, can we reproduce it on CPU?

Reply: Due to the use of flash_attention_2, it would only work on GPU.
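One way to make that hardware requirement explicit on the test itself would be a skip guard along these lines; this is a sketch (the class name is illustrative), and the diff above does not currently add such a decorator.

# Sketch of a GPU-only guard; not part of the PR.
import unittest
import torch

class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):  # illustrative name
    @unittest.skipUnless(torch.cuda.is_available(), "flash_attention_2 needs a CUDA GPU")
    def test_data_collator_completion_lm_without_padding(self):
        ...  # body as in the diff above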