feat: Allow hf dataset id to be passed via training_data_path (#431)
* Allow hf dataset id to be loaded by training_data_path

Signed-off-by: Dushyant Behl <[email protected]>

* update README

Signed-off-by: Dushyant Behl <[email protected]>

* minor changes

Signed-off-by: Abhishek <[email protected]>

---------

Signed-off-by: Dushyant Behl <[email protected]>
Signed-off-by: Abhishek <[email protected]>
Co-authored-by: Abhishek <[email protected]>
dushyantbehl and Abhishek-TAMU authored Jan 4, 2025
1 parent 6f0c61d commit 3dc8ef7
Showing 4 changed files with 79 additions and 33 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -62,13 +62,13 @@ pip install fms-hf-tuning[aim]
 For more details on how to enable and use the trackers, please see [the experiment tracking section below](#experiment-tracking).
 
 ## Data Support
-Users can pass training data in a single file using the `--training_data_path` argument along with other arguments required for various [use cases](#use-cases-supported-with-training_data_path-argument) (see details below) and the file can be in any of the [supported formats](#supported-data-formats). Alternatively, you can use our powerful [data preprocessing backend](./docs/advanced-data-preprocessing.md) to preprocess datasets on the fly.
+Users can pass training data as either a single file or a Hugging Face dataset ID using the `--training_data_path` argument along with other arguments required for various [use cases](#use-cases-supported-with-training_data_path-argument) (see details below). If a user chooses to pass a file, it can be in any of the [supported formats](#supported-data-formats). Alternatively, you can use our powerful [data preprocessing backend](./docs/advanced-data-preprocessing.md) to preprocess datasets on the fly.
 
 
 Below we list the supported data use cases for the `--training_data_path` argument. For details of our advanced data preprocessing, see [Advanced Data Preprocessing](./docs/advanced-data-preprocessing.md).
 
 ## Supported Data Formats
-We support the following data formats via `--training_data_path` argument
+We support the following file formats via the `--training_data_path` argument:
 
 Data Format | Tested Support
 ------------|---------------
@@ -77,6 +77,8 @@ JSONL | ✅
 PARQUET | ✅
 ARROW | ✅
 
+As noted above, we also support passing a HF dataset ID directly via the `--training_data_path` argument.
+
 ## Use cases supported with `training_data_path` argument
 
 ### 1. Data formats with a single sequence and a specified response_template to use for masking on completion.
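
For context, here is a minimal sketch of what this README change enables, mirroring the new e2e test added in `tests/test_sft_trainer.py` below. Only the `DataArguments` values come from this commit; the import paths and the placeholder model/train configs are assumptions:

```python
# Sketch only: import paths assumed from the repository layout.
from tuning import sft_trainer
from tuning.config import configs

# A HF Hub dataset id is now accepted where only a local file used to work.
data_args = configs.DataArguments(
    training_data_path="lhoestq/demo1",  # Hub dataset id instead of a file path
    data_formatter_template="### Text:{{review}} \n\n### Stars: {{star}}",
    response_template="\n### Stars:",
)

# model_args and train_args are hypothetical stand-ins for the MODEL_ARGS and
# TRAIN_ARGS fixtures used in the test below; configure them, then call:
# sft_trainer.train(model_args, data_args, train_args)
```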
34 changes: 32 additions & 2 deletions tests/test_sft_trainer.py
@@ -25,7 +25,7 @@
 import tempfile
 
 # Third Party
-from datasets.exceptions import DatasetGenerationError
+from datasets.exceptions import DatasetGenerationError, DatasetNotFoundError
 from transformers.trainer_callback import TrainerCallback
 import pytest
 import torch
@@ -326,7 +326,7 @@ def test_run_train_fails_training_data_path_not_exist():
     """Check fails when data path not found."""
     updated_data_path_args = copy.deepcopy(DATA_ARGS)
     updated_data_path_args.training_data_path = "fake/path"
-    with pytest.raises(ValueError):
+    with pytest.raises(DatasetNotFoundError):
         sft_trainer.train(MODEL_ARGS, updated_data_path_args, TRAIN_ARGS, None)
 
 
@@ -998,6 +998,36 @@ def test_run_chat_style_ft_using_dataconfig(datafiles, dataconfigfile):
     assert 'Provide two rhyming words for the word "love"' in output_inference
 
 
+@pytest.mark.parametrize(
+    "data_args",
+    [
+        (
+            # sample hugging face dataset id
+            configs.DataArguments(
+                training_data_path="lhoestq/demo1",
+                data_formatter_template="### Text:{{review}} \n\n### Stars: {{star}}",
+                response_template="\n### Stars:",
+            )
+        )
+    ],
+)
+def test_run_e2e_with_hf_dataset_id(data_args):
+    """
+    Check if we can run an e2e test with a hf dataset id as training_data_path.
+    """
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        sft_trainer.train(MODEL_ARGS, data_args, train_args)
+
+        # validate ft tuning configs
+        _validate_training(tempdir)
+
+        # validate inference
+        _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir))
+
+
 ############################# Helper functions #############################
 def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
     train_args = copy.deepcopy(training_args)
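
The expected exception changes because path resolution is now delegated to `datasets.load_dataset`: a string that is neither a local file nor a resolvable Hub id surfaces the library's own `DatasetNotFoundError` instead of our `ValueError`. A minimal sketch of that behavior, assuming a `datasets` version that exports this exception (as the import above requires):

```python
from datasets import load_dataset
from datasets.exceptions import DatasetNotFoundError

try:
    # Neither a local file nor a real Hub repository.
    load_dataset("fake/path")
except DatasetNotFoundError as err:
    print(f"could not resolve dataset id: {err}")
```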
2 changes: 1 addition & 1 deletion tuning/data/data_handlers.py
@@ -130,7 +130,7 @@ def replace_text(match_obj):
         if index_object not in element:
             raise KeyError("Requested template string is not a valid key in dict")
 
-        return element[index_object]
+        return str(element[index_object])
 
     return {
         dataset_text_field: re.sub(r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template)
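
The `str(...)` wrapper matters because `re.sub` requires its replacement callback to return a string, while a dataset column referenced by the formatter template (for example, the integer star rating in the HF dataset used by the new test) need not be one. A standalone sketch of the failure mode, with a made-up `element` row:

```python
import re

template = "### Text:{{review}} \n\n### Stars: {{star}}"
element = {"review": "Great movie", "star": 5}  # hypothetical row; "star" is an int

def replace_text(match_obj):
    key = match_obj.group(1)
    if key not in element:
        raise KeyError("Requested template string is not a valid key in dict")
    # Returning element[key] unwrapped would raise
    # "TypeError: expected str instance, int found" on the integer value.
    return str(element[key])

print(re.sub(r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template))
```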
70 changes: 42 additions & 28 deletions tuning/data/data_processors.py
@@ -130,42 +130,56 @@ def _load_dataset(data_path=None, builder=None, data_files=None, data_dir=None):
                 f"Failed to generate the dataset from the provided {context}."
             ) from e
 
-    if datafile:
-        loader = get_loader_for_filepath(file_path=datafile)
-        if loader in (None, ""):
-            raise ValueError(f"data path is invalid [{datafile}]")
-        return _load_dataset(builder=loader, data_files=[datafile])
-
-    data_paths = datasetconfig.data_paths
-    builder = datasetconfig.builder
-    all_datasets = []
-
-    for data_path in data_paths:
-        # CASE 1: User passes directory
-        if os.path.isdir(data_path):  # Checks if path exists and isdirectory
-            # Directory case
-            if builder:
-                # Load using a builder with a data_dir
-                dataset = _load_dataset(builder=builder, data_dir=data_path)
-            else:
-                # Load directly from the directory
-                dataset = _load_dataset(data_path=data_path)
-        else:
-            # Non-directory (file, pattern, HF dataset name)
-            # If no builder provided, attempt to infer one
-            effective_builder = (
-                builder if builder else get_loader_for_filepath(data_path)
-            )
-            if effective_builder:
-                # CASE 2: Files passed with builder. Load using the builder and specific files
-                dataset = _load_dataset(
-                    builder=effective_builder, data_files=[data_path]
-                )
-            else:
-                # CASE 3: User passes files/folder/pattern/HF_dataset which has no builder
-                dataset = _load_dataset(data_path=data_path)
-        all_datasets.append(dataset)
+    def _try_load_dataset(dataset_path, dataset_builder):
+        """
+        Helper function to call load dataset on case by case basis to ensure we handle
+        directories and files (with or without builders) and hf datasets.
+
+        Args:
+            dataset_path: Path of directory/file, pattern, or hf dataset id.
+            dataset_builder: Optional builder to use if provided.
+        Returns: dataset
+        """
+        if not dataset_path:
+            raise ValueError("Invalid dataset path")
+
+        # CASE 1: User passes directory
+        if os.path.isdir(dataset_path):  # Checks if path exists and it is a dir
+            # Directory case
+            if dataset_builder:
+                # Load using a builder with a data_dir
+                return _load_dataset(builder=dataset_builder, data_dir=dataset_path)
+
+            # If no builder then load directly from the directory
+            return _load_dataset(data_path=dataset_path)
+
+        # Non-directory (file, pattern, HF dataset name)
+        # If no builder provided, attempt to infer one
+        effective_builder = (
+            dataset_builder
+            if dataset_builder
+            else get_loader_for_filepath(dataset_path)
+        )
+
+        if effective_builder:
+            # CASE 2: Files passed with builder. Load using the builder and specific files
+            return _load_dataset(
+                builder=effective_builder, data_files=[dataset_path]
+            )
+
+        # CASE 3: User passes files/folder/pattern/HF_dataset which has no builder
+        # Still no builder, try if this is a dataset id
+        return _load_dataset(data_path=dataset_path)
+
+    if datafile:
+        return _try_load_dataset(datafile, None)
+
+    data_paths = datasetconfig.data_paths
+    builder = datasetconfig.builder
+    all_datasets = []
+
+    for data_path in data_paths:
+        dataset = _try_load_dataset(data_path, builder)
+        all_datasets.append(dataset)
 
     # Logs warning if datasets have different columns
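
The refactor funnels both the single `datafile` path and every entry of `data_paths` through one helper, whose branches map onto standard `datasets.load_dataset` call shapes. Roughly, with hypothetical local paths (only `lhoestq/demo1` appears in this commit):

```python
from datasets import load_dataset

# CASE 1: directory plus an explicit builder -> load with data_dir
ds = load_dataset("json", data_dir="path/to/data_dir")  # hypothetical directory

# CASE 2: file with an inferable builder ("json" here) -> load with data_files
ds = load_dataset("json", data_files=["path/to/train.jsonl"])  # hypothetical file

# CASE 3: not a directory and no builder -> treat the string as a Hub dataset id
ds = load_dataset("lhoestq/demo1")
```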
