Skip to content

Commit

Permalink
add more token encoding types
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 committed Jun 5, 2024
1 parent ac56dc5 commit 71b5bf4
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions llmfoundry/data/text_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
max_seq_len: int,
token_encoding_type: str = 'int32',
streams: Optional[Sequence[Stream]] = None,
remote: Optional[str] = None,
local: Optional[str] = None,
Expand Down Expand Up @@ -137,6 +138,12 @@ def __init__(
f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}',
)

if token_encoding_type not in ['int16', 'int32', 'int64']:
raise ValueError(
f'The token_encoding_type must be one of [\'int16\', \'int32\', \'int64\'], but got {token_encoding_type}',
)
self.token_encoding_type = token_encoding_type

if local is not None and (remote is None or (local == remote)):
if os.path.isdir(local):
contents = set(os.listdir(local))
Expand Down Expand Up @@ -198,8 +205,9 @@ def _read_binary_tokenized_sample(
sample: Dict[str, Any],
) -> torch.Tensor:
return torch.from_numpy(
np.frombuffer(sample['tokens'],
dtype=np.int64)[:self.max_seq_len].copy(),
np.frombuffer(
sample['tokens'], dtype=getattr(np, self.token_encoding_type)
)[:self.max_seq_len].copy(),
)

# How to process a sample
Expand Down

0 comments on commit 71b5bf4

Please sign in to comment.