From d8ea2c569acee49176401017503b820fec3a37ca Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Mon, 1 Apr 2024 11:47:53 -0700 Subject: [PATCH] Check the user provided eos / bos token id against the tokenizer eos / bos token id (#1039) * lint * lint * added warning and error message instead of setting the eos and bos token ids * Update text_data.py Adding info about the override flags in the error message. * Update llmfoundry/data/text_data.py Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * Update llmfoundry/data/text_data.py Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * adding warning if user does not provide eos or bos token id * adding warning if user does not provide eos or bos token id --------- Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> --- llmfoundry/data/text_data.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 8d7ff5849d..e85968543c 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -3,6 +3,7 @@ """Build a StreamingTextDataset dataset and dataloader for training.""" +import logging import os from itertools import islice from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence, @@ -19,6 +20,8 @@ from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerBase +log = logging.getLogger(__name__) + class StreamingTextDataset(StreamingDataset): """Generic text dataset using MosaicML's StreamingDataset. @@ -257,6 +260,34 @@ def build_text_dataloader( eos_token_id = cfg.dataset.pop('eos_token_id', None) bos_token_id = cfg.dataset.pop('bos_token_id', None) + if eos_token_id is None and bos_token_id is None and (hasattr( + tokenizer, 'eos_token_id') or hasattr(tokenizer, 'bos_token_id')): + log.warning( + 'The user has not provided an eos_token_id or bos_token_id, but the tokenizer has an eos_token_id or a bos_token_id.' + ) + + tokenizer_eos_token_id = getattr(tokenizer, 'eos_token_id', None) + if eos_token_id is not None and eos_token_id != tokenizer_eos_token_id: + eos_mismatch_str = f'Provided {eos_token_id=} does not match the eos_token_id of the tokenizer={tokenizer_eos_token_id}.' + if cfg.dataset.pop('override_eos_token_id_mismatch_error', False): + log.warning(eos_mismatch_str) + else: + raise ValueError( + eos_mismatch_str + + ' To override this error, set the override_eos_token_id_mismatch_error flag to True in the dataset config section of the YAML.' + ) + + tokenizer_bos_token_id = getattr(tokenizer, 'bos_token_id', None) + if bos_token_id is not None and bos_token_id != tokenizer_bos_token_id: + bos_mismatch_str = f'Provided {bos_token_id=} does not match the bos_token_id of the tokenizer={tokenizer_bos_token_id}.' + if cfg.dataset.pop('override_bos_token_id_mismatch_error', False): + log.warning(bos_mismatch_str) + else: + raise ValueError( + bos_mismatch_str + + ' To override this error, set the override_bos_token_id_mismatch_error flag to True in the dataset config section of the YAML.' + ) + streams = build_streams(cfg.dataset) # build dataset potentially with streams