From e73b8dff8d5fcfb02371916cbebc1350a3a1a9c9 Mon Sep 17 00:00:00 2001 From: Thomas Cleberg <84520378+thomascleberg@users.noreply.github.com> Date: Fri, 11 Oct 2024 12:32:50 -0500 Subject: [PATCH] Add Support for `revision` Dataset Parameter to specify reading from Huggingface Dataset Revision (#1912) * Add support for `revision` dataset parameter * only use revision on hf hub backed datasets * use revision tied to head * set download to use revision * feat: add config to model validator class * feat: add revision config to RL and tests for it --------- Co-authored-by: Wing Lian Co-authored-by: NanoCode012 --- docs/config.qmd | 1 + .../config/models/input/v0_4_1/__init__.py | 3 + src/axolotl/utils/data/rl.py | 1 + src/axolotl/utils/data/sft.py | 6 +- tests/test_datasets.py | 138 ++++++++++++++++++ 5 files changed, 148 insertions(+), 1 deletion(-) diff --git a/docs/config.qmd b/docs/config.qmd index 99a69a0973..8329f35535 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -90,6 +90,7 @@ datasets: shards: # Optional[int] number of shards to split data into name: # Optional[str] name of dataset configuration to load train_on_split: train # Optional[str] name of dataset split to load from + revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets. # Optional[str] fastchat conversation type, only used with type: sharegpt conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 47796add6b..1c33b59078 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -125,6 +125,7 @@ class SFTDataset(BaseModel): drop_system_message: Optional[bool] = None trust_remote_code: Optional[bool] = False + revision: Optional[str] = None class UserDefinedDPOType(BaseModel): @@ -146,6 +147,7 @@ class DPODataset(BaseModel): split: Optional[str] = None type: Optional[Union[UserDefinedDPOType, str]] = None data_files: Optional[List[str]] = None + revision: Optional[str] = None class UserDefinedKTOType(BaseModel): @@ -167,6 +169,7 @@ class KTODataset(BaseModel): type: Optional[Union[UserDefinedKTOType, str]] = None data_files: Optional[List[str]] = None trust_remote_code: Optional[bool] = False + revision: Optional[str] = None class RLType(str, Enum): diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index d0324e1ebd..35bd5fcbb7 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -90,6 +90,7 @@ def load_split(dataset_cfgs, _cfg): ds = load_dataset( # pylint: disable=invalid-name ds_cfg["path"], split=ds_cfg["split"], + revision=ds_cfg.get("revision", None), ) split_datasets.insert(i, ds) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 7d6922cbf2..39eb2c4e04 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -242,6 +242,7 @@ def for_d_in_datasets(dataset_configs): name=config_dataset.name, streaming=True, token=use_auth_token, + revision=config_dataset.revision, ) ds_from_hub = True except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): @@ -346,6 +347,7 @@ def for_d_in_datasets(dataset_configs): streaming=False, data_files=config_dataset.data_files, token=use_auth_token, + revision=config_dataset.revision, **load_ds_kwargs, ) elif ds_from_cloud and remote_file_system: @@ -380,6 +382,7 @@ def for_d_in_datasets(dataset_configs): repo_id=config_dataset.path, repo_type="dataset", filename=config_dataset.data_files, + revision=config_dataset.revision, ) elif isinstance(config_dataset.data_files, list): fp = [] @@ -389,6 +392,7 @@ def for_d_in_datasets(dataset_configs): repo_id=config_dataset.path, repo_type="dataset", filename=file, + revision=config_dataset.revision, ) ) else: @@ -433,8 +437,8 @@ def for_d_in_datasets(dataset_configs): config_dataset=config_dataset, tokenizer=tokenizer, cfg=cfg, - dataset=ds, d_base_type=d_base_type, + dataset=ds, d_prompt_style=d_prompt_style, processor=processor, ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index a274b7b894..f8b463a03e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -12,6 +12,7 @@ from transformers import AutoTokenizer from axolotl.utils.data import load_tokenized_prepared_datasets +from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.dict import DictDefault @@ -267,6 +268,143 @@ def test_load_from_single_json(self): assert "attention_mask" in dataset.features assert "labels" in dataset.features + def test_load_hub_with_dpo(self): + """Verify that processing dpo data from the hub works""" + + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "rl": "dpo", + "chat_template": "llama3", + "datasets": [ + { + "path": "fozziethebeat/alpaca_messages_2k_dpo_test", + "type": "chat_template.default", + "chat_template": "llama3", + "field_messages": "conversation", + "field_chosen": "chosen", + "field_rejected": "rejected", + "message_field_role": "role", + "message_field_content": "content", + "roles": { + "system": ["system"], + "user": ["user"], + "assistant": ["assistant"], + }, + } + ], + } + ) + + train_dataset, _ = load_prepare_dpo_datasets(cfg) + + assert len(train_dataset) == 1800 + assert "conversation" in train_dataset.features + + def test_load_hub_with_revision(self): + """Verify that processing data from the hub works with a specific revision""" + with tempfile.TemporaryDirectory() as tmp_dir: + prepared_path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + "revision": "d05c1cb", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 2000 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + + def test_load_hub_with_revision_with_dpo(self): + """Verify that processing dpo data from the hub works with a specific revision""" + + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "rl": "dpo", + "chat_template": "llama3", + "datasets": [ + { + "path": "fozziethebeat/alpaca_messages_2k_dpo_test", + "type": "chat_template.default", + "chat_template": "llama3", + "revision": "ea82cff", + "field_messages": "conversation", + "field_chosen": "chosen", + "field_rejected": "rejected", + "message_field_role": "role", + "message_field_content": "content", + "roles": { + "system": ["system"], + "user": ["user"], + "assistant": ["assistant"], + }, + } + ], + } + ) + + train_dataset, _ = load_prepare_dpo_datasets(cfg) + + assert len(train_dataset) == 1800 + assert "conversation" in train_dataset.features + + def test_load_local_hub_with_revision(self): + """Verify that a local copy of a hub dataset can be loaded with a specific revision""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_ds_path = Path("mhenrichsen/alpaca_2k_test") + tmp_ds_path.mkdir(parents=True, exist_ok=True) + snapshot_download( + repo_id="mhenrichsen/alpaca_2k_test", + repo_type="dataset", + local_dir=tmp_ds_path, + revision="d05c1cb", + ) + + prepared_path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "ds_type": "parquet", + "type": "alpaca", + "data_files": [ + "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet", + ], + "revision": "d05c1cb", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 2000 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + shutil.rmtree(tmp_ds_path) + if __name__ == "__main__": unittest.main()