Skip to content

Commit

Permalink
Fix ORPO multi gpu (#1433)
Browse files Browse the repository at this point in the history
* don't drop attention_mask for orpo

* handle multi-gpu cases better for orpo

* revert change to not drop the attention_mask from inputs for orpo
  • Loading branch information
winglian authored Mar 22, 2024
1 parent 4e69aa4 commit 34ba634
Showing 1 changed file with 78 additions and 23 deletions.
101 changes: 78 additions & 23 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from trl.trainer.utils import pad_to_length

from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
Expand Down Expand Up @@ -472,6 +473,58 @@ def compute_loss(self, model, inputs, return_outputs=False):
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
return super().compute_loss(model, inputs, return_outputs=return_outputs)

@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}

max_length = max(
inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1]
)
# Concatenate positive and negative inputs
concatenated_batch["input_ids"] = pad_to_length(
inputs["input_ids"], max_length, pad_token
)
concatenated_batch["rejected_input_ids"] = pad_to_length(
inputs["rejected_input_ids"], max_length, pad_token
)
concatenated_batch["labels"] = pad_to_length(
inputs["labels"], max_length, label_pad_token
)
concatenated_batch["rejected_labels"] = pad_to_length(
inputs["rejected_labels"], max_length, label_pad_token
)
concatenated_batch["attention_mask"] = pad_to_length(
inputs["attention_mask"], max_length, 0
)
concatenated_batch["rejected_attention_mask"] = pad_to_length(
inputs["rejected_attention_mask"], max_length, 0
)
concatenated_batch["prompt_attention_mask"] = pad_to_length(
inputs["prompt_attention_mask"], max_length, 0
).to(device=device)

input_ids = torch.cat(
[concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]],
dim=0,
).to(device=device)
attention_mask = torch.cat(
[
concatenated_batch["attention_mask"],
concatenated_batch["rejected_attention_mask"],
],
dim=0,
).to(device=device)
labels = torch.cat(
[concatenated_batch["labels"], concatenated_batch["rejected_labels"]], dim=0
).to(device=device)

return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"prompt_attention_mask": concatenated_batch["prompt_attention_mask"],
}

def orpo_compute_custom_loss(self, logits, labels):
logits = logits.contiguous()
loss = 0.0
Expand Down Expand Up @@ -512,45 +565,46 @@ def orpo_compute_logps(
dim=2,
index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
).squeeze(2)
return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(
dtype=torch.float64
) / mask.sum(dim=1).to(dtype=torch.float64)
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)

def orpo_compute_loss(self, model, inputs, return_outputs=False):
outputs_neg = model(
**{
"input_ids": inputs["rejected_input_ids"],
"attention_mask": inputs["rejected_attention_mask"],
"labels": inputs["rejected_labels"],
},
output_hidden_states=True,
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs,
label_pad_token=-100,
pad_token=self.tokenizer.pad_token_id,
device=self.accelerator.device,
)
outputs_pos = model(

# Perform a single forward pass
outputs = model(
**{
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"],
"labels": inputs["labels"],
"input_ids": concat_inputs["input_ids"],
"attention_mask": concat_inputs["attention_mask"],
"labels": concat_inputs["labels"],
},
output_hidden_states=True,
)

# Split the outputs for positive and negative examples
outputs_pos, outputs_neg = outputs.logits.chunk(2)

# Calculate NLL loss
pos_loss = self.orpo_compute_custom_loss(
logits=outputs_pos.logits, labels=inputs["input_ids"]
logits=outputs_pos, labels=concat_inputs["input_ids"].chunk(2)[0]
)

# Calculate Log Probability
pos_prob = self.orpo_compute_logps(
prompt_attention_mask=inputs["prompt_attention_mask"],
chosen_inputs=inputs["input_ids"],
chosen_attention_mask=inputs["attention_mask"],
logits=outputs_pos.logits,
prompt_attention_mask=concat_inputs["prompt_attention_mask"],
chosen_inputs=concat_inputs["input_ids"].chunk(2)[0],
chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[0],
logits=outputs_pos,
)
neg_prob = self.orpo_compute_logps(
prompt_attention_mask=inputs["prompt_attention_mask"],
chosen_inputs=inputs["rejected_input_ids"],
chosen_attention_mask=inputs["rejected_attention_mask"],
logits=outputs_neg.logits,
prompt_attention_mask=concat_inputs["prompt_attention_mask"],
chosen_inputs=concat_inputs["input_ids"].chunk(2)[1],
chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[1],
logits=outputs_neg,
)

# Calculate log odds
Expand Down Expand Up @@ -1247,6 +1301,7 @@ def build(self, total_num_steps):
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
tokenizer=self.tokenizer,
data_collator=self.build_collator(training_args, **data_collator_kwargs),
eval_data_collator=self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
Expand Down

0 comments on commit 34ba634

Please sign in to comment.