Skip to content

Commit

Permalink
Check the user provided eos / bos token id against the tokenizer eos …
Browse files Browse the repository at this point in the history
…/ bos token id (mosaicml#1039)

* lint

* lint

* added warning and error message instead of setting the eos and bos token ids

* Update text_data.py

Adding info about the override flags in the error message.

* Update llmfoundry/data/text_data.py

Co-authored-by: Vitaliy Chiley <[email protected]>

* Update llmfoundry/data/text_data.py

Co-authored-by: Vitaliy Chiley <[email protected]>

* adding warning if user does not provide eos or bos token id

* adding warning if user does not provide eos or bos token id

---------

Co-authored-by: Vitaliy Chiley <[email protected]>
  • Loading branch information
ShashankMosaicML and vchiley authored Apr 1, 2024
1 parent 349c2ff commit d8ea2c5
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions llmfoundry/data/text_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Build a StreamingTextDataset dataset and dataloader for training."""

import logging
import os
from itertools import islice
from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence,
Expand All @@ -19,6 +20,8 @@
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase

log = logging.getLogger(__name__)


class StreamingTextDataset(StreamingDataset):
"""Generic text dataset using MosaicML's StreamingDataset.
Expand Down Expand Up @@ -257,6 +260,34 @@ def build_text_dataloader(
eos_token_id = cfg.dataset.pop('eos_token_id', None)
bos_token_id = cfg.dataset.pop('bos_token_id', None)

if eos_token_id is None and bos_token_id is None and (hasattr(
tokenizer, 'eos_token_id') or hasattr(tokenizer, 'bos_token_id')):
log.warning(
'The user has not provided an eos_token_id or bos_token_id, but the tokenizer has an eos_token_id or a bos_token_id.'
)

tokenizer_eos_token_id = getattr(tokenizer, 'eos_token_id', None)
if eos_token_id is not None and eos_token_id != tokenizer_eos_token_id:
eos_mismatch_str = f'Provided {eos_token_id=} does not match the eos_token_id of the tokenizer={tokenizer_eos_token_id}.'
if cfg.dataset.pop('override_eos_token_id_mismatch_error', False):
log.warning(eos_mismatch_str)
else:
raise ValueError(
eos_mismatch_str +
' To override this error, set the override_eos_token_id_mismatch_error flag to True in the dataset config section of the YAML.'
)

tokenizer_bos_token_id = getattr(tokenizer, 'bos_token_id', None)
if bos_token_id is not None and bos_token_id != tokenizer_bos_token_id:
bos_mismatch_str = f'Provided {bos_token_id=} does not match the bos_token_id of the tokenizer={tokenizer_bos_token_id}.'
if cfg.dataset.pop('override_bos_token_id_mismatch_error', False):
log.warning(bos_mismatch_str)
else:
raise ValueError(
bos_mismatch_str +
' To override this error, set the override_bos_token_id_mismatch_error flag to True in the dataset config section of the YAML.'
)

streams = build_streams(cfg.dataset)

# build dataset potentially with streams
Expand Down

0 comments on commit d8ea2c5

Please sign in to comment.