Skip to content

Commit

Permalink
support custom field for completion from yml (axolotl-ai-cloud#580)
Browse files Browse the repository at this point in the history
* support custom field for completion from yml

* remove legacy completion check and add doc

* update README docs
  • Loading branch information
winglian authored Sep 15, 2023
1 parent 30da388 commit 8d6fe07
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 12 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- path: EleutherAI/pile
name: enron_emails
type: completion # format from earlier
field: text # Optional[str] default: text, field to use for completion data
# huggingface repo with multiple named configurations/subsets
datasets:
Expand Down Expand Up @@ -444,6 +445,9 @@ datasets:
# 'no_input_format' cannot include {input}
no_input_format: "{instruction} "
# for completions datsets, uses the provided field if not `text`
field:
# axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
Expand Down
5 changes: 5 additions & 0 deletions src/axolotl/prompt_strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module to load prompt strategies."""

import importlib
import inspect

from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig

Expand All @@ -16,6 +17,10 @@ def load(strategy, tokenizer, cfg, ds_cfg):
load_kwargs = {}
if strategy == "user_defined":
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
else:
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
return func(tokenizer, cfg, **load_kwargs)
except Exception: # pylint: disable=broad-exception-caught
return None
20 changes: 20 additions & 0 deletions src/axolotl/prompt_strategies/completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
Basic completion text
"""
from typing import Any, Dict, Optional

from axolotl.prompt_tokenizers import CompletionPromptTokenizingStrategy
from axolotl.prompters import CompletionPrompter


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strat = CompletionPromptTokenizingStrategy(
CompletionPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "field" in ds_cfg:
strat.field = ds_cfg["field"]

return strat
25 changes: 24 additions & 1 deletion src/axolotl/prompt_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,31 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
Tokenizing strategy for Completion prompts.
"""

_field: str = "text"

@property
def field(self) -> str:
return self._field

@field.setter
def field(self, new_field: str):
self._field = new_field

def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt[self.field],
"",
"",
)

def tokenize_prompt(self, prompt):
full_prompt = self._build_full_prompt(prompt["text"], None, None)
(
instruction,
_,
_,
) = self.parse_instruction_fields(prompt)

full_prompt = self._build_full_prompt(instruction, None, None)
tokenized_full_prompt = self._tokenize(full_prompt)

return tokenized_full_prompt
Expand Down
11 changes: 0 additions & 11 deletions src/axolotl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
CompletionPromptTokenizingStrategy,
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
Expand All @@ -31,7 +30,6 @@
)
from axolotl.prompters import (
AlpacaPrompter,
CompletionPrompter,
GPTeacherPrompter,
JeopardyPrompter,
MultipleChoiceConcisePrompter,
Expand Down Expand Up @@ -327,15 +325,6 @@ def for_d_in_datasets(dataset_configs):
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "completion":
ds_strategy = CompletionPromptTokenizingStrategy(
CompletionPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
else:
suffix = ""
if ":load_" in d.type:
Expand Down

0 comments on commit 8d6fe07

Please sign in to comment.