diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py
index dc7c514d75..2218e575b2 100644
--- a/scripts/data_prep/convert_text_to_mds.py
+++ b/scripts/data_prep/convert_text_to_mds.py
@@ -47,18 +47,21 @@ def parse_args() -> Namespace:
         '--compression',
         type=str,
         default='zstd',
+        required=False,
         help='The compression algorithm to use for MDS writing',
     )
     parser.add_argument(
         '--concat_tokens',
         type=int,
+        required=True,
         help='Convert text to tokens and concatenate up to this many tokens',
     )
     parser.add_argument(
         '--tokenizer',
         type=str,
+        required=True,
         help='The name of the tokenizer to use',
     )
     parser.add_argument(
@@ -77,6 +80,13 @@ def parse_args() -> Namespace:
         help=
         'The text to append to each example to separate concatenated examples',
     )
+    parser.add_argument(
+        '--use_tokenizer_eos',
+        required=False,
+        action='store_true',
+        default=False,
+        help='Use the EOS text from the tokenizer.',
+    )
     parser.add_argument(
         '--no_wrap',
         default=False,
@@ -103,11 +113,15 @@ def parse_args() -> Namespace:
 
     parsed = parser.parse_args()
 
-    # Make sure we have needed concat options
-    if (parsed.concat_tokens is not None and
-            isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None):
-        parser.error(
-            'When setting --concat_tokens, you must specify a --tokenizer')
+    # Set eos token.
+    if parsed.use_tokenizer_eos:
+        # Ensure that eos text is not specified twice.
+        if parsed.eos_text is not None:
+            parser.error(
+                'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.'
+            )
+        tokenizer = AutoTokenizer.from_pretrained(parsed.tokenizer)
+        parsed.eos_text = tokenizer.eos_token
 
     # now that we have validated them, change BOS/EOS to strings
     if parsed.bos_text is None:
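
The new flag resolves the EOS string at argument-parse time instead of requiring it on the command line. Below is a minimal, standalone sketch of that resolution logic; the argument names and error check mirror the diff, but the surrounding argparse setup and the `gpt2` tokenizer name are illustrative assumptions, not part of the PR.

```python
# Standalone sketch of the --use_tokenizer_eos resolution added in this diff.
# Assumes `transformers` is installed; 'gpt2' is only an example tokenizer.
from argparse import ArgumentParser

from transformers import AutoTokenizer

parser = ArgumentParser()
parser.add_argument('--tokenizer', type=str, required=True)
parser.add_argument('--eos_text', type=str, default=None)
parser.add_argument('--use_tokenizer_eos', action='store_true', default=False)
parsed = parser.parse_args(['--tokenizer', 'gpt2', '--use_tokenizer_eos'])

if parsed.use_tokenizer_eos:
    # Refuse to proceed if the user also passed an explicit --eos_text.
    if parsed.eos_text is not None:
        parser.error('Cannot set --eos_text with --use_tokenizer_eos.')
    # Otherwise take the EOS string from the tokenizer itself.
    tokenizer = AutoTokenizer.from_pretrained(parsed.tokenizer)
    parsed.eos_text = tokenizer.eos_token

print(parsed.eos_text)  # '<|endoftext|>' for gpt2
```

Keeping the check on `parsed.eos_text is not None` works because `--eos_text` defaults to None and is only converted to an empty string later in the script, after this validation runs.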