From 5388dc0d46e27f7e46e4cc744faaff54975c06cc Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Wed, 13 Dec 2023 19:17:04 -0800
Subject: [PATCH] Fix token counting to use attention mask instead of ids
 (#802)

---
 llmfoundry/data/denoising.py             |  1 -
 llmfoundry/data/finetuning/dataloader.py |  3 +-
 llmfoundry/data/text_data.py             | 21 ++++++--------
 tests/data/test_dataloader.py            | 37 +++++++++++++++---------
 4 files changed, 33 insertions(+), 29 deletions(-)

diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py
index 8ccf7f25e9..9c14f21751 100644
--- a/llmfoundry/data/denoising.py
+++ b/llmfoundry/data/denoising.py
@@ -527,7 +527,6 @@ def build_text_denoising_dataloader(
     )
 
     token_counting_func = get_tokens_per_batch_func(
-        pad_token_id=tokenizer.pad_token_id,
         decoder_only=cfg.mixture_of_denoisers.decoder_only_format)
 
     return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index b19cab841f..7a29d1dfed 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -216,8 +216,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
         timeout=cfg.get('timeout', 0),
     )
 
-    token_counting_func = get_tokens_per_batch_func(
-        pad_token_id=tokenizer.pad_token_id)
+    token_counting_func = get_tokens_per_batch_func()
 
     return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
 
diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py
index 51fd6b38dc..083cd48069 100644
--- a/llmfoundry/data/text_data.py
+++ b/llmfoundry/data/text_data.py
@@ -306,15 +306,13 @@ def build_text_dataloader(
     # and if tokenizing on the fly, we require that the tokenizer has a pad token.
     token_counting_func = None
     if tokenizer.pad_token_id is not None:
-        token_counting_func = get_tokens_per_batch_func(
-            pad_token_id=tokenizer.pad_token_id)
+        token_counting_func = get_tokens_per_batch_func()
 
     return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
 
 
-def get_tokens_per_batch_func(pad_token_id: int,
-                              decoder_only: bool = True
-                             ) -> Callable[[Batch], int]:
+def get_tokens_per_batch_func(
+        decoder_only: bool = True) -> Callable[[Batch], int]:
     """Returns a callable that counts the number of tokens in a batch.
 
     Args:
@@ -327,25 +325,24 @@ def get_tokens_per_batch_func(pad_token_id: int,
     """
 
     def get_num_samples_in_batch(batch: Batch) -> int:
-        if not isinstance(batch, Mapping) or 'input_ids' not in batch:
+        if not isinstance(batch, Mapping) or 'attention_mask' not in batch:
             raise ValueError(
-                'get_tokens_per_batch_func() requires a batch with an input_ids key'
+                'get_tokens_per_batch_func() requires a batch with an attention_mask key'
             )
 
-        if not decoder_only and 'decoder_input_ids' not in batch:
+        if not decoder_only and 'decoder_attention_mask' not in batch:
             raise ValueError(
-                'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_input_ids key'
+                'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_attention_mask key'
             )
 
         # Count number of non padding tokens in batch
-        input_ids_tokens = int(
-            torch.sum(batch['input_ids'] != pad_token_id).item())
+        input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
 
         # For encoder decoder models only
        decoder_input_ids_tokens = 0
         if not decoder_only:
             decoder_input_ids_tokens = int(
-                torch.sum(batch['decoder_input_ids'] != pad_token_id).item())
+                torch.sum(batch['decoder_attention_mask']).item())
 
         return input_ids_tokens + decoder_input_ids_tokens
 
diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py
index 728376229b..2cf4c51a72 100644
--- a/tests/data/test_dataloader.py
+++ b/tests/data/test_dataloader.py
@@ -630,12 +630,12 @@ def test_token_counting_func(pad_token_id: int, batch_size: int,
             decoder_batch_strings.append(' '.join(['hello'] * sample_length))
             decoder_expected_token_count += sample_length
             expected_token_count += sample_length
-        batch_tokenized['decoder_input_ids'] = gptt(
+        batch_tokenized['decoder_attention_mask'] = gptt(
             decoder_batch_strings, padding=True,
-            return_tensors='pt')['input_ids']
+            return_tensors='pt')['attention_mask']
 
     token_counting_func = get_tokens_per_batch_func(
-        pad_token_id, decoder_only=not add_decoder_input_ids)
+        decoder_only=not add_decoder_input_ids)
 
     actual_token_count = token_counting_func(batch_tokenized)
 
@@ -654,7 +654,7 @@ def test_token_counting_func_dataloader_setting(
         model_max_length: int, padding_side: str,
         monkeypatch: pytest.MonkeyPatch):
     gptt = transformers.AutoTokenizer.from_pretrained('gpt2')
-    gptt.pad_token_id = pad_token_id
+    gptt.pad_token_id = pad_token_id if pad_token_id is not None else gptt.eos_token_id
     gptt.model_max_length = model_max_length
     gptt.padding_side = padding_side
 
@@ -662,19 +662,25 @@ def test_token_counting_func_dataloader_setting(
     batch_strings = []
     expected_token_count = 0
     for _ in range(batch_size):
         sample_length = random.randint(
-            1,
-            model_max_length) if pad_token_id is not None else model_max_length
+            1, model_max_length //
+            4) if pad_token_id is not None else model_max_length // 4
         batch_strings.append(' '.join(['hello'] * sample_length))
         expected_token_count += sample_length
 
-    batch_tokenized = gptt(batch_strings,
-                           padding=True if pad_token_id is not None else False,
-                           return_tensors='pt')
+    batch_tokenized = [
+        gptt(b, padding=True if pad_token_id is not None else False)
+        for b in batch_strings
+    ]
 
     if dataloader_type == 'denoising':
-        batch_tokenized['decoder_input_ids'] = batch_tokenized[
-            'input_ids'].clone()
+        expected_token_count += 2 * batch_size  # for the two eos tokens
+        expected_token_count += 5 * batch_size  # for the corruption prefix tokens
+
+    if dataloader_type in {'finetuning-hf', 'finetuning-streaming'}:
+        for b in batch_tokenized:
+            b['labels'] = b['input_ids'].copy()
+        expected_token_count *= 2
+        expected_token_count += 1 * batch_size  # for the eos token
 
     common_args = {
         'drop_last': False,
@@ -735,8 +741,10 @@ def test_token_counting_func_dataloader_setting(
             },
             **common_args
         })
+        ds_mock = MagicMock()
+        ds_mock.tokenizer = gptt
         monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset',
-                            lambda *args, **kwargs: MagicMock())
+                            lambda *args, **kwargs: ds_mock)
         dl = build_text_dataloader(cfg, gptt, batch_size)
     elif dataloader_type == 'denoising':
         cfg = DictConfig({
@@ -754,7 +762,7 @@ def test_token_counting_func_dataloader_setting(
             },
             'mixture_of_denoisers': {
                 'decoder_only_format': False,
-                'span_mean_lengths_and_ratios': [[3, .15], [8, .5]],
+                'span_mean_lengths_and_ratios': None,
                 'sequence_mask_ratios': 0.25,
             },
             **common_args
         })
@@ -767,7 +775,8 @@
 
     cfg = om.create(cfg)
 
-    actual_token_count = dl.get_num_tokens_in_batch(batch_tokenized)
+    batch_collated = dl.dataloader.collate_fn(batch_tokenized)  # type: ignore
+    actual_token_count = dl.get_num_tokens_in_batch(batch_collated)
 
     assert actual_token_count == expected_token_count
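Illustrative note (not part of the patch): the sketch below shows why summing the attention mask is a more robust token count than comparing ids against pad_token_id. The helper names and the hand-built batch here are hypothetical, not llm-foundry APIs. With a GPT-2-style tokenizer the pad token is commonly reused as the eos token, so id-based counting silently drops real eos tokens, while the collator's attention mask marks exactly the non-padding positions.

import torch

def count_by_pad_id(batch: dict, pad_token_id: int) -> int:
    # Old approach: any position whose id equals pad_token_id is treated as padding,
    # so real tokens that happen to share the pad id (e.g. eos when pad == eos) are missed.
    return int(torch.sum(batch['input_ids'] != pad_token_id).item())

def count_by_attention_mask(batch: dict) -> int:
    # New approach: the collator marks every real token with 1 in attention_mask,
    # so the sum is exactly the number of non-padding tokens.
    return int(torch.sum(batch['attention_mask']).item())

# GPT-2-style setup where the pad token is reused as the eos token (id 50256).
pad_id = eos_id = 50256
batch = {
    'input_ids': torch.tensor([[15496, 995, eos_id, pad_id],
                               [15496, eos_id, pad_id, pad_id]]),
    'attention_mask': torch.tensor([[1, 1, 1, 0],
                                    [1, 1, 0, 0]]),
}

print(count_by_pad_id(batch, pad_id))    # 3 -- undercounts; the two real eos tokens look like padding
print(count_by_attention_mask(batch))    # 5 -- counts every real token, eos included

This is also why the test above sets gptt.pad_token_id to the eos token id when no pad id is provided, and collates batches with the dataloader's own collate_fn before counting: the count should reflect the attention mask the collator actually produces.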