Skip to content

Commit

Permalink
Enable tests to run on third-party devcies (#25327)
Browse files Browse the repository at this point in the history
* enable unit tests to run on third-party devcies other than CUDA and CPU.

* remove the modification that enabled ut on MPS

* control test on third-party device by env variable

* update

---------

Co-authored-by: statelesshz <[email protected]>
  • Loading branch information
statelesshz and statelesshz authored Aug 8, 2023
1 parent 5744482 commit 26ce4dd
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def parse_int_from_env(key, default=None):
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
_run_tool_tests = parse_flag_from_env("RUN_TOOL_TESTS", default=False)
_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)


def is_pt_tf_cross_test(test_case):
Expand Down Expand Up @@ -612,7 +613,12 @@ def require_torch_multi_npu(test_case):
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
torch_device = "cuda"
elif _run_third_party_device_tests and is_torch_npu_available():
torch_device = "npu"
else:
torch_device = "cpu"
else:
torch_device = None

Expand Down

0 comments on commit 26ce4dd

Please sign in to comment.