From d7f06f5fc898eb700a9e89f08793b2735d97889c Mon Sep 17 00:00:00 2001 From: Abhishek Maurya <124327945+Abhishek-TAMU@users.noreply.github.com> Date: Thu, 19 Dec 2024 15:59:21 -0500 Subject: [PATCH] feat: Handle passing of multiple files, multiple folders, path with patterns, HF Dataset and combination (#424) Signed-off-by: Abhishek --- .../tokenize_and_apply_input_masking.yaml | 4 +- tests/artifacts/testdata/__init__.py | 2 + .../twitter_complaints_input_output_1.json | 152 +++++++ .../twitter_complaints_input_output_2.json | 152 +++++++ tests/data/test_data_preprocessing_utils.py | 419 ++++++++++++++---- tests/test_sft_trainer.py | 27 +- tests/utils/test_config_utils.py | 30 +- tuning/data/data_config.py | 15 +- tuning/data/data_processors.py | 148 +++++-- tuning/data/setup_dataprocessor.py | 6 +- tuning/utils/utils.py | 42 +- 11 files changed, 859 insertions(+), 138 deletions(-) create mode 100644 tests/artifacts/testdata/datafolder/twitter_complaints_input_output_1.json create mode 100644 tests/artifacts/testdata/datafolder/twitter_complaints_input_output_2.json diff --git a/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml b/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml index b66b01d55..ac7e07030 100644 --- a/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml +++ b/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml @@ -10,5 +10,5 @@ datasets: remove_columns: all batched: false fn_kwargs: - input_field_name: "INPUT" - output_field_name: "OUTPUT" \ No newline at end of file + input_field_name: input + output_field_name: output \ No newline at end of file diff --git a/tests/artifacts/testdata/__init__.py b/tests/artifacts/testdata/__init__.py index b6d8c0fff..c76f065b7 100644 --- a/tests/artifacts/testdata/__init__.py +++ b/tests/artifacts/testdata/__init__.py @@ -24,6 +24,8 @@ ARROW_DATA_DIR = os.path.join(os.path.dirname(__file__), "arrow") PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet") +TWITTER_COMPLAINTS_DATA_DIR_JSON = os.path.join(DATA_DIR, "datafolder") + TWITTER_COMPLAINTS_DATA_JSON = os.path.join( JSON_DATA_DIR, "twitter_complaints_small.json" ) diff --git a/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_1.json b/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_1.json new file mode 100644 index 000000000..2668241f8 --- /dev/null +++ b/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_1.json @@ -0,0 +1,152 @@ +[ + { + "ID": 0, + "Label": 2, + "input": "@HMRCcustomers No this is my first job", + "output": "no complaint" + }, + { + "ID": 1, + "Label": 2, + "input": "@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.", + "output": "no complaint" + }, + { + "ID": 2, + "Label": 1, + "input": "If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService", + "output": "complaint" + }, + { + "ID": 3, + "Label": 1, + "input": "@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.", + "output": "complaint" + }, + { + "ID": 4, + "Label": 2, + "input": "Couples wallpaper, so cute. :) #BrothersAtHome", + "output": "no complaint" + }, + { + "ID": 5, + "Label": 2, + "input": "@mckelldogs This might just be me, but-- eyedrops? 
Artificial tears are so useful when you're sleep-deprived and sp\u2026 https://t.co/WRtNsokblG", + "output": "no complaint" + }, + { + "ID": 6, + "Label": 2, + "input": "@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?", + "output": "no complaint" + }, + { + "ID": 7, + "Label": 1, + "input": "@nationalgridus I have no water and the bill is current and paid. Can you do something about this?", + "output": "complaint" + }, + { + "ID": 8, + "Label": 1, + "input": "Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude/condescending. I'll take my $$ to @Sephora", + "output": "complaint" + }, + { + "ID": 9, + "Label": 2, + "input": "@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd", + "output": "no complaint" + }, + { + "ID": 10, + "Label": 2, + "input": "@NortonSupport Thanks much.", + "output": "no complaint" + }, + { + "ID": 11, + "Label": 2, + "input": "@VerizonSupport all of a sudden I can't connect to my primary wireless network but guest one works", + "output": "no complaint" + }, + { + "ID": 12, + "Label": 2, + "input": "Aaaahhhhh!!!! My @Razer @PlayOverwatch d.va meka headset came in!!! I didn't even know it had shipped!!! So excited\u2026 https://t.co/4gXy9xED8d", + "output": "no complaint" + }, + { + "ID": 13, + "Label": 2, + "input": "@Lin_Manuel @jmessinaphoto @VAMNit Omg a little squish!!!!! Enjoy and congrats!!!! I miss mine being so young! \ufffd\ufffd\ufffd\ufffd\ufffd\ufffd", + "output": "no complaint" + }, + { + "ID": 14, + "Label": 2, + "input": "@IanJamesPoulter What's your secret to poaching eggs? Mine NEVER look that good.", + "output": "no complaint" + }, + { + "ID": 15, + "Label": 2, + "input": "@AWSSupport When will be able Kinesis Firehose compatible with Elasticsearch 6.0? Thank you!", + "output": "no complaint" + }, + { + "ID": 16, + "Label": 2, + "input": "@NCIS_CBS https://t.co/eeVL9Eu3bE", + "output": "no complaint" + }, + { + "ID": 17, + "Label": 2, + "input": "@msetchell Via the settings? That\u2019s how I do it on master T\u2019s", + "output": "no complaint" + }, + { + "ID": 18, + "Label": 2, + "input": "Today at work there was a low flying duck heading toward a crowd of people, and I yelled \"watch out! and I'm very disappointed with myself.", + "output": "no complaint" + }, + { + "ID": 19, + "Label": 1, + "input": "@NortonSupport @NortonOnline What the hell is a dm 5-10 days to get money back bank account now overdrawn thanks guys", + "output": "complaint" + }, + { + "ID": 20, + "Label": 1, + "input": "@united not happy with this delay from Newark to Manchester tonight :( only 30 mins free Wi-fi sucks ...", + "output": "complaint" + }, + { + "ID": 21, + "Label": 1, + "input": "@ZARA_Care I've been waiting on a reply to my tweets and DMs for days now?", + "output": "complaint" + }, + { + "ID": 22, + "Label": 2, + "input": "New Listing! 
Large 2 Family Home for Sale in #Passaic Park, #NJ #realestate #homesforsale Great Location!\u2026 https://t.co/IV4OrLXkMk", + "output": "no complaint" + }, + { + "ID": 23, + "Label": 1, + "input": "@SouthwestAir I love you but when sending me flight changes please don't use military time #ignoranceisbliss", + "output": "complaint" + }, + { + "ID": 24, + "Label": 2, + "input": "@JetBlue Completely understand but would prefer being on time to filling out forms....", + "output": "no complaint" + } +] \ No newline at end of file diff --git a/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_2.json b/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_2.json new file mode 100644 index 000000000..e93fed8e4 --- /dev/null +++ b/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_2.json @@ -0,0 +1,152 @@ +[ + { + "ID": 25, + "Label": 2, + "input": "@nvidiacc I own two gtx 460 in sli. I want to try windows 8 dev preview. Which driver should I use. Can I use the windows 7 one.", + "output": "no complaint" + }, + { + "ID": 26, + "Label": 2, + "input": "Just posted a photo https://t.co/RShFwCjPHu", + "output": "no complaint" + }, + { + "ID": 27, + "Label": 2, + "input": "Love crescent rolls? Try adding pesto @PerdueChicken to them and you\u2019re going to love it! #Promotion #PerdueCrew -\u2026 https://t.co/KBHOfqCukH", + "output": "no complaint" + }, + { + "ID": 28, + "Label": 1, + "input": "@TopmanAskUs please just give me my money back.", + "output": "complaint" + }, + { + "ID": 29, + "Label": 2, + "input": "I just gave 5 stars to Tracee at @neimanmarcus for the great service I received!", + "output": "no complaint" + }, + { + "ID": 30, + "Label": 2, + "input": "@FitbitSupport when are you launching new clock faces for Indian market", + "output": "no complaint" + }, + { + "ID": 31, + "Label": 1, + "input": "@HPSupport my printer will not allow me to choose color instead it only prints monochrome #hppsdr #ijkhelp", + "output": "complaint" + }, + { + "ID": 32, + "Label": 1, + "input": "@DIRECTV can I get a monthly charge double refund when it sprinkles outside and we lose reception? #IamEmbarrasedForYou", + "output": "complaint" + }, + { + "ID": 33, + "Label": 1, + "input": "@AlfaRomeoCares Hi thanks for replying, could be my internet but link doesn't seem to be working", + "output": "complaint" + }, + { + "ID": 34, + "Label": 2, + "input": "Looks tasty! Going to share with everyone I know #FebrezeONE #sponsored https://t.co/4AQI53npei", + "output": "no complaint" + }, + { + "ID": 35, + "Label": 2, + "input": "@OnePlus_IN can OnePlus 5T do front camera portrait?", + "output": "no complaint" + }, + { + "ID": 36, + "Label": 1, + "input": "@sho_help @showtime your arrive is terrible streaming is stop and start every couple mins. 
Get it together it's xmas", + "output": "complaint" + }, + { + "ID": 37, + "Label": 2, + "input": "@KandraKPTV I just witnessed a huge building fire in Santa Monica California", + "output": "no complaint" + }, + { + "ID": 38, + "Label": 2, + "input": "@fernrocks most definitely the latter for me", + "output": "no complaint" + }, + { + "ID": 39, + "Label": 1, + "input": "@greateranglia Could I ask why the Area in front of BIC Station was not gritted withh all the snow.", + "output": "complaint" + }, + { + "ID": 40, + "Label": 2, + "input": "I'm earning points with #CricketRewards https://t.co/GfpGhqqnhE", + "output": "no complaint" + }, + { + "ID": 41, + "Label": 2, + "input": "@Schrapnel @comcast RIP me", + "output": "no complaint" + }, + { + "ID": 42, + "Label": 2, + "input": "The wait is finally over, just joined @SquareUK, hope to get started real soon!", + "output": "no complaint" + }, + { + "ID": 43, + "Label": 2, + "input": "@WholeFoods what's the best way to give feedback on a particular store to the regional/national office?", + "output": "no complaint" + }, + { + "ID": 44, + "Label": 2, + "input": "@DanielNewman I honestly would believe anything. People are...too much sometimes.", + "output": "no complaint" + }, + { + "ID": 45, + "Label": 2, + "input": "@asblough Yep! It should send you a notification with your driver\u2019s name and what time they\u2019ll be showing up!", + "output": "no complaint" + }, + { + "ID": 46, + "Label": 2, + "input": "@Wavy2Timez for real", + "output": "no complaint" + }, + { + "ID": 47, + "Label": 1, + "input": "@KenyaPower_Care no power in south b area... is it scheduled.", + "output": "complaint" + }, + { + "ID": 48, + "Label": 1, + "input": "Honda won't do anything about water leaking in brand new car. Frustrated! @HondaCustSvc @AmericanHonda", + "output": "complaint" + }, + { + "ID": 49, + "Label": 1, + "input": "@CBSNews @Dodge @ChryslerCares My driver side air bag has been recalled and replaced, but what about the passenger side?", + "output": "complaint" + } +] \ No newline at end of file diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index fed73f0e3..578daffbf 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -13,7 +13,9 @@ # limitations under the License. 
# Standard +import glob import json +import os import tempfile # Third Party @@ -21,6 +23,7 @@ from transformers import AutoTokenizer, DataCollatorForSeq2Seq from trl import DataCollatorForCompletionOnlyLM import datasets +import pyarrow import pytest import yaml @@ -34,6 +37,7 @@ from tests.artifacts.testdata import ( MODEL_NAME, TWITTER_COMPLAINTS_DATA_ARROW, + TWITTER_COMPLAINTS_DATA_DIR_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, @@ -141,23 +145,41 @@ def test_load_dataset_with_datafile(datafile, column_names): assert set(load_dataset.column_names) == column_names +@pytest.mark.parametrize("hf_dataset, splitName", [("squad", "validation")]) +def test_load_dataset_with_hf_dataset(hf_dataset, splitName): + """Ensure that hf dataset could be loaded.""" + datasetconfig = DataSetConfig( + name="text_dataset_input_output_masking", data_paths=[hf_dataset] + ) + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + load_dataset = processor.load_dataset( + datasetconfig=datasetconfig, splitName=splitName, datafile=None + ) + assert isinstance(load_dataset, Dataset) + + @pytest.mark.parametrize( - "datafile, column_names, datasetconfigname", + "datafile, column_names, datasetconfigname, builder", [ ( TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, set(["ID", "Label", "input", "output"]), "text_dataset_input_output_masking", + None, ), ( TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, set(["ID", "Label", "input", "output", "sequence"]), "text_dataset_input_output_masking", + None, ), ( TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, set(["ID", "Label", "input", "output"]), "text_dataset_input_output_masking", + None, ), ( TWITTER_COMPLAINTS_TOKENIZED_JSONL, @@ -173,6 +195,7 @@ def test_load_dataset_with_datafile(datafile, column_names): ] ), "pretokenized_dataset", + None, ), ( TWITTER_COMPLAINTS_TOKENIZED_PARQUET, @@ -188,27 +211,41 @@ def test_load_dataset_with_datafile(datafile, column_names): ] ), "pretokenized_dataset", + None, ), ( TWITTER_COMPLAINTS_DATA_JSONL, set(["Tweet text", "ID", "Label", "text_label", "output"]), "apply_custom_data_template", + None, ), ( TWITTER_COMPLAINTS_DATA_ARROW, set(["Tweet text", "ID", "Label", "text_label", "output"]), "apply_custom_data_template", + None, ), ( TWITTER_COMPLAINTS_DATA_PARQUET, set(["Tweet text", "ID", "Label", "text_label", "output"]), "apply_custom_data_template", + None, + ), + ( + TWITTER_COMPLAINTS_DATA_PARQUET, + set(["Tweet text", "ID", "Label", "text_label", "output"]), + "apply_custom_data_template", + "parquet", ), ], ) -def test_load_dataset_with_datasetconfig(datafile, column_names, datasetconfigname): +def test_load_dataset_with_datasetconfig( + datafile, column_names, datasetconfigname, builder +): """Ensure that both dataset is loaded with datafile.""" - datasetconfig = DataSetConfig(name=datasetconfigname, data_paths=[datafile]) + datasetconfig = DataSetConfig( + name=datasetconfigname, data_paths=[datafile], builder=builder + ) processor = get_datapreprocessor( processor_config=DataPreProcessorConfig(), tokenizer=None ) @@ -218,6 +255,57 @@ def test_load_dataset_with_datasetconfig(datafile, column_names, datasetconfigna assert set(load_dataset.column_names) == column_names +@pytest.mark.parametrize( + "data_paths, datasetconfigname", + [ + ( + ["fake/path"], + "apply_custom_data_template", + ), + ( + [ + TWITTER_COMPLAINTS_DATA_PARQUET.replace( + "twitter_complaints_small.parquet", 
"not_exist.parquet" + ) + ], + "apply_custom_data_template", + ), + ], +) +def test_load_dataset_with_non_exist_path(data_paths, datasetconfigname): + """Ensure that load_dataset raises error for non-exist paths.""" + datasetconfig = DataSetConfig(name=datasetconfigname, data_paths=data_paths) + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + with pytest.raises((datasets.exceptions.DatasetNotFoundError, ValueError)): + processor.load_dataset( + datasetconfig=datasetconfig, splitName="train", datafile=None + ) + + +@pytest.mark.parametrize( + "datafile, datasetconfigname, builder", + [ + (TWITTER_COMPLAINTS_DATA_PARQUET, "apply_custom_data_template", "arrow"), + ], +) +def test_load_dataset_with_datasetconfig_incorrect_builder( + datafile, datasetconfigname, builder +): + """Ensure that directory with incorrect builder cannot be passed in datasetconfig.""" + datasetconfig = DataSetConfig( + name=datasetconfigname, data_paths=[datafile], builder=builder + ) + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + with pytest.raises(pyarrow.lib.ArrowInvalid): + processor.load_dataset( + datasetconfig=datasetconfig, splitName="train", datafile=None + ) + + @pytest.mark.parametrize( "datafile, datasetconfigname", [ @@ -247,6 +335,58 @@ def test_load_dataset_with_dataconfig_and_datafile(datafile, datasetconfigname): ) +@pytest.mark.parametrize( + "datasetconfig, column_names", + [ + ( + DataSetConfig( + name="text_dataset_input_output_masking", + data_paths=[TWITTER_COMPLAINTS_DATA_DIR_JSON], + ), + set(["ID", "Label", "input", "output"]), + ), + ( + DataSetConfig( + name="text_dataset_input_output_masking", + data_paths=[TWITTER_COMPLAINTS_DATA_DIR_JSON], + builder="json", + ), + set(["ID", "Label", "input", "output"]), + ), + ], +) +def test_load_dataset_with_dataconfig_and_datafolder(datasetconfig, column_names): + """Ensure that directory can be passed in datasetconfig with/without builder.""" + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + load_dataset = processor.load_dataset( + datasetconfig=datasetconfig, splitName="train", datafile=None + ) + assert set(load_dataset.column_names) == column_names + + +@pytest.mark.parametrize( + "datasetconfig", + [ + DataSetConfig( + name="text_dataset_input_output_masking", + data_paths=[TWITTER_COMPLAINTS_DATA_DIR_JSON], + builder="arrow", + ), + ], +) +def test_load_dataset_with_dataconfig_and_datafolder_incorrect_builder(datasetconfig): + """Ensure that directory with incorrect builder cannot be passed in datasetconfig.""" + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + with pytest.raises(pyarrow.lib.ArrowInvalid): + processor.load_dataset( + datasetconfig=datasetconfig, splitName="train", datafile=None + ) + + def test_load_dataset_without_dataconfig_and_datafile(): """Ensure that both datasetconfig and datafile cannot be None.""" processor = get_datapreprocessor( @@ -256,6 +396,74 @@ def test_load_dataset_without_dataconfig_and_datafile(): processor.load_dataset(datasetconfig=None, splitName="train", datafile=None) +@pytest.mark.parametrize( + "data_paths, column_names, datasetconfigname, builder", + [ + ( + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + TWITTER_COMPLAINTS_DATA_DIR_JSON, + ], + set(["ID", "Label", "input", "output"]), + "text_dataset_input_output_masking", + None, + ), + ( + [ + TWITTER_COMPLAINTS_DATA_DIR_JSON, + 
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, + ], + set(["ID", "Label", "input", "output"]), + "text_dataset_input_output_masking", + None, + ), + ], +) +def test_load_dataset_with_datasetconfig_files_folders( + data_paths, column_names, datasetconfigname, builder +): + """Ensure that load_dataset works with passing combination of files and folders.""" + datasetconfig = DataSetConfig( + name=datasetconfigname, data_paths=data_paths, builder=builder + ) + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + load_dataset = processor.load_dataset( + datasetconfig=datasetconfig, splitName="train", datafile=None + ) + assert set(load_dataset.column_names) == column_names + + +@pytest.mark.parametrize( + "data_paths, datasetconfigname, builder", + [ + ( + [ + TWITTER_COMPLAINTS_DATA_DIR_JSON, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + ], + "text_dataset_input_output_masking", + "arrow", + ), + ], +) +def test_load_dataset_with_datasetconfig_files_folders_incorrect_builder( + data_paths, datasetconfigname, builder +): + """Ensure that load_dataset with passing combination of files and folders does support mismatch in format""" + datasetconfig = DataSetConfig( + name=datasetconfigname, data_paths=data_paths, builder=builder + ) + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + with pytest.raises(ValueError): + processor.load_dataset( + datasetconfig=datasetconfig, splitName="train", datafile=None + ) + + @pytest.mark.parametrize( "data, result", [ @@ -553,6 +761,10 @@ def test_process_dataconfig_file(data_config_path, data_path): DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, [TWITTER_COMPLAINTS_DATA_ARROW, TWITTER_COMPLAINTS_DATA_ARROW], ), + ( + DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, + [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_PARQUET], + ), ( DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, [TWITTER_COMPLAINTS_TOKENIZED_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON], @@ -602,6 +814,13 @@ def test_process_dataconfig_file(data_config_path, data_path): TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, ], ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + ], + ), ], ) def test_process_dataconfig_multiple_files(data_config_path, data_path_list): @@ -646,6 +865,122 @@ def test_process_dataconfig_multiple_files(data_config_path, data_path_list): assert formatted_dataset_field in set(train_set.column_names) +@pytest.mark.parametrize( + "data_config_path, data_paths, builder", + [ + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + [os.path.join(TWITTER_COMPLAINTS_DATA_DIR_JSON, "*.json")], + None, + ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + [os.path.join(TWITTER_COMPLAINTS_DATA_DIR_JSON, "*.json")], + "json", + ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + [os.path.join(TWITTER_COMPLAINTS_DATA_DIR_JSON, "*")], + "json", + ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + [os.path.join(TWITTER_COMPLAINTS_DATA_DIR_JSON, "*complaints*")], + "json", + ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + [TWITTER_COMPLAINTS_DATA_DIR_JSON], + None, + ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + [TWITTER_COMPLAINTS_DATA_DIR_JSON], + "json", + ), + ], +) +def test_process_dataconfig_multiple_files_folders_with_globbing( + data_config_path, data_paths, builder +): + """Ensure that datasets files matching globbing pattern are formatted 
and validated correctly based on the arguments passed in config file.""" + with open(data_config_path, "r") as f: + yaml_content = yaml.safe_load(f) + + yaml_content["datasets"][0]["data_paths"] = data_paths + yaml_content["datasets"][0]["builder"] = builder + + with tempfile.NamedTemporaryFile( + "w", delete=False, suffix=".yaml" + ) as temp_yaml_file: + yaml.dump(yaml_content, temp_yaml_file) + temp_yaml_file_path = temp_yaml_file.name + data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer) + assert isinstance(train_set, Dataset) + assert set(["input_ids", "attention_mask", "labels"]).issubset( + set(train_set.column_names) + ) + + path_or_pattern = data_paths[0] + if os.path.isdir(path_or_pattern): + # Construct a pattern for JSON files in this directory + pattern = os.path.join(path_or_pattern, "*.json") + else: + # Assume path_or_pattern is already a pattern + pattern = path_or_pattern + + data_len = sum(len(json.load(open(file, "r"))) for file in glob.glob(pattern)) + assert len(train_set) == data_len + + +@pytest.mark.parametrize( + "data_paths, datasetconfigname, builder", + [ + ( + [os.path.join(TWITTER_COMPLAINTS_DATA_DIR_JSON, "*")], + "tokenize_and_apply_input_masking", + None, + ), + ( + [os.path.join(TWITTER_COMPLAINTS_DATA_DIR_JSON, "*complaints*")], + "tokenize_and_apply_input_masking", + None, + ), + (["*squad"], "tokenize_and_apply_input_masking", None), + ( + [TWITTER_COMPLAINTS_DATA_DIR_JSON.replace("datafolder", "dataf*")], + "tokenize_and_apply_input_masking", + None, + ), + ( + [TWITTER_COMPLAINTS_DATA_DIR_JSON], + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + "parquet", + ), + ], +) +def test_process_dataconfig_multiple_files_folders_without_builder( + data_paths, datasetconfigname, builder +): + """Ensure that datasets folders / files without ext and builder + OR HF datasets passed via globbing pattern raises error.""" + datasetconfig = DataSetConfig( + name=datasetconfigname, data_paths=data_paths, builder=builder + ) + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + with pytest.raises( + (datasets.exceptions.DatasetNotFoundError, ValueError, pyarrow.lib.ArrowInvalid) + ): + processor.load_dataset( + datasetconfig=datasetconfig, splitName="train", datafile=None + ) + + @pytest.mark.parametrize( "datafiles, datasetconfigname", [ @@ -708,84 +1043,6 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( ) -@pytest.mark.parametrize( - "data_config_path, data_path_list", - [ - ( - DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, - [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_PARQUET], - ), - ( - DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, - [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON], - ), - ( - DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, - [ - TWITTER_COMPLAINTS_TOKENIZED_JSONL, - TWITTER_COMPLAINTS_TOKENIZED_ARROW, - TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, - ], - ), - ( - DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, - [ - TWITTER_COMPLAINTS_TOKENIZED_JSON, - TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, - ], - ), - ( - DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, - [ - TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, - TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, - ], - ), - ( - DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, - [TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_JSON], - ), - ], -) -def 
test_process_dataconfig_multiple_files_varied_data_formats( - data_config_path, data_path_list -): - """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file.""" - with open(data_config_path, "r") as f: - yaml_content = yaml.safe_load(f) - yaml_content["datasets"][0]["data_paths"] = data_path_list - datasets_name = yaml_content["datasets"][0]["name"] - - # Modify input_field_name and output_field_name according to dataset - if datasets_name == "text_dataset_input_output_masking": - yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { - "input_field_name": "input", - "output_field_name": "output", - } - - # Modify dataset_text_field and template according to dataset - formatted_dataset_field = "formatted_data_field" - if datasets_name == "apply_custom_data_template": - template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" - yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { - "dataset_text_field": formatted_dataset_field, - "template": template, - } - - with tempfile.NamedTemporaryFile( - "w", delete=False, suffix=".yaml" - ) as temp_yaml_file: - yaml.dump(yaml_content, temp_yaml_file) - temp_yaml_file_path = temp_yaml_file.name - data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) - - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - with pytest.raises( - (AssertionError, datasets.exceptions.DatasetGenerationCastError) - ): - (_, _, _) = _process_dataconfig_file(data_args, tokenizer) - - @pytest.mark.parametrize( "data_args", [ diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 1cee6a3c9..529f21b66 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -37,6 +37,7 @@ from scripts.run_inference import TunedCausalLM from tests.artifacts.predefined_data_configs import ( DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, ) from tests.artifacts.testdata import ( CHAT_DATA_MULTI_TURN, @@ -46,6 +47,7 @@ MALFORMATTED_DATA, MODEL_NAME, TWITTER_COMPLAINTS_DATA_ARROW, + TWITTER_COMPLAINTS_DATA_DIR_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, @@ -790,6 +792,10 @@ def test_run_causallm_ft_pretokenized(dataset_path): @pytest.mark.parametrize( "datafiles, datasetconfigname", [ + ( + [TWITTER_COMPLAINTS_DATA_DIR_JSON], + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + ), ( [ TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, @@ -797,6 +803,13 @@ def test_run_causallm_ft_pretokenized(dataset_path): ], DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, ), + ( + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + TWITTER_COMPLAINTS_DATA_DIR_JSON, + ], + DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, + ), ( [ TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, @@ -1123,7 +1136,7 @@ def test_malformatted_data(): data_args = copy.deepcopy(DATA_ARGS) data_args.training_data_path = MALFORMATTED_DATA - with pytest.raises(DatasetGenerationError): + with pytest.raises((DatasetGenerationError, ValueError)): sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) @@ -1132,20 +1145,10 @@ def test_empty_data(): data_args = copy.deepcopy(DATA_ARGS) data_args.training_data_path = EMPTY_DATA - with pytest.raises(DatasetGenerationError): + with pytest.raises((DatasetGenerationError, ValueError)): sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) -def test_data_path_is_a_directory(): - 
"""Ensure that we get ValueError if we point the data path at a dir, not a file.""" - with tempfile.TemporaryDirectory() as tempdir: - data_args = copy.deepcopy(DATA_ARGS) - data_args.training_data_path = tempdir - - with pytest.raises(ValueError): - sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) - - ### Tests for bad tuning module configurations def test_run_causallm_lora_with_invalid_modules(): """Check that we throw a value error if the target modules for lora don't exist.""" diff --git a/tests/utils/test_config_utils.py b/tests/utils/test_config_utils.py index 1cbbbaaa0..8e29750fb 100644 --- a/tests/utils/test_config_utils.py +++ b/tests/utils/test_config_utils.py @@ -17,10 +17,12 @@ # Standard import base64 +import logging import os import pickle # Third Party +from datasets import Dataset, Features, Value from peft import LoraConfig, PromptTuningConfig import pytest @@ -29,7 +31,7 @@ # Local from tuning.config import peft_config -from tuning.utils import config_utils +from tuning.utils import config_utils, utils def test_get_hf_peft_config_returns_None_for_tuning_config_None(): @@ -232,3 +234,29 @@ def test_get_json_config_can_load_from_envvar(): job_config = config_utils.get_json_config() assert job_config is not None assert job_config["model_name_or_path"] == "foobar" + + +def test_validate_datasets_logs_warnings_on_mismatch(caplog): + """Test that `validate_mergeable_datasets` logs warnings when + datasets have different columns or dtypes.""" + # Create a reference dataset with columns col1:int64 and col2:string + ds1 = Dataset.from_dict( + {"col1": [1, 2], "col2": ["hello", "world"]}, + features=Features({"col1": Value("int64"), "col2": Value("string")}), + ) + + # Create a second dataset with a different column set and a different dtype for col1 + ds2 = Dataset.from_dict( + {"col1": [0.1, 0.2], "col3": ["hi", "there"]}, + features=Features({"col1": Value("float64"), "col3": Value("string")}), + ) + + with caplog.at_level(logging.WARNING): + utils.validate_mergeable_datasets([ds1, ds2]) + + assert ( + "different columns" in caplog.text + ), "Expected a warning about differing columns." + assert ( + "expected int64" in caplog.text + ), "Expected a warning about mismatching column dtypes." 
diff --git a/tuning/data/data_config.py b/tuning/data/data_config.py index 4da83d720..0c5521baf 100644 --- a/tuning/data/data_config.py +++ b/tuning/data/data_config.py @@ -21,6 +21,8 @@ # Local from tuning.utils.utils import load_yaml_or_json +logger = logging.getLogger(__name__) + @dataclass class DataHandlerConfig: @@ -32,6 +34,7 @@ class DataHandlerConfig: class DataSetConfig: name: str data_paths: List[str] + builder: Optional[str] = None # Referring to Hugging Face dataset builder sampling: Optional[float] = None data_handlers: Optional[List[DataHandlerConfig]] = None @@ -79,14 +82,18 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig: c.data_paths = [] for p in data_paths: assert isinstance(p, str), f"path {p} should be of the type string" - assert os.path.exists(p), f"data_paths {p} does not exist" if not os.path.isabs(p): _p = os.path.abspath(p) - logging.warning( - " Provided path %s is not absolute changing it to %s", p, _p - ) + logger.warning(" Provided path %s is not absolute changing it to %s", p, _p) p = _p c.data_paths.append(p) + if "builder" in kwargs and kwargs["builder"] is not None: + builder = kwargs["builder"] + assert isinstance( + builder, str + ), f"builder should be a string representing a supported \ + Hugging Face dataset builder, but got: {builder}" + c.builder = builder if "sampling" in kwargs and kwargs["sampling"] is not None: ratio = kwargs["sampling"] assert isinstance(ratio, float) and ( diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py index 33a368314..170bc2a81 100644 --- a/tuning/data/data_processors.py +++ b/tuning/data/data_processors.py @@ -27,7 +27,9 @@ # Local from tuning.data.data_config import DataConfig, DataPreProcessorConfig, DataSetConfig from tuning.data.data_handlers import AVAILABLE_DATA_HANDLERS -from tuning.utils.utils import get_extension, get_loader_for_filepath +from tuning.utils.utils import get_loader_for_filepath, validate_mergeable_datasets + +logger = logging.getLogger(__name__) class DataPreProcessor: @@ -54,11 +56,11 @@ def register_data_handler(self, name: str, func: Callable): if not isinstance(name, str) or not callable(func): raise ValueError("Handlers should be of type Dict, str to callable") if name in self.registered_handlers: - logging.warning( + logger.warning( "Handler name '%s' already exists and will be overwritten", name ) self.registered_handlers[name] = func - logging.info("Registered new handler %s", name) + logger.info("Registered new handler %s", name) def register_data_handlers(self, handlers: Dict[str, Callable]): if handlers is None: @@ -81,33 +83,111 @@ def load_dataset( if (not datafile) and (not datasetconfig): raise ValueError("Either datafile or datasetconfig must be set") + def _load_dataset(data_path=None, builder=None, data_files=None, data_dir=None): + """ + Helper function to load a dataset using datasets.load_dataset + with standardized exception handling. + + Args: + data_path: The path argument for load_dataset (directory, file, pattern, dataset_id) + builder: Optional builder to use if provided. + data_files: Optional data_files list if loading from files. + data_dir: Optional data_dir if loading from a directory with a builder. 
+ Returns: dataset + """ + + load_kwargs = {**kwargs, "split": splitName} + if data_dir is not None: + load_kwargs["data_dir"] = data_dir + if data_files is not None: + load_kwargs["data_files"] = data_files + + # Determine the `path` parameter for load_dataset + load_path = builder if builder else data_path + + try: + return datasets.load_dataset(path=load_path, **load_kwargs) + except DatasetNotFoundError as e: + # Reraise with a more context-specific message if needed + raise e + except FileNotFoundError as e: + # Handle file/directory not found + context = ( + f"path {data_path} with builder {builder}" + if builder + else f"path {data_path}" + ) + raise ValueError(f"Data loading failed: invalid {context}.") from e + except datasets.exceptions.DatasetGenerationError as e: + context = ( + f"builder {builder} and data_dir {data_dir}" + if builder and data_dir + else f"builder {builder}" + if builder + else f"path {data_path}" + ) + raise ValueError( + f"Failed to generate the dataset from the provided {context}." + ) from e + if datafile: - files = [datafile] loader = get_loader_for_filepath(file_path=datafile) - elif datasetconfig: - files = datasetconfig.data_paths - name = datasetconfig.name - # simple check to make sure all files are of same type. - extns = [get_extension(f) for f in files] - assert extns.count(extns[0]) == len( - extns - ), f"All files in the dataset {name} should have the same extension" - loader = get_loader_for_filepath(file_path=files[0]) - - if loader in (None, ""): - raise ValueError(f"data path is invalid [{', '.join(files)}]") + if loader in (None, ""): + raise ValueError(f"data path is invalid [{datafile}]") + return _load_dataset(builder=loader, data_files=[datafile]) + + data_paths = datasetconfig.data_paths + builder = datasetconfig.builder + all_datasets = [] + + for data_path in data_paths: + # CASE 1: User passes directory + if os.path.isdir(data_path): # Checks if path exists and isdirectory + # Directory case + if builder: + # Load using a builder with a data_dir + dataset = _load_dataset(builder=builder, data_dir=data_path) + else: + # Load directly from the directory + dataset = _load_dataset(data_path=data_path) + else: + # Non-directory (file, pattern, HF dataset name) + # If no builder provided, attempt to infer one + effective_builder = ( + builder if builder else get_loader_for_filepath(data_path) + ) + + if effective_builder: + # CASE 2: Files passed with builder. 
Load using the builder and specific files + dataset = _load_dataset( + builder=effective_builder, data_files=[data_path] + ) + else: + # CASE 3: User passes files/folder/pattern/HF_dataset which has no builder + dataset = _load_dataset(data_path=data_path) + all_datasets.append(dataset) + + # Logs warning if datasets have different columns + validate_mergeable_datasets(all_datasets) + + # Concatenate all datasets try: - return datasets.load_dataset( - loader, - data_files=files, - split=splitName, - **kwargs, + if len(all_datasets) == 1: + return all_datasets[0] + + raw_datasets = datasets.concatenate_datasets(all_datasets) + logger.info( + "Datasets concatenated from %s .Concatenated dataset columns: %s", + datasetconfig.name, + list(raw_datasets.features.keys()), ) - except DatasetNotFoundError as e: - raise e - except FileNotFoundError as e: - raise ValueError(f"data path is invalid [{', '.join(files)}]") from e + return raw_datasets + + except Exception as e: + raise ValueError( + f"An error occurred while concatenating datasets from {datasetconfig.name}: {e}" + ) from e def _process_dataset_configs( self, dataset_configs: List[DataSetConfig], **extra_kwargs @@ -129,25 +209,25 @@ def _process_dataset_configs( if sum(p for p in sampling_probabilities) != 1: raise ValueError("Sampling probabilities don't sum to 1") sample_datasets = True - logging.info( + logger.info( "Sampling ratios are specified; given datasets will be interleaved." ) else: - logging.info( + logger.info( "Sampling is not specified; if multiple datasets are provided," " the given datasets will be concatenated." ) sample_datasets = False - logging.info("Starting DataPreProcessor...") + logger.info("Starting DataPreProcessor...") # Now Iterate over the multiple datasets provided to us to process for d in dataset_configs: - logging.info("Loading %s", d.name) + logger.info("Loading %s", d.name) # In future the streaming etc go as kwargs of this function raw_dataset = self.load_dataset(d, splitName) - logging.info("Loaded raw dataset : %s", str(raw_dataset)) + logger.info("Loaded raw dataset : %s", str(raw_dataset)) raw_datasets = DatasetDict() @@ -188,7 +268,7 @@ def _process_dataset_configs( kwargs["fn_kwargs"] = dict(kwargs["fn_kwargs"], **extra_kwargs) - logging.info("Applying Handler: %s Args: %s", data_handler, kwargs) + logger.info("Applying Handler: %s Args: %s", data_handler, kwargs) raw_datasets = raw_datasets.map(handler, **kwargs) @@ -207,7 +287,7 @@ def _process_dataset_configs( if sample_datasets: strategy = self.processor_config.sampling_stopping_strategy seed = self.processor_config.sampling_seed - logging.info( + logger.info( "Interleaving datasets: strategy[%s] seed[%d] probabilities[%s]", strategy, seed, @@ -238,7 +318,7 @@ def process_dataset_configs( if torch.distributed.is_available() and torch.distributed.is_initialized(): if torch.distributed.get_rank() == 0: - logging.info("Processing data on rank 0...") + logger.info("Processing data on rank 0...") train_dataset = self._process_dataset_configs(dataset_configs, **kwargs) else: train_dataset = None @@ -251,7 +331,7 @@ def process_dataset_configs( torch.distributed.broadcast_object_list(to_share, src=0) train_dataset = to_share[0] else: - logging.info("Processing data...") + logger.info("Processing data...") train_dataset = self._process_dataset_configs(dataset_configs, **kwargs) return train_dataset diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py index 2bae18bb7..b6f09c323 100644 --- 
a/tuning/data/setup_dataprocessor.py +++ b/tuning/data/setup_dataprocessor.py @@ -33,6 +33,8 @@ from tuning.data.data_preprocessing_utils import get_data_collator from tuning.data.data_processors import get_datapreprocessor +logger = logging.getLogger(__name__) + # In future we may make the fields configurable DEFAULT_INPUT_COLUMN = "input" DEFAULT_OUTPUT_COLUMN = "output" @@ -320,9 +322,9 @@ def process_dataargs( """ max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length) - logging.info("Max sequence length is %s", max_seq_length) + logger.info("Max sequence length is %s", max_seq_length) if train_args.max_seq_length > tokenizer.model_max_length: - logging.warning( + logger.warning( "max_seq_length %s exceeds tokenizer.model_max_length \ %s, using tokenizer.model_max_length %s", train_args.max_seq_length, diff --git a/tuning/utils/utils.py b/tuning/utils/utils.py index 6eef6b2cf..b840f6176 100644 --- a/tuning/utils/utils.py +++ b/tuning/utils/utils.py @@ -14,11 +14,14 @@ # Standard import json +import logging import os # Third Party import yaml +logger = logging.getLogger(__name__) + def get_extension(file_path: str) -> str: _, ext = os.path.splitext(file_path) @@ -31,9 +34,9 @@ def get_loader_for_filepath(file_path: str) -> str: return "text" if ext in (".json", ".jsonl"): return "json" - if ext in (".arrow"): + if ext in (".arrow",): return "arrow" - if ext in (".parquet"): + if ext in (".parquet",): return "parquet" return ext @@ -46,3 +49,38 @@ def load_yaml_or_json(file_path: str) -> dict: if ext == ".json": return json.load(f) return None + + +def validate_mergeable_datasets(datasets): + """Given list of datasets, validate if all datasets have same type and number of columns.""" + if len(datasets) > 1: + ref_columns = datasets[0].features + ref_column_names = list(ref_columns.keys()) + ref_column_types = {col: feat.dtype for col, feat in ref_columns.items()} + + # Check all other datasets + for i, ds in enumerate(datasets[1:], start=2): + ds_column_names = list(ds.features.keys()) + ds_column_types = {col: feat.dtype for col, feat in ds.features.items()} + + # Check same set of columns + if set(ds_column_names) != set(ref_column_names): + logger.warning( + "Dataset %d has different columns: %s. Columns in Dataset 1: %s", + i, + ds_column_names, + ref_column_names, + ) + + # Check column data types + for col in ref_column_names: + if (col in ds_column_types) and ( + ds_column_types[col] != ref_column_types[col] + ): + logger.warning( + "Column '%s' in dataset %d has type %s, expected %s", + col, + i, + ds_column_types[col], + ref_column_types[col], + )
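
For reference, a rough equivalence sketch (an assumption, not code from this patch) of the plain `datasets.load_dataset` calls that the three cases in the new `_load_dataset` helper reduce to; the folder path and the `squad` id are the ones exercised in the tests above.

import datasets

# CASE 1: a directory plus an explicit builder -> load that builder with data_dir
ds = datasets.load_dataset(
    "json", data_dir="tests/artifacts/testdata/datafolder", split="train"
)

# CASE 2: a file or glob pattern -> builder (given, or inferred from the
# extension by get_loader_for_filepath) with data_files
ds = datasets.load_dataset(
    "json",
    data_files=["tests/artifacts/testdata/datafolder/*.json"],
    split="train",
)

# CASE 3: no builder given or inferable -> the path itself becomes the
# dataset id; this covers directories loaded without a builder as well as
# Hugging Face Hub datasets such as "squad"
ds = datasets.load_dataset("squad", split="validation")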