Skip to content

Commit

Permalink
fixing sequence_id =-1 bug, adding tests (#1324)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML authored Jul 1, 2024
1 parent 6117844 commit f398db1
Show file tree
Hide file tree
Showing 2 changed files with 267 additions and 3 deletions.
12 changes: 9 additions & 3 deletions llmfoundry/callbacks/loss_perp_v_len_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,15 @@ def update(
self.sum_length += valid_labels_mask.sum(dim=0)

if sequence_id is not None:
seq_id_expanded = torch.nn.functional.one_hot(
sequence_id,
).transpose(-1, -2)
seq_id_mask = (sequence_id != -1)
sequence_id = torch.where(seq_id_mask, sequence_id, 0)
seq_id_expanded = torch.nn.functional.one_hot(sequence_id,)
seq_id_expanded = torch.where(
torch.unsqueeze(seq_id_mask, dim=-1),
seq_id_expanded,
0,
)
seq_id_expanded = seq_id_expanded.transpose(-1, -2)
seq_lens = seq_id_expanded.sum(dim=-1)
max_num_seq = seq_lens.shape[1]
seq_tok_ids = torch.arange(seq_len, device=sequence_id.device)[
Expand Down
258 changes: 258 additions & 0 deletions tests/callbacks/test_loss_perp_v_len_callback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from unittest.mock import MagicMock

import pytest
Expand All @@ -14,6 +15,7 @@
from omegaconf import OmegaConf as om

from llmfoundry import registry
from llmfoundry.callbacks.loss_perp_v_len_callback import LossPerpVLen
from llmfoundry.data.text_data import (
StreamingTextDataset,
build_text_dataloader,
Expand Down Expand Up @@ -172,3 +174,259 @@ def test_loss_perp_v_len_callback(
) / torch.sum(current_metric_dict['sum_length'])
assert torch.allclose(loss, mean_loss_seq_id)
assert torch.allclose(loss, mean_loss)


def test_metric():
    """Check LossPerpVLen outputs against hand-computed expectations.

    Two rows of 100 tokens are each packed with two sequences: row 0 holds
    sequences of length 10 and 90, row 1 holds two sequences of length 50.
    The loss function is mocked to return a fixed random tensor so the
    expected per-position means can be derived directly from it.

    Float results are compared with ``torch.allclose`` rather than exact
    equality: the expected values are summed here in a fixed left-to-right
    order, which is not guaranteed to match the metric's reduction order
    bit-for-bit. Token counts and the ``-1`` sentinel are still compared
    exactly.
    """
    batch_size = 2
    seq_len = 100
    labels = torch.tensor([[1] * seq_len] * batch_size)
    logits = torch.tensor([[1] * seq_len] * batch_size)
    sequence_id = torch.tensor([
        [0] * 10 + [1] * 90,
        [0] * 50 + [1] * 50,
    ])
    loss = torch.rand([batch_size, seq_len])
    perplexity = torch.exp(loss)

    def mock_loss_fn(input_logits: Any, input_labels: Any):
        # The metric only needs per-token losses; return the fixed tensor.
        del input_logits, input_labels
        return loss

    loss_v_len_metric = LossPerpVLen(ignore_index=-100)
    loss_v_len_metric.update(
        labels=labels,
        logits=logits,
        sequence_id=sequence_id,
        loss_fn=mock_loss_fn,
    )
    metric_dict = loss_v_len_metric.compute()

    # Every absolute position is a valid label in both rows.
    assert torch.all(metric_dict['sum_length'] == 2 * torch.ones([seq_len]))
    # Number of sequences that reach each within-sequence position:
    # 4 sequences have >= 10 tokens, 3 have >= 50, 1 has 90, none has 100.
    assert torch.all(
        metric_dict['sum_length_seq_id'] == torch.tensor(
            [4] * 10 + [3] * 40 + [1] * 40 + [0] * 10,
        ),
    )

    assert torch.allclose(
        metric_dict['mean_loss_v_len'],
        torch.mean(loss, dim=0),
    )
    assert torch.allclose(
        metric_dict['mean_perplexity_v_len'],
        torch.mean(perplexity, dim=0),
    )

    # The same within-sequence bookkeeping applies to loss and perplexity.
    for name, values in (('loss', loss), ('perplexity', perplexity)):
        actual = metric_dict[f'mean_{name}_seq_id_v_len']

        # Positions 0-9: all four sequences contribute.
        expected_0 = (
            values[0][:10] + values[0][10:20] + values[1][0:10] +
            values[1][50:60]
        ) / 4
        # Positions 10-49: every sequence except the length-10 one.
        expected_1 = (
            values[0][20:60] + values[1][10:50] + values[1][60:100]
        ) / 3
        # Positions 50-89: only the length-90 sequence in row 0.
        expected_2 = values[0][60:100]

        assert torch.allclose(actual[0:10], expected_0)
        assert torch.allclose(actual[10:50], expected_1)
        assert torch.allclose(actual[50:90], expected_2)
        # Positions 90-99: no sequence is this long, so the -1 sentinel
        # is expected and must match exactly.
        assert torch.all(actual[90:100] == -1)


def test_valid_labels():
    """Positions whose label equals ignore_index must not be counted.

    A single row of 100 tokens has its last 10 labels set to -100 (the
    ``ignore_index`` passed to the metric). For those trailing positions the
    test expects the length counters to be 0 and every mean to be -1.
    """
    batch_size = 1
    seq_len = 100
    ignore_labels_len = 10

    kept_labels = [1] * (seq_len - ignore_labels_len)
    ignored_labels = [-100] * ignore_labels_len
    labels = torch.tensor([kept_labels + ignored_labels] * batch_size)
    logits = torch.tensor([[1] * seq_len] * batch_size)
    sequence_id = torch.tensor([[0] * seq_len])
    loss = torch.rand([batch_size, seq_len])

    def mock_loss_fn(input_logits: Any, input_labels: Any):
        # Inputs are irrelevant; always hand back the fixed loss tensor.
        del input_logits, input_labels
        return loss

    metric = LossPerpVLen(ignore_index=-100)
    metric.update(
        labels=labels,
        logits=logits,
        sequence_id=sequence_id,
        loss_fn=mock_loss_fn,
    )
    metric_dict = metric.compute()

    ignored_tail = slice(-ignore_labels_len, None)

    # Counters over the ignored tail stay at zero.
    for key in ('sum_length', 'sum_length_seq_id'):
        assert torch.all(metric_dict[key][ignored_tail] == 0)

    # Means over the ignored tail are expected to be the -1 sentinel.
    for key in (
        'mean_loss_v_len',
        'mean_perplexity_v_len',
        'mean_loss_seq_id_v_len',
        'mean_perplexity_seq_id_v_len',
    ):
        assert torch.all(metric_dict[key][ignored_tail] == -1)


def test_padding():
    """Padding tokens (sequence_id == -1) must not change the metric.

    The metric is computed once on an unpadded batch and once on the same
    batch with 10 padding positions appended to each row (label -100,
    sequence_id -1, arbitrary loss). The padded positions are expected to
    report zero counts and -1 means, and everything before them must equal
    the unpadded result exactly.
    """
    batch_size = 2
    seq_len = 100
    pad_len = 10

    def run_metric(labels: Any, logits: Any, sequence_id: Any, loss: Any):
        # Feed one batch through a fresh metric, with the loss mocked out.
        def mock_loss_fn(input_logits: Any, input_labels: Any):
            del input_logits, input_labels
            return loss

        metric = LossPerpVLen(ignore_index=-100)
        metric.update(
            labels=labels,
            logits=logits,
            sequence_id=sequence_id,
            loss_fn=mock_loss_fn,
        )
        return metric.compute()

    # Sequence layout shared by both runs: row 0 packs lengths 10 and 90,
    # row 1 packs two sequences of length 50.
    row_0_ids = [0] * 10 + [1] * 90
    row_1_ids = [0] * 50 + [1] * 50

    # --- Unpadded batch ---
    labels_no_pad = torch.tensor([[1] * seq_len] * batch_size)
    logits_no_pad = torch.tensor([[1] * seq_len] * batch_size)
    sequence_id_no_pad = torch.tensor([row_0_ids, row_1_ids])
    loss_no_pad = torch.rand([batch_size, seq_len])
    metric_dict_no_pad = run_metric(
        labels_no_pad,
        logits_no_pad,
        sequence_id_no_pad,
        loss_no_pad,
    )

    # --- Same batch with pad_len padding positions appended to each row ---
    labels_pad = torch.tensor(
        [[1] * seq_len + [-100] * pad_len] * batch_size,
    )
    logits_pad = torch.tensor([[1] * (seq_len + pad_len)] * batch_size)
    sequence_id_pad = torch.tensor([
        row_0_ids + [-1] * pad_len,
        row_1_ids + [-1] * pad_len,
    ])
    # Loss on the padded positions is random on purpose: it must be ignored.
    pad_loss = torch.rand([batch_size, pad_len])
    loss_pad = torch.cat([loss_no_pad, pad_loss], dim=-1)
    metric_dict_pad = run_metric(
        labels_pad,
        logits_pad,
        sequence_id_pad,
        loss_pad,
    )

    count_keys = ('sum_length', 'sum_length_seq_id')
    mean_keys = (
        'mean_loss_v_len',
        'mean_perplexity_v_len',
        'mean_loss_seq_id_v_len',
        'mean_perplexity_seq_id_v_len',
    )

    # Padded positions contribute nothing.
    for key in count_keys:
        assert torch.all(metric_dict_pad[key][-pad_len:] == 0)
    for key in mean_keys:
        assert torch.all(metric_dict_pad[key][-pad_len:] == -1)

    # Everything before the padding matches the unpadded run exactly.
    for key in count_keys + mean_keys:
        assert torch.all(
            metric_dict_pad[key][:-pad_len] == metric_dict_no_pad[key],
        )

0 comments on commit f398db1

Please sign in to comment.