Preprocess dataset size fix (#1131)
* overwrite cache on preprocess step
* don't cache the TokenizedPromptDataset at all
* load_from_cache_file no longer needed
winglian authored Jan 17, 2024
1 parent ece0211 commit 7570446
Showing 4 changed files with 53 additions and 16 deletions.
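
Note on the change: the diffs below thread a new cfg.is_preprocess flag from the preprocess CLI into the dataset map/filter calls, so preprocessing recomputes (and thereby overwrites) any previously written Hugging Face datasets cache instead of silently reusing it. The following is a minimal sketch of that pattern only, not the project code; the is_preprocess flag and add_length function here are hypothetical stand-ins for cfg.is_preprocess and axolotl's own helper.

# Sketch: how load_from_cache_file gates cache reuse in datasets.Dataset.map.
from datasets import Dataset

def add_length(example):
    example["length"] = len(example["input_ids"])
    return example

is_preprocess = True  # set by the preprocess CLI in the real code path

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]]})
ds = ds.map(
    add_length,
    num_proc=1,
    # While preprocessing, ignore any existing cached result and recompute,
    # so a stale fingerprinted cache is not loaded in place of fresh output.
    load_from_cache_file=not is_preprocess,
)
print(ds[0]["length"])  # 3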
1 change: 1 addition & 0 deletions src/axolotl/cli/preprocess.py
@@ -25,6 +25,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
     parsed_cfg = load_cfg(config, **kwargs)
+    parsed_cfg.is_preprocess = True
     check_accelerate_default_config()
     check_user_token()
     parser = transformers.HfArgumentParser((PreprocessCliArgs))
6 changes: 5 additions & 1 deletion src/axolotl/datasets.py
@@ -35,7 +35,10 @@ def __init__( # pylint: disable=super-init-not-called
     ):
         self.prompt_tokenizer = prompt_tokenizer
         self.process_count = process_count
-        super().__init__(self.process(dataset).data, **kwargs)
+        super().__init__(
+            self.process(dataset).data,
+            **kwargs,
+        )

     def process(self, dataset):
         features = dataset.features.keys()
@@ -52,6 +55,7 @@ def process(self, dataset):
             self.prompt_tokenizer.tokenize_prompt,
             num_proc=num_proc,
             remove_columns=features,
+            keep_in_memory=True,
             **map_kwargs,
         )

40 changes: 30 additions & 10 deletions src/axolotl/utils/data.py
@@ -594,12 +594,16 @@ def get_dataset_wrapper(
         )
         dataset_prompter = UnsupportedPrompter()
         dataset_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
     elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
         dataset_prompter = UnsupportedPrompter()
         dataset_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
     elif d_base_type == "alpaca":
         dataset_prompter = AlpacaPrompter(d_prompt_style)
@@ -610,7 +614,9 @@
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "explainchoice":
@@ -622,7 +628,9 @@
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "concisechoice":
@@ -634,7 +642,9 @@
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "summarizetldr":
@@ -646,7 +656,9 @@
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "jeopardy":
@@ -658,7 +670,9 @@
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "oasst":
@@ -670,7 +684,9 @@
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "gpteacher":
@@ -682,7 +698,9 @@
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "reflection":
@@ -694,7 +712,9 @@
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     else:
22 changes: 17 additions & 5 deletions src/axolotl/utils/trainer.py
@@ -111,27 +111,39 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     with zero_first(is_main_process()):
         if cfg.group_by_length:
             train_dataset = train_dataset.map(
-                add_length, num_proc=cfg.dataset_processes
+                add_length,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
             )

         if cfg.sample_packing:
             train_dataset = train_dataset.map(
-                add_position_ids, num_proc=cfg.dataset_processes
+                add_position_ids,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
             )
             if cfg.eval_sample_packing is not False:
                 if eval_dataset:
                     eval_dataset = eval_dataset.map(
-                        add_position_ids, num_proc=cfg.dataset_processes
+                        add_position_ids,
+                        num_proc=cfg.dataset_processes,
+                        load_from_cache_file=not cfg.is_preprocess,
                     )

         if cfg.group_by_length or cfg.sample_packing:
             max_input_len = np.max(get_dataset_lengths(train_dataset))
             LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

-        train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
+        train_dataset = train_dataset.filter(
+            drop_long,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+        )
         if eval_dataset:
             eval_dataset = eval_dataset.filter(
-                drop_long, num_proc=cfg.dataset_processes
+                drop_long,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
             )

         # Phi doesn't want the attention_mask feature when training
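
Usage note on the trainer changes above: Dataset.filter accepts the same num_proc and load_from_cache_file arguments as Dataset.map, so drop_long-style length filtering can also bypass a stale cache while preprocessing. A hedged sketch follows; sequence_len and this drop_long are illustrative stand-ins for cfg.sequence_len and axolotl's helper, not the project code.

# Sketch: filtering over-long rows without reusing a previously cached result.
from datasets import Dataset

sequence_len = 4  # hypothetical maximum length

def drop_long(example):
    return len(example["input_ids"]) <= sequence_len

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8]]})
ds = ds.filter(
    drop_long,
    num_proc=1,
    load_from_cache_file=False,  # what "not cfg.is_preprocess" evaluates to during preprocess
)
print(len(ds))  # 1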
