From 76883851d233d3734c19b1979ede7020059ea37d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 11 Oct 2024 13:33:20 -0400 Subject: [PATCH] add warning that sharegpt will be deprecated (#1957) * add warning that sharegpt will be deprecated * add helper script for chat_templates and document deprecation * Update src/axolotl/prompt_strategies/sharegpt.py Co-authored-by: NanoCode012 --------- Co-authored-by: NanoCode012 --- README.md | 2 +- scripts/chat_datasets.py | 60 +++++++++++++++++++++++ src/axolotl/prompt_strategies/sharegpt.py | 3 ++ 3 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 scripts/chat_datasets.py diff --git a/README.md b/README.md index f6f4e4e806..4ce7a351bb 100644 --- a/README.md +++ b/README.md @@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod - typescript type: ... # unimplemented custom format - # fastchat conversation + # fastchat conversation (to be deprecated soon; use chat_template instead) # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py - path: ... 
type: sharegpt
diff --git a/scripts/chat_datasets.py b/scripts/chat_datasets.py
new file mode 100644
index 0000000000..5eb5bde1e2
--- /dev/null
+++ b/scripts/chat_datasets.py
@@ -0,0 +1,69 @@
+"""
+helper script to parse chat datasets into a usable yaml
+"""
+import click
+import yaml
+from datasets import load_dataset
+
+
+def _first_present(candidates, available):
+    """Return the first entry of `candidates` found in `available`, else None."""
+    return next((name for name in candidates if name in available), None)
+
+
+@click.command()
+@click.argument("dataset", type=str)
+@click.option("--split", type=str, default="train")
+def parse_dataset(dataset=None, split="train"):
+    """
+    Print a `datasets:` yaml entry of type `chat_template` for a chat dataset.
+
+    Loads `split` of `dataset`, detects the column holding the messages and the
+    per-message keys holding the role and the content, and prints the result.
+    Raises ValueError when no known messages column, role key or content key
+    is present.
+    """
+    ds_cfg = {}
+    ds_cfg["path"] = dataset
+    ds_cfg["split"] = split
+    ds_cfg["type"] = "chat_template"
+    # NOTE(review): looks like a placeholder for the user to fill in - confirm
+    ds_cfg["chat_template"] = "<<>>"
+
+    features = load_dataset(dataset, split=split).features
+    feature_keys = features.keys()
+    field_messages = _first_present(
+        ["conversation", "conversations", "messages"], feature_keys
+    )
+    if not field_messages:
+        raise ValueError(
+            f'No conversation field found in dataset: {", ".join(feature_keys)}'
+        )
+    ds_cfg["field_messages"] = field_messages
+
+    # Inspect the column that was actually detected: a hard-coded
+    # "conversations" lookup fails for "conversation"/"messages" datasets.
+    # Assumes the column is a list-of-dicts feature, so [0] is the per-message
+    # field mapping - TODO confirm for other feature layouts.
+    message_fields = features[field_messages][0].keys()
+    message_field_role = _first_present(["from", "role"], message_fields)
+    if not message_field_role:
+        raise ValueError(
+            f'No role field found in messages: {", ".join(message_fields)}'
+        )
+    ds_cfg["message_field_role"] = message_field_role
+
+    message_field_content = _first_present(
+        ["content", "text", "value"], message_fields
+    )
+    if not message_field_content:
+        raise ValueError(
+            f'No content field found in messages: {", ".join(message_fields)}'
+        )
+    ds_cfg["message_field_content"] = message_field_content
+
+    print(yaml.dump({"datasets": [ds_cfg]}))
+
+
+if __name__ == "__main__":
+    parse_dataset()
diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index 321f19554b..4565c35d5d 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ 
b/src/axolotl/prompt_strategies/sharegpt.py @@ -61,6 +61,9 @@ def build_loader( default_conversation: Optional[str] = None, ): def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + LOG.warning( + "sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.", + ) conversation = ( ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg