Fix FT dataset loading after key name change, make test more detailed
boomanaiden154 committed Sep 29, 2023
1 parent 46da826 commit 9cceb70
Showing 2 changed files with 34 additions and 11 deletions.
5 changes: 2 additions & 3 deletions llmfoundry/data/finetuning/tasks.py
@@ -48,16 +48,15 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:

 __all__ = ['dataset_constructor']


 def _read_binary_tokenized_sample(
         sample: Dict[str, Any]) -> Dict[str, torch.Tensor]:
     example = {
         'input_ids':
             torch.from_numpy(
-                np.frombuffer(sample['tokens'], dtype=np.int64).copy()),
+                np.frombuffer(sample['prompt'], dtype=np.int64).copy()),
         'labels':
             torch.from_numpy(
-                np.frombuffer(sample['labels'], dtype=np.int64).copy()),
+                np.frombuffer(sample['response'], dtype=np.int64).copy()),
     }
     example['attention_mask'] = torch.ones(example['input_ids'].size())
     return example
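For context (not part of the commit): a minimal, self-contained sketch of how a binary tokenized sample stored under the renamed 'prompt' and 'response' keys decodes back into tensors, mirroring the logic in _read_binary_tokenized_sample above. The decode_sample name and the literal token values are illustrative only.

from typing import Any, Dict

import numpy as np
import torch


def decode_sample(sample: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    # np.frombuffer returns a read-only view, so copy before wrapping it
    # in a torch tensor.
    example = {
        'input_ids':
            torch.from_numpy(
                np.frombuffer(sample['prompt'], dtype=np.int64).copy()),
        'labels':
            torch.from_numpy(
                np.frombuffer(sample['response'], dtype=np.int64).copy()),
    }
    example['attention_mask'] = torch.ones(example['input_ids'].size())
    return example


sample = {
    'prompt': np.asarray([1, 2, 3, 4], dtype=np.int64).tobytes(),
    'response': np.asarray([2, 3, 4, 5], dtype=np.int64).tobytes(),
}
decoded = decode_sample(sample)
assert torch.equal(decoded['input_ids'], torch.tensor([1, 2, 3, 4]))
assert torch.equal(decoded['labels'], torch.tensor([2, 3, 4, 5]))
assert torch.equal(decoded['attention_mask'], torch.ones(4))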
40 changes: 32 additions & 8 deletions tests/test_dataloader.py
@@ -66,13 +66,13 @@ def build_mock_ft_streaming_dataset(data_path: str, split: str):
 def build_mock_tokenized_ft_streaming_dataset(data_path: str, split: str):
     columns = {'prompt': 'bytes', 'response': 'bytes'}

-    dataset = [{
-        'prompt': numpy.asarray([1, 2, 3, 4]).tobytes(),
-        'response': numpy.asarray([2, 3, 4, 5]).tobytes()
-    }, {
-        'prompt': numpy.asarray([2, 3, 4, 5]).tobytes(),
-        'response': numpy.asarray([3, 4, 5, 6]).tobytes()
-    }]
+    dataset = []
+
+    for i in range(0, 64):
+        dataset.append({
+            'prompt': numpy.asarray([i, i, i, i]).tobytes(),
+            'response': numpy.asarray([i + 1, i + 1, i + 1, i + 1]).tobytes()
+        })

     output_path = os.path.join(data_path, split)
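For context (not part of the commit): one way the 64 mock samples above could be written out as a streaming (MDS) dataset. This sketch assumes mosaicml-streaming's MDSWriter together with the 'bytes' column encoding declared in columns; the helper's actual write loop is collapsed out of this diff, so treat it as an approximation rather than the repository's code.

import os

import numpy
from streaming import MDSWriter


def write_mock_tokenized_ft_dataset(data_path: str, split: str) -> None:
    columns = {'prompt': 'bytes', 'response': 'bytes'}
    output_path = os.path.join(data_path, split)
    with MDSWriter(columns=columns, out=output_path,
                   compression=None) as writer:
        for i in range(0, 64):
            # Each sample is four int64 prompt tokens [i, i, i, i] and four
            # response tokens [i + 1, i + 1, i + 1, i + 1], stored as raw bytes.
            writer.write({
                'prompt': numpy.asarray([i, i, i, i]).tobytes(),
                'response': numpy.asarray([i + 1] * 4).tobytes(),
            })

# Example usage: write_mock_tokenized_ft_dataset('/tmp/mock_ft_data', 'train')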

@@ -527,13 +527,37 @@ def test_finetuning_dataloader_streaming_tokenized(tmp_path: pathlib.Path):
         tokenizer_kwargs={'model_max_length': 2048},
     )

-    ft_dataloader = build_finetuning_dataloader(cfg, tokenizer, 4)
+    ft_dataloader = build_finetuning_dataloader(cfg, tokenizer, 32)

     expected_keys = ['input_ids', 'attention_mask', 'labels']

+    batch_idx = 0
     for batch in ft_dataloader:
         for k in expected_keys:
             assert k in batch
+            t = batch[k]
+            if batch_idx == 0:
+                if k == 'input_ids':
+                    for i in range(0, 32):
+                        bi = batch_idx * 32 + i
+                        # Only check the first four elements. The rest is
+                        # padding added by the collator up to the maximum
+                        # sequence length.
+                        assert torch.equal(t[i][:4],
+                                           torch.tensor([bi, bi, bi, bi]))
+                if k == 'labels':
+                    for i in range(0, 32):
+                        bi = batch_idx * 32 + i + 1
+                        # Look at indices 4-8 as the collator pads the labels
+                        # and the actual labels end up in these positions.
+                        assert torch.equal(t[i][4:8],
+                                           torch.tensor([bi, bi, bi, bi]))
+                if k == 'attention_mask':
+                    for i in range(0, 32):
+                        # We only have four tokens per sample, so the attention
+                        # mask should have 1s in the first four positions.
+                        assert torch.equal(t[i][:4], torch.ones(4))
+        batch_idx += 1


 @pytest.mark.parametrize('add_bad_data_dropped', [True, False])
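For context (not part of the commit): a worked example of the values the new assertions expect for one sample of the first batch. Where the real label tokens land (indices 4-8 here) follows from the collator padding described in the test's comments; everything outside the checked slices is padding up to the configured maximum sequence length.

import torch

# Batch 0, sample i = 5: the mock dataset stores prompt = [5, 5, 5, 5] and
# response = [6, 6, 6, 6], so the test expects:
i = 5
expected_input_ids_prefix = torch.tensor([i, i, i, i])  # batch['input_ids'][i][:4]
expected_labels_slice = torch.tensor([i + 1] * 4)       # batch['labels'][i][4:8]
expected_attention_prefix = torch.ones(4)               # batch['attention_mask'][i][:4]

The updated test can be run on its own with: pytest tests/test_dataloader.py -k test_finetuning_dataloader_streaming_tokenized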
