From fd3b80716aaa5b06ac979fc6a0885d032bcb5d14 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 8 Nov 2024 13:45:49 -0500 Subject: [PATCH] remove fastchat and sharegpt (#2021) * remove fastchat and sharegpt * remove imports * remove more fastchat imports * chore: remove unused functions * feat: remove sharegpt and deprecate from docs * chore: remove unused sharegpt checks * fix: remove sharegpt type from tests * feat: add sharegpt deprecation error * feat: update readme --------- Co-authored-by: NanoCode012 --- README.md | 7 +- devtools/dev_chat_template.yml | 2 +- docs/config.qmd | 11 +- docs/dataset-formats/conversation.qmd | 59 +-- requirements.txt | 1 - src/axolotl/cli/preprocess.py | 21 - src/axolotl/cli/train.py | 19 - .../fastchat_conversation_turns.py | 231 -------- src/axolotl/prompt_strategies/instruct.py | 33 -- src/axolotl/prompt_strategies/llama2_chat.py | 6 +- src/axolotl/prompt_strategies/sharegpt.py | 223 -------- .../prompt_strategies/sharegpt_jokes.py | 28 - src/axolotl/prompt_tokenizers.py | 157 +----- src/axolotl/prompters.py | 159 +----- src/axolotl/utils/config/__init__.py | 26 - .../config/models/input/v0_4_1/__init__.py | 26 +- src/axolotl/utils/tokenization.py | 64 --- tests/prompt_strategies/test_sharegpt.py | 500 ------------------ tests/test_normalize_config.py | 10 +- tests/test_prompt_tokenizers.py | 210 +------- tests/test_validation.py | 33 -- tests/test_validation_dataset.py | 6 +- 22 files changed, 28 insertions(+), 1804 deletions(-) delete mode 100644 src/axolotl/monkeypatch/fastchat_conversation_turns.py delete mode 100644 src/axolotl/prompt_strategies/instruct.py delete mode 100644 src/axolotl/prompt_strategies/sharegpt.py delete mode 100644 src/axolotl/prompt_strategies/sharegpt_jokes.py delete mode 100644 tests/prompt_strategies/test_sharegpt.py diff --git a/README.md b/README.md index b3f292c7dd..077dd6fee5 100644 --- a/README.md +++ b/README.md @@ -383,11 +383,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod - typescript type: ... # unimplemented custom format - # fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template) - # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py + # chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template - path: ... - type: sharegpt - conversation: chatml # default: vicuna_v1.1 + type: chat_template + chat_template: chatml # defaults to tokenizer's chat_template # local - path: data.jsonl # or json diff --git a/devtools/dev_chat_template.yml b/devtools/dev_chat_template.yml index 9697da4b33..27dc9be1af 100644 --- a/devtools/dev_chat_template.yml +++ b/devtools/dev_chat_template.yml @@ -1,4 +1,4 @@ -# Example config for debugging the sharegpt prompt format +# Example config for debugging the chat_template prompt format base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer diff --git a/docs/config.qmd b/docs/config.qmd index 238f7201db..09691bc770 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -83,7 +83,7 @@ lora_on_cpu: true datasets: # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files - path: vicgalle/alpaca-gpt4 - # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] + # The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection] type: alpaca # format | format: (chat/instruct) | .load_ ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file data_files: # Optional[str] path to source data files @@ -92,15 +92,6 @@ datasets: train_on_split: train # Optional[str] name of dataset split to load from revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets. - # Optional[str] fastchat conversation type, only used with type: sharegpt - conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py - field_human: # Optional[str]. Human key to use for conversation. - field_model: # Optional[str]. Assistant key to use for conversation. - # Add additional keys from your dataset as input or output roles - roles: - input: # Optional[List[str]]. These will be masked based on train_on_input - output: # Optional[List[str]]. - # Custom user instruction prompt - path: repo type: diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd index c7273c5be5..fb9aed3ffa 100644 --- a/docs/dataset-formats/conversation.qmd +++ b/docs/dataset-formats/conversation.qmd @@ -6,33 +6,8 @@ order: 3 ## sharegpt -UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below. +IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below. -conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt) - -```{.json filename="data.jsonl"} -{"conversations": [{"from": "...", "value": "..."}]} -``` - -Note: `type: sharegpt` opens special configs: -- `conversation`: enables conversions to many Conversation types. Refer to the 'name' [here](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) for options. -- `roles`: allows you to specify the roles for input and output. This is useful for datasets with custom roles such as `tool` etc to support masking. -- `field_human`: specify the key to use instead of `human` in the conversation. -- `field_model`: specify the key to use instead of `gpt` in the conversation. - -```yaml -datasets: - path: ... - type: sharegpt - - conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py - field_human: # Optional[str]. Human key to use for conversation. - field_model: # Optional[str]. Assistant key to use for conversation. - # Add additional keys from your dataset as input or output roles - roles: - input: # Optional[List[str]]. These will be masked based on train_on_input - output: # Optional[List[str]]. -``` ## pygmalion @@ -40,38 +15,6 @@ datasets: {"conversations": [{"role": "...", "value": "..."}]} ``` -## sharegpt.load_role - -conversations where `role` is used instead of `from` - -```{.json filename="data.jsonl"} -{"conversations": [{"role": "...", "value": "..."}]} -``` - -## sharegpt.load_guanaco - -conversations where `from` is `prompter` `assistant` instead of default sharegpt - -```{.json filename="data.jsonl"} -{"conversations": [{"from": "...", "value": "..."}]} -``` - -## sharegpt.load_ultrachat - -conversations where the turns field is 'messages', human is 'user' and gpt is 'assistant'. - -```{.json filename="data.jsonl"} -{"messages": [{"user": "...", "assistant": "..."}]} -``` - -## sharegpt_jokes - -creates a chat where bot is asked to tell a joke, then explain why the joke is funny - -```{.json filename="data.jsonl"} -{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]} -``` - ## chat_template diff --git a/requirements.txt b/requirements.txt index 735f860a55..d1fdccaf77 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,7 +28,6 @@ scipy scikit-learn==1.4.2 pynvml art -fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe gradio==3.50.2 tensorboard python-dotenv==1.0.1 diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index aab29e2670..a1592aa785 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -23,10 +23,6 @@ ) from axolotl.common.cli import PreprocessCliArgs from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH -from axolotl.prompt_strategies.sharegpt import ( - register_chatml_template, - register_llama3_template, -) from axolotl.utils.trainer import disable_datasets_caching LOG = logging.getLogger("axolotl.cli.preprocess") @@ -44,23 +40,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): return_remaining_strings=True ) - if parsed_cfg.chat_template == "chatml": - if parsed_cfg.default_system_message: - LOG.info( - f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}" - ) - register_chatml_template(parsed_cfg.default_system_message) - else: - register_chatml_template() - elif parsed_cfg.chat_template == "llama3": - if parsed_cfg.default_system_message: - LOG.info( - f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}" - ) - register_llama3_template(parsed_cfg.default_system_message) - else: - register_llama3_template() - if not parsed_cfg.dataset_prepared_path: msg = ( Fore.RED diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 16d66a82f0..2a40e854ee 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -19,10 +19,6 @@ ) from axolotl.common.cli import TrainerCliArgs from axolotl.integrations.base import PluginManager -from axolotl.prompt_strategies.sharegpt import ( - register_chatml_template, - register_llama3_template, -) from axolotl.train import train LOG = logging.getLogger("axolotl.cli.train") @@ -42,21 +38,6 @@ def do_train(cfg, cli_args) -> None: print_axolotl_text_art() check_accelerate_default_config() check_user_token() - if cfg.chat_template == "chatml" and cfg.default_system_message: - LOG.info( - f"ChatML set. Adding default system message: {cfg.default_system_message}" - ) - register_chatml_template(cfg.default_system_message) - else: - register_chatml_template() - - if cfg.chat_template == "llama3" and cfg.default_system_message: - LOG.info( - f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}" - ) - register_llama3_template(cfg.default_system_message) - else: - register_llama3_template() if cfg.rl: # and cfg.rl != "orpo": dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py deleted file mode 100644 index a09bfddb4b..0000000000 --- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -monkeypatch to add a get_turns method -""" - -import logging -from typing import Generator, Tuple - -from fastchat.conversation import SeparatorStyle - -LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns") - - -def get_prompt(self) -> str: - ret = "" - for role, msg in self.get_turns(): - ret += role + msg - return ret - - -def get_turns( # pylint: disable=too-many-return-statements - self, -) -> Generator[Tuple[str, str], None, None]: - """Get the prompt for generation.""" - system_prompt = self.system_template.format(system_message=self.system_message) - if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: - yield "", system_prompt + self.sep - for role, message in self.messages: - if message: - yield role + ": ", message + self.sep - else: - yield role + ":", "" - return - if self.sep_style == SeparatorStyle.ADD_COLON_TWO: - seps = [self.sep, self.sep2] - yield "", system_prompt + seps[0] - for i, (role, message) in enumerate(self.messages): - if message: - yield role + ": ", message + seps[i % 2] - else: - yield role + ":", "" - return - if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: - yield "", system_prompt + self.sep - for role, message in self.messages: - if message: - yield role + ": ", message + self.sep - else: - yield role + ": ", "" # must be end with a space - return - if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: - yield "", "" if system_prompt == "" else system_prompt + self.sep - for role, message in self.messages: - if message: - yield role + "\n", message + self.sep - else: - yield role + "\n", "" - return - if self.sep_style == SeparatorStyle.NO_COLON_SINGLE: - yield "", system_prompt - for role, message in self.messages: - if message: - yield role, message + self.sep - else: - yield role, "" - return - if self.sep_style == SeparatorStyle.NO_COLON_TWO: - seps = [self.sep, self.sep2] - yield "", system_prompt - for i, (role, message) in enumerate(self.messages): - if message: - yield role, message + seps[i % 2] - else: - yield role, "" - return - if self.sep_style == SeparatorStyle.RWKV: - yield "", system_prompt - for i, (role, message) in enumerate(self.messages): - if message: - yield role + ": ", message.replace("\r\n", "\n").replace( - "\n\n", "\n" - ) + "\n\n" - else: - yield role + ":", "" - return - if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral": - if self.system_message: - if self.messages: - # For llama, the system message is incorporated into the first human instruction - first_role, first_msg = self.messages[0] - if first_role == self.roles[0]: - system_prompt += first_msg - self.messages.pop(0) - yield "", system_prompt - for i, (role, message) in enumerate(self.messages): - if message: - if (i % 2 == 0 and not self.system_message) or ( - i % 2 != 0 and self.system_message - ): - role = " " + role - yield role + " ", message - else: - yield role, "" - return - if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral": - contains_sys_msg = False - if self.system_message: - contains_sys_msg = True - if self.messages: - # There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction separated by a newline - first_role, first_msg = self.messages[0] - if first_role == self.roles[0]: - system_prompt = self.system_template.format( - system_message=" " + self.system_message - ) - system_prompt += first_msg - self.messages.pop(0) - yield "", system_prompt - for i, (role, message) in enumerate(self.messages): - if message and i == 0 and not contains_sys_msg: - yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a ` [INST]` at the beginning of the first instruction. - elif message: - yield role + " ", message - else: - yield role, "" - return - if self.sep_style == SeparatorStyle.LLAMA3: - if self.system_message: - # For llama3, the system message is NOT incorporated into the first human instruction - # All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|> - yield "", system_prompt - for i, (role, message) in enumerate(self.messages): - if message: - yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>" - else: - yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", "" - return - if self.sep_style == SeparatorStyle.GEMMA: - if self.system_message: - raise ValueError("Gemma chat template does not support system messages") - for i, (role, message) in enumerate(self.messages): - prefix = "" if i == 0 else "" - message_str = message if message else "" - yield prefix + "" + role + "\n", message_str + "\n" - return - if self.sep_style == SeparatorStyle.CHATGLM: - # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 - # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 - round_add_n = 1 if self.name == "chatglm2" else 0 - if system_prompt: - yield "", system_prompt + self.sep - - for i, (role, message) in enumerate(self.messages): - if i % 2 == 0: - yield "", f"[Round {i//2 + round_add_n}]{self.sep}" - - if message: - yield f"{role}:", f"{message}{self.sep}" - else: - yield f"{role}:", "" - return - if self.sep_style == SeparatorStyle.CHATML: - yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n" - for role, message in self.messages: - if message: - yield role + "\n", message + self.sep + "\n" - else: - yield role + "\n", "" - return - if self.sep_style == SeparatorStyle.CHATGLM3: - if self.system_message: - yield "", system_prompt - for role, message in self.messages: - if message: - yield role + "\n", " " + message - else: - yield role - return - if self.sep_style == SeparatorStyle.CHATINTERN: - # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 - seps = [self.sep, self.sep2] - yield "", system_prompt - for i, (role, message) in enumerate(self.messages): - prefix = "" if i % 2 == 0 else "" - if message: - yield prefix + role + ":", message + seps[i % 2] + "\n" - else: - yield role + ":", "" - return - if self.sep_style == SeparatorStyle.DOLLY: - seps = [self.sep, self.sep2] - yield "", system_prompt - for i, (role, message) in enumerate(self.messages): - if message: - suffix = "\n\n" if i % 2 == 1 else "" - yield role + ":\n", message + seps[i % 2] + suffix - else: - yield role + ":\n", "" - return - if self.sep_style == SeparatorStyle.PHOENIX: - yield "", system_prompt - for role, message in self.messages: - if message: - yield role + ": ", "" + message + "" - else: - yield role + ": " + "", "" - return - if self.sep_style == SeparatorStyle.ROBIN: - yield "", system_prompt + self.sep - for role, message in self.messages: - if message: - yield role + ":\n", message + self.sep - else: - yield role + ":\n", "" - return - if self.sep_style == SeparatorStyle.FALCON_CHAT: - if self.system_message: - yield "", system_prompt + self.sep - for role, message in self.messages: - if message: - yield role + ": ", message + self.sep - else: - yield role + ":", "" - else: - raise ValueError(f"Invalid style: {self.sep_style}") - - -def add_get_turns_to_conversation(): - import fastchat.conversation - - fastchat.conversation.Conversation.get_turns = get_turns - fastchat.conversation.Conversation.get_prompt = get_prompt diff --git a/src/axolotl/prompt_strategies/instruct.py b/src/axolotl/prompt_strategies/instruct.py deleted file mode 100644 index 3d63674890..0000000000 --- a/src/axolotl/prompt_strategies/instruct.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Module containing the InstructShareGPTPromptTokenizingStrategy class""" -from typing import Any, Dict, Optional - -from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy -from axolotl.prompters import ShareGPTPrompterV2 - - -def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): - conversation = ( - ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None - ) - strategy = InstructShareGPTPromptTokenizingStrategy( - # pylint: disable=duplicate-code - ShareGPTPrompterV2( - conversation=conversation, - ), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - return strategy - - -class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): - """ - basic sharegpt strategy to grab conversations from the sample row - """ - - def get_conversation_thread(self, prompt): - return [ - {"from": "human", "value": prompt["instruction"]}, - {"from": "gpt", "value": prompt["output"]}, - ] diff --git a/src/axolotl/prompt_strategies/llama2_chat.py b/src/axolotl/prompt_strategies/llama2_chat.py index a1f5ffefff..29e091bfd0 100644 --- a/src/axolotl/prompt_strategies/llama2_chat.py +++ b/src/axolotl/prompt_strategies/llama2_chat.py @@ -29,7 +29,7 @@ from typing import Generator, List, Sequence from axolotl.prompt_tokenizers import PromptTokenizingStrategy -from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE +from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID @dataclass @@ -75,7 +75,7 @@ def append_message(self, role: str, message: str): class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy): """ - Tokenizing strategy for ShareGPT prompts. + Tokenizing strategy for Llama2 prompts. adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py """ @@ -191,7 +191,7 @@ def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]: conv.messages = [] # pylint: disable=R0801 for j, sentence in enumerate(source): role = roles[sentence["from"]] - assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE + assert role == conv.roles[j % 2], ALTERNATING_ASSERTION_FAILED_ROLE if sentence["value"]: conv.append_message(role, sentence["value"]) yield conv diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py deleted file mode 100644 index 069d243f52..0000000000 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ /dev/null @@ -1,223 +0,0 @@ -"""Module containing the SimpleShareGPTPromptTokenizingStrategy class""" - -import logging -from typing import Any, Dict, Optional, Type - -from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template - -from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy -from axolotl.prompters import ShareGPTPrompterV2 -from axolotl.utils.tokenization import ( - chatml_to_conversation, - merge_consecutive_messages, -) - -LOG = logging.getLogger("axolotl") - - -def register_chatml_template(system_message=None): - system_message = system_message or "You are a helpful assistant." - register_conv_template( - Conversation( - name="chatml", - system_template="<|im_start|>system\n{system_message}", - system_message=system_message, - roles=("<|im_start|>user", "<|im_start|>assistant"), - sep_style=SeparatorStyle.CHATML, - sep="<|im_end|>", - ) - ) - register_conv_template( - Conversation( - name="chatml_glaive", - system_template="<|im_start|>system\n{system_message}", - system_message=system_message, - roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"), - sep_style=SeparatorStyle.CHATML, - sep="<|im_end|>", - ) - ) - - -def register_llama3_template(system_message=None): - system_message = system_message or "You are a helpful assistant." - register_conv_template( - Conversation( - name="llama3", - system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", - system_message=system_message, - roles=("user", "assistant"), - sep_style=SeparatorStyle.LLAMA3, - sep="", - stop_str="<|eot_id|>", - stop_token_ids=[128001, 128009], - ) - ) - - -def build_loader( - tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"], - prompter_cls: Type["ShareGPTPrompterV2"], - default_conversation: Optional[str] = None, -): - def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): - LOG.warning( - "sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template", - ) - conversation = ( - ds_cfg["conversation"] - if ds_cfg and "conversation" in ds_cfg - else default_conversation - ) - field_human = ( - ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None - ) - field_model = ( - ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None - ) - roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None - strategy = tokenization_strategy_cls( - prompter_cls( - conversation=conversation, - role_key_model=field_model, - role_key_human=field_human, - roles=roles, - ), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"): - strategy.strict = ds_cfg["strict"] - if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"): - strategy.messages = ds_cfg["field_messages"] - return strategy - - return _load - - -class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): - """ - basic sharegpt strategy to grab conversations from the sample row - """ - - _strict = False - _messages = "conversations" - - @property - def strict(self): - return self._strict - - @strict.setter - def strict(self, strict): - self._strict = strict - - @property - def messages(self): - return self._messages - - @messages.setter - def messages(self, messages): - self._messages = messages - - def get_conversation_thread(self, prompt): - conversations = prompt[self.messages] - if self.strict: - return conversations - role_key = "from" - if "role" in conversations[0].keys(): - role_key = "role" - value_key = "value" - if "text" in conversations[0].keys(): - value_key = "text" - elif "content" in conversations[0].keys(): - value_key = "content" - # remap roles - allow for assistant turn" - role_map = { - "user": "human", - "human": "human", - "assistant": "gpt", - "gpt": "gpt", - "system": "system", - } - turns = [ - { - "from": ( - role_map[t[role_key]] if t[role_key] in role_map else t[role_key] - ), - "value": t[value_key], - "weight": 1 - if "weight" not in t or t["weight"] is None - else t["weight"], - } - for t in conversations - ] - return turns - - -class SimpleRoleShareGPTPromptTokenizingStrategy( - SimpleShareGPTPromptTokenizingStrategy -): - """ - basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from - """ - - def get_conversation_thread(self, prompt): - conversations = prompt["conversations"] - # remap role: prompter/assistant, text: ... => from: human/gpt, value: ... - turns = [{"from": t["role"], "value": t["value"]} for t in conversations] - return turns - - -class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): - """ - sharegpt strategy that remaps oasst data to sharegpt format - """ - - def get_conversation_thread(self, prompt): - conversations = prompt["conversations"] - # remap role: prompter/assistant, text: ... => from: human/gpt, value: ... - role_map = {"prompter": "human", "assistant": "gpt"} - turns = [ - {"from": role_map[t["role"]], "value": t["text"]} for t in conversations - ] - return turns - - -class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy): - """ - sharegpt strategy that remaps ultrachat data to sharegpt format - """ - - def get_conversation_thread(self, prompt): - conversations = prompt["messages"] - role_map = {"user": "human", "assistant": "gpt"} - turns = [ - {"from": role_map[t["role"]], "value": t["content"]} for t in conversations - ] - return turns - - -class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy): - """ - sharegpt strategy that remaps glaive data to sharegpt format - """ - - def get_conversation_thread(self, prompt): - conversation = chatml_to_conversation(prompt) - conversation = merge_consecutive_messages(conversation) - - return conversation - - -load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2) -load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2) -load_ultrachat = build_loader( - UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2 -) -load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2) -load_glaive = build_loader( - GlaiveShareGPTPromptTokenizingStrategy, - ShareGPTPrompterV2, - default_conversation="chatml_glaive", -) diff --git a/src/axolotl/prompt_strategies/sharegpt_jokes.py b/src/axolotl/prompt_strategies/sharegpt_jokes.py deleted file mode 100644 index 404302c81e..0000000000 --- a/src/axolotl/prompt_strategies/sharegpt_jokes.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Module for Jokes prompts using sharegpt style """ -from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy -from axolotl.prompters import ShareGPTPrompterV2 - - -def load(tokenizer, cfg): - return SimpleJokesShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): - """ - Tokenization strategy for asking bot to tell a joke and then explain why its funny - """ - - # title, text, explanation - def get_conversation_thread(self, prompt): - title = "" if not prompt["title"] else prompt["title"] + " " - return [ - {"from": "human", "value": "Tell me a joke."}, - {"from": "gpt", "value": title + prompt["text"]}, - {"from": "human", "value": "Why is that joke funny?"}, - {"from": "gpt", "value": prompt["explanation"]}, - ] diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 51d497a23c..bd6e3f9dce 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -1,17 +1,12 @@ """Module containing PromptTokenizingStrategy and Prompter classes""" import abc -import copy import logging from typing import Dict, List, Tuple, Union -from fastchat.conversation import Conversation from transformers import BatchEncoding, PreTrainedTokenizer -from axolotl.monkeypatch.fastchat_conversation_turns import ( - add_get_turns_to_conversation, -) -from axolotl.prompters import IGNORE_TOKEN_ID, Prompter +from axolotl.prompters import Prompter LOG = logging.getLogger("axolotl") @@ -21,8 +16,6 @@ LLAMA_DEFAULT_BOS_TOKEN = "" # nosec LLAMA_DEFAULT_UNK_TOKEN = "" # nosec -add_get_turns_to_conversation() - class InvalidDataException(Exception): """ @@ -331,154 +324,6 @@ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]: ) -class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): - """ - Tokenizing strategy for ShareGPT prompts. - """ - - def get_conversation_thread(self, prompt): - return prompt["conversations"] - - def tokenize_prompt(self, prompt): - # Initial values. We will append to these as we go through the conversation. - result, current_len = tokenize_prompt_default() - conversation: Conversation = ( - self.prompter._conversation.copy() # pylint: disable=protected-access - ) - - input_roles = {conversation.roles[0]} - output_roles = {conversation.roles[1]} - - if len(conversation.roles) == 3: - tool_role_label = conversation.roles[2] - input_roles.add(tool_role_label) - - # Add roles from the config - if self.prompter.roles: - if "input" in self.prompter.roles and self.prompter.roles["input"]: - for role in self.prompter.roles["input"]: - input_roles.add(role) - - if "output" in self.prompter.roles and self.prompter.roles["output"]: - for role in self.prompter.roles["output"]: - output_roles.add(role) - - # support for custom roles from the dataset, only useful for vicuna style prompts/roles - role_remap = [] - if ( - conversation.name == "vicuna_v1.1" - and "roles" in prompt - and len(prompt["roles"]) >= 2 - ): - role_remap = [ - {"from": conversation.roles[0], "to": prompt["roles"][0]}, - {"from": conversation.roles[1], "to": prompt["roles"][1]}, - ] - - try: - for _, part in enumerate( - self.prompter.build_prompt(self.get_conversation_thread(prompt)) - ): - if not isinstance(part, tuple): - LOG.warning(f"expected tuple, got {part}") - continue - - if len(part) <= 2: - role, content = part - weight = 1 - else: - role, content, weight = part - - # Uses "in" because role contains extra characters - input_turn = any(r.lower() in role.lower() for r in input_roles) - output_turn = any(r.lower() in role.lower() for r in output_roles) - empty_role = role.strip() == "" - - if not any([input_turn, output_turn, empty_role]): - LOG.warning(f"unhandled role: {role}") - continue - - if input_turn: - role = ( - role.replace(role_remap[0]["from"], role_remap[0]["to"]) - if role_remap - else role - ) - turn = role + content - # this is still the user query, we should - if not content.strip(): - LOG.warning(f"user turn has empty text: {prompt}") - res = self._tokenize( - turn, - add_eos_token=False, - strip_bos_token=True, - ) - if self.train_on_inputs and weight == 1: - labels = copy.deepcopy(res["input_ids"]) - else: - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - elif output_turn: - role = ( - role.replace(role_remap[1]["from"], role_remap[1]["to"]) - if role_remap - else role - ) - turn = role + content - # this should be the assistant response, should end with an eos token - if not content.strip(): - LOG.warning(f"assistant turn has empty text: {prompt}") - add_eos_token = not ( - conversation.name == "chatml" - and conversation.sep == self.tokenizer.eos_token - ) - res = self._tokenize( - turn, - add_eos_token=add_eos_token, - strip_bos_token=True, - ) - role_res = self._tokenize( - role.rstrip(), - add_eos_token=False, - strip_bos_token=True, - ) - labels = copy.deepcopy(res["input_ids"]) - if not self.train_on_inputs: - # mask out role tokens from the labels - len_role = len(role_res["input_ids"]) - labels[:len_role] = [IGNORE_TOKEN_ID] * min( - len_role, len(labels) - ) - if weight == 0: - # everything from this is masked out from the labels - # (role is masked out too because it makes no sense if contents is masked out) - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - - elif empty_role: - turn = content - # this is only ever the first part, should include the bos token and the user query - res = self._tokenize( - turn, add_eos_token=False, strip_bos_token=False - ) - if self.train_on_inputs and weight == 1: - labels = copy.deepcopy(res["input_ids"]) - else: - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - - # pylint: disable=duplicate-code - result, current_len = parse_tokenized_to_result( - result, - current_len, - res, - labels, - pad_token_id=self.tokenizer.pad_token_id, - ) - return result - except (KeyError, AssertionError, IndexError) as err: - raise InvalidDataException(str(err)) from err - - def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]: """ Returns the default values for the tokenize prompt function diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 18b73e725e..ec680702dc 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -5,7 +5,6 @@ from typing import Generator, Optional, Union from colorama import Fore -from fastchat.conversation import Conversation, get_conv_template LOG = logging.getLogger("axolotl") IGNORE_TOKEN_ID = -100 @@ -262,166 +261,10 @@ def __repr__(self) -> str: ) -SHAREGPT_ASSERTION_FAILED_ROLE = ( +ALTERNATING_ASSERTION_FAILED_ROLE = ( "Role did not alternate between turns (gpt and human). Please check your data." ) -CONVERSATION_ROLE_FORMAT = { - "chatml": "<|im_start|>{ROLE}", - "zephyr": "<|{ROLE}|>", - "vicuna_v1.1": "{ROLE}", - "llama3": "<|start_header_id|>{ROLE}<|end_header_id|>", -} - - -class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods - """ - A prompter that generates prompts for the ShareGPT - """ - - role_key_human = "human" - role_key_model = "gpt" - # Optional, only used for tool usage datasets. - role_key_tool: Optional[str] = None - # Optional, role input/output mapping - roles: Optional[dict] = None - - def __init__( - self, - prompt_style=None, # pylint: disable=unused-argument - conversation: Optional[Union[str, Conversation]] = None, - role_key_human: Optional[str] = None, - role_key_model: Optional[str] = None, - role_key_tool: Optional[str] = None, - roles: Optional[dict] = None, - ): - if conversation: - if isinstance(conversation, Conversation): - self._conversation = conversation - else: - self._conversation = get_conv_template(conversation) - else: - self._conversation = get_conv_template("vicuna_v1.1") - if role_key_human: - self.role_key_human = role_key_human - if role_key_model: - self.role_key_model = role_key_model - if role_key_tool: - self.role_key_tool = role_key_tool - if roles: - self.roles = roles - - def _build_result(self, source): - if len(source) < 2: - # If there isn't a back and forth conversation, ignore it - # also happens on the data splitting leaving empty conversations - raise IndexError( - f"A conversation entry has less than 2 messages :\n{source}" - ) - - conv = self._conversation.copy() - - original_source = source.copy() - # Add the conversation system prompt if provided, otherwise use the default one - if source[0]["from"] == "system": - conv.set_system_message(source[0]["value"]) - source.pop(0) - - roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]} - if self.role_key_tool: - roles[self.role_key_tool] = conv.roles[2] - - try: - # Apply prompt templates - if source[0]["from"] not in roles: - # Skip the first one if it is not from human - source = source[1:] - except IndexError as err: - # sometimes there is a bing or system chat - raise err - - conv.messages = [] - for _, sentence in enumerate(source): - from_role = sentence["from"] - if from_role in roles: - role = roles[from_role] - else: - if self._conversation.name not in CONVERSATION_ROLE_FORMAT: - raise NotImplementedError( - f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet." - "Please help us by creating an Issue to add support for this conversation type." - ) - - if self._conversation.name in ["llama3"]: - role = from_role - else: - role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format( - ROLE=from_role - ) - - if len(conv.messages) > 0 and ((role == conv.messages[-1][0])): - if ( - role != "assistant" - ): # back to back assistant calls may be okay for tool calls - LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") - - conv.append_message(role, sentence["value"]) - turns = list(conv.get_turns()) - original_source_length = len(original_source) - assert len(turns) in [ - original_source_length - 1, - original_source_length, - original_source_length + 1, - ] - if len(turns) == original_source_length + 1: - original_source = [{"weight": None}] + original_source - elif len(turns) == original_source_length - 1: - original_source = original_source[1:] - return [ - (*turn, weight) - for turn, weight in zip( - turns, - [ - 1 if "weight" not in e or e["weight"] is None else e["weight"] - for e in original_source - ], - ) - ] - - def build_prompt(self, source) -> Generator[str, None, None]: - turns = self._build_result(source) - - for part in turns: - if part[0] and not part[1]: - LOG.warning(f"role with empty message: {part[0]}") - yield part - - def __repr__(self) -> str: - turns = self._build_result([{"from": "{from}", "value": "{value}"}]) - return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns]) - - -class ShareGPTPrompterV2(ShareGPTPrompter): - """ - A V2 prompter that generates prompts for the ShareGPT - """ - - def __init__( - self, - conversation: Optional[Union[str, Conversation]] = None, - role_key_human: Optional[str] = None, - role_key_model: Optional[str] = None, - role_key_tool: Optional[str] = None, - roles: Optional[dict] = None, - ): - super().__init__( - conversation=conversation, - role_key_human=role_key_human, - role_key_model=role_key_model, - role_key_tool=role_key_tool, - roles=roles, - ) - class UnsupportedPrompter(Prompter): """ diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index afc8c4fc41..6e5ecda03a 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -215,11 +215,6 @@ def normalize_cfg_datasets(cfg): if cfg.chat_template: if cfg.datasets: for idx, ds_cfg in enumerate(cfg.datasets): - if ds_cfg.type == "sharegpt" and not ds_cfg.conversation: - LOG.info( - f"updating dataset {ds_cfg.path} with `conversation: {cfg.chat_template}` to match your chat_template" - ) - cfg.datasets[idx].conversation = cfg.chat_template if ( ds_cfg.type in ["orpo.chat_template", "chat_template"] and not ds_cfg.chat_template @@ -461,27 +456,6 @@ def legacy_validate_config(cfg): "`early_stopping_patience` requires that eval_steps should evenly divide save_steps." ) - if cfg.datasets: - for idx, ds_cfg in enumerate(cfg.datasets): - if not ds_cfg.type: - continue - if ds_cfg.type == "sharegpt:chat": - LOG.warning( - PendingDeprecationWarning( - "`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead." - ) - ) - cfg.datasets[idx].type = "sharegpt" - if "sharegpt_simple" in ds_cfg.type: - LOG.warning( - PendingDeprecationWarning( - "`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead." - ) - ) - cfg.datasets[idx].type = cfg.datasets[idx].type.replace( - "sharegpt_simple", "sharegpt" - ) - if cfg.saves_per_epoch and cfg.save_steps: raise ValueError( "save_steps and saves_per_epoch are mutually exclusive and cannot be used together." diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 64fec67e0e..8310fd3e57 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -783,26 +783,16 @@ class Config: @field_validator("datasets", mode="before") @classmethod - def fix_sharegpt_datasets(cls, datasets): - for idx, ds_cfg in enumerate(datasets): - if not ds_cfg["type"]: + def deprecate_sharegpt_datasets(cls, datasets): + for _, ds_cfg in enumerate(datasets): + if not ds_cfg.get("type"): continue - if ds_cfg["type"] == "sharegpt:chat": - LOG.warning( - PendingDeprecationWarning( - "`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead." - ) - ) - datasets[idx]["type"] = "sharegpt" - if "sharegpt_simple" in ds_cfg["type"]: - LOG.warning( - PendingDeprecationWarning( - "`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead." - ) - ) - datasets[idx]["type"] = datasets[idx]["type"].replace( - "sharegpt_simple", "sharegpt" + + if ds_cfg["type"].startswith("sharegpt"): + raise ValueError( + "`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead." ) + return datasets @model_validator(mode="before") diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index f353aebec9..97c3c64650 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -1,8 +1,6 @@ """Module for tokenization utilities""" import logging -import re -from typing import Dict, List from termcolor import colored @@ -93,65 +91,3 @@ def check_rl_example_labels(example, tokenizer, text_only=False): LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n") return delimiter.join(colored_tokens) - - -GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"] -GLAIVE_TO_SHAREGPT_ROLE = { - "SYSTEM": "system", - "USER": "human", - "ASSISTANT": "gpt", - "FUNCTION RESPONSE": "tool", -} - -GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ") - - -def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]: - """ - Converts a ChatML formatted row to a list of messages in ShareGPT format. - Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb. - """ - - system_prompt = row.get("system") - if system_prompt: - system_prompt = system_prompt.removeprefix("SYSTEM: ") - - chat_str = row["chat"] - chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s] - - chat_msg_dicts = [ - {"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value} - for role, value in zip(chat_msgs[::2], chat_msgs[1::2]) - ] - - if system_prompt: - chat_msg_dicts = [ - {"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt} - ] + chat_msg_dicts - - return chat_msg_dicts - - -def merge_consecutive_messages(messages): - """ - Merge consecutive messages from the same sender into a single message. - This can be useful with datasets that contain multiple consecutive tool calls. - """ - - merged_messages = [] - current_from = None - current_message = "" - - for msg in messages: - if current_from == msg["from"]: - current_message += msg["value"] - else: - if current_from is not None: - merged_messages.append({"from": current_from, "value": current_message}) - current_from = msg["from"] - current_message = msg["value"] - - if current_from is not None: - merged_messages.append({"from": current_from, "value": current_message}) - - return merged_messages diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py deleted file mode 100644 index e7a73a0de5..0000000000 --- a/tests/prompt_strategies/test_sharegpt.py +++ /dev/null @@ -1,500 +0,0 @@ -""" -Test module for sharegpt integration w chatml -""" - -import pytest -from datasets import Dataset -from tokenizers import AddedToken -from transformers import AutoTokenizer - -from axolotl.datasets import TokenizedPromptDataset -from axolotl.prompt_strategies.sharegpt import ( - GlaiveShareGPTPromptTokenizingStrategy, - SimpleShareGPTPromptTokenizingStrategy, - register_chatml_template, - register_llama3_template, -) -from axolotl.prompters import ShareGPTPrompterV2 - -register_chatml_template() -register_llama3_template() - - -@pytest.fixture(name="sharegpt_dataset") -def fixture_sharegpt_dataset(): - return Dataset.from_list( - [ - { - "conversations": [ - { - "from": "system", - "value": "repeat", - }, - { - "from": "human", - "value": "hello", - }, - { - "from": "gpt", - "value": "hello", - }, - { - "from": "human", - "value": "goodbye", - }, - { - "from": "gpt", - "value": "goodbye", - }, - ] - } - ] - ) - - -@pytest.fixture(name="sharegpt_dataset_with_weights") -def fixture_sharegpt_dataset_with_weights(): - return Dataset.from_list( - [ - { - "conversations": [ - { - "from": "system", - "value": "repeat", - }, - { - "from": "human", - "value": "hello", - "weight": 1, - }, - { - "from": "gpt", - "value": "hello", - "weight": 0, - }, - { - "from": "human", - "value": "rehello", - "weight": 0, - }, - { - "from": "gpt", - "value": "rehello", - "weight": 1, - }, - { - "from": "human", - "value": "goodbye", - }, - { - "from": "gpt", - "value": "goodbye", - "weight": 0, - }, - ] - } - ] - ) - - -@pytest.fixture(name="glaive_dataset") -def fixture_sharegpt_glaive_dataset(): - return Dataset.from_list( - [ - { - "system": "SYSTEM: This is a system prompt", - "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>", - } - ] - ) - - -@pytest.fixture(name="multi_role_dataset") -def fixture_multi_role_dataset(): - return Dataset.from_list( - [ - { - "conversations": [ - { - "from": "system", - "value": "use get_weather(city) to get the weather for a city", - }, - { - "from": "human", - "value": "hello, what's the weather in New York?", - }, - { - "from": "gpt", - "value": "let me get that for you", - }, - { - "from": "tool", - "value": "get_weather(New York)", - }, - { - "from": "gpt", - "value": "the weather in New York is 70 degrees and sunny", - }, - ] - } - ] - ) - - -@pytest.fixture(name="tokenizer") -def fixture_tokenizer(): - tokenizer = AutoTokenizer.from_pretrained( - "casperhansen/mistral-7b-instruct-v0.1-awq" - ) - tokenizer.add_special_tokens( - { - "eos_token": AddedToken( - "<|im_end|>", rstrip=False, lstrip=False, normalized=False - ) - } - ) - tokenizer.add_tokens( - [ - AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False), - ] - ) - - return tokenizer - - -@pytest.fixture(name="llama3_tokenizer") -def fixture_llama3_tokenizer(): - tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B") - tokenizer.eos_token = "<|eot_id|>" - - return tokenizer - - -class TestSharegptLlama3: - """Test class for ShareGPT style datasets with llama-3 prompts""" - - def test_tokenization(self, sharegpt_dataset, llama3_tokenizer): - strategy = SimpleShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2( - conversation="llama3", - role_key_model=None, - role_key_human=None, - ), - llama3_tokenizer, - False, # train_on_inputs - 2048, # sequence_len - ) - - dataset_wrapper = TokenizedPromptDataset( - strategy, sharegpt_dataset, process_count=1 - ) - - input_ids = dataset_wrapper[0]["input_ids"] - - # fmt: off - # pylint: disable=duplicate-code - assert input_ids == [ - 128000, # bos - 128006, 9125, 128007, # system header - 271, 31724, 128009, # sys prompt, eot - 128006, 882, 128007, # user header - 271, 15339, 128009, # user prompt eot - 128006, 78191, 128007, # assistant header - 271, 15339, 128009, # assistant response eot - 128006, 882, 128007, - 271, 19045, 29474, 128009, - 128006, 78191, 128007, - 271, 19045, 29474, 128009, - ] - # fmt: on - - def test_tokenization_with_weights( - self, sharegpt_dataset_with_weights, llama3_tokenizer - ): - strategy = SimpleShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2( - conversation="llama3", - role_key_model=None, - role_key_human=None, - ), - llama3_tokenizer, - False, # train_on_inputs - 2048, # sequence_len - ) - - dataset_wrapper = TokenizedPromptDataset( - strategy, sharegpt_dataset_with_weights, process_count=1 - ) - - input_ids = dataset_wrapper[0]["input_ids"] - - # fmt: off - # pylint: disable=duplicate-code - assert input_ids == [ - 128000, # bos - 128006, 9125, 128007, # system header - 271, 31724, 128009, # sys prompt, eot - 128006, 882, 128007, # user header - 271, 15339, 128009, # user prompt eot - 128006, 78191, 128007, # assistant header - 271, 15339, 128009, # assistant response eot - 128006, 882, 128007, - 271, 11310, 4896, 128009, - 128006, 78191, 128007, - 271, 11310, 4896, 128009, - 128006, 882, 128007, - 271, 19045, 29474, 128009, - 128006, 78191, 128007, - 271, 19045, 29474, 128009, - ] - # fmt: on - - -class TestSharegptChatML: - """ - Test class for sharegpt prompter - """ - - def test_no_double_im_end(self, sharegpt_dataset, tokenizer): - strategy = SimpleShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2( - conversation="chatml", - role_key_model=None, - role_key_human=None, - ), - tokenizer, - False, # train_on_inputs - 2048, # sequence_len - ) - - dataset_wrapper = TokenizedPromptDataset( - strategy, sharegpt_dataset, process_count=1 - ) - - input_ids = dataset_wrapper[0]["input_ids"] - # fmt: off - assert input_ids == [ - # 28705, 13, is " \n" - 1, # bos - 32001, 1587, 13, 25997, 32000, 28705, 13, # system - 32001, 2188, 13, 21558, 32000, 28705, 13, # human - 32001, 13892, 13, 21558, 32000, 28705, 13, # gpt - 32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human - 32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt - ] - # fmt: on - - def test_no_double_im_end_with_weights( - self, sharegpt_dataset_with_weights, tokenizer - ): - strategy = SimpleShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2( - conversation="chatml", - role_key_model=None, - role_key_human=None, - ), - tokenizer, - False, # train_on_inputs - 2048, # sequence_len - ) - - dataset_wrapper = TokenizedPromptDataset( - strategy, sharegpt_dataset_with_weights, process_count=1 - ) - - input_ids = dataset_wrapper[0]["input_ids"] - # fmt: off - assert input_ids == [ - # 28705, 13, is " \n" - 1, # bos - 32001, 1587, 13, 25997, 32000, 28705, 13, # system - 32001, 2188, 13, 21558, 32000, 28705, 13, # human - 32001, 13892, 13, 21558, 32000, 28705, 13, # gpt - 32001, 2188, 13, 267, 21558, 32000, 28705, 13, # human - 32001, 13892, 13, 267, 21558, 32000, 28705, 13, # gpt - 32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human - 32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt - ] - # fmt: on - - def test_no_train_on_input(self, sharegpt_dataset, tokenizer): - strategy = SimpleShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2( - conversation="chatml", - role_key_model=None, - role_key_human=None, - ), - tokenizer, - False, # train_on_inputs - 2048, # sequence_len - ) - - dataset_wrapper = TokenizedPromptDataset( - strategy, sharegpt_dataset, process_count=1 - ) - - labels = dataset_wrapper[0]["labels"] - # fmt: off - assert labels == [ - -100, # bos - -100, -100, -100, -100, -100, -100, -100, # system - -100, -100, -100, -100, -100, -100, -100, # human - -100, -100, 13, 21558, 32000, 28705, 13, # gpt - -100, -100, -100, -100, -100, -100, -100, -100, # human - -100, -100, 13, 12684, 17664, 32000, 28705, 13, # gpt - ] - # fmt: on - - def test_no_train_on_input_with_weights( - self, sharegpt_dataset_with_weights, tokenizer - ): - strategy = SimpleShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2( - conversation="chatml", - role_key_model=None, - role_key_human=None, - ), - tokenizer, - False, # train_on_inputs - 2048, # sequence_len - ) - - dataset_wrapper = TokenizedPromptDataset( - strategy, sharegpt_dataset_with_weights, process_count=1 - ) - - labels = dataset_wrapper[0]["labels"] - # fmt: off - assert labels == [ - -100, # bos - -100, -100, -100, -100, -100, -100, -100, # system - -100, -100, -100, -100, -100, -100, -100, # human - -100, -100, -100, -100, -100, -100, -100, # gpt with weight zero - -100, -100, -100, -100, -100, -100, -100, -100, # human - -100, -100, 13, 267, 21558, 32000, 28705, 13, # gpt - -100, -100, -100, -100, -100, -100, -100, -100, # human - -100, -100, -100, -100, -100, -100, -100, -100 # gpt with weight zero - ] - # fmt: on - - def test_w_train_on_input(self, sharegpt_dataset, tokenizer): - strategy = SimpleShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2( - conversation="chatml", - role_key_model=None, - role_key_human=None, - ), - tokenizer, - True, # train_on_inputs - 2048, # sequence_len - ) - - dataset_wrapper = TokenizedPromptDataset( - strategy, sharegpt_dataset, process_count=1 - ) - - labels = dataset_wrapper[0]["labels"] - # fmt: off - assert labels == [ - 1, # bos - 32001, 1587, 13, 25997, 32000, 28705, 13, # system - 32001, 2188, 13, 21558, 32000, 28705, 13, # human - 32001, 13892, 13, 21558, 32000, 28705, 13, # gpt - 32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human - 32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt - ] - # fmt: on - - def test_w_train_on_input_with_weights( - self, sharegpt_dataset_with_weights, tokenizer - ): - strategy = SimpleShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2( - conversation="chatml", - role_key_model=None, - role_key_human=None, - ), - tokenizer, - True, # train_on_inputs - 2048, # sequence_len - ) - - dataset_wrapper = TokenizedPromptDataset( - strategy, sharegpt_dataset_with_weights, process_count=1 - ) - - labels = dataset_wrapper[0]["labels"] - # fmt: off - assert labels == [ - 1, # bos - 32001, 1587, 13, 25997, 32000, 28705, 13, # system - 32001, 2188, 13, 21558, 32000, 28705, 13, # human - -100, -100, -100, -100, -100, -100, -100, # gpt with weight 0 - -100, -100, -100, -100, -100, -100, -100, -100, # human with weight 0 - 32001, 13892, 13, 267, 21558, 32000, 28705, 13, # gpt - 32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human - -100, -100, -100, -100, -100, -100, -100, -100 # gpt with weight 0 - ] - # fmt: on - - def test_chatml_glaive(self, glaive_dataset, tokenizer): - strategy = GlaiveShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2( - conversation="chatml", - role_key_model=None, - role_key_human=None, - ), - tokenizer, - True, # train_on_inputs - 2048, # sequence_len - ) - - dataset_wrapper = TokenizedPromptDataset( - strategy, glaive_dataset, process_count=1 - ) - - labels = dataset_wrapper[0]["labels"] - # fmt: off - assert labels == [ - 1, # bos - 32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13, # system - 32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13, # human - 32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt - ] - # fmt: on - - def test_multi_role_dataset(self, multi_role_dataset, tokenizer): - strategy = SimpleShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}), - tokenizer, - False, # train_on_inputs - 2048, # sequence_len - ) - - dataset_wrapper = TokenizedPromptDataset( - strategy, multi_role_dataset, process_count=1 - ) - - input_ids = dataset_wrapper[0]["input_ids"] - # fmt: off - assert input_ids == [ - 1, # bos - 32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13, # system - 32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13, # human - 32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt - 32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13, # tool - 32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt - ] - # fmt: on - - labels = dataset_wrapper[0]["labels"] - # fmt: off - assert labels == [ - -100, # bos - -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # system - -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # human - -100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt - -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool - -100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt - ] - # fmt: on diff --git a/tests/test_normalize_config.py b/tests/test_normalize_config.py index 2e76ceb45d..0d663183d4 100644 --- a/tests/test_normalize_config.py +++ b/tests/test_normalize_config.py @@ -39,12 +39,12 @@ def test_chat_template_chatml(self): "datasets": [ { "path": "lorem/ipsum", - "type": "sharegpt", - "conversation": "vicuna_v1.1", + "type": "chat_template", + "chat_template": "gemma", }, { "path": "sit/amet", - "type": "sharegpt", + "type": "chat_template", }, ], } @@ -52,8 +52,8 @@ def test_chat_template_chatml(self): normalize_cfg_datasets(cfg) - assert cfg.datasets[0].conversation == "vicuna_v1.1" - assert cfg.datasets[1].conversation == "chatml" + assert cfg.datasets[0].chat_template == "gemma" + assert cfg.datasets[1].chat_template == "chatml" @patch("axolotl.utils.config.is_torch_bf16_gpu_available") def test_bf16_auto_setter_available(self, mock_bf16_avail): diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index 63e9a621bf..4fb72f3e1d 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -3,7 +3,6 @@ import json import logging import unittest -from copy import deepcopy from pathlib import Path from typing import Optional @@ -21,12 +20,8 @@ LLama2ChatTokenizingStrategy, ) from axolotl.prompt_strategies.orpo.chat_template import load -from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy -from axolotl.prompt_tokenizers import ( - AlpacaPromptTokenizingStrategy, - ShareGPTPromptTokenizingStrategy, -) -from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2 +from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy +from axolotl.prompters import AlpacaPrompter, PromptStyle from axolotl.utils.dict import DictDefault LOG = logging.getLogger("axolotl") @@ -65,17 +60,6 @@ } -def prompt_strat(conversation, tokenizer): - "Helper function to create a prompt strategy for testing." - prompter = ShareGPTPrompterV2(conversation=conversation) - return ShareGPTPromptTokenizingStrategy( - prompter, - tokenizer, - False, - 2048, - ) - - class TestPromptTokenizationStrategies(unittest.TestCase): """ Test class for prompt tokenization strategies. @@ -98,196 +82,6 @@ def setUp(self) -> None: } ) - def test_sharegpt_integration(self): - with open( - Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8" - ) as fin: - data = fin.read() - conversation = json.loads(data) - with open( - Path(__file__).parent / "fixtures/conversation.tokenized.json", - encoding="utf-8", - ) as fin: - data = fin.read() - tokenized_conversation = json.loads(data) - prompter = ShareGPTPrompterV2() - strat = ShareGPTPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - example = strat.tokenize_prompt(conversation) - for fields in ["input_ids", "attention_mask", "labels"]: - self.assertEqual(len(example[fields]), len(tokenized_conversation[fields])) - self.assertEqual(example[fields], tokenized_conversation[fields]) - - def test_sharegpt_warnings_integration(self): - with open( - Path(__file__).parent / "fixtures/conversation.missingturns.json", - encoding="utf-8", - ) as fin: - data = fin.read() - conversation = json.loads(data) - prompter = ShareGPTPrompterV2() - strat = ShareGPTPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - with self._caplog.at_level(logging.WARNING): - strat.tokenize_prompt(conversation) - assert "assistant turn has empty text" in self._caplog.records[1].message - - def test_sharegpt_warnings_turns(self): - conversation = { - "conversations": [ - {"from": "system", "value": "lorem"}, - {"from": "gpt", "value": "ipsum"}, - {"from": "human", "value": "dolor"}, - {"from": "human", "value": "dolor"}, - {"from": "gpt", "value": "sit"}, - ] - } - prompter = ShareGPTPrompterV2() - strat = ShareGPTPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - with self._caplog.at_level(logging.WARNING): - strat.tokenize_prompt(conversation) - assert ( - "Role did not alternate between turns (gpt and human)" - in self._caplog.records[0].message - ) - - def test_sharegpt_llama(self): - "Make sure the sharegpt/llama is tokenized and formatted correctly." - strat = prompt_strat("llama-2", self.tokenizer) - - def tokenize(conv): - return strat.tokenize_prompt(deepcopy(conv))["input_ids"] - - def decode(ids): - return strat.tokenizer.decode(ids) - - # fmt: off - # System message, multi-turn conversations - mt_ids = tokenize(test_data['multi_turn_sys']) - assert decode(mt_ids) == ' [INST] <>\nlorem\n<>\n\nabc [/INST] ipsum [INST] 123 [/INST] sit' - assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] - - # System message, single-turn conversations - st_ids = tokenize(test_data['single_turn_sys']) - assert decode(st_ids) == ' [INST] <>\nlorem\n<>\n\nabc [/INST] ipsum' - assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] - - # No system message, single-turn - ns_ids = tokenize(test_data['single_turn_no_sys']) - assert decode(ns_ids) == ' [INST] abc [/INST] ipsum' - assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2] - - # No system message, multi-turn - ns_mt_ids = tokenize(test_data['multi_turn_no_sys']) - assert decode(ns_mt_ids) == ' [INST] abc [/INST] ipsum [INST] 123 [/INST] sit' - assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] - # fmt: on - - def test_sharegpt_mistral(self): - "Make sure the sharegpt/mistral is tokenized and formatted correctly." - strat = prompt_strat("mistral", self.tokenizer) - - def tokenize(conv): - return strat.tokenize_prompt(deepcopy(conv))["input_ids"] - - def decode(ids): - return strat.tokenizer.decode(ids) - - # fmt: off - # System message, multi-turn conversations - mt_ids = tokenize(test_data['multi_turn_sys']) - assert decode(mt_ids) == ' [INST] lorem\nabc [/INST] ipsum [INST] 123 [/INST] sit' - assert mt_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] - - # System message, single-turn conversations - st_ids = tokenize(test_data['single_turn_sys']) - assert decode(st_ids) == ' [INST] lorem\nabc [/INST] ipsum' - assert st_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] - - # No system message, single-turn - ns_ids = tokenize(test_data['single_turn_no_sys']) - assert decode(ns_ids) == ' [INST] abc [/INST] ipsum' - assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2] - - # No system message, multi-turn - ns_mt_ids = tokenize(test_data['multi_turn_no_sys']) - assert decode(ns_mt_ids) == ' [INST] abc [/INST] ipsum [INST] 123 [/INST] sit' - assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] - # fmt: on - - def test_sharegpt_changes_roles(self): - conversation = { - "roles": ["USER", "CHARACTER"], - "conversations": [ - {"from": "system", "value": "lorem"}, - {"from": "gpt", "value": "ipsum"}, - {"from": "human", "value": "dolor"}, - {"from": "gpt", "value": "sit"}, - ], - } - prompter = ShareGPTPrompterV2() - strat = ShareGPTPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - with self._caplog.at_level(logging.WARNING): - res = strat.tokenize_prompt(conversation) - assert "CHARACTER" in self.tokenizer.decode(res["input_ids"]) - - def test_sharegpt_assistant_label_ignore(self): - conversation = { - "roles": ["user", "assistant"], - "conversations": [ - {"from": "system", "value": "lorem"}, - {"from": "gpt", "value": "ipsum"}, - {"from": "human", "value": "dolor"}, - {"from": "gpt", "value": "sit"}, - ], - } - prompter = ShareGPTPrompterV2() - strat = ShareGPTPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - with self._caplog.at_level(logging.WARNING): - res = strat.tokenize_prompt(conversation) - idx = res["input_ids"].index(20255) # assistant token - assert res["labels"][idx] == -100 - - def test_glaive_tool_label_ignore(self): - conversation = { - "system": "SYSTEM: This is a system prompt", - "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>", - } - prompter = ShareGPTPrompterV2() - strat = GlaiveShareGPTPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - with self._caplog.at_level(logging.WARNING): - res = strat.tokenize_prompt(conversation) - idx = res["input_ids"].index(13566) # assistant token - assert res["labels"][idx] == -100 - def test_no_sys_prompt(self): """ tests the interface between the user and assistant parts diff --git a/tests/test_validation.py b/tests/test_validation.py index fb63977f5c..67670b1928 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -646,39 +646,6 @@ def test_merge_lora_no_bf16_fail(self, minimal_cfg): validate_config(cfg) - def test_sharegpt_deprecation(self, minimal_cfg): - cfg = ( - DictDefault( - {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]} - ) - | minimal_cfg - ) - with self._caplog.at_level(logging.WARNING): - new_cfg = validate_config(cfg) - assert any( - "`type: sharegpt:chat` will soon be deprecated." in record.message - for record in self._caplog.records - ) - assert new_cfg.datasets[0].type == "sharegpt" - - cfg = ( - DictDefault( - { - "datasets": [ - {"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"} - ] - } - ) - | minimal_cfg - ) - with self._caplog.at_level(logging.WARNING): - new_cfg = validate_config(cfg) - assert any( - "`type: sharegpt_simple` will soon be deprecated." in record.message - for record in self._caplog.records - ) - assert new_cfg.datasets[0].type == "sharegpt:load_role" - def test_no_conflict_save_strategy(self, minimal_cfg): cfg = ( DictDefault( diff --git a/tests/test_validation_dataset.py b/tests/test_validation_dataset.py index 389424217b..7e288f8165 100644 --- a/tests/test_validation_dataset.py +++ b/tests/test_validation_dataset.py @@ -48,9 +48,8 @@ def test_dataset_config_no_drop_param(self, minimal_cfg): | { "datasets": [ { - "path": "LDJnr/Puffin", - "type": "sharegpt", - "conversation": "chatml", + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", "shards": 10, } ] @@ -62,7 +61,6 @@ def test_dataset_config_no_drop_param(self, minimal_cfg): def _check_config(): assert checked_cfg.datasets[0].path == cfg.datasets[0].path assert checked_cfg.datasets[0].type == cfg.datasets[0].type - assert checked_cfg.datasets[0].conversation == cfg.datasets[0].conversation assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards _check_config()