feature: add overlap_len option to prompt strategies
This is useful with smaller datasets, where the default behavior splits the data into chunks of context-size length, showing each part of the data only once. A nonzero overlap_len makes consecutive chunks share tokens, so each region of the data can appear in more than one training example (see the sketch after the file summary below).
Karl-Johan Alm committed Oct 3, 2023
1 parent 2642cae commit 2c80e9b
Showing 3 changed files with 18 additions and 1 deletion.
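
For context, here is a minimal standalone sketch of the chunking logic this commit adds to completion.py. The chunk helper and the toy values below are illustrative, not part of the commit:

    def chunk(tokens, sequence_len=2048, overlap_len=0):
        # Stride between chunk starts; overlap_len=0 reproduces the previous
        # behavior of non-overlapping, sequence_len-sized chunks.
        steps = sequence_len - overlap_len
        if steps < 1:
            raise ValueError("Sequence length must be greater than overlap length")
        return [tokens[i : i + sequence_len] for i in range(0, len(tokens), steps)]

    # Toy example: 10 tokens, window of 4, overlap of 2.
    print(chunk(list(range(10)), sequence_len=4, overlap_len=2))
    # [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9], [8, 9]]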
5 changes: 4 additions & 1 deletion src/axolotl/prompt_strategies/completion.py
@@ -51,9 +51,11 @@ def tokenize_prompt(self, prompt):

         full_prompt = self._build_full_prompt(instruction, None, None)
         tokenized_full_prompt = self._tokenize(full_prompt)
+        steps = self.sequence_len - self.overlap_len
+        if steps < 1: raise ValueError("Sequence length must be greater than overlap length")

         for key, val in tokenized_full_prompt.items():
-            for i in range(0, len(val), self.sequence_len):
+            for i in range(0, len(val), steps):
                 res[key].append(val[i : i + self.sequence_len])

         return dict(res)
@@ -84,6 +86,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
+        cfg.overlap_len,
         max_length=cfg.sequence_len * 64,
     )
     if ds_cfg and "field" in ds_cfg:
2 changes: 2 additions & 0 deletions src/axolotl/prompt_tokenizers.py
@@ -42,12 +42,14 @@ def __init__(
         tokenizer,
         train_on_inputs: bool = False,
         sequence_len: int = 2048,
+        overlap_len: int = 0,
     ):
         self.prompter = prompter
         self.tokenizer: PreTrainedTokenizer = tokenizer
         self.train_on_inputs = train_on_inputs
         self.sequence_len = sequence_len
         self.max_length = sequence_len
+        self.overlap_len = overlap_len

     @abc.abstractmethod
     def tokenize_prompt(self, prompt):
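
Because overlap_len defaults to 0 in the constructor above, the chunking stride equals sequence_len and existing configs behave exactly as before. A quick hypothetical check:

    # With the default overlap_len=0, the stride (steps) equals sequence_len,
    # so chunk boundaries match the pre-change behavior.
    sequence_len, overlap_len = 2048, 0
    steps = sequence_len - overlap_len
    assert steps == sequence_len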
12 changes: 12 additions & 0 deletions src/axolotl/utils/data.py
@@ -94,6 +94,8 @@ def load_tokenized_prepared_datasets(
         md5(
             (
                 str(cfg.sequence_len)
+                + ","
+                + str(cfg.overlap_len)
                 + "@"
                 + "|".join(
                     sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
@@ -277,6 +279,7 @@ def for_d_in_datasets(dataset_configs):
             tokenizer,
             cfg.train_on_inputs,
             cfg.sequence_len,
+            cfg.overlap_len,
         )
         ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
         datasets.append(ds_wrapper)
@@ -286,6 +289,7 @@ def for_d_in_datasets(dataset_configs):
             tokenizer,
             cfg.train_on_inputs,
             cfg.sequence_len,
+            cfg.overlap_len,
         )
         ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
         datasets.append(ds_wrapper)
@@ -295,6 +299,7 @@ def for_d_in_datasets(dataset_configs):
             tokenizer,
             cfg.train_on_inputs,
             cfg.sequence_len,
+            cfg.overlap_len,
         )
         ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
         datasets.append(ds_wrapper)
@@ -304,6 +309,7 @@ def for_d_in_datasets(dataset_configs):
             tokenizer,
             cfg.train_on_inputs,
             cfg.sequence_len,
+            cfg.overlap_len,
         )
         ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
         datasets.append(ds_wrapper)
@@ -313,6 +319,7 @@ def for_d_in_datasets(dataset_configs):
             tokenizer,
             cfg.train_on_inputs,
             cfg.sequence_len,
+            cfg.overlap_len,
         )
         ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
         datasets.append(ds_wrapper)
@@ -322,6 +329,7 @@ def for_d_in_datasets(dataset_configs):
             tokenizer,
             cfg.train_on_inputs,
             cfg.sequence_len,
+            cfg.overlap_len,
         )
         ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
         datasets.append(ds_wrapper)
@@ -331,6 +339,7 @@ def for_d_in_datasets(dataset_configs):
             tokenizer,
             cfg.train_on_inputs,
             cfg.sequence_len,
+            cfg.overlap_len,
         )
         ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
         datasets.append(ds_wrapper)
@@ -340,6 +349,7 @@ def for_d_in_datasets(dataset_configs):
             tokenizer,
             cfg.train_on_inputs,
             cfg.sequence_len,
+            cfg.overlap_len,
         )
         ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
         datasets.append(ds_wrapper)
@@ -391,6 +401,8 @@ def load_prepare_datasets(
         md5(
             (
                 str(cfg.sequence_len)
+                + ","
+                + str(cfg.overlap_len)
                 + "@"
                 + str(max_packed_sequence_len)
                 + seed
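
The two md5 changes above fold overlap_len into the prepared-dataset cache key, so changing the overlap forces re-tokenization instead of reusing chunks built with a different stride. A simplified, hypothetical illustration of the idea (not axolotl's exact key format):

    from hashlib import md5

    # Hypothetical cache key: any change to overlap_len yields a new digest,
    # invalidating datasets prepared with the old chunking stride.
    def cache_key(sequence_len, overlap_len, dataset_spec):
        return md5(f"{sequence_len},{overlap_len}@{dataset_spec}".encode()).hexdigest()

    print(cache_key(2048, 0, "data.jsonl:completion:None"))
    print(cache_key(2048, 256, "data.jsonl:completion:None"))  # different digest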
