
Adding more token encoding types #1254

Merged: 16 commits, Jun 6, 2024
7 changes: 6 additions & 1 deletion llmfoundry/data/__init__.py
@@ -1,7 +1,11 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

-from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset
+from llmfoundry.data.data import (
+    ConcatTokensDataset,
+    NoConcatDataset,
+    stream_remote_local_validate,
+)
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.data.finetuning import (
Seq2SeqFinetuningCollator,
@@ -55,4 +59,5 @@
'auto_packing_ratio',
'profile_packing',
'ConcatenatedSequenceCollatorWrapper',
+    'stream_remote_local_validate',
]
16 changes: 15 additions & 1 deletion llmfoundry/data/data.py
@@ -5,7 +5,7 @@
import os
import warnings
from abc import ABC, abstractmethod
-from typing import Dict, Iterable, Union
+from typing import Dict, Iterable, Optional, Union

import datasets as hf_datasets
import numpy as np
@@ -153,3 +153,17 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]:
# convert to bytes to store in MDS binary format
'tokens': np.asarray(concat_sample).tobytes(),
}


+def stream_remote_local_validate(
+    remote: Optional[str],
+    local: Optional[str],
+    split: Optional[str],
+):
+    if remote is None or (local == remote):
+        if local is not None and os.path.isdir(local):
+            contents = set(os.listdir(local))
+            if split is not None and split not in contents:
+                raise ValueError(
+                    f'Local directory {local} does not contain split {split}',
+                )
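
A minimal sketch of how this new helper behaves, assuming llm-foundry is installed; the temporary directory layout below is made up for illustration:

```python
# Minimal sketch; the directory layout is hypothetical.
import os
import tempfile

from llmfoundry.data import stream_remote_local_validate

with tempfile.TemporaryDirectory() as local:
    os.makedirs(os.path.join(local, 'train'))
    # Passes: remote is None and a 'train' split exists under local.
    stream_remote_local_validate(remote=None, local=local, split='train')
    # Raises ValueError: there is no 'val' subdirectory in the local dataset.
    try:
        stream_remote_local_validate(remote=None, local=local, split='val')
    except ValueError as err:
        print(err)
```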
41 changes: 23 additions & 18 deletions llmfoundry/data/finetuning/tasks.py
@@ -59,6 +59,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
from streaming import Stream, StreamingDataset
from transformers import PreTrainedTokenizerBase

+from llmfoundry.data import stream_remote_local_validate
from llmfoundry.data.finetuning.collator import (
_HF_IGNORE_INDEX,
stitch_turns_decoder_only,
@@ -494,26 +495,14 @@ def is_valid_ift_example(
return True


-def _stream_remote_local_validate(
-    remote: Optional[str],
-    local: Optional[str],
-    split: Optional[str],
-):
-    if remote is None or (local == remote):
-        if local is not None and os.path.isdir(local):
-            contents = set(os.listdir(local))
-            if split is not None and split not in contents:
-                raise ValueError(
-                    f'Local directory {local} does not contain split {split}',
-                )


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.

Args:
tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
tokenize samples.
+        token_encoding_type (str): The encoding type of the tokenized samples. Can be one of
+            ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'].
streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
@@ -574,6 +563,7 @@ class StreamingFinetuningDataset(StreamingDataset):
def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
+        token_encoding_type: str = 'int64',
streams: Optional[Sequence[Stream]] = None,
local: Optional[str] = None,
remote: Optional[str] = None,
@@ -606,11 +596,26 @@ def __init__(
f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}',
)

+        if token_encoding_type not in [
+            'int8',
+            'int16',
+            'int32',
+            'int64',
+            'uint8',
+            'uint16',
+            'uint32',
+            'uint64',
+        ]:
+            raise ValueError(
+                f'The token_encoding_type must be one of [\'int8\', \'int16\', \'int32\', \'int64\', \'uint8\', \'uint16\', \'uint32\', \'uint64\'], but got {token_encoding_type}',
+            )
+        self.token_encoding_type = token_encoding_type

if streams is None:
-            _stream_remote_local_validate(remote, local, split)
+            stream_remote_local_validate(remote, local, split)
else:
for stream in streams:
-                _stream_remote_local_validate(
+                stream_remote_local_validate(
stream.remote,
stream.local,
split,
@@ -656,11 +661,11 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
if isinstance(sample['input_ids'], bytes):
sample['input_ids'] = np.frombuffer(
sample['input_ids'],
-                dtype=np.int64,
+                dtype=getattr(np, self.token_encoding_type),
)[:self.max_seq_len].tolist().copy()
sample['labels'] = np.frombuffer(
sample['labels'],
-                dtype=np.int64,
+                dtype=getattr(np, self.token_encoding_type),
)[:self.max_seq_len].tolist().copy()
elif isinstance(sample['input_ids'], np.ndarray):
sample['input_ids'] = sample['input_ids'][:self.max_seq_len
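
A hedged usage sketch for the new parameter; the paths, split name, tokenizer choice, and keyword values below are placeholders, not taken from this PR:

```python
# Usage sketch, assuming MDS shards at /tmp/mds-finetune were written with
# uint32 token ids; all paths and values here are placeholders.
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset

tokenizer = AutoTokenizer.from_pretrained('gpt2')
dataset = StreamingFinetuningDataset(
    tokenizer=tokenizer,
    token_encoding_type='uint32',  # must match the dtype used at conversion
    local='/tmp/mds-finetune',
    split='train',
    max_seq_len=2048,
)
```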
49 changes: 37 additions & 12 deletions llmfoundry/data/text_data.py
@@ -4,7 +4,6 @@
"""Build a StreamingTextDataset dataset and dataloader for training."""

import inspect
-import os
from itertools import islice
from typing import (
Any,
@@ -25,6 +24,7 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry import registry
+from llmfoundry.data import stream_remote_local_validate
from llmfoundry.utils.registry_utils import construct_from_registry

__all__ = [
@@ -41,6 +41,8 @@ class StreamingTextDataset(StreamingDataset):
tokenizer (Tokenizer): HuggingFace tokenizer to
tokenize samples.
max_seq_len (int): The max sequence length of each sample.
+        token_encoding_type (str): The encoding type of the tokenized samples. Can be one of
+            ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'].
streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
@@ -106,6 +108,7 @@ def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
max_seq_len: int,
+        token_encoding_type: str = 'int64',
streams: Optional[Sequence[Stream]] = None,
remote: Optional[str] = None,
local: Optional[str] = None,
@@ -137,13 +140,30 @@ def __init__(
f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}',
)

-        if local is not None and (remote is None or (local == remote)):
-            if os.path.isdir(local):
-                contents = set(os.listdir(local))
-                if split not in contents:
-                    raise ValueError(
-                        f'local directory {local} does not contain split {split}',
-                    )
+        if token_encoding_type not in [
+            'int8',
+            'int16',
+            'int32',
+            'int64',
+            'uint8',
+            'uint16',
+            'uint32',
+            'uint64',
+        ]:
+            raise ValueError(
+                f'The token_encoding_type must be one of [\'int8\', \'int16\', \'int32\', \'int64\', \'uint8\', \'uint16\', \'uint32\', \'uint64\'], but got {token_encoding_type}',
+            )
+        self.token_encoding_type = token_encoding_type
+
+        if streams is None:
+            stream_remote_local_validate(remote, local, split)
+        else:
+            for stream in streams:
+                stream_remote_local_validate(
+                    stream.remote,
+                    stream.local,
+                    split,
+                )

# TODO: discover where yamls are being converted incorrect, but temporary workaround
if isinstance(shuffle_block_size, float):
@@ -197,10 +217,15 @@ def _read_binary_tokenized_sample(
self,
sample: Dict[str, Any],
) -> torch.Tensor:
-        return torch.from_numpy(
-            np.frombuffer(sample['tokens'],
-                          dtype=np.int64)[:self.max_seq_len].copy(),
-        )
+        if isinstance(sample['tokens'], np.ndarray):
+            return torch.from_numpy(sample['tokens'][:self.max_seq_len].copy())
+        else:
+            return torch.from_numpy(
+                np.frombuffer(
+                    sample['tokens'],
+                    dtype=getattr(np, self.token_encoding_type),
+                )[:self.max_seq_len].copy(),
+            )

# How to process a sample
def __getitem__(self,
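
To make the two branches above concrete, here is a small self-contained round-trip sketch; the dtype and token values are illustrative only:

```python
# Round-trip sketch of the two sample formats _read_binary_tokenized_sample
# now accepts; dtype and values are illustrative.
import numpy as np
import torch

token_encoding_type = 'int32'
max_seq_len = 4
tokens = np.array([11, 22, 33, 44, 55], dtype=getattr(np, token_encoding_type))

# New-style shards ('ndarray' columns) hand back an ndarray directly.
from_ndarray = torch.from_numpy(tokens[:max_seq_len].copy())

# Older shards ('bytes' columns) require decoding with the matching dtype.
from_bytes = torch.from_numpy(
    np.frombuffer(tokens.tobytes(),
                  dtype=getattr(np, token_encoding_type))[:max_seq_len].copy(),
)
assert torch.equal(from_ndarray, from_bytes)
```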
2 changes: 1 addition & 1 deletion scripts/data_prep/convert_dataset_hf.py
@@ -377,7 +377,7 @@ def main(args: Namespace) -> None:
tokenizer = build_tokenizer(args.tokenizer, args.tokenizer_kwargs)
# we will enforce length, so suppress warnings about sequences too long for the model
tokenizer.model_max_length = int(1e30)
-    columns = {'tokens': 'bytes'}
+    columns = {'tokens': 'ndarray'}
else:
mode = ConcatMode.NO_CONCAT
tokenizer = None
2 changes: 1 addition & 1 deletion scripts/data_prep/convert_dataset_json.py
@@ -175,7 +175,7 @@ def main(args: Namespace) -> None:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
# we will enforce length, so suppress warnings about sequences too long for the model
tokenizer.model_max_length = int(1e30)
-    columns = {'tokens': 'bytes'}
+    columns = {'tokens': 'ndarray'}
else:
mode = ConcatMode.NO_CONCAT
tokenizer = None
2 changes: 1 addition & 1 deletion scripts/data_prep/convert_text_to_mds.py
@@ -356,7 +356,7 @@ def download_and_convert(
no_wrap=no_wrap,
)

-    columns = {'tokens': 'bytes'}
+    columns = {'tokens': 'ndarray'}

log.info('Converting to MDS format...')
with MDSWriter(
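
This 'bytes' to 'ndarray' switch relies on the ndarray column type supported by MDSWriter in the mosaicml-streaming library; a minimal sketch, with a placeholder output path:

```python
# Sketch of writing token arrays with the 'ndarray' column type; the
# output directory and sample values are placeholders.
import numpy as np
from streaming import MDSWriter

columns = {'tokens': 'ndarray'}
with MDSWriter(out='/tmp/mds-out', columns=columns) as writer:
    writer.write({'tokens': np.asarray([1, 2, 3], dtype=np.uint16)})
```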