Fix token counting to use attention mask instead of ids (#802)
dakinggg authored Dec 14, 2023
1 parent 5fdcc43 commit 5388dc0
Showing 4 changed files with 33 additions and 29 deletions.
1 change: 0 additions & 1 deletion llmfoundry/data/denoising.py
@@ -527,7 +527,6 @@ def build_text_denoising_dataloader(
     )
 
     token_counting_func = get_tokens_per_batch_func(
-        pad_token_id=tokenizer.pad_token_id,
         decoder_only=cfg.mixture_of_denoisers.decoder_only_format)
 
     return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
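Side note on why the pad_token_id argument can be dropped here (my reading of the change, not stated in the diff): GPT-style tokenizers often ship without a pad token, and in practice the pad token is frequently set equal to the EOS token, so a count based on comparing ids against the pad id is either unavailable or ambiguous. A quick check with the Hugging Face GPT-2 tokenizer:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('gpt2')
print(tok.pad_token_id)                      # None: GPT-2 has no pad token by default
tok.pad_token = tok.eos_token                # common workaround for batching
print(tok.pad_token_id == tok.eos_token_id)  # True: pad and EOS now share an id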
3 changes: 1 addition & 2 deletions llmfoundry/data/finetuning/dataloader.py
@@ -216,8 +216,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
         timeout=cfg.get('timeout', 0),
     )
 
-    token_counting_func = get_tokens_per_batch_func(
-        pad_token_id=tokenizer.pad_token_id)
+    token_counting_func = get_tokens_per_batch_func()
 
    return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)

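For illustration, here is roughly what the returned callable sees once the dataloader has collated a batch. This is a minimal sketch with made-up tensors; the import path follows the location of get_tokens_per_batch_func in the diff below.

import torch
from llmfoundry.data.text_data import get_tokens_per_batch_func

token_counting_func = get_tokens_per_batch_func()
batch = {
    'input_ids': torch.tensor([[10, 11, 12, 50256],
                               [10, 11, 50256, 50256]]),  # right-padded toy ids
    'attention_mask': torch.tensor([[1, 1, 1, 0],
                                    [1, 1, 0, 0]]),
}
print(token_counting_func(batch))  # 5: only positions with mask == 1 are counted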
21 changes: 9 additions & 12 deletions llmfoundry/data/text_data.py
@@ -306,15 +306,13 @@ def build_text_dataloader(
     # and if tokenizing on the fly, we require that the tokenizer has a pad token.
     token_counting_func = None
     if tokenizer.pad_token_id is not None:
-        token_counting_func = get_tokens_per_batch_func(
-            pad_token_id=tokenizer.pad_token_id)
+        token_counting_func = get_tokens_per_batch_func()
 
     return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
 
 
-def get_tokens_per_batch_func(pad_token_id: int,
-                              decoder_only: bool = True
-                             ) -> Callable[[Batch], int]:
+def get_tokens_per_batch_func(
+        decoder_only: bool = True) -> Callable[[Batch], int]:
     """Returns a callable that counts the number of tokens in a batch.
 
     Args:
@@ -327,25 +325,24 @@ def get_tokens_per_batch_func(pad_token_id: int,
     """
 
     def get_num_samples_in_batch(batch: Batch) -> int:
-        if not isinstance(batch, Mapping) or 'input_ids' not in batch:
+        if not isinstance(batch, Mapping) or 'attention_mask' not in batch:
             raise ValueError(
-                'get_tokens_per_batch_func() requires a batch with an input_ids key'
+                'get_tokens_per_batch_func() requires a batch with an attention_mask key'
             )
 
-        if not decoder_only and 'decoder_input_ids' not in batch:
+        if not decoder_only and 'decoder_attention_mask' not in batch:
             raise ValueError(
-                'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_input_ids key'
+                'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_attention_mask key'
             )
 
         # Count number of non padding tokens in batch
-        input_ids_tokens = int(
-            torch.sum(batch['input_ids'] != pad_token_id).item())
+        input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
 
         # For encoder decoder models only
         decoder_input_ids_tokens = 0
         if not decoder_only:
             decoder_input_ids_tokens = int(
-                torch.sum(batch['decoder_input_ids'] != pad_token_id).item())
+                torch.sum(batch['decoder_attention_mask']).item())
 
         return input_ids_tokens + decoder_input_ids_tokens

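A small contrast of the old and new counting logic. The pad == EOS scenario is my assumed motivation; the commit message only says the count should come from the attention mask.

import torch

pad_id = eos_id = 50256                               # common for GPT-style models
input_ids = torch.tensor([[10, 11, eos_id, pad_id]])  # 3 real tokens + 1 pad
attention_mask = torch.tensor([[1, 1, 1, 0]])

old_count = int(torch.sum(input_ids != pad_id).item())  # 2: the real EOS is dropped
new_count = int(torch.sum(attention_mask).item())       # 3: matches the mask exactly
print(old_count, new_count)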
37 changes: 23 additions & 14 deletions tests/data/test_dataloader.py
@@ -630,12 +630,12 @@ def test_token_counting_func(pad_token_id: int, batch_size: int,
             decoder_batch_strings.append(' '.join(['hello'] * sample_length))
             decoder_expected_token_count += sample_length
             expected_token_count += sample_length
-        batch_tokenized['decoder_input_ids'] = gptt(
+        batch_tokenized['decoder_attention_mask'] = gptt(
             decoder_batch_strings, padding=True,
-            return_tensors='pt')['input_ids']
+            return_tensors='pt')['attention_mask']
 
     token_counting_func = get_tokens_per_batch_func(
-        pad_token_id, decoder_only=not add_decoder_input_ids)
+        decoder_only=not add_decoder_input_ids)
 
     actual_token_count = token_counting_func(batch_tokenized)
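As a side illustration of what the test relies on (assuming a GPT-2 tokenizer with pad set to EOS, as in the test setup): tokenizing with padding=True yields an attention_mask whose sum is exactly the number of real tokens, which is the expected_token_count the test accumulates.

from transformers import AutoTokenizer

gptt = AutoTokenizer.from_pretrained('gpt2')
gptt.pad_token = gptt.eos_token          # GPT-2 needs a pad token for batching
enc = gptt(['hello hello hello', 'hello'], padding=True, return_tensors='pt')
print(enc['attention_mask'])             # tensor([[1, 1, 1], [1, 0, 0]])
print(int(enc['attention_mask'].sum()))  # 4 == 3 + 1 real tokens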

@@ -654,27 +654,33 @@ def test_token_counting_func_dataloader_setting(
         model_max_length: int, padding_side: str,
         monkeypatch: pytest.MonkeyPatch):
     gptt = transformers.AutoTokenizer.from_pretrained('gpt2')
-    gptt.pad_token_id = pad_token_id
+    gptt.pad_token_id = pad_token_id if pad_token_id is not None else gptt.eos_token_id
     gptt.model_max_length = model_max_length
     gptt.padding_side = padding_side
 
     batch_strings = []
     expected_token_count = 0
     for _ in range(batch_size):
-        sample_length = random.randint(
-            1,
-            model_max_length) if pad_token_id is not None else model_max_length
+        sample_length = random.randint(
+            1, model_max_length //
+            4) if pad_token_id is not None else model_max_length // 4
         batch_strings.append(' '.join(['hello'] * sample_length))
         expected_token_count += sample_length
 
-    batch_tokenized = gptt(batch_strings,
-                           padding=True if pad_token_id is not None else False,
-                           return_tensors='pt')
+    batch_tokenized = [
+        gptt(b, padding=True if pad_token_id is not None else False)
+        for b in batch_strings
+    ]
 
     if dataloader_type == 'denoising':
-        batch_tokenized['decoder_input_ids'] = batch_tokenized[
-            'input_ids'].clone()
+        expected_token_count += 2 * batch_size  # for the two eos tokens
+        expected_token_count += 5 * batch_size  # for the corruption prefix tokens
+
+    if dataloader_type in {'finetuning-hf', 'finetuning-streaming'}:
+        for b in batch_tokenized:
+            b['labels'] = b['input_ids'].copy()
+        expected_token_count *= 2
+        expected_token_count += 1 * batch_size  # for the eos token
 
     common_args = {
         'drop_last': False,
@@ -735,8 +741,10 @@ def test_token_counting_func_dataloader_setting(
             },
             **common_args
         })
+        ds_mock = MagicMock()
+        ds_mock.tokenizer = gptt
         monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset',
-                            lambda *args, **kwargs: MagicMock())
+                            lambda *args, **kwargs: ds_mock)
         dl = build_text_dataloader(cfg, gptt, batch_size)
     elif dataloader_type == 'denoising':
         cfg = DictConfig({
@@ -754,7 +762,7 @@ def test_token_counting_func_dataloader_setting(
             },
             'mixture_of_denoisers': {
                 'decoder_only_format': False,
-                'span_mean_lengths_and_ratios': [[3, .15], [8, .5]],
+                'span_mean_lengths_and_ratios': None,
                 'sequence_mask_ratios': 0.25,
             },
             **common_args
Expand All @@ -767,7 +775,8 @@ def test_token_counting_func_dataloader_setting(

cfg = om.create(cfg)

actual_token_count = dl.get_num_tokens_in_batch(batch_tokenized)
batch_collated = dl.dataloader.collate_fn(batch_tokenized) # type: ignore
actual_token_count = dl.get_num_tokens_in_batch(batch_collated)

assert actual_token_count == expected_token_count

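The last hunk switches the test from passing the hand-built batch straight to get_num_tokens_in_batch to first collating it with the dataloader's own collate_fn. A rough stand-in for what that collate step does, using transformers.DataCollatorWithPadding in place of the dataloader's actual collate_fn (an assumption for illustration only):

from transformers import AutoTokenizer, DataCollatorWithPadding

tok = AutoTokenizer.from_pretrained('gpt2')
tok.pad_token = tok.eos_token
samples = [tok('hello hello'), tok('hello')]              # per-sample encodings, like batch_tokenized
collator = DataCollatorWithPadding(tokenizer=tok, return_tensors='pt')
batch = collator(samples)                                 # padded batch dict with attention_mask
print(int(batch['attention_mask'].sum()))                 # 3 real tokens across the batch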
