def process(self, dataset):
    """
    Tokenize every row of ``dataset`` with ``self.prompt_tokenizer``.

    The work is delegated to ``dataset.map``: the tokenizer's
    ``tokenize_prompt`` is the mapped function, the dataset's existing
    feature names are passed as ``remove_columns``, and at most 64 worker
    processes are requested. When the tokenizer reports
    ``supports_batched``, rows are handed over in batches of 100.

    ``dataset`` must expose ``features`` (a mapping keyed by column name)
    and ``map`` -- presumably a Hugging Face ``Dataset``; confirm against
    callers. Returns whatever ``dataset.map`` returns.
    """
    features = dataset.features.keys()

    # os.cpu_count() is documented to return None when the count cannot be
    # determined; min(64, None) would raise TypeError, so fall back to one
    # worker in that case.
    num_proc = min(64, os.cpu_count() or 1)

    map_kwargs = {}
    # Tokenizers that predate the ``supports_batched`` property are treated
    # as row-at-a-time instead of failing with AttributeError.
    if getattr(self.prompt_tokenizer, "supports_batched", False):
        map_kwargs["batched"] = True
        map_kwargs["batch_size"] = 100

    return dataset.map(
        self.prompt_tokenizer.tokenize_prompt,
        num_proc=num_proc,
        remove_columns=features,
        **map_kwargs,
    )
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Completion prompts.

    Works on batches: ``tokenize_prompt`` receives a mapping of column name
    to a list of values, tokenizes the configured text field of every row,
    and cuts each tokenized row into consecutive chunks of at most
    ``sequence_len`` tokens. One input row can therefore produce several
    output rows.
    """

    # Name of the dataset column holding the raw text; override via ``field``.
    _field: str = "text"

    def __init__(self, *args, max_length=None, **kwargs):
        """
        Forward all arguments to the parent strategy.

        ``max_length``, when given, replaces the ``max_length`` attribute set
        by the parent constructor, so a document longer than a single
        ``sequence_len`` can be tokenized before it is chunked.
        """
        super().__init__(*args, **kwargs)
        if max_length is not None:
            self.max_length = max_length

    @property
    def supports_batched(self):
        """Always True: ``tokenize_prompt`` expects a batch of rows."""
        return True

    @property
    def field(self) -> str:
        """Name of the column the completion text is read from."""
        return self._field

    @field.setter
    def field(self, new_field: str):
        self._field = new_field

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """
        Return ``(text, "", "")`` for a single row.

        Only the first slot is populated, with ``prompt[self.field]``; the
        other two are always empty strings.
        """
        return (
            prompt[self.field],
            "",
            "",
        )

    def tokenize_prompt(self, prompt):
        """
        Tokenize a batch of rows and split every row into fixed-size chunks.

        ``prompt`` maps each column name to a list of per-row values. The
        result maps each key of the tokenizer output to a list of chunks,
        where every chunk holds at most ``sequence_len`` items.

        Raises ``ValueError`` if ``sequence_len`` is smaller than 1.
        """
        chunk_len = self.sequence_len
        # A zero step makes range() fail with an unrelated-looking message and
        # a negative step yields no chunks at all, silently dropping every
        # row -- reject both up front.
        if chunk_len < 1:
            raise ValueError(
                f"sequence_len must be a positive integer, got {chunk_len}"
            )

        res = defaultdict(list)
        feature_names = list(prompt.keys())
        # Transpose the column-oriented batch into one dict per row.
        for row in zip(*prompt.values()):
            prompt_row = dict(zip(feature_names, row))
            (
                instruction,
                _,
                _,
            ) = self.parse_instruction_fields(prompt_row)

            full_prompt = self._build_full_prompt(instruction, None, None)
            tokenized_full_prompt = self._tokenize(full_prompt)

            # Every key is sliced with the same boundaries, so chunk i of one
            # key lines up with chunk i of the others as long as the
            # tokenizer returned equal-length lists.
            for key, val in tokenized_full_prompt.items():
                for start in range(0, len(val), chunk_len):
                    res[key].append(val[start : start + chunk_len])

        return dict(res)

    def _build_full_prompt(
        self, instruction, input, response
    ):  # pylint: disable=redefined-builtin
        """Return the first prompt string produced by ``self.prompter``."""
        return next(iter(self.prompter.build_prompt(instruction, input, response)))


class CompletionPrompter:
    """
    Prompter for completion

    Passes the text through untouched: ``build_prompt`` yields exactly one
    item, the ``instruction`` it was given.
    """

    def build_prompt(
        self,
        instruction: str,
        input=None,  # pylint: disable=redefined-builtin, unused-argument
        output=None,  # pylint: disable=unused-argument
    ) -> Generator[str, None, None]:
        """Yield ``instruction`` unchanged; ``input`` and ``output`` are ignored."""
        yield instruction
+++ b/src/axolotl/prompt_tokenizers.py @@ -41,11 +41,16 @@ def __init__( self.tokenizer: PreTrainedTokenizer = tokenizer self.train_on_inputs = train_on_inputs self.sequence_len = sequence_len + self.max_length = sequence_len @abc.abstractmethod def tokenize_prompt(self, prompt): pass + @property + def supports_batched(self): + return False + @functools.lru_cache(maxsize=128) def _get_user_token(self): try: @@ -77,7 +82,7 @@ def _tokenize( result = self.tokenizer( prompt, truncation=True, - max_length=self.sequence_len, + max_length=self.max_length, padding=False, return_tensors=None, ) @@ -86,7 +91,7 @@ def _tokenize( if ( len(result["input_ids"]) > 0 and result["input_ids"][-1] != self.tokenizer.eos_token_id - and len(result["input_ids"]) < self.sequence_len + and len(result["input_ids"]) < self.max_length and add_eos_token ): result["input_ids"].append(self.tokenizer.eos_token_id) @@ -247,46 +252,6 @@ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: ) -class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for Completion prompts. - """ - - _field: str = "text" - - @property - def field(self) -> str: - return self._field - - @field.setter - def field(self, new_field: str): - self._field = new_field - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return ( - prompt[self.field], - "", - "", - ) - - def tokenize_prompt(self, prompt): - ( - instruction, - _, - _, - ) = self.parse_instruction_fields(prompt) - - full_prompt = self._build_full_prompt(instruction, None, None) - tokenized_full_prompt = self._tokenize(full_prompt) - - return tokenized_full_prompt - - def _build_full_prompt( - self, instruction, input, response - ): # pylint: disable=redefined-builtin - return next(iter(self.prompter.build_prompt(instruction, input, response))) - - class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): """ Tokenizing strategy for Reflection prompts. 
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 5322a10182..d6d14c3694 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -135,20 +135,6 @@ def match_prompt_style(self): self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:" -class CompletionPrompter: - """ - Prompter for completion - """ - - def build_prompt( - self, - instruction: str, - input=None, # pylint: disable=redefined-builtin, unused-argument - output=None, # pylint: disable=unused-argument - ) -> Generator[str, None, None]: - yield instruction - - class GPTeacherPrompter(AlpacaPrompter): """ Prompter for GPTeacher