Remove graph breaks for torch.compile() in padding free branch in DataCollatorForCompletionOnlyLM #2158

Open · wants to merge 25 commits into main

Commits
4472501  feat: Add info to batch in DataCollatorForCompletionOnlyLM (Abhishek-TAMU, Oct 2, 2024)
6cfa171  fix: formatting (Abhishek-TAMU, Oct 2, 2024)
a821ce0  feat: Add info to batch in DataCollatorForCompletionOnlyLM (Abhishek-TAMU, Oct 2, 2024)
fb669b6  fix: formatting (Abhishek-TAMU, Oct 2, 2024)
f4b1955  Merge branch 'huggingface:main' into collator_batch (Abhishek-TAMU, Oct 14, 2024)
1b7c060  Merge branch 'collator_batch' of github.com:Abhishek-TAMU/trl into co… (Abhishek-TAMU, Oct 21, 2024)
c3578f8  Merge branch 'main' into collator_batch (Abhishek-TAMU, Oct 21, 2024)
e83fc8a  fix: max_length_k to int (Abhishek-TAMU, Oct 21, 2024)
68554b1  fix: Added comments (Abhishek-TAMU, Oct 21, 2024)
2a7dd47  Merge remote-tracking branch 'trl/main' into collator_batch (Abhishek-TAMU, Oct 30, 2024)
b0a52e2  test cases (Abhishek-TAMU, Oct 30, 2024)
054a6ef  test cases (Abhishek-TAMU, Oct 30, 2024)
376ad21  test cases (Abhishek-TAMU, Oct 30, 2024)
9a08ea3  Merge remote-tracking branch 'trl/main' into collator_batch (Abhishek-TAMU, Nov 12, 2024)
a97045b  feat: Add info to batch in DataCollatorForCompletionOnlyLM (Abhishek-TAMU, Oct 2, 2024)
f31a780  fix: formatting (Abhishek-TAMU, Oct 2, 2024)
29ba8a3  feat: Add info to batch in DataCollatorForCompletionOnlyLM (Abhishek-TAMU, Oct 2, 2024)
d1441e1  test cases (Abhishek-TAMU, Oct 30, 2024)
d55a6e2  test cases (Abhishek-TAMU, Oct 30, 2024)
7dccc2d  test cases (Abhishek-TAMU, Oct 30, 2024)
5e5224e  unit test changes (Abhishek-TAMU, Nov 12, 2024)
1b434b0  unit test changes (Abhishek-TAMU, Nov 12, 2024)
ef1e304  Merge remote-tracking branch 'trl/main' into collator_batch (Abhishek-TAMU, Nov 18, 2024)
77894b1  style (qgallouedec, Nov 19, 2024)
911f60c  Merge branch 'main' into collator_batch (qgallouedec, Nov 19, 2024)
44 changes: 44 additions & 0 deletions tests/test_sft_trainer.py
@@ -654,6 +654,50 @@ def test_data_collator_completion_lm_with_multiple_text(self):
result_text = tokenizer.decode(batch["input_ids"][i, last_pad_idx + 1 :])
self.assertEqual(result_text, "I have not been masked correctly.")

def test_data_collator_completion_lm_without_padding(self):
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
Member: Does the issue only occur with a CUDA device? In other words, can we reproduce it on CPU?

Author: Due to the use of flash_attention_2, it would work only on GPU.
model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
torch_dtype = getattr(torch, "bfloat16", None)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="flash_attention_2")
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

formatted_dataset = lambda example: {
"output": f"### prompt:\n{example['prompt'].strip()}\n\n### completion:\n{example['completion'].strip()}{tokenizer.eos_token}"
}
Member: Is dataset formatting required here, or can we drop it?

Author: Dataset formatting is required because SFTTrainer and DataCollatorForCompletionOnlyLM expect the dataset to have a specific format: a single text field that combines the prompt and the completion in a way the model can understand. This function includes both, ensuring the data collator can correctly identify where the completion starts using the response_template.

train_dataset = self.standard_prompt_completion_dataset["train"].map(formatted_dataset)

response_template = "### completion:\n"
data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, padding_free=True)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
max_steps=2,
per_device_train_batch_size=2,
gradient_accumulation_steps=1,
save_steps=2,
learning_rate=1e-5,
dataset_text_field="output",
torch_compile=True,
torch_compile_backend="inductor",
torch_compile_mode="default"
)

trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
data_collator=data_collator,
args=training_args,
)

trainer.train()
assert trainer.state.log_history[-1]["train_loss"] is not None
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
del os.environ["CUDA_VISIBLE_DEVICES"]

def test_data_collator_chat_completion_lm(self):
instruction_template = "### Human:"
assistant_template = "### Assistant:"
19 changes: 19 additions & 0 deletions trl/trainer/utils.py
@@ -232,6 +232,25 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D
batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
batch["labels"][batch["position_ids"] == 0] = self.ignore_index

# Calculate cumulative sequence lengths for queries and keys to prevent graph breaks during further computations.
flattened_position_ids = batch["position_ids"].flatten()
indices_q = torch.arange(
flattened_position_ids.size(0), device=flattened_position_ids.device, dtype=torch.int32
)
batch["cu_seq_lens_q"] = torch.cat(
(
indices_q[flattened_position_ids == 0],
torch.tensor(
flattened_position_ids.size(), device=flattened_position_ids.device, dtype=torch.int32
),
)
)
batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"]

# Determine maximum sequence lengths to prevent graph breaks during further computations.
batch["max_length_k"] = flattened_position_ids.max().item() + 1
batch["max_length_q"] = batch["max_length_k"]

return batch
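To see what the added fields contain, here is a self-contained sketch on a toy flattened batch (illustrative values, not part of the diff):

import torch

# Two packed sequences of lengths 3 and 2, flattened into one row.
position_ids = torch.tensor([[0, 1, 2, 0, 1]])
flat = position_ids.flatten()
indices = torch.arange(flat.size(0), dtype=torch.int32)

# Sequence starts are where position_ids reset to 0; append the total length.
cu_seq_lens = torch.cat((indices[flat == 0], torch.tensor(flat.size(), dtype=torch.int32)))
print(cu_seq_lens)            # tensor([0, 3, 5], dtype=torch.int32)
print(flat.max().item() + 1)  # 3, the longest sequence in the batch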

