From f398db10af1c6591779cd80cb93db2a55b6a3bc3 Mon Sep 17 00:00:00 2001
From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com>
Date: Mon, 1 Jul 2024 11:37:38 -0700
Subject: [PATCH] fixing sequence_id =-1 bug, adding tests (#1324)

---
Reviewer notes (placed below the three-dash line, so they are not part of
the commit message):

* Line structure and indentation of this patch were restored from a
  whitespace-collapsed copy. Hunk line counts match the hunk headers and
  the diffstat; indentation of unchanged context lines is inferred and
  should be confirmed against the target files before applying.
* Positions where sequence_id == -1 (used for padding in test_padding) are
  remapped to 0 before one_hot, and their one-hot rows are then zeroed, so
  they contribute nothing to seq_lens.
* The new tests compare float tensors with exact `==` on unseeded
  torch.rand input, while the pre-existing test uses torch.allclose.
  TODO confirm the exact comparisons cannot fail on rounding differences.

 .../callbacks/loss_perp_v_len_callback.py     |  12 +-
 .../test_loss_perp_v_len_callback.py          | 258 ++++++++++++++++++
 2 files changed, 267 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/callbacks/loss_perp_v_len_callback.py b/llmfoundry/callbacks/loss_perp_v_len_callback.py
index aa9519c255..ebb9583224 100644
--- a/llmfoundry/callbacks/loss_perp_v_len_callback.py
+++ b/llmfoundry/callbacks/loss_perp_v_len_callback.py
@@ -262,9 +262,15 @@ def update(
         self.sum_length += valid_labels_mask.sum(dim=0)
 
         if sequence_id is not None:
-            seq_id_expanded = torch.nn.functional.one_hot(
-                sequence_id,
-            ).transpose(-1, -2)
+            seq_id_mask = (sequence_id != -1)
+            sequence_id = torch.where(seq_id_mask, sequence_id, 0)
+            seq_id_expanded = torch.nn.functional.one_hot(sequence_id,)
+            seq_id_expanded = torch.where(
+                torch.unsqueeze(seq_id_mask, dim=-1),
+                seq_id_expanded,
+                0,
+            )
+            seq_id_expanded = seq_id_expanded.transpose(-1, -2)
             seq_lens = seq_id_expanded.sum(dim=-1)
             max_num_seq = seq_lens.shape[1]
             seq_tok_ids = torch.arange(seq_len, device=sequence_id.device)[
diff --git a/tests/callbacks/test_loss_perp_v_len_callback.py b/tests/callbacks/test_loss_perp_v_len_callback.py
index 46bde1c2f1..4c487560d2 100644
--- a/tests/callbacks/test_loss_perp_v_len_callback.py
+++ b/tests/callbacks/test_loss_perp_v_len_callback.py
@@ -1,5 +1,6 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
+from typing import Any
 from unittest.mock import MagicMock
 
 import pytest
@@ -14,6 +15,7 @@
 from omegaconf import OmegaConf as om
 
 from llmfoundry import registry
+from llmfoundry.callbacks.loss_perp_v_len_callback import LossPerpVLen
 from llmfoundry.data.text_data import (
     StreamingTextDataset,
     build_text_dataloader,
@@ -172,3 +174,259 @@ def test_loss_perp_v_len_callback(
     ) / torch.sum(current_metric_dict['sum_length'])
     assert torch.allclose(loss, mean_loss_seq_id)
     assert torch.allclose(loss, mean_loss)
+
+
+def test_metric():
+    batch_size = 2
+    seq_len = 100
+    labels = torch.tensor([[
+        1,
+    ] * seq_len] * batch_size)
+    logits = torch.tensor([[
+        1,
+    ] * seq_len] * batch_size)
+    sequence_id = torch.tensor([[
+        0,
+    ] * 10 + [
+        1,
+    ] * 90, [
+        0,
+    ] * 50 + [
+        1,
+    ] * 50])
+    loss = torch.rand([batch_size, seq_len])
+    perplexity = torch.exp(loss)
+
+    def mock_loss_fn(input_logits: Any, input_labels: Any):
+        del input_logits, input_labels
+        return loss
+
+    loss_v_len_metric = LossPerpVLen(ignore_index=-100)
+    loss_v_len_metric.update(
+        labels=labels,
+        logits=logits,
+        sequence_id=sequence_id,
+        loss_fn=mock_loss_fn,
+    )
+    metric_dict = loss_v_len_metric.compute()
+
+    assert torch.all(metric_dict['sum_length'] == 2 * torch.ones([100]))
+    assert torch.all(
+        metric_dict['sum_length_seq_id'] == torch.tensor([
+            4,
+        ] * 10 + [
+            3,
+        ] * 40 + [
+            1,
+        ] * 40 + [
+            0,
+        ] * 10),
+    )
+    assert torch.all(metric_dict['mean_loss_v_len'] == torch.mean(loss, dim=0))
+    assert torch.all(
+        metric_dict['mean_perplexity_v_len'] == torch.mean(perplexity, dim=0),
+    )
+
+    expected_mean_loss_seq_id_v_len_0 = (
+        loss[0][:10] + loss[0][10:20] + loss[1][0:10] + loss[1][50:60]
+    ) / 4
+    expected_mean_loss_seq_id_v_len_1 = (
+        loss[0][20:60] + loss[1][10:50] + loss[1][60:100]
+    ) / 3
+    expected_mean_loss_seq_id_v_len_2 = loss[0][60:100]
+    expected_mean_loss_seq_id_v_len_3 = -1
+
+    assert torch.all(
+        metric_dict['mean_loss_seq_id_v_len'][0:10] ==
+        expected_mean_loss_seq_id_v_len_0,
+    )
+    assert torch.all(
+        metric_dict['mean_loss_seq_id_v_len'][10:50] ==
+        expected_mean_loss_seq_id_v_len_1,
+    )
+    assert torch.all(
+        metric_dict['mean_loss_seq_id_v_len'][50:90] ==
+        expected_mean_loss_seq_id_v_len_2,
+    )
+    assert torch.all(
+        metric_dict['mean_loss_seq_id_v_len'][90:100] ==
+        expected_mean_loss_seq_id_v_len_3,
+    )
+
+    expected_mean_perplexity_seq_id_v_len_0 = (
+        perplexity[0][:10] + perplexity[0][10:20] + perplexity[1][0:10] +
+        perplexity[1][50:60]
+    ) / 4
+    expected_mean_perplexity_seq_id_v_len_1 = (
+        perplexity[0][20:60] + perplexity[1][10:50] + perplexity[1][60:100]
+    ) / 3
+    expected_mean_perplexity_seq_id_v_len_2 = perplexity[0][60:100]
+    expected_mean_perplexity_seq_id_v_len_3 = -1
+
+    assert torch.all(
+        metric_dict['mean_perplexity_seq_id_v_len'][0:10] ==
+        expected_mean_perplexity_seq_id_v_len_0,
+    )
+    assert torch.all(
+        metric_dict['mean_perplexity_seq_id_v_len'][10:50] ==
+        expected_mean_perplexity_seq_id_v_len_1,
+    )
+    assert torch.all(
+        metric_dict['mean_perplexity_seq_id_v_len'][50:90] ==
+        expected_mean_perplexity_seq_id_v_len_2,
+    )
+    assert torch.all(
+        metric_dict['mean_perplexity_seq_id_v_len'][90:100] ==
+        expected_mean_perplexity_seq_id_v_len_3,
+    )
+
+
+def test_valid_labels():
+    batch_size = 1
+    seq_len = 100
+    ignore_labels_len = 10
+    labels = torch.tensor([[
+        1,
+    ] * (seq_len - ignore_labels_len) + [
+        -100,
+    ] * ignore_labels_len] * batch_size)
+    logits = torch.tensor([[
+        1,
+    ] * seq_len] * batch_size)
+    sequence_id = torch.tensor([[
+        0,
+    ] * seq_len])
+    loss = torch.rand([batch_size, seq_len])
+
+    def mock_loss_fn(input_logits: Any, input_labels: Any):
+        del input_logits, input_labels
+        return loss
+
+    loss_v_len_metric = LossPerpVLen(ignore_index=-100)
+    loss_v_len_metric.update(
+        labels=labels,
+        logits=logits,
+        sequence_id=sequence_id,
+        loss_fn=mock_loss_fn,
+    )
+    metric_dict = loss_v_len_metric.compute()
+    assert torch.all(metric_dict['sum_length'][-ignore_labels_len:] == 0)
+    assert torch.all(metric_dict['sum_length_seq_id'][-ignore_labels_len:] == 0)
+    assert torch.all(metric_dict['mean_loss_v_len'][-ignore_labels_len:] == -1)
+    assert torch.all(
+        metric_dict['mean_perplexity_v_len'][-ignore_labels_len:] == -1,
+    )
+    assert torch.all(
+        metric_dict['mean_loss_seq_id_v_len'][-ignore_labels_len:] == -1,
+    )
+    assert torch.all(
+        metric_dict['mean_perplexity_seq_id_v_len'][-ignore_labels_len:] == -1,
+    )
+
+
+def test_padding():
+    batch_size = 2
+    seq_len = 100
+
+    labels_no_pad = torch.tensor([[
+        1,
+    ] * seq_len] * batch_size)
+    logits_no_pad = torch.tensor([[
+        1,
+    ] * seq_len] * batch_size)
+    sequence_id_no_pad = torch.tensor([[
+        0,
+    ] * 10 + [
+        1,
+    ] * 90, [
+        0,
+    ] * 50 + [
+        1,
+    ] * 50])
+    loss_no_pad = torch.rand([batch_size, seq_len])
+
+    def mock_loss_fn_no_pad(input_logits: Any, input_labels: Any):
+        del input_logits, input_labels
+        return loss_no_pad
+
+    loss_v_len_metric_no_pad = LossPerpVLen(ignore_index=-100)
+    loss_v_len_metric_no_pad.update(
+        labels=labels_no_pad,
+        logits=logits_no_pad,
+        sequence_id=sequence_id_no_pad,
+        loss_fn=mock_loss_fn_no_pad,
+    )
+    metric_dict_no_pad = loss_v_len_metric_no_pad.compute()
+
+    pad_len = 10
+    labels_pad = torch.tensor([[
+        1,
+    ] * seq_len + [
+        -100,
+    ] * pad_len] * batch_size)
+    logits_pad = torch.tensor([[
+        1,
+    ] * (seq_len + pad_len)] * batch_size)
+    sequence_id_pad = torch.tensor([[
+        0,
+    ] * 10 + [
+        1,
+    ] * 90 + [
+        -1,
+    ] * pad_len, [
+        0,
+    ] * 50 + [
+        1,
+    ] * 50 + [
+        -1,
+    ] * pad_len])
+    loss_pad = torch.cat([loss_no_pad,
+                          torch.rand([batch_size, pad_len])],
+                         dim=-1)
+
+    def mock_loss_fn_pad(input_logits: Any, input_labels: Any):
+        del input_logits, input_labels
+        return loss_pad
+
+    loss_v_len_metric_pad = LossPerpVLen(ignore_index=-100)
+    loss_v_len_metric_pad.update(
+        labels=labels_pad,
+        logits=logits_pad,
+        sequence_id=sequence_id_pad,
+        loss_fn=mock_loss_fn_pad,
+    )
+    metric_dict_pad = loss_v_len_metric_pad.compute()
+
+    assert torch.all(metric_dict_pad['sum_length'][-pad_len:] == 0)
+    assert torch.all(metric_dict_pad['sum_length_seq_id'][-pad_len:] == 0)
+    assert torch.all(metric_dict_pad['mean_loss_v_len'][-pad_len:] == -1)
+    assert torch.all(metric_dict_pad['mean_perplexity_v_len'][-pad_len:] == -1)
+    assert torch.all(metric_dict_pad['mean_loss_seq_id_v_len'][-pad_len:] == -1)
+    assert torch.all(
+        metric_dict_pad['mean_perplexity_seq_id_v_len'][-pad_len:] == -1,
+    )
+
+    assert torch.all(
+        metric_dict_pad['sum_length'][:-pad_len] ==
+        metric_dict_no_pad['sum_length'],
+    )
+    assert torch.all(
+        metric_dict_pad['sum_length_seq_id'][:-pad_len] ==
+        metric_dict_no_pad['sum_length_seq_id'],
+    )
+    assert torch.all(
+        metric_dict_pad['mean_loss_v_len'][:-pad_len] ==
+        metric_dict_no_pad['mean_loss_v_len'],
+    )
+    assert torch.all(
+        metric_dict_pad['mean_perplexity_v_len'][:-pad_len] ==
+        metric_dict_no_pad['mean_perplexity_v_len'],
+    )
+    assert torch.all(
+        metric_dict_pad['mean_loss_seq_id_v_len'][:-pad_len] ==
+        metric_dict_no_pad['mean_loss_seq_id_v_len'],
+    )
+    assert torch.all(
+        metric_dict_pad['mean_perplexity_seq_id_v_len'][:-pad_len] ==
+        metric_dict_no_pad['mean_perplexity_seq_id_v_len'],
+    )