From 2adfa9b3c94ec08015f911698052eec5e6e0d37c Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Tue, 22 Oct 2024 17:14:08 +0000
Subject: [PATCH] Add test

---
Review note: the "unset" case now passes clear=True to mock.patch.dict and the
new test is commented; the post-image blob id on the index line predates this.

 tests/data/test_dataset.py | 25 ++++++++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py
index 071c189b68..b89fcc4b37 100644
--- a/tests/data/test_dataset.py
+++ b/tests/data/test_dataset.py
@@ -1,15 +1,38 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
+import os
 from contextlib import nullcontext
 from typing import Optional
 from unittest import mock
 
 import pytest
 
-from llmfoundry.data.finetuning.tasks import dataset_constructor
+from llmfoundry.data.finetuning.tasks import (
+    _get_num_processes,
+    dataset_constructor,
+)
 from llmfoundry.utils.exceptions import DatasetTooSmallError
 
 
+def test_get_num_processes():
+    """Check that MAX_NUM_PROC caps, but never raises, the process count."""
+    # A cap below the unset-variable result (8 for 16 mocked CPUs) is honored.
+    with mock.patch.dict(os.environ, {'MAX_NUM_PROC': '4'}):
+        with mock.patch('os.cpu_count', return_value=16):
+            assert _get_num_processes() == 4
+
+    # A cap above that result leaves it unchanged.
+    with mock.patch.dict(os.environ, {'MAX_NUM_PROC': '32'}):
+        with mock.patch('os.cpu_count', return_value=16):
+            assert _get_num_processes() == 8
+
+    # clear=True drops any MAX_NUM_PROC inherited from the real environment,
+    # so this case really runs with the variable unset.
+    with mock.patch.dict(os.environ, {}, clear=True):
+        with mock.patch('os.cpu_count', return_value=16):
+            assert _get_num_processes() == 8
+
+
 @pytest.mark.parametrize('num_canonical_nodes', [None, 8, 2])
 def test_finetuning_streaming_dataset_too_small(
     num_canonical_nodes: Optional[int],