Skip to content

Commit

Permalink
enable tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenglongjiepheonix committed Jul 23, 2024
1 parent c9c7571 commit 97e6431
Showing 1 changed file with 36 additions and 36 deletions.
72 changes: 36 additions & 36 deletions tests/fx/parallelization/test_tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,17 +198,17 @@ def run_test_tie_word_embeddings(rank: int, world_size: int, model_id: str, mode
tearDown()


# @parameterized.expand(DUMMY_MODELS_TO_TEST)
# @unittest.skipIf(
# not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run"
# )
# def test_all_rank_results_match(
# model_id,
# model_kwargs,
# ):
# for world_size in [1, 2, 4, 8]:
# if world_size <= NUM_AVAILABLE_DEVICES:
# spawn(world_size, run_test_all_rank_results_match, model_id, model_kwargs, deterministic=True)
@parameterized.expand(DUMMY_MODELS_TO_TEST)
@unittest.skipIf(
    not is_gpu_available() or not is_torch_compile_available(), "requires gpu and torch version >= 2.3.0 to run"
)
def test_all_rank_results_match(
    model_id,
    model_kwargs,
):
    """Verify every rank produces identical results, for each world size the host can run."""
    # Skip world sizes that exceed the number of devices actually available.
    feasible_sizes = (size for size in (1, 2, 4, 8) if size <= NUM_AVAILABLE_DEVICES)
    for size in feasible_sizes:
        spawn(size, run_test_all_rank_results_match, model_id, model_kwargs, deterministic=True)


@parameterized.expand(DUMMY_MODELS_TO_TEST)
Expand All @@ -226,28 +226,28 @@ def test_parameters_persist_bewteen_recompile(
)


# @parameterized.expand(DUMMY_MODELS_TO_TEST)
# @unittest.skipIf(
# not is_gpu_available() or not is_torch_compile_available() or NUM_AVAILABLE_DEVICES < 2,
# "requires more than one gpu and torch version >= 2.3.0 to run",
# )
# def test_parallel_results_matches_non_parallel(
# model_id,
# model_kwargs,
# ):
# # world_size == 2 is enough
# spawn(2, run_test_parallel_results_matches_non_parallel, model_id, model_kwargs, deterministic=True)


# @parameterized.expand(DUMMY_MODELS_TO_TEST)
# @unittest.skipIf(
# not is_gpu_available() or not is_torch_compile_available(),
# "requires gpu and torch version >= 2.3.0 to run",
# )
# def test_tie_word_embeddings(
# model_id,
# model_kwargs,
# ):
# for world_size in [1, 2]:
# if world_size <= NUM_AVAILABLE_DEVICES:
# spawn(world_size, run_test_tie_word_embeddings, model_id, model_kwargs, deterministic=False)
@parameterized.expand(DUMMY_MODELS_TO_TEST)
@unittest.skipIf(
    not is_gpu_available() or not is_torch_compile_available() or NUM_AVAILABLE_DEVICES < 2,
    "requires more than one gpu and torch version >= 2.3.0 to run",
)
def test_parallel_results_matches_non_parallel(
    model_id,
    model_kwargs,
):
    """Verify that a parallelized run reproduces the non-parallel reference output."""
    # A two-process run is sufficient to expose any parallel/non-parallel divergence.
    world_size = 2
    spawn(world_size, run_test_parallel_results_matches_non_parallel, model_id, model_kwargs, deterministic=True)


@parameterized.expand(DUMMY_MODELS_TO_TEST)
@unittest.skipIf(
    not is_gpu_available() or not is_torch_compile_available(),
    "requires gpu and torch version >= 2.3.0 to run",
)
def test_tie_word_embeddings(
    model_id,
    model_kwargs,
):
    """Verify tied word embeddings behave correctly under parallelization."""
    # Restrict to world sizes this machine can actually host.
    runnable = (ws for ws in (1, 2) if ws <= NUM_AVAILABLE_DEVICES)
    for ws in runnable:
        spawn(ws, run_test_tie_word_embeddings, model_id, model_kwargs, deterministic=False)

0 comments on commit 97e6431

Please sign in to comment.