From ac8e023d4534611a9845a2c993bdc91af7b56fbd Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Wed, 1 Nov 2023 10:29:53 -0700
Subject: [PATCH] Add num_proc to map and filter calls (#706)

---
 llmfoundry/data/finetuning/tasks.py | 8 +++++++-
 tests/test_hf_conversion_script.py  | 3 ++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index f2bd0239c8..edbfcc28c7 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -339,14 +339,20 @@ def dataset_mapper(example: Dict):
                 example = preprocessing_fn(example)
             return _tokenize_formatted_example(example, tokenizer)
 
+        detected_cpu_count = os.cpu_count() or 1
+        num_cpus_to_use = max(1, detected_cpu_count - 4)
+
         columns_to_remove = list(dataset[0].keys())
         tokenized_dataset = dataset.map(
             dataset_mapper,
             batched=False,
             remove_columns=columns_to_remove,
+            num_proc=num_cpus_to_use,
         )
         prompt_length_filtered_dataset = tokenized_dataset.filter(
-            lambda example: len(example['input_ids']) < max_seq_len)
+            lambda example: len(example['input_ids']) < max_seq_len,
+            num_proc=num_cpus_to_use,
+        )
         examples_removed = len(tokenized_dataset) - len(
             prompt_length_filtered_dataset)
 
diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py
index d2f203d3a0..d2c2a9e1c9 100644
--- a/tests/test_hf_conversion_script.py
+++ b/tests/test_hf_conversion_script.py
@@ -5,7 +5,7 @@
 import os
 import pathlib
 import sys
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 from composer import Trainer
 from composer.loggers import MLFlowLogger
@@ -254,6 +254,7 @@ def test_callback_inits_with_defaults():
 @pytest.mark.parametrize(
     'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
     [('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)])
+@patch('os.cpu_count', MagicMock(return_value=None))
 def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
                                          fsdp_state_dict_type: Optional[str],
                                          log_to_mlflow: bool,
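
Below is a minimal, standalone sketch (not part of the patch) of the pattern the diff introduces: compute a CPU budget from os.cpu_count() with a fallback of 1, then pass num_proc to Hugging Face datasets map and filter so tokenization and length filtering run in parallel worker processes. The toy data and the toy_tokenize helper are hypothetical stand-ins for the real preprocessing and tokenizer in llm-foundry.

import os

from datasets import Dataset


def toy_tokenize(example):
    # Hypothetical whitespace "tokenizer" standing in for the real
    # preprocessing/tokenization; module-level so worker processes can pickle it.
    return {'input_ids': example['text'].split()}


if __name__ == '__main__':
    # Mirror the patch's heuristic: os.cpu_count() can return None (the test
    # above patches it to None), so fall back to 1 and keep a few cores free.
    detected_cpu_count = os.cpu_count() or 1
    num_cpus_to_use = max(1, detected_cpu_count - 4)

    # Hypothetical toy data standing in for a real finetuning dataset.
    dataset = Dataset.from_dict(
        {'text': ['short example', 'a much longer example that gets dropped']})

    tokenized_dataset = dataset.map(
        toy_tokenize,
        batched=False,
        remove_columns=['text'],
        num_proc=num_cpus_to_use,  # fan tokenization out across processes
    )
    filtered_dataset = tokenized_dataset.filter(
        lambda example: len(example['input_ids']) < 4,
        num_proc=num_cpus_to_use,
    )
    print(len(tokenized_dataset) - len(filtered_dataset), 'examples removed')

Subtracting 4 from the detected count presumably leaves headroom for the main process; the "or 1" fallback is what the test exercises by mocking os.cpu_count to return None.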