Skip to content

Commit

Permalink
Add check to validate response length
Browse files: browse the repository at this point in the history
Signed-off-by: Olivier Delalleau <[email protected]>
  • Loading branch information
odelalleau committed Jan 10, 2024
1 parent 520e842 commit ada1fb2
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -283,8 +283,15 @@ def infer(self, inference_batch):
response_lengths = strategy.get_lengths().to(torch.int64).view((-1, 1))
assert (response_lengths <= self.cfg.encoder_seq_length).all()
response_lengths = broadcast_2d_tensor_within_pp(response_lengths, dtype=torch.int64).flatten()
max_response_length = response_lengths.max().item()

response_tokens = torch.cuda.LongTensor(actor_output["token_ids"])
if max_response_length != response_tokens.size(1): # sanity check to validate response length
raise AssertionError(
f"max response length ({max_response_length}) does not match the size of "
f"`response_tokens` ({response_tokens.size(1)})"
)

# TODO(geshen): get nemo generate to return the unaltered log probs
log_probs = self.get_inference_log_probs(
response_tokens, forward_micro_batch_size=self.forward_micro_batch_size
Expand Down

0 comments on commit ada1fb2

Please sign in to comment.