Skip to content

Commit

Permalink
Make dataset_processes configurable (#651)
Browse files Browse the repository at this point in the history
I'm using the Axolotl script to train models on https://modal.com serverless GPUs. Unfortunately, their environment seems to have some kind of bug where if I try to run `datasets.filter` with too high a `num_proc`, it throws an error and dies.

This PR adds a new configuration option `dataset_processes`, which lets you explicitly set the number of processes used to map/filter the dataset. If not included, this defaults to the current behavior of setting that to `os.cpu_count()`.
  • Loading branch information
corbt authored Sep 29, 2023
1 parent 590d603 commit 9ec2077
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,9 @@ datasets:
dataset_prepared_path: data/last_run_prepared
# push prepared dataset to hub
push_dataset_to_hub: # repo path
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set.
dataset_processes: # defaults to os.cpu_count() if not set
# push checkpoints to hub
hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
Expand Down
2 changes: 2 additions & 0 deletions src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def normalize_config(cfg):
else:
cfg.torch_dtype = torch.float32

cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()

model_config = load_model_config(cfg)
cfg.model_config_type = model_config.model_type

Expand Down
16 changes: 11 additions & 5 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,19 +400,25 @@ def disable_datasets_caching():
def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
with zero_first(is_main_process()):
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
if eval_dataset:
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
eval_dataset = eval_dataset.filter(
drop_long, num_proc=cfg.dataset_processes
)

if cfg.group_by_length:
train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
train_dataset = train_dataset.map(
add_length, num_proc=cfg.dataset_processes
)

if cfg.sample_packing:
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
train_dataset = train_dataset.map(
add_position_ids, num_proc=cfg.dataset_processes
)
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids, num_proc=os.cpu_count()
add_position_ids, num_proc=cfg.dataset_processes
)

# Phi doesn't want the attention_mask feature when training
Expand Down

0 comments on commit 9ec2077

Please sign in to comment.