diff --git a/llmfoundry/data/data.py b/llmfoundry/data/data.py index 482c296fa5..04eb6d345d 100644 --- a/llmfoundry/data/data.py +++ b/llmfoundry/data/data.py @@ -4,6 +4,7 @@ """Datasets for converting to MDS Shards.""" import os import warnings +from abc import ABC, abstractmethod from typing import Dict, Iterable, Union import datasets as hf_datasets @@ -35,39 +36,20 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]: yield {'text': sample['text'].encode('utf-8')} -class ConcatTokensDataset(IterableDataset): - """An IterableDataset that returns token samples for MDSWriter. - - Returns dicts of {'tokens': bytes} - - To use data created by this class and written to MDS format: - - ```python - import torch - from streaming.base import StreamingDataset - from transformers import AutoTokenizer +class AbstractConcatTokensDataset(ABC, IterableDataset): + """Abstract class for defining an IterableDataset that tokenizes and. - tokenizer = AutoTokenizer.from_pretrained('your/tokenizer') - ds = StreamingDataset(local='mds-data-folder', split='val') - - # note, you need to copy the numpy array because the original is non-writeable - # and torch does not support non-writeable tensors, so you get a scary warning and - # if you do try to write to the tensor you get undefined behavior - tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy()) - print(tokenizer.decode(tokens)) - ``` + concatenates text samples on the fly. """ def __init__( self, - hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset], tokenizer: PreTrainedTokenizerBase, max_length: int, bos_text: str, eos_text: str, no_wrap: bool, ): - self.hf_dataset = hf_dataset self.tokenizer = tokenizer os.environ['TOKENIZERS_PARALLELISM'] = 'false' self.max_length = max_length @@ -114,8 +96,47 @@ def __init__( 'in duplicated special tokens. Please be sure this is what you intend.', ) + @abstractmethod def __iter__(self) -> Iterable[Dict[str, bytes]]: + pass + + +class ConcatTokensDataset(AbstractConcatTokensDataset): + """An IterableDataset that returns token samples for MDSWriter. + + Returns dicts of {'tokens': bytes} + + To use data created by this class and written to MDS format: + ```python + import torch + from streaming.base import StreamingDataset + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained('your/tokenizer') + ds = StreamingDataset(local='mds-data-folder', split='val') + + # note, you need to copy the numpy array because the original is non-writeable + # and torch does not support non-writeable tensors, so you get a scary warning and + # if you do try to write to the tensor you get undefined behavior + tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy()) + print(tokenizer.decode(tokens)) + ``` + """ + + def __init__( + self, + hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset], + tokenizer: PreTrainedTokenizerBase, + max_length: int, + bos_text: str, + eos_text: str, + no_wrap: bool, + ): + self.hf_dataset = hf_dataset + super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap) + + def __iter__(self) -> Iterable[Dict[str, bytes]]: buffer = [] for sample in self.hf_dataset: encoded = self.tokenizer( diff --git a/llmfoundry/utils/data_prep_utils.py b/llmfoundry/utils/data_prep_utils.py index 9601e8618e..1662ab74c2 100644 --- a/llmfoundry/utils/data_prep_utils.py +++ b/llmfoundry/utils/data_prep_utils.py @@ -108,7 +108,7 @@ def __init__( output_folder: str, object_store: Optional[ObjectStore], ): - """Iterable that downloads files from an object store before yielding. + """Iterable that downloads files before yielding the local filename. If object_store is None, input_folder_prefix is treated as a local path. @@ -138,7 +138,4 @@ def __iter__(self): object_name=object_name, output_filename=output_filename, ) - - with open(output_filename) as _txt_file: - txt = _txt_file.read() - yield {'text': txt} + yield output_filename diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index 98e9843570..6c94798682 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -7,9 +7,11 @@ import tempfile from argparse import ArgumentParser, Namespace from concurrent.futures import ProcessPoolExecutor +from functools import partial from glob import glob -from typing import Iterable, List, Tuple, cast +from typing import Dict, Iterable, List, Tuple, cast +import numpy as np import psutil from composer.utils import ( ObjectStore, @@ -18,9 +20,9 @@ ) from streaming import MDSWriter from tqdm import tqdm -from transformers import AutoTokenizer +from transformers import AutoTokenizer, PreTrainedTokenizerBase -from llmfoundry.data import ConcatTokensDataset +from llmfoundry.data.data import AbstractConcatTokensDataset from llmfoundry.utils import maybe_create_mosaicml_logger from llmfoundry.utils.data_prep_utils import ( DownloadingIterable, @@ -37,6 +39,62 @@ DONE_FILENAME = '.text_to_mds_conversion_done' +class ConcatTokensFromFilesDataset(AbstractConcatTokensDataset): + """An IterableDataset that returns token samples for MDSWriter from files. + + Returns dicts of {'tokens': bytes} + + Each file is considered a sequence. + """ + + def __init__( + self, + files: Iterable[str], + tokenizer: PreTrainedTokenizerBase, + max_length: int, + bos_text: str, + eos_text: str, + no_wrap: bool, + ): + self.files = files + super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap) + + def __iter__(self) -> Iterable[Dict[str, bytes]]: + + buffer = [] + for file in self.files: + with open(file, 'r') as f: + buffer += self.bos_tokens + first_chunk = True + # Read the file in 1MB chunks to avoid memory issues + for chunk in iter(partial(f.read, 1000000), ''): + # Tokenize the chunk + encoded = self.tokenizer( + chunk, + truncation=False, + padding=False, + ) + iids = encoded['input_ids'] + + # If this is not the first chunk, remove the BOS token + if not first_chunk: + if iids[0] == self.tokenizer.bos_token_id: + iids = iids[1:] + + # Add the tokens to the buffer + buffer += iids + while len(buffer) >= self.max_length: + concat_sample = buffer[:self.max_length] + buffer = buffer[self. + max_length:] if self.should_wrap else [] + yield {'tokens': np.asarray(concat_sample).tobytes()} + + first_chunk = False + + # Add the EOS token to the buffer to separate files. + buffer += self.eos_tokens + + def parse_args() -> Namespace: """Parse commandline arguments.""" parser = ArgumentParser( @@ -277,8 +335,8 @@ def download_and_convert( # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up # to the maximum sequence length - dataset = ConcatTokensDataset( - hf_dataset=downloading_iter, + dataset = ConcatTokensFromFilesDataset( + files=downloading_iter, max_length=concat_tokens, tokenizer=tokenizer, eos_text=eos_text,