diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md
index 127f297321..d141667218 100644
--- a/docs/source/rloo_trainer.md
+++ b/docs/source/rloo_trainer.md
@@ -269,6 +269,17 @@ python -m openrlbenchmark.rlops_multi_metrics \
     --scan-history
 ```
 
+## Reinforce++
+
+The [Reinforce++](https://hijkzzz.notion.site/reinforce-plus-plus) report by Jian Hu suggests several optimization tricks to enhance the performance and stability of RLHF. They include:
+
+- Clipping rewards: limiting reward values to a specific range to mitigate the impact of extreme rewards on model updates, thus preventing gradient explosion
+- Normalizing rewards: scaling rewards to have a mean of 0 and a standard deviation of 1, which helps stabilize training
+- Normalizing advantages: scaling advantages to have a mean of 0 and a standard deviation of 1, which helps stabilize training
+- Using a token-level KL penalty (the default) instead of a sequence-level KL penalty
+
+These options are available via the corresponding arguments in the [`RLOOConfig`] class.
+
 ## RLOOTrainer
diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py
index df943ae264..ac02b95ec8 100644
--- a/trl/trainer/rloo_config.py
+++ b/trl/trainer/rloo_config.py
@@ -42,6 +42,14 @@ class RLOOConfig(OnPolicyConfig):
             Clip range.
         rloo_k (`int`, *optional*, defaults to `2`):
             REINFORCE Leave-One-Out (RLOO) number of online samples per prompt.
+        normalize_reward (`bool`, *optional*, defaults to `False`):
+            Whether to normalize rewards.
+        reward_clip_range (`float`, *optional*, defaults to `10.0`):
+            Clip range for rewards.
+        normalize_advantage (`bool`, *optional*, defaults to `False`):
+            Whether to normalize advantages.
+        token_level_kl (`bool`, *optional*, defaults to `True`):
+            Whether to use token-level KL penalty or sequence-level KL penalty.
""" exp_name: str = field( @@ -72,3 +80,19 @@ class RLOOConfig(OnPolicyConfig): default=2, metadata={"help": "REINFORCE Leave-One-Out (RLOO) number of online samples per prompt."}, ) + normalize_reward: bool = field( + default=False, + metadata={"help": "Whether to normalize rewards"}, + ) + reward_clip_range: float = field( + default=10.0, + metadata={"help": "Clip range for rewards"}, + ) + normalize_advantage: bool = field( + default=False, + metadata={"help": "Whether to normalize advantages"}, + ) + token_level_kl: bool = field( + default=True, + metadata={"help": "Whether to use token-level KL penalty or sequence-level KL penalty"}, + ) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index cae2c7c494..2e16e6c7f8 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -308,6 +308,8 @@ def repeat_generator(): ref_logprobs = [] scores = [] sequence_lengths = [] + + # Generate responses and compute logprobs with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: query_responses, logitss = batch_generation( unwrapped_model, @@ -317,6 +319,7 @@ def repeat_generator(): generation_config, ) + # Process responses in batches for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): query = queries[i : i + args.local_rollout_forward_batch_size] query_response = query_responses[i : i + args.local_rollout_forward_batch_size] @@ -349,12 +352,15 @@ def repeat_generator(): reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length ) + # Store batch results responses.append(response) postprocessed_responses.append(postprocessed_response) logprobs.append(logprob) ref_logprobs.append(ref_logprob) sequence_lengths.append(sequence_length) scores.append(score) + + # Concatenate all batched results responses = torch.cat(responses, 0) postprocessed_responses = torch.cat(postprocessed_responses, 0) logprobs = torch.cat(logprobs, 0) @@ -380,8 +386,23 @@ def repeat_generator(): ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) # 4. 
+                # Compute KL divergence
                 kl = logprobs - ref_logprobs
-                non_score_reward = (-args.kl_coef * kl).sum(1)
+
+                # Normalize rewards
+                if args.normalize_reward:
+                    scores = (scores - scores.mean()) / (scores.std() + 1e-8)
+                    scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range)
+
+                # Compute total reward with KL penalty
+                if args.token_level_kl:
+                    # Token-level KL penalty: apply KL penalty per token
+                    token_kl_penalty = -args.kl_coef * kl
+                    non_score_reward = token_kl_penalty.sum(1)
+                else:
+                    # Sequence-level KL penalty: sum KL across tokens first
+                    sequence_kl = kl.sum(1)
+                    non_score_reward = -args.kl_coef * sequence_kl
                 rlhf_reward = scores + non_score_reward
 
                 # vectorized RLOO advantages implementation
@@ -389,6 +410,11 @@ def repeat_generator():
                 baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)
                 advantages = rlhf_reward - baseline
                 advantages = advantages.flatten()
+
+                # Normalize advantages
+                if args.normalize_advantage:
+                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+
                 torch.cuda.empty_cache()
 
                 # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
@@ -403,32 +429,46 @@ def repeat_generator():
                         with accelerator.accumulate(model):
                             micro_batch_end = micro_batch_start + args.per_device_train_batch_size
                             micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
+
+                            # Get batch data
                             mb_advantage = advantages[micro_batch_inds]
                             mb_responses = responses[micro_batch_inds]
                             mb_query_responses = query_responses[micro_batch_inds]
                             mb_logprobs = logprobs[micro_batch_inds]
 
+                            # Forward pass
                             output = forward(model, mb_query_responses, processing_class.pad_token_id)
                             logits = output.logits[:, context_length - 1 : -1]
                             logits /= args.temperature + 1e-7
+
+                            # Compute new logprobs
                             new_all_logprobs = F.log_softmax(logits, dim=-1)
                             new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
                             new_logprobs = torch.masked_fill(
                                 new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                             )
+
+                            # Compute probability ratios
                             new_ratio = (new_logprobs - mb_logprobs).exp()
                             new_logprobs = new_logprobs.sum(1)
                             mb_logprobs = mb_logprobs.sum(1)
                             logprobs_diff = new_logprobs - mb_logprobs
                             ratio = torch.exp(logprobs_diff)
+
+                            # PPO clipped loss
                             pg_losses = -mb_advantage * ratio
                             pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                             pg_loss_max = torch.max(pg_losses, pg_losses2)
                             pg_loss = pg_loss_max.mean()
+
+                            # Final loss
                             loss = pg_loss
+
+                            # Optimization step
                             accelerator.backward(loss)
                             optimizer.step()
                             optimizer.zero_grad()
+
                             with torch.no_grad():
                                 pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
                                 prob_dist = torch.nn.functional.softmax(logits, dim=-1)
@@ -443,6 +483,7 @@ def repeat_generator():
                                 ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
                         gradient_accumulation_idx += 1
                     minibatch_idx += 1
+
                     # del everything and empty cache
                     # fmt: off
                     del (
                     )
                     # fmt: on
                     torch.cuda.empty_cache()
+
+            # Compute metrics
             with torch.no_grad():
                 mean_kl = kl.sum(1).mean()
                 mean_entropy = (-logprobs).sum(1).mean()
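
Taken together, the changes above expose four new knobs on `RLOOConfig`. Below is a minimal configuration sketch, assuming the fields land as shown in this diff; the `output_dir` and the concrete values are illustrative placeholders, not part of the PR.

```python
from trl import RLOOConfig

config = RLOOConfig(
    output_dir="rloo_reinforce_pp",  # placeholder output path
    rloo_k=2,                        # existing field: online samples per prompt
    kl_coef=0.05,                    # existing field: KL penalty coefficient
    normalize_reward=True,           # z-score the reward-model scores
    reward_clip_range=10.0,          # clamp (normalized) rewards to [-10.0, 10.0]
    normalize_advantage=True,        # z-score the leave-one-out advantages
    token_level_kl=False,            # switch from the default token-level KL to sequence-level
)
```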
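The reward-shaping path added to `rloo_trainer.py` can also be followed on toy tensors. The sketch below is a standalone illustration of the arithmetic, not the trainer code itself; the tensor values, `kl_coef=0.05`, `rloo_k=2`, and the clip range of `10.0` are made up for the example.

```python
import torch

# Toy rollout: rloo_k = 2 completions for one prompt, 3 generated tokens each.
kl_coef, rloo_k = 0.05, 2
kl = torch.tensor([[0.10, 0.20, 0.00],
                   [0.30, 0.10, 0.10]])   # per-token KL(policy || ref), shape (rloo_k * num_prompts, seq_len)
scores = torch.tensor([1.0, -0.5])        # scalar reward-model scores, one per completion

# normalize_reward=True: z-score the scores, then clip to [-reward_clip_range, reward_clip_range]
scores = (scores - scores.mean()) / (scores.std() + 1e-8)
scores = torch.clamp(scores, -10.0, 10.0)

# token_level_kl=True sums the per-token penalty; token_level_kl=False sums the KL first.
# For a plain sum like this the two branches coincide numerically.
non_score_reward_token = (-kl_coef * kl).sum(1)
non_score_reward_seq = -kl_coef * kl.sum(1)
rlhf_reward = scores + non_score_reward_token

# Leave-one-out baseline: each sample is compared against the mean of its rloo_k - 1 siblings.
rlhf_reward = rlhf_reward.reshape(rloo_k, -1)
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
advantages = (rlhf_reward - baseline).flatten()

# normalize_advantage=True: z-score the advantages before the clipped policy-gradient update.
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
```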