Add env var for configuring the maximum number of processes to use for dataset processing (#1606)
irenedea authored Oct 22, 2024
1 parent 8e78eb5 commit 97d7f6b
Showing 2 changed files with 40 additions and 8 deletions.
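
In short, HF dataset tokenization and filtering now derive their worker count from the visible CPUs minus an 8-core margin, and a MAX_NUM_PROC environment variable can lower, but never raise, that default. A usage sketch (the variable name comes from the diff below; the value and launch style are illustrative assumptions):

# Cap dataset-processing workers on a shared or oversubscribed machine.
# Assumption: this runs before the finetuning dataset is constructed.
import os

os.environ['MAX_NUM_PROC'] = '4'  # dataset.map/.filter will use num_proc <= 4
# equivalently, from a shell: MAX_NUM_PROC=4 python <your train script>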
28 changes: 21 additions & 7 deletions llmfoundry/data/finetuning/tasks.py
@@ -515,6 +515,22 @@ def is_valid_ift_example(
     return True


+def _get_num_processes() -> int:
+    """Get the number of processes to use for dataset processing."""
+    detected_cpu_count = os.cpu_count() or 1
+    detected_cpus_with_margin = detected_cpu_count - 8
+    num_proc = max(1, detected_cpus_with_margin)
+
+    # Check if the user has set the MAX_NUM_PROC environment variable
+    # which caps the number of processes used for dataset processing.
+    if 'MAX_NUM_PROC' in os.environ:
+        max_num_proc_env = int(os.environ['MAX_NUM_PROC'])
+        if max_num_proc_env < num_proc:
+            num_proc = max_num_proc_env
+
+    return num_proc
+
+
 class StreamingFinetuningDataset(StreamingDataset):
     """Finetuning dataset with flexible tokenization using StreamingDataset.
@@ -960,18 +976,16 @@ def dataset_mapper(example: dict):
            )
            return mapping_fn(example, tokenizer)

-        detected_cpu_count = os.cpu_count() or 1
-        detected_cpus_with_margin = detected_cpu_count - 8
-        num_cpus_to_use = max(1, detected_cpus_with_margin)
-        if len(dataset) < num_cpus_to_use:
-            num_cpus_to_use = 1
+        num_proc = _get_num_processes()
+        if len(dataset) < num_proc:
+            num_proc = 1

         columns_to_remove = list(dataset[0].keys())
         tokenized_dataset = dataset.map(
             dataset_mapper,
             batched=False,
             remove_columns=columns_to_remove,
-            num_proc=num_cpus_to_use,
+            num_proc=num_proc,
             desc='Tokenizing dataset',
         )
@@ -983,7 +997,7 @@ def dataset_mapper(example: dict):
                 target_responses,
                 decoder_only_format,
             ),
-            num_proc=num_cpus_to_use,
+            num_proc=num_proc,
             desc='Filtering out long prompts',
         )
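
Behavior of the new helper on a hypothetical 16-CPU host (a worked sketch of the logic above, not extra code from the commit):

# _get_num_processes() outcomes with os.cpu_count() == 16 (hypothetical):
default = max(1, 16 - 8)   # no env var set      -> 8 processes
capped = min(4, default)   # MAX_NUM_PROC=4      -> lowered to 4
uncapped = default         # MAX_NUM_PROC=32     -> above the default, still 8
assert (default, capped, uncapped) == (8, 4, 8)
# The caller additionally falls back to num_proc = 1 when the dataset has
# fewer examples than the chosen process count.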
20 changes: 19 additions & 1 deletion tests/data/test_dataset.py
@@ -1,15 +1,33 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
+import os
 from contextlib import nullcontext
 from typing import Optional
+from unittest import mock

 import pytest

-from llmfoundry.data.finetuning.tasks import dataset_constructor
+from llmfoundry.data.finetuning.tasks import (
+    _get_num_processes,
+    dataset_constructor,
+)
 from llmfoundry.utils.exceptions import DatasetTooSmallError


+def test_get_num_processes():
+    with mock.patch.dict(os.environ, {'MAX_NUM_PROC': '4'}):
+        with mock.patch('os.cpu_count', return_value=16):
+            assert _get_num_processes() == 4
+
+    with mock.patch.dict(os.environ, {'MAX_NUM_PROC': '32'}):
+        with mock.patch('os.cpu_count', return_value=16):
+            assert _get_num_processes() == 8
+
+    with mock.patch.dict(os.environ, {}):
+        with mock.patch('os.cpu_count', return_value=16):
+            assert _get_num_processes() == 8
+
+
 @pytest.mark.parametrize('num_canonical_nodes', [None, 8, 2])
 def test_finetuning_streaming_dataset_too_small(
     num_canonical_nodes: Optional[int],
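
To run just the new test (plain pytest selection; no llmfoundry-specific tooling assumed):

# In-process equivalent of:
#   pytest tests/data/test_dataset.py -k test_get_num_processes
import pytest

pytest.main(['tests/data/test_dataset.py', '-k', 'test_get_num_processes'])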
