From 88399e0001f09738d8d2fd98062b28211fd4691f Mon Sep 17 00:00:00 2001 From: root <23239305+b-chu@users.noreply.github.com> Date: Wed, 10 Apr 2024 14:15:25 +0000 Subject: [PATCH] Add remote code option to allow execution of DBRX tokenizer --- scripts/data_prep/convert_text_to_mds.py | 25 ++++++++++++++++--- .../data_prep/test_convert_text_to_mds.py | 3 +++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index df39e38a90..be986fc24d 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -114,6 +114,13 @@ def parse_args() -> Namespace: help='If true, reprocess the input_folder to mds format. Otherwise, ' + 'only reprocess upon changes to the input folder or dataset creation parameters.', ) + parser.add_argument( + '--trust-remote-code', + type=bool, + required=False, + default=False, + help='If true, allows custom code to be executed to load the tokenizer', + ) parsed = parser.parse_args() @@ -124,7 +131,8 @@ def parse_args() -> Namespace: parser.error( 'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.' ) - tokenizer = AutoTokenizer.from_pretrained(parsed.tokenizer) + tokenizer = AutoTokenizer.from_pretrained( + parsed.tokenizer, trust_remote_code=parsed.trust_remote_code) parsed.eos_text = tokenizer.eos_token # now that we have validated them, change BOS/EOS to strings @@ -171,6 +179,7 @@ def get_task_args( bos_text: str, no_wrap: bool, compression: str, + trust_remote_code: bool, ) -> Iterable: """Get download_and_convert arguments split across n_groups. @@ -187,6 +196,7 @@ def get_task_args( bos_text (str): Text to prepend to each example to separate concatenated samples no_wrap: (bool): Whether to let text examples wrap across multiple training examples compression (str): The compression algorithm to use for MDS writing + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer """ num_objects = len(object_names) objs_per_group = math.ceil(num_objects / n_groups) @@ -202,6 +212,7 @@ def get_task_args( bos_text, no_wrap, compression, + trust_remote_code, ) @@ -223,6 +234,7 @@ def download_and_convert( bos_text: str, no_wrap: bool, compression: str, + trust_remote_code: bool, ): """Downloads and converts text fies to MDS format. @@ -236,6 +248,7 @@ def download_and_convert( bos_text (str): Text to prepend to each example to separate concatenated samples no_wrap: (bool): Whether to let text examples wrap across multiple training examples compression (str): The compression algorithm to use for MDS writing + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer """ object_store = maybe_create_object_store_from_uri(input_folder) @@ -244,7 +257,8 @@ def download_and_convert( downloading_iter = DownloadingIterable(object_names=file_names, output_folder=tmp_dir, object_store=object_store) - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, trust_remote_code=trust_remote_code) tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up @@ -353,6 +367,7 @@ def convert_text_to_mds( processes: int, args_str: str, reprocess: bool, + trust_remote_code: bool, ): """Convert a folder of text files to MDS format. @@ -368,6 +383,7 @@ def convert_text_to_mds( processes (int): The number of processes to use. args_str (str): String representation of the arguments reprocess (bool): Whether to always reprocess the given folder of text files + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer """ is_remote_output = is_remote_path(output_folder) @@ -396,7 +412,7 @@ def convert_text_to_mds( # Download and convert the text files in parallel args = get_task_args(object_names, local_output_folder, input_folder, processes, tokenizer_name, concat_tokens, eos_text, - bos_text, no_wrap, compression) + bos_text, no_wrap, compression, trust_remote_code) with ProcessPoolExecutor(max_workers=processes) as executor: list(executor.map(download_and_convert_starargs, args)) @@ -405,7 +421,7 @@ def convert_text_to_mds( else: download_and_convert(object_names, local_output_folder, input_folder, tokenizer_name, concat_tokens, eos_text, bos_text, - no_wrap, compression) + no_wrap, compression, trust_remote_code) # Write a done file with the args and object names write_done_file(local_output_folder, args_str, object_names) @@ -462,6 +478,7 @@ def _args_str(original_args: Namespace) -> str: compression=args.compression, processes=args.processes, reprocess=args.reprocess, + trust_remote_code=args.trust_remote_code, args_str=_args_str(args)) except Exception as e: if mosaicml_logger is not None: diff --git a/tests/a_scripts/data_prep/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py index e458cb1dfc..bd96de695c 100644 --- a/tests/a_scripts/data_prep/test_convert_text_to_mds.py +++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py @@ -106,6 +106,7 @@ def call_convert_text_to_mds() -> None: processes=processes, args_str='Namespace()', reprocess=False, + trust_remote_code=False, ) call_convert_text_to_mds() @@ -195,6 +196,7 @@ def call_convert_text_to_mds(reprocess: bool): processes=1, args_str='Namespace()', reprocess=reprocess, + trust_remote_code=False, ) # Create input text data @@ -234,6 +236,7 @@ def test_input_folder_not_exist(tmp_path: pathlib.Path): processes=1, args_str='Namespace()', reprocess=False, + trust_remote_code=False, )