Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add loss generating token counts #1610

Merged
merged 12 commits into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions llmfoundry/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_data_spec(

def get_tokens_per_batch_func(
decoder_only: bool = True,
) -> Callable[[Batch], int]:
) -> Callable[[Batch], Union[int, dict[str, int]]]:
"""Returns a callable that counts the number of tokens in a batch.

Args:
Expand All @@ -95,7 +95,7 @@ def get_tokens_per_batch_func(
Callable[[Batch], int]: A callable that counts the number of tokens in a batch.
"""

def get_num_tokens_in_batch(batch: Batch) -> int:
def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]:
if not isinstance(batch, Mapping) or (
'attention_mask' not in batch and 'input_ids' not in batch
):
Expand All @@ -114,13 +114,30 @@ def get_num_tokens_in_batch(batch: Batch) -> int:
else:
input_ids_tokens = batch['input_ids'].numel()

loss_generating_tokens = None
if 'labels' in batch:
loss_generating_tokens = int(
torch.sum(batch['labels'] != -100).item(),
dakinggg marked this conversation as resolved.
Show resolved Hide resolved
)

# Subtract one for each example in the batch that starts with a non -100,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dakinggg I don't think this subtraction is necessary. Instead, you can just do this:

loss_generating_tokens = int(
                torch.sum(batch['labels'][...,1:] != CROSS_ENTROPY_IGNORE_INDEX).item(),
            )

*I just came across this PR while looking into how Mosaic's libs handle the gradient accumulation bug recently discussed on x.com.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yeah, that should work too :)

# because those will be shifted off
loss_generating_tokens -= int(
torch.sum(batch['labels'][:, 0] != -100).item(),
)

# For encoder decoder models only
decoder_input_ids_tokens = 0
if not decoder_only:
decoder_input_ids_tokens = int(
torch.sum(batch['decoder_attention_mask']).item(),
)

if loss_generating_tokens is not None:
return {
'total': input_ids_tokens + decoder_input_ids_tokens,
'loss_generating': loss_generating_tokens,
}
return input_ids_tokens + decoder_input_ids_tokens

return get_num_tokens_in_batch
Expand Down
17 changes: 15 additions & 2 deletions tests/data/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pathlib
import random
import shutil
from collections import Counter
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from typing import Any, Callable, ContextManager, Literal, Optional, Union
Expand Down Expand Up @@ -1186,12 +1187,15 @@ def test_token_counting_func_dataloader_setting(

batch_strings = []
expected_token_count = 0
expected_loss_generating_token_count = 0
sample_lengths = []
for _ in range(batch_size):
# Get randomly different lengths if we are going to add padding
sample_length = random.randint(1, model_max_length // 4) if (
pad_token_id is not None and not tensor_input
) else model_max_length // 4
batch_strings.append(' '.join(['hello'] * sample_length))
sample_lengths.append(sample_length)
expected_token_count += sample_length

batch_tokenized = [
Expand All @@ -1208,8 +1212,15 @@ def test_token_counting_func_dataloader_setting(
for b in batch_tokenized:
b['labels'] = b['input_ids'].copy() # type: ignore
batch_tokenized = [{'turns': [b]} for b in batch_tokenized]
expected_loss_generating_token_count = expected_token_count
expected_token_count *= 2
expected_token_count += 1 * batch_size # for the eos token
expected_loss_generating_token_count += 1 * batch_size # for the eos token
else:
expected_loss_generating_token_count = expected_token_count

number_of_shifted_off_labels = Counter(sample_lengths)[max(sample_lengths)]
expected_loss_generating_token_count -= 1 * number_of_shifted_off_labels # because the labels will be shifted

common_args = {
'drop_last': False,
Expand Down Expand Up @@ -1311,9 +1322,11 @@ def build_from_hf(
raise NotImplementedError()

batch_collated = dl.dataloader.collate_fn(batch_tokenized) # type: ignore
actual_token_count = dl.get_num_tokens_in_batch(batch_collated)
actual_total_token_count = dl.get_num_tokens_in_batch(batch_collated, token_type='total')

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be missing something, but how can we pass in token_type here when it's not in the function definition here? https://github.com/mosaicml/llm-foundry/pull/1610/files#diff-9568d89aed75ca69416abe2a592c6bb9732129049a62c34e4e9263c18495a236R99

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function being called here is actually defined on the DataSpec class in Composer (https://github.com/mosaicml/composer/blob/28756dd52e96371689b764cb72c336406460ad35/composer/core/data_spec.py#L301). The DataSpec takes in a function from the user and uses it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Part of the reason for doing it this way was to maintain backwards compatibility with any existing user defined get_num_tokens_in_batch functions out there.

actual_loss_generating_token_count = dl.get_num_tokens_in_batch(batch_collated, token_type='loss_generating')

assert actual_token_count == expected_token_count
assert actual_total_token_count == expected_token_count
assert actual_loss_generating_token_count == expected_loss_generating_token_count


def test_build_unknown_dataloader():
Expand Down
Loading