From 866bfe6041a1af069d5c3e63d282842cc8c014d7 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Wed, 5 Jun 2024 10:49:13 -0700 Subject: [PATCH] add more token encoding types --- llmfoundry/data/text_data.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 78dadeac3b..bfd1b73812 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -41,6 +41,7 @@ class StreamingTextDataset(StreamingDataset): tokenizer (Tokenizer): HuggingFace tokenizer to tokenize samples. max_seq_len (int): The max sequence length of each sample. + token_encoding_type (str): The type of the token encoding, either 'int16', 'int32', or 'int64'. streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from, which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. @@ -106,7 +107,7 @@ def __init__( self, tokenizer: PreTrainedTokenizerBase, max_seq_len: int, - token_encoding_type: str = 'int32', + token_encoding_type: str = 'int64', streams: Optional[Sequence[Stream]] = None, remote: Optional[str] = None, local: Optional[str] = None, @@ -206,7 +207,8 @@ def _read_binary_tokenized_sample( ) -> torch.Tensor: return torch.from_numpy( np.frombuffer( - sample['tokens'], dtype=getattr(np, self.token_encoding_type) + sample['tokens'], + dtype=getattr(np, self.token_encoding_type), )[:self.max_seq_len].copy(), )