Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add eos_text and bos_text defaults for convert_text_to_mds.py #826

Closed
wants to merge 8 commits into from
Closed
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions scripts/data_prep/convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from argparse import ArgumentParser, Namespace
from concurrent.futures import ProcessPoolExecutor
from glob import glob
from typing import Iterable, List, Tuple, cast
from typing import Iterable, List, Optional, Tuple, cast

import psutil
from composer.utils import (ObjectStore, maybe_create_object_store_from_uri,
Expand Down Expand Up @@ -109,11 +109,6 @@ def parse_args() -> Namespace:
parser.error(
'When setting --concat_tokens, you must specify a --tokenizer')

# now that we have validated them, change BOS/EOS to strings
if parsed.bos_text is None:
parsed.bos_text = ''
if parsed.eos_text is None:
parsed.eos_text = ''
return parsed


Expand Down Expand Up @@ -226,7 +221,9 @@ def download_and_convert(
downloading_iter = DownloadingIterable(object_names=file_names,
output_folder=tmp_dir,
object_store=object_store)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name,
add_bos_token=False,
add_eos_token=False)
tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace
Comment on lines +225 to 227
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this because if a user specifies a particular bos_text or eos_text the tokenizer should not automatically also add bos or eos token


# Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up
Expand Down Expand Up @@ -328,13 +325,13 @@ def convert_text_to_mds(
output_folder: str,
input_folder: str,
concat_tokens: int,
eos_text: str,
bos_text: str,
no_wrap: bool,
compression: str,
processes: int,
args_str: str,
reprocess: bool,
bos_text: Optional[str] = None,
eos_text: Optional[str] = None,
):
"""Convert a folder of text files to MDS format.

Expand All @@ -343,14 +340,31 @@ def convert_text_to_mds(
output_folder (str): Folder to write MDS shards to
input_folder (str): Folder of text files to process
    concat_tokens (int): Concatenate up to this many tokens
    eos_text (str): Text to append to each example to separate concatenated samples
bos_text (str): Text to prepend to each example to separate concatenated samples
    no_wrap (bool): Whether to let text examples wrap across multiple training examples
compression (str): The compression algorithm to use for MDS writing
processes (int): The number of processes to use.
args_str (str): String representation of the arguments
reprocess (bool): Whether to always reprocess the given folder of text files
bos_text (Optional[str]): Text to prepend to each example to separate concatenated samples
If None and tokenizer.add_bos_token is True, use the tokenizer's bos_token, otherwise use an empty string.
    eos_text (Optional[str]): Text to append to each example to separate concatenated samples
If None, use the tokenizer's eos_token.
"""
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

if bos_text is None:
if hasattr(tokenizer, 'add_bos_token') and tokenizer.add_bos_token:
tokenizer_bos = tokenizer.bos_token
bos_text = tokenizer_bos if tokenizer_bos is not None else ''
else:
bos_text = ''

if eos_text is None:
tokenizer_eos = tokenizer.eos_token
eos_text = tokenizer_eos if tokenizer_eos is not None else ''

assert bos_text is not None and eos_text is not None # for pyright

is_remote_output = is_remote_path(output_folder)

object_names = get_object_names(input_folder)
Expand Down
Loading