Add Support for revision Dataset Parameter to specify reading from Huggingface Dataset Revision (#1912)

* Add support for `revision` dataset parameter

* only use revision on hf hub backed datasets

* use revision tied to head

* set download to use revision

* feat: add config to model validator class

* feat: add revision config to RL and tests for it

---------

Co-authored-by: Wing Lian <[email protected]>
Co-authored-by: NanoCode012 <[email protected]>
3 people authored Oct 11, 2024
1 parent 2fbc6b0 commit e73b8df
Showing 5 changed files with 148 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/config.qmd
@@ -90,6 +90,7 @@ datasets:
shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.

# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
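
For illustration (not part of the commit's docs), a dataset entry pinned to a specific Hub revision can be written the way the new tests do it; the dataset path and commit hash below come from those tests, and any commit hash, tag, or branch name should work in their place:

from axolotl.utils.dict import DictDefault

# Minimal sketch of a config that pins a Hub dataset to a fixed revision.
cfg = DictDefault(
    {
        "tokenizer_config": "huggyllama/llama-7b",
        "sequence_len": 1024,
        "datasets": [
            {
                "path": "mhenrichsen/alpaca_2k_test",
                "type": "alpaca",
                "revision": "d05c1cb",  # commit hash; a tag or branch name is also accepted
            },
        ],
    }
)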
3 changes: 3 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -125,6 +125,7 @@ class SFTDataset(BaseModel):
drop_system_message: Optional[bool] = None

trust_remote_code: Optional[bool] = False
revision: Optional[str] = None


class UserDefinedDPOType(BaseModel):
@@ -146,6 +147,7 @@ class DPODataset(BaseModel):
split: Optional[str] = None
type: Optional[Union[UserDefinedDPOType, str]] = None
data_files: Optional[List[str]] = None
revision: Optional[str] = None


class UserDefinedKTOType(BaseModel):
@@ -167,6 +169,7 @@ class KTODataset(BaseModel):
type: Optional[Union[UserDefinedKTOType, str]] = None
data_files: Optional[List[str]] = None
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None


class RLType(str, Enum):
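
As a rough sketch of what the new field means for validation (a standalone model, not the actual SFTDataset/DPODataset/KTODataset classes): `revision` is optional and defaults to None, so configs that omit it keep validating unchanged.

from typing import Optional

from pydantic import BaseModel


class DatasetWithRevision(BaseModel):
    # Hypothetical stand-in for the dataset config models above.
    path: str
    revision: Optional[str] = None


print(DatasetWithRevision(path="mhenrichsen/alpaca_2k_test").revision)  # None
print(DatasetWithRevision(path="mhenrichsen/alpaca_2k_test", revision="d05c1cb").revision)  # d05c1cb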
1 change: 1 addition & 0 deletions src/axolotl/utils/data/rl.py
@@ -90,6 +90,7 @@ def load_split(dataset_cfgs, _cfg):
ds = load_dataset( # pylint: disable=invalid-name
ds_cfg["path"],
split=ds_cfg["split"],
revision=ds_cfg.get("revision", None),
)
split_datasets.insert(i, ds)

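For reference, `datasets.load_dataset` itself accepts a `revision` argument for Hub-backed datasets; a minimal sketch using the dataset and commit hash from the new DPO test (omitting `revision` falls back to the latest version):

from datasets import load_dataset

# Sketch: load a Hub dataset pinned to a specific revision.
ds = load_dataset(
    "fozziethebeat/alpaca_messages_2k_dpo_test",
    split="train",
    revision="ea82cff",  # commit hash from the new test; tags and branches also work
)
print(len(ds))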
6 changes: 5 additions & 1 deletion src/axolotl/utils/data/sft.py
@@ -242,6 +242,7 @@ def for_d_in_datasets(dataset_configs):
name=config_dataset.name,
streaming=True,
token=use_auth_token,
revision=config_dataset.revision,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
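
The surrounding try/except decides whether the path is a Hub dataset; a hedged sketch of that probe with the revision threaded through (simplified relative to the real sft.py logic):

from datasets import load_dataset
from huggingface_hub.utils import HFValidationError

# Sketch: treat the path as Hub-backed only if a streaming load at the pinned revision succeeds.
ds_from_hub = False
try:
    load_dataset(
        "mhenrichsen/alpaca_2k_test",
        streaming=True,
        revision="d05c1cb",  # assumed pin; same hash as the new tests
    )
    ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
    pass
print(ds_from_hub)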
@@ -346,6 +347,7 @@ def for_d_in_datasets(dataset_configs):
streaming=False,
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
@@ -380,6 +382,7 @@ def for_d_in_datasets(dataset_configs):
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
@@ -389,6 +392,7 @@
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
@@ -433,8 +437,8 @@ def for_d_in_datasets(dataset_configs):
config_dataset=config_dataset,
tokenizer=tokenizer,
cfg=cfg,
dataset=ds,
d_base_type=d_base_type,
dataset=ds,
d_prompt_style=d_prompt_style,
processor=processor,
)
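
When only individual data files are fetched, the same revision is forwarded to `huggingface_hub.hf_hub_download`; a hedged sketch using the repo, file name, and hash that appear in the new local-copy test:

from huggingface_hub import hf_hub_download

# Sketch: download one data file from a dataset repo at a pinned revision.
local_path = hf_hub_download(
    repo_id="mhenrichsen/alpaca_2k_test",
    repo_type="dataset",
    filename="alpaca_2000.parquet",
    revision="d05c1cb",
)
print(local_path)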
138 changes: 138 additions & 0 deletions tests/test_datasets.py
@@ -12,6 +12,7 @@
from transformers import AutoTokenizer

from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.dict import DictDefault


@@ -267,6 +268,143 @@ def test_load_from_single_json(self):
assert "attention_mask" in dataset.features
assert "labels" in dataset.features

def test_load_hub_with_dpo(self):
"""Verify that processing dpo data from the hub works"""

cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
],
}
)

train_dataset, _ = load_prepare_dpo_datasets(cfg)

assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features

def test_load_hub_with_revision(self):
"""Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"revision": "d05c1cb",
},
],
}
)

dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)

assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features

def test_load_hub_with_revision_with_dpo(self):
"""Verify that processing dpo data from the hub works with a specific revision"""

cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"revision": "ea82cff",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
],
}
)

train_dataset, _ = load_prepare_dpo_datasets(cfg)

assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features

def test_load_local_hub_with_revision(self):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)

prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)

dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)

assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)


if __name__ == "__main__":
unittest.main()
