Fix computation of response length in the general case
The previous logic only worked for EOS / <extra_id_1>

Signed-off-by: Olivier Delalleau <[email protected]>
odelalleau committed Jan 10, 2024
1 parent c00bbda commit 520e842
Showing 5 changed files with 81 additions and 72 deletions.
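
Aside (not part of the commit): a minimal, self-contained sketch of why the fix matters. It assumes a hypothetical case where sampling stops on a two-token end string and the tail of the row is padded with a pad id that differs from EOS; counting non-EOS tokens then overestimates the response length, while recording the step at which the end condition first fires — the approach taken by the new `TrackLengthGPTModelTextGenerationStrategy` below — recovers it.

```python
import torch

EOS, PAD = 0, 1
# Hypothetical batch of one: 3 prompt tokens, generation that ends on the
# two-token end string [7, 8], and a tail padded with PAD != EOS.
tokens = torch.tensor([[5, 5, 5, 9, 7, 8, PAD, PAD]])
prompt_lengths = torch.tensor([3])
max_length = 5  # generation budget

# Old-style estimate: count tokens that are not EOS. With a non-EOS terminator
# and non-EOS padding this counts the whole row.
old_lengths = (tokens != EOS).sum(-1)

# New-style estimate: while "generating", record the first step at which the
# end condition fires (a stand-in for end_of_generation_condition in the strategy).
end_idx = torch.full_like(prompt_lengths, -1)
for step in range(prompt_lengths.item(), tokens.size(1)):
    is_end = tokens[:, step] == 8  # toy end condition: last token of the end string
    end_idx = torch.where((end_idx < 0) & is_end, step, end_idx)
new_lengths = torch.where(end_idx >= 0, end_idx + 1, prompt_lengths + max_length)

print(old_lengths.tolist(), new_lengths.tolist())  # [8] vs [6]
```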
41 changes: 16 additions & 25 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py
@@ -32,24 +32,19 @@
from nemo.collections.nlp.parts.utils_funcs import get_last_rank
from nemo_aligner.models.alignable_interface import AlignableGenerativeInterface
from nemo_aligner.utils.distributed import (
broadcast_2d_tensor,
broadcast_2d_tensor_within_pp,
calculate_distributed_entropy,
from_parallel_logits_to_logprobs,
)
from nemo_aligner.utils.text_generation_utils import TrackLengthGPTModelTextGenerationStrategy
from nemo_aligner.utils.train_utils import (
grad_reductions,
prepare_for_training_step,
set_eval,
set_sync_funcs,
set_train,
)
from nemo_aligner.utils.utils import (
calculate_dialogue_response_lengths,
configure_batch_sizes,
cpu_weight_swap,
masked_mean,
offload_distributed_adam,
)
from nemo_aligner.utils.utils import configure_batch_sizes, cpu_weight_swap, masked_mean, offload_distributed_adam


class MegatronGPTActorModel(MegatronGPTModel, AlignableGenerativeInterface):
@@ -256,13 +251,9 @@ def get_inference_log_probs(self, response_tokens, forward_micro_batch_size):
)

logprobs = torch.cat(logprobs_list) if len(logprobs_list) > 0 else None
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
# broadcast it from last PP stage to everything else
logprobs = broadcast_2d_tensor(
logprobs,
parallel_state.get_pipeline_model_parallel_last_rank(),
parallel_state.get_pipeline_model_parallel_group(),
)

# Broadcast it from last PP stage to everything else.
logprobs = broadcast_2d_tensor_within_pp(logprobs)

return logprobs

@@ -280,20 +271,20 @@ def infer(self, inference_batch):
prompt_lengths = inference_batch["length"].cuda(non_blocking=True)
inputs = (prompt_tokens, prompt_lengths)

strategy = TrackLengthGPTModelTextGenerationStrategy(
model=self, context_lengths=prompt_lengths, max_length=self._length_params["max_length"]
)
actor_output = self.generate(
inputs=inputs, length_params=self._length_params, sampling_params=self._sampling_params
inputs=inputs, length_params=self._length_params, sampling_params=self._sampling_params, strategy=strategy
)

response_tokens = torch.cuda.LongTensor(actor_output["token_ids"])
response_lengths = calculate_dialogue_response_lengths(
tokens=response_tokens,
prompt_lengths=prompt_lengths,
tokenizer=self.tokenizer,
end_strings=self._sampling_params["end_strings"],
max_generation_length=self._length_params["max_length"],
max_sequence_length=self.cfg.encoder_seq_length,
)
response_lengths = None
if parallel_state.is_pipeline_last_stage():
response_lengths = strategy.get_lengths().to(torch.int64).view((-1, 1))
assert (response_lengths <= self.cfg.encoder_seq_length).all()
response_lengths = broadcast_2d_tensor_within_pp(response_lengths, dtype=torch.int64).flatten()

response_tokens = torch.cuda.LongTensor(actor_output["token_ids"])
# TODO(geshen): get nemo generate to return the unaltered log probs
log_probs = self.get_inference_log_probs(
response_tokens, forward_micro_batch_size=self.forward_micro_batch_size
9 changes: 2 additions & 7 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_reward_model.py
@@ -35,7 +35,7 @@
from nemo.utils import AppState, logging
from nemo_aligner.models.alignable_interface import Inferrable, SupervisedInterface
from nemo_aligner.models.nlp.gpt.gpt_reward_model import GPTRewardModel
from nemo_aligner.utils.distributed import broadcast_2d_tensor, gather_tensor
from nemo_aligner.utils.distributed import broadcast_2d_tensor, broadcast_2d_tensor_within_pp, gather_tensor
from nemo_aligner.utils.text_generation_utils import tokenize_batch
from nemo_aligner.utils.train_utils import (
finish_validation_step,
@@ -415,12 +415,7 @@ def infer(
if self.enable_standardization:
rewards = (rewards - self.rew_mean) / self.rew_std

if parallel_state.get_pipeline_model_parallel_world_size() > 1:
rewards = broadcast_2d_tensor(
rewards,
parallel_state.get_pipeline_model_parallel_last_rank(),
parallel_state.get_pipeline_model_parallel_group(),
)
rewards = broadcast_2d_tensor_within_pp(rewards)

rewards_list = gather_tensor(
rewards, dst=parallel_state.get_data_parallel_src_rank(), group=parallel_state.get_data_parallel_group()
16 changes: 14 additions & 2 deletions nemo_aligner/utils/distributed.py
@@ -69,13 +69,25 @@ def broadcast_2d_tensor(tensor, src, group, dtype=torch.float32):
return tensor


def broadcast_2d_tensor_within_mp(tensor):
def broadcast_2d_tensor_within_mp(tensor, dtype=torch.float32):
"""helper function to broadcast within the model parallel group
"""
group = parallel_state.get_model_parallel_group()

if torch.distributed.get_world_size(group) > 1:
return broadcast_2d_tensor(tensor, get_model_parallel_src_rank(), group)
return broadcast_2d_tensor(tensor, get_model_parallel_src_rank(), group, dtype=dtype)

return tensor


def broadcast_2d_tensor_within_pp(tensor, dtype=torch.float32):
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
return broadcast_2d_tensor(
tensor,
parallel_state.get_pipeline_model_parallel_last_rank(),
parallel_state.get_pipeline_model_parallel_group(),
dtype=dtype,
)

return tensor

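Aside (not part of the diff): the calling convention for the new helper, condensed from the actor change above — only the source rank (the last pipeline stage) needs to hold the tensor, every other rank passes None and receives the broadcast copy, and with a pipeline-parallel size of 1 the call is a no-op.

```python
# Sketch of the pattern used in megatron_gpt_ppo_actor.py above.
response_lengths = None
if parallel_state.is_pipeline_last_stage():
    response_lengths = strategy.get_lengths().to(torch.int64).view((-1, 1))
response_lengths = broadcast_2d_tensor_within_pp(response_lengths, dtype=torch.int64).flatten()
```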
49 changes: 49 additions & 0 deletions nemo_aligner/utils/text_generation_utils.py
@@ -14,11 +14,60 @@

"""Utilities for generating text."""

from typing import Any, List

import torch

from megatron.core import parallel_state
from nemo.collections.nlp.modules.common.text_generation_strategy import GPTModelTextGenerationStrategy
from nemo.utils import logging


class TrackLengthGPTModelTextGenerationStrategy(GPTModelTextGenerationStrategy):
"""
Text generation strategy that tracks the length of the generated text.
TODO This is a temporary workaround until NeMo's `generate()` function returns this information.
"""

def __init__(self, model: Any, context_lengths: torch.Tensor, max_length: int):
super().__init__(model)
self._context_lengths = context_lengths
self._max_length = max_length
self._end_idx = torch.full_like(context_lengths, fill_value=-1)

def end_of_generation_condition(
self, tokens: torch.Tensor, prev: torch.Tensor, eod_id: int, end_strings: List[str]
) -> torch.Tensor:
is_end = super().end_of_generation_condition(tokens=tokens, prev=prev, eod_id=eod_id, end_strings=end_strings)
assert len(is_end) == len(tokens)
if len(tokens) != len(self._context_lengths):
raise RuntimeError(
"Batch size mismatch: the `context_lengths` tensor provided in the constructor has batch size "
f"{len(self._context_lengths)}, while the generated tokens have batch size {len(tokens)}"
)
context_length = tokens.size(1) - 1 # the input tokens come from `tokens[:, : context_length + 1]`
started = self._context_lengths <= context_length
# The generation ends right now when three conditions hold:
# - it has started
# - the end generation is triggered now
# - it did *not* end before
self._end_idx = torch.where(started & is_end & (self._end_idx < 0), context_length, self._end_idx)
return is_end

def get_lengths(self) -> torch.Tensor:
"""
Return the total lengths of the generated sequences, in # of tokens.
The total length of a generated sequence counts both:
* the context tokens (i.e., the input prompt)
* the token(s) that ended generation, if any (e.g. the `EOS` token or the token(s) corresponding to
an element of `sampling_params.end_strings`)
"""
assert parallel_state.is_pipeline_last_stage(), "only the last pp stage can compute lengths"
return torch.where(self._end_idx >= 0, self._end_idx + 1, self._context_lengths + self._max_length)


def pad_batch(batch, pad_id):
"""batch each element of the batch to be the size of the longest sequence
"""
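Aside (not part of the diff): how the new strategy is wired into generation, condensed from the actor change above; `model`, `prompt_tokens`, `prompt_lengths`, `length_params`, and `sampling_params` are stand-ins for the actor's attributes.

```python
strategy = TrackLengthGPTModelTextGenerationStrategy(
    model=model, context_lengths=prompt_lengths, max_length=length_params["max_length"]
)
output = model.generate(
    inputs=(prompt_tokens, prompt_lengths),
    length_params=length_params,
    sampling_params=sampling_params,
    strategy=strategy,
)
# On the last pipeline stage, strategy.get_lengths() now returns one length per
# sequence, counting the prompt plus the terminating token(s); sequences that
# never hit an end condition get prompt length + max_length.
```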
38 changes: 0 additions & 38 deletions nemo_aligner/utils/utils.py
@@ -184,44 +184,6 @@ def calculate_response_lengths(tokens, eos_id):
return (tokens != eos_id).sum(-1)


def calculate_dialogue_response_lengths(
tokens, prompt_lengths, tokenizer, end_strings, max_generation_length, max_sequence_length
):
# for EOS
eos_length = calculate_response_lengths(tokens, tokenizer.eos_id)

if "<extra_id_1>" in end_strings:
# for the extra_id_1
extra_id_1_idx = tokenizer.text_to_ids("<extra_id_1>")[-1]
mask = tokens == extra_id_1_idx

# take the last extra id token index(assumes we are not padding with extra_id_1)
length_with_extra_id_1 = torch.argmax(
mask * torch.arange(tokens.size(-1), device=torch.cuda.current_device()), dim=-1
)

# if it terminated on the extra token id, then it must have been generated by the model, otherwise it couldn't have
length_with_extra_id_1 = torch.where(
length_with_extra_id_1 >= prompt_lengths, length_with_extra_id_1, torch.iinfo(torch.int32).max
)

# either terminated using eos id or extra id 1
lengths = torch.minimum(eos_length, length_with_extra_id_1)
else:
lengths = eos_length

# we also want the model to learn EOS or extra id 1
lengths = lengths + 1
# Ensure we never go over `length_params.max_length`. Note that this means the response may not necessarily
# end with EOS / extra_id_1 (we should not enforce it as PPO training requires the real generated token).
max_lengths = prompt_lengths + max_generation_length
lengths = torch.minimum(lengths, max_lengths)

# Prompts' max size and `max_length` should be such that we never exceed the encoder input size.
assert (lengths <= max_sequence_length).all()
return lengths


def configure_batch_sizes(mbs, gbs, dp=1):
app_state = AppState()
_reconfigure_microbatch_calculator(
