Skip to content

Commit

Permalink
Chunk file reads and tokenization
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed May 25, 2024
1 parent 2e10d95 commit 4ada4b5
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 32 deletions.
65 changes: 43 additions & 22 deletions llmfoundry/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Datasets for converting to MDS Shards."""
import os
import warnings
from abc import ABC, abstractmethod
from typing import Dict, Iterable, Union

import datasets as hf_datasets
Expand Down Expand Up @@ -35,39 +36,20 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]:
yield {'text': sample['text'].encode('utf-8')}


class ConcatTokensDataset(IterableDataset):
"""An IterableDataset that returns token samples for MDSWriter.
Returns dicts of {'tokens': bytes}
To use data created by this class and written to MDS format:
```python
import torch
from streaming.base import StreamingDataset
from transformers import AutoTokenizer
class AbstractConcatTokensDataset(ABC, IterableDataset):
"""Abstract class for defining an IterableDataset that tokenizes and.
tokenizer = AutoTokenizer.from_pretrained('your/tokenizer')
ds = StreamingDataset(local='mds-data-folder', split='val')
# note, you need to copy the numpy array because the original is non-writeable
# and torch does not support non-writeable tensors, so you get a scary warning and
# if you do try to write to the tensor you get undefined behavior
tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy())
print(tokenizer.decode(tokens))
```
concatenates text samples on the fly.
"""

def __init__(
self,
hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset],
tokenizer: PreTrainedTokenizerBase,
max_length: int,
bos_text: str,
eos_text: str,
no_wrap: bool,
):
self.hf_dataset = hf_dataset
self.tokenizer = tokenizer
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
self.max_length = max_length
Expand Down Expand Up @@ -114,8 +96,47 @@ def __init__(
'in duplicated special tokens. Please be sure this is what you intend.',
)

@abstractmethod
def __iter__(self) -> Iterable[Dict[str, bytes]]:
pass


class ConcatTokensDataset(AbstractConcatTokensDataset):
"""An IterableDataset that returns token samples for MDSWriter.
Returns dicts of {'tokens': bytes}
To use data created by this class and written to MDS format:
```python
import torch
from streaming.base import StreamingDataset
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('your/tokenizer')
ds = StreamingDataset(local='mds-data-folder', split='val')
# note, you need to copy the numpy array because the original is non-writeable
# and torch does not support non-writeable tensors, so you get a scary warning and
# if you do try to write to the tensor you get undefined behavior
tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy())
print(tokenizer.decode(tokens))
```
"""

def __init__(
self,
hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset],
tokenizer: PreTrainedTokenizerBase,
max_length: int,
bos_text: str,
eos_text: str,
no_wrap: bool,
):
self.hf_dataset = hf_dataset
super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap)

def __iter__(self) -> Iterable[Dict[str, bytes]]:
buffer = []
for sample in self.hf_dataset:
encoded = self.tokenizer(
Expand Down
7 changes: 2 additions & 5 deletions llmfoundry/utils/data_prep_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(
output_folder: str,
object_store: Optional[ObjectStore],
):
"""Iterable that downloads files from an object store before yielding.
"""Iterable that downloads files before yielding the local filename.
If object_store is None, input_folder_prefix is treated as a local path.
Expand Down Expand Up @@ -138,7 +138,4 @@ def __iter__(self):
object_name=object_name,
output_filename=output_filename,
)

with open(output_filename) as _txt_file:
txt = _txt_file.read()
yield {'text': txt}
yield output_filename
68 changes: 63 additions & 5 deletions scripts/data_prep/convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import tempfile
from argparse import ArgumentParser, Namespace
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from glob import glob
from typing import Iterable, List, Tuple, cast
from typing import Dict, Iterable, List, Tuple, cast

import numpy as np
import psutil
from composer.utils import (
ObjectStore,
Expand All @@ -18,9 +20,9 @@
)
from streaming import MDSWriter
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from llmfoundry.data import ConcatTokensDataset
from llmfoundry.data.data import AbstractConcatTokensDataset
from llmfoundry.utils import maybe_create_mosaicml_logger
from llmfoundry.utils.data_prep_utils import (
DownloadingIterable,
Expand All @@ -37,6 +39,62 @@
DONE_FILENAME = '.text_to_mds_conversion_done'


class ConcatTokensFromFilesDataset(AbstractConcatTokensDataset):
"""An IterableDataset that returns token samples for MDSWriter from files.
Returns dicts of {'tokens': bytes}
Each file is considered a sequence.
"""

def __init__(
self,
files: Iterable[str],
tokenizer: PreTrainedTokenizerBase,
max_length: int,
bos_text: str,
eos_text: str,
no_wrap: bool,
):
self.files = files
super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap)

def __iter__(self) -> Iterable[Dict[str, bytes]]:

buffer = []
for file in self.files:
with open(file, 'r') as f:
buffer += self.bos_tokens
first_chunk = True
# Read the file in 1MB chunks to avoid memory issues
for chunk in iter(partial(f.read, 1000000), ''):
# Tokenize the chunk
encoded = self.tokenizer(
chunk,
truncation=False,
padding=False,
)
iids = encoded['input_ids']

# If this is not the first chunk, remove the BOS token
if not first_chunk:
if iids[0] == self.tokenizer.bos_token_id:
iids = iids[1:]

# Add the tokens to the buffer
buffer += iids
while len(buffer) >= self.max_length:
concat_sample = buffer[:self.max_length]
buffer = buffer[self.
max_length:] if self.should_wrap else []
yield {'tokens': np.asarray(concat_sample).tobytes()}

first_chunk = False

# Add the EOS token to the buffer to separate files.
buffer += self.eos_tokens


def parse_args() -> Namespace:
"""Parse commandline arguments."""
parser = ArgumentParser(
Expand Down Expand Up @@ -277,8 +335,8 @@ def download_and_convert(

# Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up
# to the maximum sequence length
dataset = ConcatTokensDataset(
hf_dataset=downloading_iter,
dataset = ConcatTokensFromFilesDataset(
files=downloading_iter,
max_length=concat_tokens,
tokenizer=tokenizer,
eos_text=eos_text,
Expand Down

0 comments on commit 4ada4b5

Please sign in to comment.