From 836ab9537d790f25c95d42c8da4c722bd967dc96 Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Thu, 21 Dec 2023 17:04:14 -0800
Subject: [PATCH] Fix token counting to allow there to be no attention mask
 (#818)

---
 llmfoundry/data/text_data.py  | 10 +++++++---
 tests/data/test_dataloader.py | 22 +++++++++++++++-------
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py
index 083cd48069..1c0894a451 100644
--- a/llmfoundry/data/text_data.py
+++ b/llmfoundry/data/text_data.py
@@ -325,9 +325,10 @@ def get_tokens_per_batch_func(
     """
 
     def get_num_samples_in_batch(batch: Batch) -> int:
-        if not isinstance(batch, Mapping) or 'attention_mask' not in batch:
+        if not isinstance(batch, Mapping) or ('attention_mask' not in batch and
+                                              'input_ids' not in batch):
             raise ValueError(
-                'get_tokens_per_batch_func() requires a batch with an attention_mask key'
+                'get_tokens_per_batch_func() requires a batch with an attention_mask key or an input_ids key'
             )
 
         if not decoder_only and 'decoder_attention_mask' not in batch:
@@ -336,7 +337,10 @@ def get_num_samples_in_batch(batch: Batch) -> int:
             )
 
         # Count number of non padding tokens in batch
-        input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
+        if 'attention_mask' in batch:
+            input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
+        else:
+            input_ids_tokens = batch['input_ids'].numel()
 
         # For encoder decoder models only
         decoder_input_ids_tokens = 0
diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py
index 2cf4c51a72..bf818347a0 100644
--- a/tests/data/test_dataloader.py
+++ b/tests/data/test_dataloader.py
@@ -642,16 +642,17 @@ def test_token_counting_func(pad_token_id: int, batch_size: int,
     assert actual_token_count == expected_token_count
 
 
-@pytest.mark.parametrize(
-    'dataloader_type',
-    ['finetuning-hf', 'finetuning-streaming', 'denoising', 'text'])
+@pytest.mark.parametrize('dataloader_type,tensor_input',
+                         [('finetuning-hf', False),
+                          ('finetuning-streaming', False), ('denoising', False),
+                          ('text', True), ('text', False)])
 @pytest.mark.parametrize('pad_token_id', [100, None])
 @pytest.mark.parametrize('batch_size', [1, 8])
 @pytest.mark.parametrize('model_max_length', [1024])
 @pytest.mark.parametrize('padding_side', ['left'])
 def test_token_counting_func_dataloader_setting(
-        dataloader_type: str, pad_token_id: Optional[int], batch_size: int,
-        model_max_length: int, padding_side: str,
+        dataloader_type: str, tensor_input: bool, pad_token_id: Optional[int],
+        batch_size: int, model_max_length: int, padding_side: str,
         monkeypatch: pytest.MonkeyPatch):
     gptt = transformers.AutoTokenizer.from_pretrained('gpt2')
     gptt.pad_token_id = pad_token_id if pad_token_id is not None else gptt.eos_token_id
@@ -661,9 +662,11 @@ def test_token_counting_func_dataloader_setting(
     batch_strings = []
     expected_token_count = 0
     for _ in range(batch_size):
+        # Get randomly different lengths if we are going to add padding
         sample_length = random.randint(
             1, model_max_length //
-            4) if pad_token_id is not None else model_max_length // 4
+            4) if (pad_token_id is not None and
+                   not tensor_input) else model_max_length // 4
         batch_strings.append(' '.join(['hello'] * sample_length))
         expected_token_count += sample_length
 
@@ -672,13 +675,18 @@ def test_token_counting_func_dataloader_setting(
         for b in batch_strings
     ]
 
+    if tensor_input:
+        batch_tokenized = [
+            torch.tensor(b['input_ids']) for b in batch_tokenized
+        ]
+
     if dataloader_type == 'denoising':
         expected_token_count += 2 * batch_size  # for the two eos tokens
         expected_token_count += 5 * batch_size  # for the corruption prefix tokens
 
     if dataloader_type in {'finetuning-hf', 'finetuning-streaming'}:
         for b in batch_tokenized:
-            b['labels'] = b['input_ids'].copy()
+            b['labels'] = b['input_ids'].copy()  # type: ignore
         expected_token_count *= 2
         expected_token_count += 1 * batch_size  # for the eos token
 
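For reference, below is a minimal standalone sketch of the counting rule this patch introduces in get_num_samples_in_batch: count non-padding tokens from attention_mask when it is present, and fall back to counting every element of input_ids when it is absent. The count_input_tokens helper name and the sample batches are illustrative only, not part of the patch or the llm-foundry API.

import torch

def count_input_tokens(batch: dict) -> int:
    # Mirrors the patched logic: prefer the attention_mask when available.
    if 'attention_mask' in batch:
        # Mask is 1 for real tokens and 0 for padding, so the sum is the
        # number of non-padding tokens.
        return int(torch.sum(batch['attention_mask']).item())
    # No mask (e.g. fixed-length pretraining text batches): assume every
    # position in input_ids is a real token.
    return batch['input_ids'].numel()

# Left-padded batch: the two pad positions are excluded, giving 6 tokens.
padded = {
    'input_ids': torch.tensor([[0, 0, 5, 6], [1, 2, 3, 4]]),
    'attention_mask': torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]]),
}
assert count_input_tokens(padded) == 6

# Batch without an attention_mask: all 8 elements of input_ids are counted.
unpadded = {'input_ids': torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])}
assert count_input_tokens(unpadded) == 8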