diff --git a/src/axolotl/prompt_strategies/completion.py b/src/axolotl/prompt_strategies/completion.py
index 3285e667cb..2f53988ddf 100644
--- a/src/axolotl/prompt_strategies/completion.py
+++ b/src/axolotl/prompt_strategies/completion.py
@@ -51,9 +51,12 @@ def tokenize_prompt(self, prompt):
         full_prompt = self._build_full_prompt(instruction, None, None)
         tokenized_full_prompt = self._tokenize(full_prompt)
 
+        steps = self.sequence_len - self.overlap_len
+        if steps < 1:
+            raise ValueError("Sequence length must be greater than overlap length")
         for key, val in tokenized_full_prompt.items():
-            for i in range(0, len(val), self.sequence_len):
+            for i in range(0, len(val), steps):
                 res[key].append(val[i : i + self.sequence_len])
 
         return dict(res)
 
@@ -84,6 +86,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
+        cfg.overlap_len,
         max_length=cfg.sequence_len * 64,
     )
     if ds_cfg and "field" in ds_cfg:
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 1b39336642..d0fcab07e7 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -42,12 +42,14 @@ def __init__(
         tokenizer,
         train_on_inputs: bool = False,
         sequence_len: int = 2048,
+        overlap_len: int = 0,
     ):
         self.prompter = prompter
         self.tokenizer: PreTrainedTokenizer = tokenizer
         self.train_on_inputs = train_on_inputs
         self.sequence_len = sequence_len
         self.max_length = sequence_len
+        self.overlap_len = overlap_len
 
     @abc.abstractmethod
     def tokenize_prompt(self, prompt):
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index bdad21fb11..75d80f8e27 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -94,6 +94,8 @@ def load_tokenized_prepared_datasets(
         md5(
             (
                 str(cfg.sequence_len)
+                + ","
+                + str(cfg.overlap_len)
                 + "@"
                 + "|".join(
                     sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
@@ -277,6 +279,7 @@ def for_d_in_datasets(dataset_configs):
                 tokenizer,
                 cfg.train_on_inputs,
                 cfg.sequence_len,
+                cfg.overlap_len,
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
@@ -286,6 +289,7 @@ def for_d_in_datasets(dataset_configs):
                 tokenizer,
                 cfg.train_on_inputs,
                 cfg.sequence_len,
+                cfg.overlap_len,
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
@@ -295,6 +299,7 @@ def for_d_in_datasets(dataset_configs):
                 tokenizer,
                 cfg.train_on_inputs,
                 cfg.sequence_len,
+                cfg.overlap_len,
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
@@ -304,6 +309,7 @@ def for_d_in_datasets(dataset_configs):
                 tokenizer,
                 cfg.train_on_inputs,
                 cfg.sequence_len,
+                cfg.overlap_len,
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
@@ -313,6 +319,7 @@ def for_d_in_datasets(dataset_configs):
                 tokenizer,
                 cfg.train_on_inputs,
                 cfg.sequence_len,
+                cfg.overlap_len,
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
@@ -322,6 +329,7 @@ def for_d_in_datasets(dataset_configs):
                 tokenizer,
                 cfg.train_on_inputs,
                 cfg.sequence_len,
+                cfg.overlap_len,
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
@@ -331,6 +339,7 @@ def for_d_in_datasets(dataset_configs):
                 tokenizer,
                 cfg.train_on_inputs,
                 cfg.sequence_len,
+                cfg.overlap_len,
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
@@ -340,6 +349,7 @@ def for_d_in_datasets(dataset_configs):
                 tokenizer,
                 cfg.train_on_inputs,
                 cfg.sequence_len,
+                cfg.overlap_len,
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
@@ -391,6 +401,8 @@ def load_prepare_datasets(
         md5(
             (
                 str(cfg.sequence_len)
+                + ","
+                + str(cfg.overlap_len)
                 + "@"
                 + str(max_packed_sequence_len)
                 + seed
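For review context, a minimal standalone sketch of the sliding-window chunking that the `steps` change in `tokenize_prompt` implements. The helper name `chunk_with_overlap` and the toy values are illustrative only, not part of this PR:

```python
# Sketch of the overlap-aware chunking added in tokenize_prompt.
# Operates on a plain list; the real code applies this per field of the
# tokenized prompt (input_ids, attention_mask, labels).

def chunk_with_overlap(token_ids, sequence_len=2048, overlap_len=0):
    """Split token_ids into sequence_len windows; consecutive windows
    share their trailing/leading overlap_len tokens."""
    steps = sequence_len - overlap_len
    if steps < 1:
        raise ValueError("Sequence length must be greater than overlap length")
    return [token_ids[i : i + sequence_len] for i in range(0, len(token_ids), steps)]

# sequence_len=8, overlap_len=2 -> stride of 6, so windows overlap by 2 tokens:
print(chunk_with_overlap(list(range(20)), sequence_len=8, overlap_len=2))
# [[0, 1, ..., 7], [6, 7, ..., 13], [12, 13, ..., 19], [18, 19]]
```

Note that the final window can be shorter than `sequence_len`. `cfg.overlap_len` is also folded into both md5 dataset hashes above, so datasets cached under a different overlap setting are re-prepared rather than reused.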