diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 155f5d376d..42180f32b3 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -227,6 +227,9 @@ def __init__( self.eval_data_collator = eval_data_collator super().__init__(*_args, **kwargs) self.train_data_collator = self.data_collator + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + if self.args.orpo_alpha: + self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") def create_optimizer(self): if self.args.loraplus_lr_ratio is None: @@ -469,8 +472,112 @@ def compute_loss(self, model, inputs, return_outputs=False): # outputs = model(**inputs) # loss = trainer_weighted_loss(outputs, labels, shift_labels=True) # return (loss, outputs) if return_outputs else loss + if self.args.orpo_alpha: + return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs) return super().compute_loss(model, inputs, return_outputs=return_outputs) + def orpo_compute_custom_loss(self, logits, labels): + logits = logits.contiguous() + loss = 0.0 + + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # Flatten the tokens + loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean( + dim=-1 + ) + + return loss + + def orpo_compute_logps( + self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits + ): + # Get the shape of chosen_attention_mask[:, :-1] + chosen_shape = chosen_attention_mask[:, :-1].shape + + # Calculate the padding size + pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1) + + # Pad prompt_attention_mask with zeros to match the desired shape + prompt_attention_mask_padded = torch.nn.functional.pad( + prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0 + ) + + # Perform the subtraction operation + mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded + + per_token_logps = torch.gather( + logits[:, :-1, :].log_softmax(-1), + dim=2, + index=(mask * chosen_inputs[:, 1:]).unsqueeze(2), + ).squeeze(2) + return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to( + dtype=torch.float64 + ) / mask.sum(dim=1).to(dtype=torch.float64) + + def orpo_compute_loss(self, model, inputs, return_outputs=False): + outputs_neg = model( + **{ + "input_ids": inputs["rejected_input_ids"], + "attention_mask": inputs["rejected_attention_mask"], + "labels": inputs["rejected_labels"], + }, + output_hidden_states=True, + ) + outputs_pos = model( + **{ + "input_ids": inputs["input_ids"], + "attention_mask": inputs["attention_mask"], + "labels": inputs["labels"], + }, + output_hidden_states=True, + ) + + # Calculate NLL loss + pos_loss = self.orpo_compute_custom_loss( + logits=outputs_pos.logits, labels=inputs["input_ids"] + ) + + # Calculate Log Probability + pos_prob = self.orpo_compute_logps( + prompt_attention_mask=inputs["prompt_attention_mask"], + chosen_inputs=inputs["input_ids"], + chosen_attention_mask=inputs["attention_mask"], + logits=outputs_pos.logits, + ) + neg_prob = self.orpo_compute_logps( + prompt_attention_mask=inputs["prompt_attention_mask"], + chosen_inputs=inputs["rejected_input_ids"], + chosen_attention_mask=inputs["rejected_attention_mask"], + logits=outputs_neg.logits, + ) + + # Calculate log odds + log_odds = (pos_prob - neg_prob) - ( + torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob)) + ) + sig_ratio = torch.nn.functional.sigmoid(log_odds) + ratio = torch.log(sig_ratio) + + # Calculate the Final Loss + loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to( + dtype=torch.bfloat16 + ) + + metrics = {} + metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item() + metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item() + metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item() + metrics["log_odds"] = torch.mean(log_odds).cpu().item() + self.store_metrics(metrics, train_eval="train") + + return (loss, outputs_pos) if return_outputs else loss + @wraps(Trainer.push_to_hub) def push_to_hub(self, *args, **kwargs) -> str: """ @@ -531,6 +638,28 @@ def create_accelerator_and_postprocess(self): return res + def log(self, logs: Dict[str, float]) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`Dict[str, float]`): + The values to log. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs) + + def store_metrics( + self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" + ) -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + class AxolotlMambaTrainer(AxolotlTrainer): """ @@ -630,139 +759,6 @@ def create_scheduler( return self.lr_scheduler -class AxolotlORPOTrainer(AxolotlTrainer): - """Axolotl trainer for ORPO""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._stored_metrics = defaultdict(lambda: defaultdict(list)) - self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - - def compute_custom_loss(self, logits, labels): - logits = logits.contiguous() - loss = 0.0 - - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - - # Flatten the tokens - loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean( - dim=-1 - ) - - return loss - - def compute_logps( - self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits - ): - # Get the shape of chosen_attention_mask[:, :-1] - chosen_shape = chosen_attention_mask[:, :-1].shape - - # Calculate the padding size - pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1) - - # Pad prompt_attention_mask with zeros to match the desired shape - prompt_attention_mask_padded = torch.nn.functional.pad( - prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0 - ) - - # Perform the subtraction operation - mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded - - per_token_logps = torch.gather( - logits[:, :-1, :].log_softmax(-1), - dim=2, - index=(mask * chosen_inputs[:, 1:]).unsqueeze(2), - ).squeeze(2) - return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to( - dtype=torch.float64 - ) / mask.sum(dim=1).to(dtype=torch.float64) - - def compute_loss(self, model, inputs, return_outputs=False): - outputs_neg = model( - **{ - "input_ids": inputs["rejected_input_ids"], - "attention_mask": inputs["rejected_attention_mask"], - "labels": inputs["rejected_labels"], - }, - output_hidden_states=True, - ) - outputs_pos = model( - **{ - "input_ids": inputs["input_ids"], - "attention_mask": inputs["attention_mask"], - "labels": inputs["labels"], - }, - output_hidden_states=True, - ) - - # Calculate NLL loss - pos_loss = self.compute_custom_loss( - logits=outputs_pos.logits, labels=inputs["input_ids"] - ) - - # Calculate Log Probability - pos_prob = self.compute_logps( - prompt_attention_mask=inputs["prompt_attention_mask"], - chosen_inputs=inputs["input_ids"], - chosen_attention_mask=inputs["attention_mask"], - logits=outputs_pos.logits, - ) - neg_prob = self.compute_logps( - prompt_attention_mask=inputs["prompt_attention_mask"], - chosen_inputs=inputs["rejected_input_ids"], - chosen_attention_mask=inputs["rejected_attention_mask"], - logits=outputs_neg.logits, - ) - - # Calculate log odds - log_odds = (pos_prob - neg_prob) - ( - torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob)) - ) - sig_ratio = torch.nn.functional.sigmoid(log_odds) - ratio = torch.log(sig_ratio) - - # Calculate the Final Loss - loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to( - dtype=torch.bfloat16 - ) - - metrics = {} - metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item() - metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item() - metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item() - metrics["log_odds"] = torch.mean(log_odds).cpu().item() - self.store_metrics(metrics, train_eval="train") - - return (loss, outputs_pos) if return_outputs else loss - - def log(self, logs: Dict[str, float]) -> None: - """ - Log `logs` on the various objects watching training, including stored metrics. - - Args: - logs (`Dict[str, float]`): - The values to log. - """ - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - return super().log(logs) - - def store_metrics( - self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" - ) -> None: - for key, value in metrics.items(): - self._stored_metrics[train_eval][key].append(value) - - class AxolotlDPOTrainer(DPOTrainer): """ Extend the base DPOTrainer for axolotl helpers @@ -934,8 +930,6 @@ def _get_trainer_cls(self): return ReLoRATrainer if self.cfg.model_config_type == "mamba": return AxolotlMambaTrainer - if self.cfg.rl == "orpo" and self.cfg.orpo_alpha: - return AxolotlORPOTrainer return AxolotlTrainer def build(self, total_num_steps): @@ -1214,7 +1208,8 @@ def build(self, total_num_steps): training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) - training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha + if self.cfg.rl == "orpo": + training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha if self.cfg.neftune_noise_alpha is not None: training_arguments_kwargs[