Commit: logic down, not tested

irenedea committed Jan 8, 2024
1 parent 9bdf480 commit b5a89ba
Showing 3 changed files with 50 additions and 76 deletions.
30 changes: 27 additions & 3 deletions llmfoundry/data/data.py
@@ -4,6 +4,7 @@
 """Datasets for converting to MDS Shards."""
 import os
 import warnings
+from abc import ABC, abstractmethod
 from typing import Dict, Iterable, Union
 
 import datasets as hf_datasets
@@ -28,7 +29,7 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]:
         yield {'text': sample['text'].encode('utf-8')}
 
 
-class ConcatTokensDataset(IterableDataset):
+class AbstractConcatTokensDataset(ABC, IterableDataset):
     """An IterableDataset that returns token samples for MDSWriter.
 
     Returns dicts of {'tokens': bytes}
@@ -53,14 +54,12 @@ class ConcatTokensDataset(IterableDataset):
 
     def __init__(
         self,
-        hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset],
         tokenizer: PreTrainedTokenizerBase,
        max_length: int,
         bos_text: str,
         eos_text: str,
         no_wrap: bool,
     ):
-        self.hf_dataset = hf_dataset
         self.tokenizer = tokenizer
         os.environ['TOKENIZERS_PARALLELISM'] = 'false'
         self.max_length = max_length
@@ -99,6 +98,31 @@ def __init__(
                 'in duplicated special tokens. Please be sure this is what you intend.'
             )
 
+    @abstractmethod
+    def __iter__(self) -> Iterable[Dict[str, bytes]]:
+        pass
+
+
+class ConcatTokensDataset(AbstractConcatTokensDataset):
+    """An IterableDataset that returns token samples for MDSWriter.
+
+    Samples are taken from a HuggingFace dataset.
+
+    Returns dicts of {'tokens': bytes}
+    """
+
+    def __init__(
+        self,
+        hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset],
+        tokenizer: PreTrainedTokenizerBase,
+        max_length: int,
+        bos_text: str,
+        eos_text: str,
+        no_wrap: bool,
+    ):
+        self.hf_dataset = hf_dataset
+        super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap)
+
     def __iter__(self) -> Iterable[Dict[str, bytes]]:
 
         buffer = []
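
With this refactor, all of the BOS/EOS token bookkeeping and special-token warnings live in AbstractConcatTokensDataset.__init__, and a subclass only has to supply __iter__. As a rough sketch of the intended extension point (not part of this commit; the class name and `texts` parameter are invented for illustration):

```python
from typing import Dict, Iterable, List

import numpy as np
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.data import AbstractConcatTokensDataset


class ConcatTokensFromListDataset(AbstractConcatTokensDataset):
    """Hypothetical subclass: concatenates tokens from in-memory strings."""

    def __init__(self, texts: List[str], tokenizer: PreTrainedTokenizerBase,
                 max_length: int, bos_text: str, eos_text: str, no_wrap: bool):
        self.texts = texts
        # The base class validates bos/eos tokenization and sets max_length,
        # should_wrap, bos_tokens, and eos_tokens.
        super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap)

    def __iter__(self) -> Iterable[Dict[str, bytes]]:
        buffer: List[int] = []
        for text in self.texts:
            encoded = self.tokenizer(text, truncation=False, padding=False)
            buffer += self.bos_tokens + encoded['input_ids'] + self.eos_tokens
            while len(buffer) >= self.max_length:
                concat_sample = buffer[:self.max_length]
                buffer = buffer[self.max_length:] if self.should_wrap else []
                yield {'tokens': np.asarray(concat_sample).tobytes()}
```
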
88 changes: 19 additions & 69 deletions llmfoundry/utils/data_prep_utils.py
@@ -3,13 +3,16 @@
 
 import json
 import os
 from functools import partial
 from glob import glob
-from typing import List, Optional
+from typing import Dict, Iterable, List, Optional
 
 import numpy as np
 from composer.utils import ObjectStore
 from torch.utils.data import IterableDataset
 from transformers import PreTrainedTokenizerBase
 
+from llmfoundry.data.data import AbstractConcatTokensDataset
+
 
 def with_id(basename: str, shard_id: int) -> str:
     """Get a new basename with the given shard_id.
@@ -112,99 +115,46 @@ def __iter__(self):
             yield output_filename
 
 
-class ConcatTokensFromFilesDataset(IterableDataset):
-    """An IterableDataset that returns token samples for MDSWriter.
+class ConcatTokensFromFilesDataset(AbstractConcatTokensDataset):
+    """An IterableDataset that returns token samples for MDSWriter from files.
 
     Returns dicts of {'tokens': bytes}
 
-    To use data created by this class and written to MDS format:
-
-    ```python
-    import torch
-    from streaming.base import StreamingDataset
-    from transformers import AutoTokenizer
-
-    tokenizer = AutoTokenizer.from_pretrained('your/tokenizer')
-    ds = StreamingDataset(local='mds-data-folder', split='val')
-
-    # note, you need to copy the numpy array because the original is non-writeable
-    # and torch does not support non-writeable tensors, so you get a scary warning and
-    # if you do try to write to the tensor you get undefined behavior
-    tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy())
-    print(tokenizer.decode(tokens))
-    ```
+    Each file is considered a sequence.
     """
 
     def __init__(
         self,
-        files: list[str],
+        files: Iterable[str],
         tokenizer: PreTrainedTokenizerBase,
         max_length: int,
         bos_text: str,
         eos_text: str,
         no_wrap: bool,
     ):
         self.files = files
-        self.tokenizer = tokenizer
-        os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-        self.max_length = max_length
-        self.bos_text = bos_text
-        self.eos_text = eos_text
-        self.should_wrap = not no_wrap
-
-        self.bos_tokens = self.tokenizer(self.bos_text,
-                                         truncation=False,
-                                         padding=False,
-                                         add_special_tokens=False)['input_ids']
-        if len(self.bos_tokens) > 1:
-            warnings.warn(
-                f'You specified --concat_tokens with --bos_text, but your BOS text is not tokenizing to one token\
-                , instead we got {self.bos_tokens}. Quit if this was in error.')
-
-        self.eos_tokens = self.tokenizer(self.eos_text,
-                                         truncation=False,
-                                         padding=False,
-                                         add_special_tokens=False)['input_ids']
-        if len(self.eos_tokens) > 1:
-            warnings.warn(
-                f'You specified --concat_tokens with --eos_text, but your EOS text is not tokenizing to one token\
-                , instead we got {self.eos_tokens}. Quit if this was in error.')
-
-        eos_text_provided = self.eos_text != ''
-        bos_text_provided = self.bos_text != ''
-        test_text = self.tokenizer('')
-        if len(test_text['input_ids']) > 0 and (eos_text_provided or
-                                                bos_text_provided):
-            message = 'both eos and bos' if eos_text_provided and bos_text_provided else (
-                'eos_text' if eos_text_provided else 'bos_text')
-            warnings.warn(
-                f'The provided tokenizer adds special tokens, but you also specified {message}. This may result '
-                +
-                'in duplicated special tokens. Please be sure this is what you intend.'
-            )
+        super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap)
 
     def __iter__(self) -> Iterable[Dict[str, bytes]]:
 
         buffer = []
-        files = ['']
-        for file in files:
+        for file in self.files:
             with open(file, 'r') as f:
                 buffer += self.bos_tokens
                 for chunk in iter(partial(f.read, 1000), ''):
-                    encoded = self.tokenizer(chunk, truncation=False, padding=False)
+                    encoded = self.tokenizer(chunk,
+                                             truncation=False,
+                                             padding=False)
                     iids = encoded['input_ids']
                     buffer += iids
                     while len(buffer) >= self.max_length:
                         concat_sample = buffer[:self.max_length]
-                        buffer = buffer[self.max_length:] if self.should_wrap else []
-                        yield {
-                            'tokens': np.asarray(concat_sample).tobytes()
-                        }
+                        buffer = buffer[
+                            self.max_length:] if self.should_wrap else []
+                        yield {'tokens': np.asarray(concat_sample).tobytes()}
                 buffer += self.eos_tokens
+        # Finish up the last of the tokens.
         while len(buffer) >= self.max_length:
             concat_sample = buffer[:self.max_length]
             buffer = buffer[self.max_length:] if self.should_wrap else []
-            yield {
-                'tokens': np.asarray(concat_sample).tobytes()
-            }
+            yield {'tokens': np.asarray(concat_sample).tobytes()}
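
Not part of the commit, but useful for review: a minimal end-to-end sketch of the refactored class feeding an MDSWriter. The corpus glob, tokenizer name, max_length, and the {'tokens': 'bytes'} column spec are illustrative assumptions, not values taken from this diff:

```python
from glob import glob

from streaming import MDSWriter
from transformers import AutoTokenizer

from llmfoundry.utils.data_prep_utils import ConcatTokensFromFilesDataset

tokenizer = AutoTokenizer.from_pretrained('gpt2')
dataset = ConcatTokensFromFilesDataset(
    files=glob('corpus/*.txt'),  # each file is treated as one sequence
    tokenizer=tokenizer,
    max_length=2048,
    bos_text='',
    eos_text='<|endoftext|>',
    no_wrap=False,
)

# Each yielded sample is a dict of {'tokens': bytes}, ready for an MDS shard.
with MDSWriter(out='mds-out', columns={'tokens': 'bytes'}) as writer:
    for sample in dataset:
        writer.write(sample)
```
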
8 changes: 4 additions & 4 deletions scripts/data_prep/convert_text_to_mds.py
@@ -17,8 +17,8 @@
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
-from llmfoundry.data import ConcatTokensDataset
-from llmfoundry.utils.data_prep_utils import (DownloadingIterable,
+from llmfoundry.utils.data_prep_utils import (ConcatTokensFromFilesDataset,
+                                              DownloadingIterable,
                                               merge_shard_groups)
 
 log = logging.getLogger(__name__)
@@ -231,8 +231,8 @@ def download_and_convert(
 
     # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up
     # to the maximum sequence length
-    dataset = ConcatTokensDataset(
-        hf_dataset=downloading_iter,
+    dataset = ConcatTokensFromFilesDataset(
+        files=downloading_iter,
         max_length=concat_tokens,
         tokenizer=tokenizer,
         eos_text=eos_text,
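
The read-back example this commit drops from the ConcatTokensFromFilesDataset docstring still applies to shards produced by either subclass. Reproduced as a sketch, with 'your/tokenizer' and 'mds-data-folder' as placeholders and the numpy import the original snippet omitted:

```python
import numpy as np
import torch
from streaming.base import StreamingDataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('your/tokenizer')
ds = StreamingDataset(local='mds-data-folder', split='val')

# Copy the buffer: the original is non-writeable, and torch does not support
# non-writeable tensors (writing to one is undefined behavior).
tokens = torch.from_numpy(
    np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy())
print(tokenizer.decode(tokens))
```
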
