add more token encoding types
snarayan21 committed Jun 5, 2024
1 parent 71b5bf4 commit 866bfe6
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions llmfoundry/data/text_data.py
@@ -41,6 +41,7 @@ class StreamingTextDataset(StreamingDataset):
         tokenizer (Tokenizer): HuggingFace tokenizer to
             tokenize samples.
         max_seq_len (int): The max sequence length of each sample.
+        token_encoding_type (str): The type of the token encoding, either 'int16', 'int32', or 'int64'.
         streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
             which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
             ``remote``/``local``. Defaults to ``None``.
@@ -106,7 +107,7 @@ def __init__(
         self,
         tokenizer: PreTrainedTokenizerBase,
         max_seq_len: int,
-        token_encoding_type: str = 'int32',
+        token_encoding_type: str = 'int64',
         streams: Optional[Sequence[Stream]] = None,
         remote: Optional[str] = None,
         local: Optional[str] = None,
@@ -206,7 +207,8 @@ def _read_binary_tokenized_sample(
     ) -> torch.Tensor:
         return torch.from_numpy(
             np.frombuffer(
-                sample['tokens'], dtype=getattr(np, self.token_encoding_type)
+                sample['tokens'],
+                dtype=getattr(np, self.token_encoding_type),
             )[:self.max_seq_len].copy(),
         )
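
The reformatted _read_binary_tokenized_sample above decodes a sample's raw bytes with whichever numpy dtype token_encoding_type names. Below is a minimal round-trip sketch (a standalone illustration, not part of the commit; the token ids and max_seq_len value are made up) showing why the read-side dtype must match the dtype the shards were written with:

import numpy as np
import torch

token_encoding_type = 'int32'  # must match the dtype used when the shards were written

# Write side: serialize token ids to raw bytes, as a tokenized shard stores them.
token_ids = np.array([101, 2023, 2003, 102], dtype=getattr(np, token_encoding_type))
raw_bytes = token_ids.tobytes()

# Read side: mirrors _read_binary_tokenized_sample -- decode with the matching
# dtype, truncate to max_seq_len, and copy() because np.frombuffer returns a
# read-only view that torch.from_numpy warns about wrapping directly.
max_seq_len = 4
decoded = torch.from_numpy(
    np.frombuffer(raw_bytes, dtype=getattr(np, token_encoding_type))[:max_seq_len].copy(),
)
assert decoded.tolist() == [101, 2023, 2003, 102]

Decoding int32-written bytes as int64 would fuse every two tokens into one meaningless value, so the parameter exists to keep the write and read sides in sync.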

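On the constructor side, the default changes from 'int32' to 'int64', the widest supported encoding; narrower dtypes must now be requested explicitly. A hypothetical usage sketch follows (the tokenizer choice, max_seq_len, and local path are placeholders; the import path matches the changed file, and running it would require real shards at the given path):

from transformers import AutoTokenizer

from llmfoundry.data.text_data import StreamingTextDataset

tokenizer = AutoTokenizer.from_pretrained('gpt2')

# The encoding must be wide enough for the tokenizer's vocabulary: GPT-2's
# ~50k token ids overflow int16 (max 32767), so int32 is the narrowest safe
# choice here.
dataset = StreamingTextDataset(
    tokenizer=tokenizer,
    max_seq_len=2048,
    token_encoding_type='int32',
    local='/tmp/text_data',  # placeholder path to locally cached shards
)

Relative to the new int64 default, int32 halves the bytes read per token, which is the practical payoff of supporting the narrower encodings this commit documents.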
