def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """Build a tokenizing strategy for plain completion-style text.

    Args:
        tokenizer: tokenizer passed through to the tokenizing strategy.
        cfg: run config; ``train_on_inputs`` and ``sequence_len`` are read.
        ds_cfg: optional per-dataset config; when it contains a ``"field"``
            key, that column name overrides the strategy's default text field.

    Returns:
        A configured ``CompletionPromptTokenizingStrategy``.
    """
    strategy = CompletionPromptTokenizingStrategy(
        CompletionPrompter(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    # Let the dataset config point the strategy at a custom text column.
    if ds_cfg and "field" in ds_cfg:
        strategy.field = ds_cfg["field"]
    return strategy
""" + _field: str = "text" + + @property + def field(self) -> str: + return self._field + + @field.setter + def field(self, new_field: str): + self._field = new_field + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: + return ( + prompt[self.field], + "", + "", + ) + def tokenize_prompt(self, prompt): - full_prompt = self._build_full_prompt(prompt["text"], None, None) + ( + instruction, + _, + _, + ) = self.parse_instruction_fields(prompt) + + full_prompt = self._build_full_prompt(instruction, None, None) tokenized_full_prompt = self._tokenize(full_prompt) return tokenized_full_prompt