diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 915267786f..3d9ed056ef 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -515,6 +515,22 @@ def is_valid_ift_example( return True +def _get_num_processes() -> int: + """Get the number of processes to use for dataset processing.""" + detected_cpu_count = os.cpu_count() or 1 + detected_cpus_with_margin = detected_cpu_count - 8 + num_proc = max(1, detected_cpus_with_margin) + + # Check if the user has set the MAX_NUM_PROC environment variable + # which caps the number of processes used for dataset processing. + if 'MAX_NUM_PROC' in os.environ: + max_num_proc_env = int(os.environ['MAX_NUM_PROC']) + if max_num_proc_env < num_proc: + num_proc = max_num_proc_env + + return num_proc + + class StreamingFinetuningDataset(StreamingDataset): """Finetuning dataset with flexible tokenization using StreamingDataset. @@ -960,18 +976,16 @@ def dataset_mapper(example: dict): ) return mapping_fn(example, tokenizer) - detected_cpu_count = os.cpu_count() or 1 - detected_cpus_with_margin = detected_cpu_count - 8 - num_cpus_to_use = max(1, detected_cpus_with_margin) - if len(dataset) < num_cpus_to_use: - num_cpus_to_use = 1 + num_proc = _get_num_processes() + if len(dataset) < num_proc: + num_proc = 1 columns_to_remove = list(dataset[0].keys()) tokenized_dataset = dataset.map( dataset_mapper, batched=False, remove_columns=columns_to_remove, - num_proc=num_cpus_to_use, + num_proc=num_proc, desc='Tokenizing dataset', ) @@ -983,7 +997,7 @@ def dataset_mapper(example: dict): target_responses, decoder_only_format, ), - num_proc=num_cpus_to_use, + num_proc=num_proc, desc='Filtering out long prompts', ) diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 071c189b68..b89fcc4b37 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -1,15 +1,33 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import os from contextlib import nullcontext from typing import Optional from unittest import mock import pytest -from llmfoundry.data.finetuning.tasks import dataset_constructor +from llmfoundry.data.finetuning.tasks import ( + _get_num_processes, + dataset_constructor, +) from llmfoundry.utils.exceptions import DatasetTooSmallError +def test_get_num_processes(): + with mock.patch.dict(os.environ, {'MAX_NUM_PROC': '4'}): + with mock.patch('os.cpu_count', return_value=16): + assert _get_num_processes() == 4 + + with mock.patch.dict(os.environ, {'MAX_NUM_PROC': '32'}): + with mock.patch('os.cpu_count', return_value=16): + assert _get_num_processes() == 8 + + with mock.patch.dict(os.environ, {}): + with mock.patch('os.cpu_count', return_value=16): + assert _get_num_processes() == 8 + + @pytest.mark.parametrize('num_canonical_nodes', [None, 8, 2]) def test_finetuning_streaming_dataset_too_small( num_canonical_nodes: Optional[int],