From aeb4d0ed088c3b8e61784ba7f8215e370756af45 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 21 Mar 2024 10:07:37 -0700
Subject: [PATCH 1/3] don't drop attention_mask for orpo

---
 src/axolotl/utils/trainer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 380264a7ac..8b8c901ff4 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -116,8 +116,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
         LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

     if (
-        cfg.is_mistral_derived_model and cfg.flash_attention
-    ) or cfg.model_config_type == "mamba":
+        (cfg.is_mistral_derived_model and cfg.flash_attention)
+        or cfg.model_config_type == "mamba"
+    ) and cfg.rl != "orpo":
         LOG.info("dropping attention_mask column")
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
"labels": labels, + "attention_mask": attention_mask, + "prompt_attention_mask": concatenated_batch["prompt_attention_mask"], + } + def orpo_compute_custom_loss(self, logits, labels): logits = logits.contiguous() loss = 0.0 @@ -512,45 +565,46 @@ def orpo_compute_logps( dim=2, index=(mask * chosen_inputs[:, 1:]).unsqueeze(2), ).squeeze(2) - return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to( - dtype=torch.float64 - ) / mask.sum(dim=1).to(dtype=torch.float64) + return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1) def orpo_compute_loss(self, model, inputs, return_outputs=False): - outputs_neg = model( - **{ - "input_ids": inputs["rejected_input_ids"], - "attention_mask": inputs["rejected_attention_mask"], - "labels": inputs["rejected_labels"], - }, - output_hidden_states=True, + concat_inputs = AxolotlTrainer.orpo_concatenate_inputs( + inputs, + label_pad_token=-100, + pad_token=self.tokenizer.pad_token_id, + device=self.accelerator.device, ) - outputs_pos = model( + + # Perform a single forward pass + outputs = model( **{ - "input_ids": inputs["input_ids"], - "attention_mask": inputs["attention_mask"], - "labels": inputs["labels"], + "input_ids": concat_inputs["input_ids"], + "attention_mask": concat_inputs["attention_mask"], + "labels": concat_inputs["labels"], }, output_hidden_states=True, ) + # Split the outputs for positive and negative examples + outputs_pos, outputs_neg = outputs.logits.chunk(2) + # Calculate NLL loss pos_loss = self.orpo_compute_custom_loss( - logits=outputs_pos.logits, labels=inputs["input_ids"] + logits=outputs_pos, labels=concat_inputs["input_ids"].chunk(2)[0] ) # Calculate Log Probability pos_prob = self.orpo_compute_logps( - prompt_attention_mask=inputs["prompt_attention_mask"], - chosen_inputs=inputs["input_ids"], - chosen_attention_mask=inputs["attention_mask"], - logits=outputs_pos.logits, + prompt_attention_mask=concat_inputs["prompt_attention_mask"], + chosen_inputs=concat_inputs["input_ids"].chunk(2)[0], + chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[0], + logits=outputs_pos, ) neg_prob = self.orpo_compute_logps( - prompt_attention_mask=inputs["prompt_attention_mask"], - chosen_inputs=inputs["rejected_input_ids"], - chosen_attention_mask=inputs["rejected_attention_mask"], - logits=outputs_neg.logits, + prompt_attention_mask=concat_inputs["prompt_attention_mask"], + chosen_inputs=concat_inputs["input_ids"].chunk(2)[1], + chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[1], + logits=outputs_neg, ) # Calculate log odds @@ -1247,6 +1301,7 @@ def build(self, total_num_steps): train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, args=training_args, + tokenizer=self.tokenizer, data_collator=self.build_collator(training_args, **data_collator_kwargs), eval_data_collator=self.build_collator( training_args, is_eval=True, **data_collator_kwargs From f391cc3a74a7c1522baab8bcabed226cfc2eff61 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 22 Mar 2024 12:20:21 -0700 Subject: [PATCH 3/3] revert change to not drop the attention_mask from inputs for orpo --- src/axolotl/utils/trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 8b8c901ff4..380264a7ac 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -116,9 +116,8 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True) if ( - 
From f391cc3a74a7c1522baab8bcabed226cfc2eff61 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 22 Mar 2024 12:20:21 -0700
Subject: [PATCH 3/3] revert change to not drop the attention_mask from inputs for orpo

---
 src/axolotl/utils/trainer.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 8b8c901ff4..380264a7ac 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -116,9 +116,8 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
         LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

     if (
-        (cfg.is_mistral_derived_model and cfg.flash_attention)
-        or cfg.model_config_type == "mamba"
-    ) and cfg.rl != "orpo":
+        cfg.is_mistral_derived_model and cfg.flash_attention
+    ) or cfg.model_config_type == "mamba":
         LOG.info("dropping attention_mask column")
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
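Note (illustration only, not part of the patch series): PATCH 3/3 exactly reverts PATCH 1/3, so across the three patches the attention_mask-dropping guard in process_datasets_for_packing ends up unchanged and the ORPO-specific handling that survives the series lives entirely in the trainer changes of PATCH 2/3. For reference, a hypothetical helper (not present in axolotl) restating the guard's final form:

    def should_drop_attention_mask(cfg) -> bool:
        # Final form of the condition after PATCH 3/3, identical to its pre-series form.
        return (
            cfg.is_mistral_derived_model and cfg.flash_attention
        ) or cfg.model_config_type == "mamba"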