diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 6d51f20db4..baa3052cad 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -48,16 +48,15 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 
 __all__ = ['dataset_constructor']
 
-
 def _read_binary_tokenized_sample(
         sample: Dict[str, Any]) -> Dict[str, torch.Tensor]:
     example = {
         'input_ids':
             torch.from_numpy(
-                np.frombuffer(sample['tokens'], dtype=np.int64).copy()),
+                np.frombuffer(sample['prompt'], dtype=np.int64).copy()),
         'labels':
             torch.from_numpy(
-                np.frombuffer(sample['labels'], dtype=np.int64).copy()),
+                np.frombuffer(sample['response'], dtype=np.int64).copy()),
     }
     example['attention_mask'] = torch.ones(example['input_ids'].size())
     return example
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 7580f88ecf..85ae62e303 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -66,13 +66,13 @@ def build_mock_ft_streaming_dataset(data_path: str, split: str):
 
 def build_mock_tokenized_ft_streaming_dataset(data_path: str, split: str):
     columns = {'prompt': 'bytes', 'response': 'bytes'}
-    dataset = [{
-        'prompt': numpy.asarray([1, 2, 3, 4]).tobytes(),
-        'response': numpy.asarray([2, 3, 4, 5]).tobytes()
-    }, {
-        'prompt': numpy.asarray([2, 3, 4, 5]).tobytes(),
-        'response': numpy.asarray([3, 4, 5, 6]).tobytes()
-    }]
+    dataset = []
+
+    for i in range(0, 64):
+        dataset.append({
+            'prompt': numpy.asarray([i, i, i, i]).tobytes(),
+            'response': numpy.asarray([i + 1, i + 1, i + 1, i + 1]).tobytes()
+        })
 
     output_path = os.path.join(data_path, split)
 
@@ -527,13 +527,37 @@ def test_finetuning_dataloader_streaming_tokenized(tmp_path: pathlib.Path):
         tokenizer_kwargs={'model_max_length': 2048},
     )
 
-    ft_dataloader = build_finetuning_dataloader(cfg, tokenizer, 4)
+    ft_dataloader = build_finetuning_dataloader(cfg, tokenizer, 32)
 
     expected_keys = ['input_ids', 'attention_mask', 'labels']
+    batch_idx = 0
 
     for batch in ft_dataloader:
         for k in expected_keys:
             assert k in batch
+            t = batch[k]
+            if batch_idx == 0:
+                if k == 'input_ids':
+                    for i in range(0, 32):
+                        bi = batch_idx * 32 + i
+                        # Only check the first four elements. The rest will
+                        # be padding tokens, up to the maximum sequence
+                        # length introduced by the collator.
+                        assert torch.equal(t[i][:4],
+                                           torch.tensor([bi, bi, bi, bi]))
+                if k == 'labels':
+                    for i in range(0, 32):
+                        bi = batch_idx * 32 + i + 1
+                        # Look at indices 4:8, as the collator pads the labels
+                        # and the actual labels end up in these positions.
+                        assert torch.equal(t[i][4:8],
+                                           torch.tensor([bi, bi, bi, bi]))
+                if k == 'attention_mask':
+                    for i in range(0, 32):
+                        # Each sample has only four tokens, so the attention
+                        # mask should have 1s in the first four positions.
+                        assert torch.equal(t[i][:4], torch.ones(4))
+        batch_idx += 1
 
 
 @pytest.mark.parametrize('add_bad_data_dropped', [True, False])
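
For reviewers who want to poke at the serialization contract outside the test suite, here is a minimal round-trip sketch of what `_read_binary_tokenized_sample` now expects from the `prompt`/`response` byte columns. It is standalone and uses made-up token values; the explicit `dtype=np.int64` on the encode side is an addition of the sketch, since the mock dataset's bare `numpy.asarray(...)` relies on the platform's default integer being 8 bytes wide, which `np.frombuffer(..., dtype=np.int64)` assumes.

```python
import numpy as np
import torch

# Encode the way build_mock_tokenized_ft_streaming_dataset does: raw int64
# bytes per column. dtype is explicit here; the test relies on the platform
# default (C long) being 8 bytes.
sample = {
    'prompt': np.asarray([5, 5, 5, 5], dtype=np.int64).tobytes(),
    'response': np.asarray([6, 6, 6, 6], dtype=np.int64).tobytes(),
}

# Decode the way _read_binary_tokenized_sample does. np.frombuffer returns a
# read-only view of the bytes, so .copy() is needed before torch.from_numpy,
# which otherwise warns about (and would share) non-writable memory.
example = {
    'input_ids':
        torch.from_numpy(
            np.frombuffer(sample['prompt'], dtype=np.int64).copy()),
    'labels':
        torch.from_numpy(
            np.frombuffer(sample['response'], dtype=np.int64).copy()),
}
example['attention_mask'] = torch.ones(example['input_ids'].size())

assert torch.equal(example['input_ids'], torch.tensor([5, 5, 5, 5]))
assert torch.equal(example['labels'], torch.tensor([6, 6, 6, 6]))
assert torch.equal(example['attention_mask'], torch.ones(4))
```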