Skip to content

Commit

Permalink
support custom field for completion from yml
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 15, 2023
1 parent a5a625f commit c615f11
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/axolotl/prompt_strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module to load prompt strategies."""

import importlib
import inspect

from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig

Expand All @@ -16,6 +17,10 @@ def load(strategy, tokenizer, cfg, ds_cfg):
load_kwargs = {}
if strategy == "user_defined":
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
else:
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
return func(tokenizer, cfg, **load_kwargs)
except Exception: # pylint: disable=broad-exception-caught
return None
20 changes: 20 additions & 0 deletions src/axolotl/prompt_strategies/completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
Basic completion text
"""
from typing import Any, Dict, Optional

from axolotl.prompt_tokenizers import CompletionPromptTokenizingStrategy
from axolotl.prompters import CompletionPrompter


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strat = CompletionPromptTokenizingStrategy(
CompletionPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "field" in ds_cfg:
strat.field = ds_cfg["field"]

return strat
25 changes: 24 additions & 1 deletion src/axolotl/prompt_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,31 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
Tokenizing strategy for Completion prompts.
"""

_field: str = "text"

@property
def field(self) -> str:
return self._field

@field.setter
def field(self, new_field: str):
self._field = new_field

def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt[self.field],
"",
"",
)

def tokenize_prompt(self, prompt):
full_prompt = self._build_full_prompt(prompt["text"], None, None)
(
instruction,
_,
_,
) = self.parse_instruction_fields(prompt)

full_prompt = self._build_full_prompt(instruction, None, None)
tokenized_full_prompt = self._tokenize(full_prompt)

return tokenized_full_prompt
Expand Down

0 comments on commit c615f11

Please sign in to comment.