diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 708a3a54e39..9bf35147307 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3061,6 +3061,7 @@ def test_multi_gpu_data_parallel_forward(self): with torch.no_grad(): _ = model(**self._prepare_for_class(inputs_dict, model_class)) + @require_torch_gpu @require_torch_multi_gpu def test_model_parallelization(self): if not self.test_model_parallel: @@ -3123,6 +3124,7 @@ def get_current_gpu_memory_use(): gc.collect() torch.cuda.empty_cache() + @require_torch_gpu @require_torch_multi_gpu def test_model_parallel_equal_results(self): if not self.test_model_parallel: