From 42c2d9a003d697a060eae76c0bf54a0ffbf7722a Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Thu, 6 Jun 2024 11:52:32 -0700 Subject: [PATCH] Adding more token encoding types (#1254) * add more token encoing types * add more token encoing types * add tests * add tests * ft support, tests * linting is shortening my lifespan * linting is shortening my lifespan * long tensor * long tensor * long tensor * feedbacc * import * import --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/data/__init__.py | 9 +- llmfoundry/data/data.py | 50 ++++- llmfoundry/data/finetuning/tasks.py | 36 +-- llmfoundry/data/text_data.py | 47 +++- scripts/data_prep/README.md | 17 ++ scripts/data_prep/convert_dataset_hf.py | 12 +- scripts/data_prep/convert_dataset_json.py | 30 +-- scripts/data_prep/convert_text_to_mds.py | 13 +- .../data_prep/test_convert_text_to_mds.py | 3 +- tests/data/test_data_encodings.py | 205 ++++++++++++++++++ tests/data/test_dataloader.py | 6 +- 11 files changed, 350 insertions(+), 78 deletions(-) create mode 100644 tests/data/test_data_encodings.py diff --git a/llmfoundry/data/__init__.py b/llmfoundry/data/__init__.py index 966ca90c86..5710be0c55 100644 --- a/llmfoundry/data/__init__.py +++ b/llmfoundry/data/__init__.py @@ -1,7 +1,12 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset +from llmfoundry.data.data import ( + SUPPORTED_MDS_ENCODING_TYPES, + ConcatTokensDataset, + NoConcatDataset, + stream_remote_local_validate, +) from llmfoundry.data.dataloader import build_dataloader from llmfoundry.data.finetuning import ( Seq2SeqFinetuningCollator, @@ -55,4 +60,6 @@ 'auto_packing_ratio', 'profile_packing', 'ConcatenatedSequenceCollatorWrapper', + 'stream_remote_local_validate', + 'SUPPORTED_MDS_ENCODING_TYPES', ] diff --git a/llmfoundry/data/data.py b/llmfoundry/data/data.py index 04eb6d345d..bde68a6998 100644 --- a/llmfoundry/data/data.py +++ b/llmfoundry/data/data.py @@ -5,16 +5,31 @@ import os import warnings from abc import ABC, abstractmethod -from typing import Dict, Iterable, Union +from typing import Dict, Iterable, Optional, Union import datasets as hf_datasets import numpy as np +from numpy.typing import NDArray from torch.utils.data import IterableDataset from transformers import PreTrainedTokenizerBase __all__ = [ + 'AbstractConcatTokensDataset', 'ConcatTokensDataset', 'NoConcatDataset', + 'stream_remote_local_validate', + 'SUPPORTED_MDS_ENCODING_TYPES', +] + +SUPPORTED_MDS_ENCODING_TYPES = [ + 'int8', + 'int16', + 'int32', + 'int64', + 'uint8', + 'uint16', + 'uint32', + 'uint64', ] @@ -97,14 +112,14 @@ def __init__( ) @abstractmethod - def __iter__(self) -> Iterable[Dict[str, bytes]]: + def __iter__(self) -> Iterable[Union[Dict[str, bytes], Dict[str, NDArray]]]: pass class ConcatTokensDataset(AbstractConcatTokensDataset): """An IterableDataset that returns token samples for MDSWriter. 
- Returns dicts of {'tokens': bytes} + Returns dicts of {'tokens': ndarray:int32} To use data created by this class and written to MDS format: @@ -119,7 +134,7 @@ class ConcatTokensDataset(AbstractConcatTokensDataset): # note, you need to copy the numpy array because the original is non-writeable # and torch does not support non-writeable tensors, so you get a scary warning and # if you do try to write to the tensor you get undefined behavior - tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy()) + tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int32).copy()) print(tokenizer.decode(tokens)) ``` """ @@ -136,7 +151,7 @@ def __init__( self.hf_dataset = hf_dataset super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap) - def __iter__(self) -> Iterable[Dict[str, bytes]]: + def __iter__(self) -> Iterable[Dict[str, NDArray]]: buffer = [] for sample in self.hf_dataset: encoded = self.tokenizer( @@ -150,6 +165,27 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]: concat_sample = buffer[:self.max_length] buffer = buffer[self.max_length:] if self.should_wrap else [] yield { - # convert to bytes to store in MDS binary format - 'tokens': np.asarray(concat_sample).tobytes(), + # convert to ndarray to store in MDS format + 'tokens': np.asarray(concat_sample, dtype=np.int32), } + + +def stream_remote_local_validate( + remote: Optional[str], + local: Optional[str], + split: Optional[str], +): + """Check that, if needed, the local/split directory exists. + + Args: + remote (Optional[str]): Remote path to the dataset. + local (Optional[str]): Local path to the dataset. + split (Optional[str]): Subdirectory specifying which dataset split to use, if any. + """ + if remote is None or (local == remote): + if local is not None and os.path.isdir(local): + contents = set(os.listdir(local)) + if split is not None and split not in contents: + raise ValueError( + f'Local directory {local} does not contain split {split}', + ) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index b7cce4d20a..40f178fb6e 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -59,6 +59,10 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: from streaming import Stream, StreamingDataset from transformers import PreTrainedTokenizerBase +from llmfoundry.data import ( + SUPPORTED_MDS_ENCODING_TYPES, + stream_remote_local_validate, +) from llmfoundry.data.finetuning.collator import ( _HF_IGNORE_INDEX, stitch_turns_decoder_only, @@ -494,26 +498,15 @@ def is_valid_ift_example( return True -def _stream_remote_local_validate( - remote: Optional[str], - local: Optional[str], - split: Optional[str], -): - if remote is None or (local == remote): - if local is not None and os.path.isdir(local): - contents = set(os.listdir(local)) - if split is not None and split not in contents: - raise ValueError( - f'Local directory {local} does not contain split {split}', - ) - - class StreamingFinetuningDataset(StreamingDataset): """Finetuning dataset with flexible tokenization using StreamingDataset. Args: tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to tokenize samples. + token_encoding_type (str): The encoding type of the tokenized samples. This is only used + for legacy datasets that have been written directly as 'bytes' instead of numpy + arrays. Types are auto-inferred for numpy arrays. Defaults to 'int64'. 
streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from, which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. @@ -574,6 +567,7 @@ class StreamingFinetuningDataset(StreamingDataset): def __init__( self, tokenizer: PreTrainedTokenizerBase, + token_encoding_type: str = 'int64', streams: Optional[Sequence[Stream]] = None, local: Optional[str] = None, remote: Optional[str] = None, @@ -606,11 +600,17 @@ def __init__( f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}', ) + if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES: + raise ValueError( + f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}', + ) + self.token_encoding_type = token_encoding_type + if streams is None: - _stream_remote_local_validate(remote, local, split) + stream_remote_local_validate(remote, local, split) else: for stream in streams: - _stream_remote_local_validate( + stream_remote_local_validate( stream.remote, stream.local, split, @@ -656,11 +656,11 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: if isinstance(sample['input_ids'], bytes): sample['input_ids'] = np.frombuffer( sample['input_ids'], - dtype=np.int64, + dtype=getattr(np, self.token_encoding_type), )[:self.max_seq_len].tolist().copy() sample['labels'] = np.frombuffer( sample['labels'], - dtype=np.int64, + dtype=getattr(np, self.token_encoding_type), )[:self.max_seq_len].tolist().copy() elif isinstance(sample['input_ids'], np.ndarray): sample['input_ids'] = sample['input_ids'][:self.max_seq_len diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 60b81cd145..86d5edbaf4 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -4,7 +4,6 @@ """Build a StreamingTextDataset dataset and dataloader for training.""" import inspect -import os from itertools import islice from typing import ( Any, @@ -25,6 +24,10 @@ from transformers import PreTrainedTokenizerBase from llmfoundry import registry +from llmfoundry.data import ( + SUPPORTED_MDS_ENCODING_TYPES, + stream_remote_local_validate, +) from llmfoundry.utils.registry_utils import construct_from_registry __all__ = [ @@ -41,6 +44,9 @@ class StreamingTextDataset(StreamingDataset): tokenizer (Tokenizer): HuggingFace tokenizer to tokenize samples. max_seq_len (int): The max sequence length of each sample. + token_encoding_type (str): The encoding type of the tokenized samples. This is only used + for legacy datasets that have been written directly as 'bytes' instead of numpy + arrays. Types are auto-inferred for numpy arrays. Defaults to 'int64'. streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from, which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. 
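For legacy shards whose `tokens` column was written as raw `bytes`, the dtype cannot be recovered from the shard metadata, so the caller supplies it through the new `token_encoding_type` argument documented above. A minimal sketch of the reading side, assuming a legacy uint16 dataset; the tokenizer name, local path, and split are placeholders rather than part of this change:

```python
import torch
from transformers import AutoTokenizer

from llmfoundry.data.text_data import StreamingTextDataset

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

# Legacy dataset: 'tokens' stored as a raw byte buffer of uint16 ids.
# token_encoding_type tells _read_binary_tokenized_sample which dtype to pass
# to np.frombuffer; the returned tensor is still cast to int64 for the model.
dataset = StreamingTextDataset(
    tokenizer=tokenizer,
    max_seq_len=2048,
    token_encoding_type='uint16',
    local='/path/to/legacy-mds-dataset',  # placeholder path
    split='train',
    batch_size=1,
)

sample = dataset[0]
assert sample.dtype == torch.int64
```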
@@ -106,6 +112,7 @@ def __init__( self, tokenizer: PreTrainedTokenizerBase, max_seq_len: int, + token_encoding_type: str = 'int64', streams: Optional[Sequence[Stream]] = None, remote: Optional[str] = None, local: Optional[str] = None, @@ -137,13 +144,21 @@ def __init__( f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}', ) - if local is not None and (remote is None or (local == remote)): - if os.path.isdir(local): - contents = set(os.listdir(local)) - if split not in contents: - raise ValueError( - f'local directory {local} does not contain split {split}', - ) + if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES: + raise ValueError( + f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}', + ) + self.token_encoding_type = token_encoding_type + + if streams is None: + stream_remote_local_validate(remote, local, split) + else: + for stream in streams: + stream_remote_local_validate( + stream.remote, + stream.local, + split, + ) # TODO: discover where yamls are being converted incorrect, but temporary workaround if isinstance(shuffle_block_size, float): @@ -197,10 +212,18 @@ def _read_binary_tokenized_sample( self, sample: Dict[str, Any], ) -> torch.Tensor: - return torch.from_numpy( - np.frombuffer(sample['tokens'], - dtype=np.int64)[:self.max_seq_len].copy(), - ) + # Modeling code still expects int64 tensors. + if isinstance(sample['tokens'], np.ndarray): + return torch.from_numpy( + sample['tokens'][:self.max_seq_len].copy(), + ).to(torch.int64) + else: + return torch.from_numpy( + np.frombuffer( + sample['tokens'], + dtype=getattr(np, self.token_encoding_type), + )[:self.max_seq_len].copy(), + ).to(torch.int64) # How to process a sample def __getitem__(self, diff --git a/scripts/data_prep/README.md b/scripts/data_prep/README.md index 7881298b2f..3601cc865f 100644 --- a/scripts/data_prep/README.md +++ b/scripts/data_prep/README.md @@ -35,6 +35,23 @@ python convert_dataset_json.py \ Where `--path` can be a single json file, or a folder containing json files. `--split` denotes the intended split (hf defaults to `train`). +### Raw text files + +Using the `convert_text_to_mds.py` script, we convert a [text file](https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt) containing the complete works of William Shakespeare. + + +```bash +# Convert json dataset to StreamingDataset format +mkdir shakespeare && cd shakespeare +curl -O https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt +cd .. 
+python convert_text_to_mds.py \ + --output_folder my-copy-shakespeare \ + --input_folder shakespeare \ + --concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b \ + --compression zstd +``` + ## Converting a finetuning dataset Using the `convert_finetuning_dataset.py` script you can run a command such as: diff --git a/scripts/data_prep/convert_dataset_hf.py b/scripts/data_prep/convert_dataset_hf.py index d7aaa52193..bf7f145610 100644 --- a/scripts/data_prep/convert_dataset_hf.py +++ b/scripts/data_prep/convert_dataset_hf.py @@ -12,6 +12,8 @@ import datasets as hf_datasets import psutil +import torch +from numpy.typing import NDArray from streaming import MDSWriter from torch.utils.data import DataLoader, Dataset, IterableDataset from tqdm import tqdm @@ -338,7 +340,7 @@ def build_dataloader( def generate_samples( loader: DataLoader, truncate_num_samples: Optional[int] = None, -) -> Iterable[Dict[str, bytes]]: +) -> Iterable[Union[Dict[str, bytes], Dict[str, NDArray]]]: """Generator over samples of a dataloader. Args: @@ -356,7 +358,11 @@ def generate_samples( if truncate_num_samples is not None and n_samples == truncate_num_samples: return n_samples += 1 - yield {k: v[idx] for k, v in batch.items()} + yield { + k: + v[idx].numpy() if isinstance(v[idx], torch.Tensor) else v[idx] + for k, v in batch.items() + } def main(args: Namespace) -> None: @@ -377,7 +383,7 @@ def main(args: Namespace) -> None: tokenizer = build_tokenizer(args.tokenizer, args.tokenizer_kwargs) # we will enforce length, so suppress warnings about sequences too long for the model tokenizer.model_max_length = int(1e30) - columns = {'tokens': 'bytes'} + columns = {'tokens': 'ndarray:int32'} else: mode = ConcatMode.NO_CONCAT tokenizer = None diff --git a/scripts/data_prep/convert_dataset_json.py b/scripts/data_prep/convert_dataset_json.py index fb117ddef3..37b0465692 100644 --- a/scripts/data_prep/convert_dataset_json.py +++ b/scripts/data_prep/convert_dataset_json.py @@ -6,11 +6,11 @@ from argparse import ArgumentParser, Namespace from enum import Enum from glob import glob -from typing import Dict, Iterable, Optional +from typing import Optional import datasets as hf_datasets from streaming import MDSWriter -from torch.utils.data import DataLoader, IterableDataset +from torch.utils.data import IterableDataset from tqdm import tqdm from transformers import AutoTokenizer, PreTrainedTokenizerBase @@ -140,30 +140,6 @@ def build_hf_dataset( return dataset -def generate_samples( - loader: DataLoader, - truncate_num_samples: Optional[int] = None, -) -> Iterable[Dict[str, bytes]]: - """Generator over samples of a dataloader. - - Args: - loader (DataLoader): A dataloader emitting batches like {key: [sample0_bytes, sample1_bytes, sample2_bytes, ...]} - truncate_num_samples (Optional[int]): An optional # of samples to stop at. - - Yields: - Sample dicts. - """ - n_samples = 0 - for batch in loader: - keys = list(batch.keys()) - current_bs = len(batch[keys[0]]) - for idx in range(current_bs): - if truncate_num_samples is not None and n_samples == truncate_num_samples: - return - n_samples += 1 - yield {k: v[idx] for k, v in batch.items()} - - def main(args: Namespace) -> None: """Main: create C4/pile streaming dataset. 
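The conversion scripts now declare the `tokens` column as `ndarray:int32` instead of `bytes`, so MDS records the dtype alongside each shard. A minimal sketch of that writer/reader round trip, assuming an illustrative output directory and vocabulary size:

```python
import numpy as np
from streaming import MDSWriter, StreamingDataset

out_dir = '/tmp/ndarray-tokens-demo'  # placeholder output location

# Writing: token ids go in as int32 ndarrays; no .tobytes() call is needed.
columns = {'tokens': 'ndarray:int32'}
with MDSWriter(out=out_dir, columns=columns, compression='zstd') as writer:
    for _ in range(8):
        writer.write({
            'tokens': np.random.randint(0, 50_000, 2048, dtype=np.int32),
        })

# Reading: samples come back as int32 ndarrays, so consumers no longer need to
# know (or guess) the dtype, unlike the old 'bytes' columns.
dataset = StreamingDataset(local=out_dir, shuffle=False)
assert dataset[0]['tokens'].dtype == np.int32
```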
@@ -175,7 +151,7 @@ def main(args: Namespace) -> None: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) # we will enforce length, so suppress warnings about sequences too long for the model tokenizer.model_max_length = int(1e30) - columns = {'tokens': 'bytes'} + columns = {'tokens': 'ndarray:int32'} else: mode = ConcatMode.NO_CONCAT tokenizer = None diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index 365cc9b71d..b2f0b0e7b4 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -18,6 +18,7 @@ maybe_create_object_store_from_uri, parse_uri, ) +from numpy.typing import NDArray from streaming import MDSWriter from tqdm import tqdm from transformers import AutoTokenizer, PreTrainedTokenizerBase @@ -42,7 +43,7 @@ class ConcatTokensFromFilesDataset(AbstractConcatTokensDataset): """An IterableDataset that returns token samples for MDSWriter from files. - Returns dicts of {'tokens': bytes} + Returns dicts of {'tokens': ndarray:int32} Each file is considered a sequence. """ @@ -59,7 +60,7 @@ def __init__( self.files = files super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap) - def __iter__(self) -> Iterable[Dict[str, bytes]]: + def __iter__(self) -> Iterable[Dict[str, NDArray]]: buffer = [] for file in self.files: @@ -87,7 +88,9 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]: concat_sample = buffer[:self.max_length] buffer = buffer[self. max_length:] if self.should_wrap else [] - yield {'tokens': np.asarray(concat_sample).tobytes()} + yield { + 'tokens': np.asarray(concat_sample, dtype=np.int32), + } first_chunk = False @@ -98,7 +101,7 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]: while len(buffer) >= self.max_length: concat_sample = buffer[:self.max_length] buffer = buffer[self.max_length:] if self.should_wrap else [] - yield {'tokens': np.asarray(concat_sample).tobytes()} + yield {'tokens': np.asarray(concat_sample, dtype=np.int32)} def parse_args() -> Namespace: @@ -356,7 +359,7 @@ def download_and_convert( no_wrap=no_wrap, ) - columns = {'tokens': 'bytes'} + columns = {'tokens': 'ndarray:int32'} log.info('Converting to MDS format...') with MDSWriter( diff --git a/tests/a_scripts/data_prep/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py index df4309e13d..8dac151f55 100644 --- a/tests/a_scripts/data_prep/test_convert_text_to_mds.py +++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py @@ -9,7 +9,6 @@ from typing import Callable, Iterable, List from unittest.mock import Mock, patch -import numpy as np import pytest from streaming import StreamingDataset from transformers import AutoTokenizer @@ -194,7 +193,7 @@ def call_convert_text_to_mds() -> None: n_tokens = 0 for i in range(dataset.num_samples): sample = dataset[i] - tokens = np.frombuffer(sample['tokens'], dtype=int) + tokens = sample['tokens'] if i == 0: # For the first sample, check that the decoded sample matches the text_content decoded = tokenizer.decode(tokens) assert decoded == text_content[:len(decoded)] diff --git a/tests/data/test_data_encodings.py b/tests/data/test_data_encodings.py new file mode 100644 index 0000000000..a45bfbcb88 --- /dev/null +++ b/tests/data/test_data_encodings.py @@ -0,0 +1,205 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +import pathlib + +import numpy as np +import pytest +import torch +from streaming import MDSWriter + +from llmfoundry.data import SUPPORTED_MDS_ENCODING_TYPES, 
StreamingTextDataset +from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset + + +@pytest.mark.parametrize( + 'token_encoding_type', + SUPPORTED_MDS_ENCODING_TYPES + ['default'], +) +@pytest.mark.parametrize('use_bytes', [True, False]) +@pytest.mark.parametrize('samples', [10]) +@pytest.mark.parametrize('max_seq_len', [2048]) +def test_encoding_types_text( + tmp_path: pathlib.Path, + token_encoding_type: str, + use_bytes: bool, + samples: int, + max_seq_len: int, +): + dataset_local_path = str(tmp_path) + if token_encoding_type != 'default': + encoding_dtype = getattr(np, token_encoding_type) + else: + encoding_dtype = None + + if use_bytes: + columns = { + 'tokens': 'bytes', + } + else: + columns = { + 'tokens': + 'ndarray:' + token_encoding_type + if token_encoding_type != 'default' else 'ndarray', + } + + with MDSWriter(out=dataset_local_path, columns=columns) as writer: + for _ in range(samples): + if token_encoding_type != 'default': + tokens = np.random.randint( + 0, + np.iinfo(encoding_dtype).max, + max_seq_len, + dtype=encoding_dtype, + ) + else: + tokens = np.random.randint( + 0, + 200, + max_seq_len, + ) + if use_bytes: + tokens = tokens.tobytes() + writer.write({'tokens': tokens}) + + if use_bytes and token_encoding_type != 'default': + dataset = StreamingTextDataset( + tokenizer=None, + token_encoding_type=token_encoding_type, + max_seq_len=max_seq_len, + local=dataset_local_path, + batch_size=1, + ) + else: + # There should be no need to pass in the token encoding type if writing out ndarrays, + # or if using the default token encoding type. + dataset = StreamingTextDataset( + tokenizer=None, + max_seq_len=max_seq_len, + local=dataset_local_path, + batch_size=1, + ) + + for _, sample in enumerate(dataset): + # StreamingTextDataset should return an int64 torch Tensor + assert sample.dtype == torch.int64 + assert sample.shape == (max_seq_len,) + + +@pytest.mark.parametrize( + 'token_encoding_type', + SUPPORTED_MDS_ENCODING_TYPES + ['default'], +) +@pytest.mark.parametrize('use_bytes', [True, False]) +@pytest.mark.parametrize('samples', [10]) +@pytest.mark.parametrize('max_seq_len', [2048]) +def test_encoding_types_finetuning( + tmp_path: pathlib.Path, + token_encoding_type: str, + use_bytes: bool, + samples: int, + max_seq_len: int, +): + dataset_local_path = str(tmp_path) + if token_encoding_type != 'default': + encoding_dtype = getattr(np, token_encoding_type) + else: + encoding_dtype = None + + if use_bytes: + columns = { + 'input_ids': 'bytes', + 'labels': 'bytes', + } + else: + columns = { + 'input_ids': + 'ndarray:' + token_encoding_type + if token_encoding_type != 'default' else 'ndarray', + 'labels': + 'ndarray:' + token_encoding_type + if token_encoding_type != 'default' else 'ndarray', + } + + with MDSWriter(out=dataset_local_path, columns=columns) as writer: + for _ in range(samples): + if token_encoding_type != 'default': + input_ids = np.random.randint( + 0, + np.iinfo(encoding_dtype).max, + max_seq_len, + dtype=encoding_dtype, + ) + labels = np.random.randint( + 0, + np.iinfo(encoding_dtype).max, + max_seq_len, + dtype=encoding_dtype, + ) + else: + input_ids = np.random.randint( + 0, + 200, + max_seq_len, + ) + labels = np.random.randint( + 0, + 200, + max_seq_len, + ) + if use_bytes: + input_ids = input_ids.tobytes() + labels = labels.tobytes() + writer.write({'input_ids': input_ids, 'labels': labels}) + + if use_bytes and token_encoding_type != 'default': + dataset = StreamingFinetuningDataset( + tokenizer=None, + 
token_encoding_type=token_encoding_type, + local=dataset_local_path, + max_seq_len=max_seq_len, + batch_size=1, + ) + else: + # There should be no need to pass in the token encoding type if writing out ndarrays, + # or if using the default token encoding type. + dataset = StreamingFinetuningDataset( + tokenizer=None, + local=dataset_local_path, + max_seq_len=max_seq_len, + batch_size=1, + ) + + for _, sample in enumerate(dataset): + # StreamingFinetuningDataset puts samples in a list, and converts arrays to lists too. + assert isinstance(sample['turns'][0]['input_ids'][0], int) + assert len(sample['turns'][0]['input_ids']) == max_seq_len + assert isinstance(sample['turns'][0]['labels'][0], int) + assert len(sample['turns'][0]['labels']) == max_seq_len + + +@pytest.mark.parametrize( + 'token_encoding_type', + ['int17', 'float32', 'complex', 'int4'], +) +@pytest.mark.parametrize('use_finetuning', [True, False]) +def test_unsupported_encoding_type( + token_encoding_type: str, + use_finetuning: bool, +): + with pytest.raises(ValueError, match='The token_encoding_type*'): + if use_finetuning: + StreamingFinetuningDataset( + tokenizer=None, + token_encoding_type=token_encoding_type, + local='dataset/path', + max_seq_len=2048, + batch_size=1, + ) + else: + StreamingTextDataset( + tokenizer=None, + token_encoding_type=token_encoding_type, + max_seq_len=2048, + local='dataset/path', + batch_size=1, + ) diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 7c8e808bab..ec27df8121 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -114,8 +114,8 @@ def build_mock_ft_streaming_dataset( columns = {'input_ids': 'bytes', 'labels': 'bytes'} else: columns = { - 'input_ids': 'ndarray:uint32', - 'labels': 'ndarray:uint32', + 'input_ids': 'ndarray:int32', + 'labels': 'ndarray:int32', } else: columns = {'prompt': 'str', 'response': 'str'} @@ -142,7 +142,7 @@ def build_mock_ft_streaming_dataset( else: sample_to_write[key] = np.asarray( sample[key], - dtype=np.uint32, + dtype=np.int32, ) output_writer.write(sample_to_write) else:
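The tests above exercise both dataset classes against legacy `bytes` shards and new `ndarray` shards. For completeness, this is roughly how a legacy finetuning dataset would opt into the new argument; the local path is a placeholder and the tokenizer choice is illustrative only:

```python
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

# Legacy finetuning shards that stored 'input_ids' and 'labels' as raw bytes of
# uint32 values; token_encoding_type supplies the dtype for np.frombuffer.
# Datasets written with 'ndarray:*' columns can omit the argument entirely,
# since the dtype is auto-inferred from the shard metadata.
dataset = StreamingFinetuningDataset(
    tokenizer=tokenizer,
    token_encoding_type='uint32',
    local='/path/to/legacy-finetuning-mds',  # placeholder path
    max_seq_len=2048,
    batch_size=1,
)
```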