Fix token counting to allow there to be no attention mask (#818)
dakinggg authored Dec 22, 2023
1 parent bbf5cc7 commit 836ab95
Showing 2 changed files with 22 additions and 10 deletions.
llmfoundry/data/text_data.py: 10 changes (7 additions, 3 deletions)
@@ -325,9 +325,10 @@ def get_tokens_per_batch_func(
     """

     def get_num_samples_in_batch(batch: Batch) -> int:
-        if not isinstance(batch, Mapping) or 'attention_mask' not in batch:
+        if not isinstance(batch, Mapping) or ('attention_mask' not in batch and
+                                              'input_ids' not in batch):
             raise ValueError(
-                'get_tokens_per_batch_func() requires a batch with an attention_mask key'
+                'get_tokens_per_batch_func() requires a batch with an attention_mask key or an input_ids key'
             )

         if not decoder_only and 'decoder_attention_mask' not in batch:
@@ -336,7 +337,10 @@ def get_num_samples_in_batch(batch: Batch) -> int:
             )

         # Count number of non padding tokens in batch
-        input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
+        if 'attention_mask' in batch:
+            input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
+        else:
+            input_ids_tokens = batch['input_ids'].numel()

         # For encoder decoder models only
         decoder_input_ids_tokens = 0
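
For context, the counting behavior after this change can be summarized with a small standalone sketch (illustrative only; count_batch_tokens is a hypothetical stand-in for the inner get_num_samples_in_batch, and the real function also handles encoder-decoder batches): with an attention_mask, only non-padding tokens are counted; with only input_ids, every element is counted.

import torch

def count_batch_tokens(batch: dict) -> int:
    # With a mask, padding positions are 0, so the sum is the non-padding token count.
    if 'attention_mask' in batch:
        return int(torch.sum(batch['attention_mask']).item())
    # Without a mask, assume the batch is unpadded and count every input id.
    return batch['input_ids'].numel()

# Example: the second sequence has one padding position.
masked = {
    'input_ids': torch.tensor([[5, 6, 7], [8, 9, 0]]),
    'attention_mask': torch.tensor([[1, 1, 1], [1, 1, 0]]),
}
unmasked = {'input_ids': torch.tensor([[5, 6, 7], [8, 9, 10]])}

print(count_batch_tokens(masked))    # 5
print(count_batch_tokens(unmasked))  # 6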
tests/data/test_dataloader.py: 22 changes (15 additions, 7 deletions)
@@ -642,16 +642,17 @@ def test_token_counting_func(pad_token_id: int, batch_size: int,
     assert actual_token_count == expected_token_count


-@pytest.mark.parametrize(
-    'dataloader_type',
-    ['finetuning-hf', 'finetuning-streaming', 'denoising', 'text'])
+@pytest.mark.parametrize('dataloader_type,tensor_input',
+                         [('finetuning-hf', False),
+                          ('finetuning-streaming', False), ('denoising', False),
+                          ('text', True), ('text', False)])
 @pytest.mark.parametrize('pad_token_id', [100, None])
 @pytest.mark.parametrize('batch_size', [1, 8])
 @pytest.mark.parametrize('model_max_length', [1024])
 @pytest.mark.parametrize('padding_side', ['left'])
 def test_token_counting_func_dataloader_setting(
-        dataloader_type: str, pad_token_id: Optional[int], batch_size: int,
-        model_max_length: int, padding_side: str,
+        dataloader_type: str, tensor_input: bool, pad_token_id: Optional[int],
+        batch_size: int, model_max_length: int, padding_side: str,
         monkeypatch: pytest.MonkeyPatch):
     gptt = transformers.AutoTokenizer.from_pretrained('gpt2')
     gptt.pad_token_id = pad_token_id if pad_token_id is not None else gptt.eos_token_id
@@ -661,9 +662,11 @@ def test_token_counting_func_dataloader_setting(
     batch_strings = []
     expected_token_count = 0
     for _ in range(batch_size):
+        # Get randomly different lengths if we are going to add padding
         sample_length = random.randint(
             1, model_max_length //
-            4) if pad_token_id is not None else model_max_length // 4
+            4) if (pad_token_id is not None and
+                   not tensor_input) else model_max_length // 4
         batch_strings.append(' '.join(['hello'] * sample_length))
         expected_token_count += sample_length

@@ -672,13 +675,18 @@ def test_token_counting_func_dataloader_setting(
         for b in batch_strings
     ]

+    if tensor_input:
+        batch_tokenized = [
+            torch.tensor(b['input_ids']) for b in batch_tokenized
+        ]
+
     if dataloader_type == 'denoising':
         expected_token_count += 2 * batch_size  # for the two eos tokens
         expected_token_count += 5 * batch_size  # for the corruption prefix tokens

     if dataloader_type in {'finetuning-hf', 'finetuning-streaming'}:
         for b in batch_tokenized:
-            b['labels'] = b['input_ids'].copy()
+            b['labels'] = b['input_ids'].copy()  # type: ignore
         expected_token_count *= 2
         expected_token_count += 1 * batch_size  # for the eos token
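
For context, the new ('text', True) parametrization exercises batches that are plain input_ids tensors with no attention_mask, which is exactly the case the token-counting fix above covers. A minimal sketch of that setup (illustrative names, not the exact test fixtures; it assumes the GPT-2 tokenizer can be loaded):

import torch
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
batch_strings = ['hello hello hello', 'hello hello']
batch_tokenized = [tokenizer(b) for b in batch_strings]

# tensor_input=True keeps only raw input_ids tensors: no padding, no attention_mask.
tensor_batch = [torch.tensor(b['input_ids']) for b in batch_tokenized]

# With no mask to sum over, the expected count is simply every element.
expected_token_count = sum(t.numel() for t in tensor_batch)
print(expected_token_count)  # 5: GPT-2 encodes each 'hello' / ' hello' as one token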
