Run HF dataset processing on local rank 0 first #716

Merged · 17 commits · Nov 6, 2023
32 changes: 29 additions & 3 deletions llmfoundry/data/finetuning/tasks.py
@@ -38,6 +38,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import datasets as hf_datasets
+from composer.utils import dist
 from omegaconf import DictConfig
 from streaming import StreamingDataset
 from transformers import PreTrainedTokenizerBase
@@ -332,6 +333,13 @@ def build_from_hf(
             preprocessing_fn = self.get_preprocessing_fn_from_str(
                 proto_preprocessing_fn, dataset_name)
 
+        signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'
+
+        if dist.get_local_rank() != 0:
+            log.debug('Waiting for local_rank 0 to finish data prep')
+            with dist.local_rank_zero_download_and_wait(signal_file_path):
+                dist.barrier()
+
         dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs)
 
         def dataset_mapper(example: Dict):
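
Note on the pattern: together with the completion block at the end of this diff, these lines implement a rank-0-first gate. Every local rank other than 0 parks at a barrier until local rank 0 has populated the HF datasets cache and written a signal file, so the download and tokenization happen once per node. A minimal sketch of the pattern in isolation, assuming Composer's dist utilities are initialized; prepare_data() is a hypothetical stand-in for the load-and-tokenize work, and the cleanup here is guarded and double-barriered so that per-rank barrier counts stay matched, which the diff itself does not do:

    import os
    from composer.utils import dist

    signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'

    if dist.get_local_rank() != 0:
        # Non-zero ranks block here until local rank 0 signals completion.
        with dist.local_rank_zero_download_and_wait(signal_file_path):
            dist.barrier()

    # Local rank 0 reaches this first and fills the on-disk cache; the
    # remaining ranks run it afterwards against the warm cache.
    dataset = prepare_data()  # hypothetical stand-in for load + tokenize

    if dist.get_local_rank() == 0:
        with open(signal_file_path, 'wb') as f:
            f.write(b'local_rank0_completed_data_prep')
        dist.barrier()  # pairs with the waiting ranks' barrier above

    dist.barrier()  # every rank has finished data prep past this point
    if dist.get_local_rank() == 0:
        os.remove(signal_file_path)  # remove the marker exactly once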
@@ -340,18 +348,21 @@ def dataset_mapper(example: Dict):
             return _tokenize_formatted_example(example, tokenizer)
 
         detected_cpu_count = os.cpu_count() or 1
-        num_cpus_to_use = max(1, detected_cpu_count - 4)
+        detected_cpus_with_margin = detected_cpu_count - 8
+        num_cpus_to_use = max(1, detected_cpus_with_margin)
 
         columns_to_remove = list(dataset[0].keys())
         tokenized_dataset = dataset.map(
             dataset_mapper,
             batched=False,
             remove_columns=columns_to_remove,
             num_proc=num_cpus_to_use,
+            desc='Tokenizing dataset',
         )
         prompt_length_filtered_dataset = tokenized_dataset.filter(
             lambda example: len(example['input_ids']) < max_seq_len,
             num_proc=num_cpus_to_use,
+            desc='Filtering out long prompts',
         )
 
         examples_removed = len(tokenized_dataset) - len(
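
For context on the knobs above: datasets runs the mapped or filtering function in num_proc worker processes, and desc labels the progress bar for each stage. Widening the reserved margin from 4 to 8 cores leaves more headroom for the training process and the OS. A self-contained toy sketch of the same sizing rule (the dataset and column names are illustrative only, not from the PR):

    import os
    import datasets as hf_datasets

    # Same rule as the diff: keep a fixed margin of cores free,
    # but never drop below a single worker.
    detected_cpu_count = os.cpu_count() or 1
    num_cpus_to_use = max(1, detected_cpu_count - 8)

    dataset = hf_datasets.Dataset.from_dict({'text': ['hello world'] * 1000})
    dataset = dataset.map(
        lambda example: {'n_chars': len(example['text'])},
        num_proc=num_cpus_to_use,  # fan the work out across processes
        desc='Counting characters',  # progress-bar label
    )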
@@ -361,17 +372,32 @@ def dataset_mapper(example: Dict):
                 f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.'
             )
 
+        pad_token_id = tokenizer.pad_token_id
         empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter(
             lambda example: len(example['input_ids']) > 0 and len(example[
-                'labels']) > 0 and any(token_id != tokenizer.pad_token_id
-                                       for token_id in example['labels']))
+                'labels']) > 0 and any(token_id != pad_token_id
+                                       for token_id in example['labels']),
+            num_proc=num_cpus_to_use,
+            desc='Filtering out empty examples')
+
+        log.debug('Done tokenizing and filtering examples.')
 
         empty_examples_removed = len(prompt_length_filtered_dataset) - len(
             empty_examples_dropped_dataset)
         if empty_examples_removed > 0:
             warnings.warn(
                 f'Dropped {empty_examples_removed} examples where the prompt or response was empty, '
                 + 'or the response was only padding tokens.')
 
+        if dist.get_local_rank() == 0:
+            log.debug('Local rank 0 finished data prep')
+            with open(signal_file_path, 'wb') as f:
+                f.write(b'local_rank0_completed_data_prep')
+
+        dist.barrier()
+        os.remove(signal_file_path)
+
+        log.debug('All ranks finished data prep')
         return empty_examples_dropped_dataset
 
     def build_from_streaming(self, *args: Any,
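
One deliberate detail in the last hunk: the filter lambda now closes over a plain pad_token_id int instead of the tokenizer object, so the closure that datasets serializes for each worker process stays lightweight. A minimal illustration, assuming a standard transformers tokenizer (the gpt2 checkpoint is only an example):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

    # Hoisting the id means the lambda captures a small int rather than
    # the whole tokenizer, which is cheaper to ship to filter workers.
    pad_token_id = tokenizer.pad_token_id
    example = {'labels': [pad_token_id, 42, pad_token_id]}
    keep = any(token_id != pad_token_id for token_id in example['labels'])
    print(keep)  # True: at least one non-padding label survives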