From 97e6431ae9962bc6b1821e798cd70fc0f9b0a499 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 23 Jul 2024 22:46:33 +0200 Subject: [PATCH] enable tests --- .../parallelization/test_tensor_parallel.py | 72 +++++++++---------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/tests/fx/parallelization/test_tensor_parallel.py b/tests/fx/parallelization/test_tensor_parallel.py index fe09d6fff69..9626fccec3b 100644 --- a/tests/fx/parallelization/test_tensor_parallel.py +++ b/tests/fx/parallelization/test_tensor_parallel.py @@ -198,17 +198,17 @@ def run_test_tie_word_embeddings(rank: int, world_size: int, model_id: str, mode tearDown() -# @parameterized.expand(DUMMY_MODELS_TO_TEST) -# @unittest.skipIf( -# not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run" -# ) -# def test_all_rank_results_match( -# model_id, -# model_kwargs, -# ): -# for world_size in [1, 2, 4, 8]: -# if world_size <= NUM_AVAILABLE_DEVICES: -# spawn(world_size, run_test_all_rank_results_match, model_id, model_kwargs, deterministic=True) +@parameterized.expand(DUMMY_MODELS_TO_TEST) +@unittest.skipIf( + not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run" +) +def test_all_rank_results_match( + model_id, + model_kwargs, +): + for world_size in [1, 2, 4, 8]: + if world_size <= NUM_AVAILABLE_DEVICES: + spawn(world_size, run_test_all_rank_results_match, model_id, model_kwargs, deterministic=True) @parameterized.expand(DUMMY_MODELS_TO_TEST) @@ -226,28 +226,28 @@ def test_parameters_persist_bewteen_recompile( ) -# @parameterized.expand(DUMMY_MODELS_TO_TEST) -# @unittest.skipIf( -# not is_gpu_available() or not is_torch_compile_available() or NUM_AVAILABLE_DEVICES < 2, -# "requires more than one gpu and torch version >= 2.3.0 to run", -# ) -# def test_parallel_results_matches_non_parallel( -# model_id, -# model_kwargs, -# ): -# # world_size == 2 is enough -# spawn(2, run_test_parallel_results_matches_non_parallel, model_id, model_kwargs, deterministic=True) - - -# @parameterized.expand(DUMMY_MODELS_TO_TEST) -# @unittest.skipIf( -# not is_gpu_available() or not is_torch_compile_available(), -# "requires gpu and torch version >= 2.3.0 to run", -# ) -# def test_tie_word_embeddings( -# model_id, -# model_kwargs, -# ): -# for world_size in [1, 2]: -# if world_size <= NUM_AVAILABLE_DEVICES: -# spawn(world_size, run_test_tie_word_embeddings, model_id, model_kwargs, deterministic=False) +@parameterized.expand(DUMMY_MODELS_TO_TEST) +@unittest.skipIf( + not is_gpu_available() or not is_torch_compile_available() or NUM_AVAILABLE_DEVICES < 2, + "requires more than one gpu and torch version >= 2.3.0 to run", +) +def test_parallel_results_matches_non_parallel( + model_id, + model_kwargs, +): + # world_size == 2 is enough + spawn(2, run_test_parallel_results_matches_non_parallel, model_id, model_kwargs, deterministic=True) + + +@parameterized.expand(DUMMY_MODELS_TO_TEST) +@unittest.skipIf( + not is_gpu_available() or not is_torch_compile_available(), + "requires gpu and torch version >= 2.3.0 to run", +) +def test_tie_word_embeddings( + model_id, + model_kwargs, +): + for world_size in [1, 2]: + if world_size <= NUM_AVAILABLE_DEVICES: + spawn(world_size, run_test_tie_word_embeddings, model_id, model_kwargs, deterministic=False)