make tests of pytorch_example device agnostic (#27081)
statelesshz authored Oct 30, 2023
1 parent 6b46677 commit cd19b19
Showing 2 changed files with 32 additions and 33 deletions.
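In short, the commit replaces CUDA-specific checks (torch.cuda.device_count(), a combined CUDA-and-apex test gating --fp16) with device-agnostic helpers from transformers.testing_utils. A minimal sketch of the resulting pattern, assuming the helpers behave as their names suggest (the diffs below show the actual changes):

    from transformers.testing_utils import (
        backend_device_count,
        is_torch_fp16_available_on_device,
        torch_device,
    )

    # Decisions previously tied to torch.cuda are now made against whatever
    # `torch_device` resolves to ("cuda", "npu", "xpu", "cpu", ...).
    if backend_device_count(torch_device) > 1:
        pass  # multi-device-only branches, e.g. skipping tests that need drop_last
    if is_torch_fp16_available_on_device(torch_device):
        pass  # safe to append "--fp16" to the example's CLI arguments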
24 changes: 12 additions & 12 deletions examples/pytorch/test_accelerate_examples.py
@@ -24,11 +24,16 @@
 import unittest
 from unittest import mock
 
-import torch
 from accelerate.utils import write_basic_config
 
-from transformers.testing_utils import TestCasePlus, get_gpu_count, run_command, slow, torch_device
-from transformers.utils import is_apex_available
+from transformers.testing_utils import (
+    TestCasePlus,
+    backend_device_count,
+    is_torch_fp16_available_on_device,
+    run_command,
+    slow,
+    torch_device,
+)
 
 
 logging.basicConfig(level=logging.DEBUG)
@@ -54,11 +54,6 @@ def get_results(output_dir):
     return results
 
 
-def is_cuda_and_apex_available():
-    is_using_cuda = torch.cuda.is_available() and torch_device == "cuda"
-    return is_using_cuda and is_apex_available()
-
-
 stream_handler = logging.StreamHandler(sys.stdout)
 logger.addHandler(stream_handler)
 
@@ -93,7 +93,7 @@ def test_run_glue_no_trainer(self):
             --with_tracking
         """.split()
 
-        if is_cuda_and_apex_available():
+        if is_torch_fp16_available_on_device(torch_device):
             testargs.append("--fp16")
 
         run_command(self._launch_args + testargs)
@@ -119,7 +119,7 @@ def test_run_clm_no_trainer(self):
             --with_tracking
         """.split()
 
-        if torch.cuda.device_count() > 1:
+        if backend_device_count(torch_device) > 1:
             # Skipping because there are not enough batches to train the model + would need a drop_last to work.
             return
 
@@ -152,7 +152,7 @@ def test_run_mlm_no_trainer(self):
     @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
     def test_run_ner_no_trainer(self):
         # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
-        epochs = 7 if get_gpu_count() > 1 else 2
+        epochs = 7 if backend_device_count(torch_device) > 1 else 2
 
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
@@ -326,7 +326,7 @@ def test_run_image_classification_no_trainer(self):
             --checkpointing_steps 1
         """.split()
 
-        if is_cuda_and_apex_available():
+        if is_torch_fp16_available_on_device(torch_device):
             testargs.append("--fp16")
 
         run_command(self._launch_args + testargs)
41 changes: 20 additions & 21 deletions examples/pytorch/test_pytorch_examples.py
@@ -20,11 +20,15 @@
 import sys
 from unittest.mock import patch
 
-import torch
-
 from transformers import ViTMAEForPreTraining, Wav2Vec2ForPreTraining
-from transformers.testing_utils import CaptureLogger, TestCasePlus, get_gpu_count, slow, torch_device
-from transformers.utils import is_apex_available
+from transformers.testing_utils import (
+    CaptureLogger,
+    TestCasePlus,
+    backend_device_count,
+    is_torch_fp16_available_on_device,
+    slow,
+    torch_device,
+)
 
 
 SRC_DIRS = [
@@ -86,11 +90,6 @@ def get_results(output_dir):
     return results
 
 
-def is_cuda_and_apex_available():
-    is_using_cuda = torch.cuda.is_available() and torch_device == "cuda"
-    return is_using_cuda and is_apex_available()
-
-
 stream_handler = logging.StreamHandler(sys.stdout)
 logger.addHandler(stream_handler)
 
@@ -116,7 +115,7 @@ def test_run_glue(self):
             --max_seq_length=128
         """.split()
 
-        if is_cuda_and_apex_available():
+        if is_torch_fp16_available_on_device(torch_device):
             testargs.append("--fp16")
 
         with patch.object(sys, "argv", testargs):
@@ -141,7 +140,7 @@ def test_run_clm(self):
             --overwrite_output_dir
         """.split()
 
-        if torch.cuda.device_count() > 1:
+        if backend_device_count(torch_device) > 1:
             # Skipping because there are not enough batches to train the model + would need a drop_last to work.
             return
 
@@ -203,7 +202,7 @@ def test_run_mlm(self):
 
     def test_run_ner(self):
         # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
-        epochs = 7 if get_gpu_count() > 1 else 2
+        epochs = 7 if backend_device_count(torch_device) > 1 else 2
 
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
@@ -312,7 +311,7 @@ def test_run_swag(self):
     def test_generation(self):
         testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
 
-        if is_cuda_and_apex_available():
+        if is_torch_fp16_available_on_device(torch_device):
             testargs.append("--fp16")
 
         model_type, model_name = (
@@ -401,7 +400,7 @@ def test_run_image_classification(self):
             --seed 42
         """.split()
 
-        if is_cuda_and_apex_available():
+        if is_torch_fp16_available_on_device(torch_device):
             testargs.append("--fp16")
 
         with patch.object(sys, "argv", testargs):
@@ -431,7 +430,7 @@ def test_run_speech_recognition_ctc(self):
             --seed 42
         """.split()
 
-        if is_cuda_and_apex_available():
+        if is_torch_fp16_available_on_device(torch_device):
             testargs.append("--fp16")
 
         with patch.object(sys, "argv", testargs):
@@ -462,7 +461,7 @@ def test_run_speech_recognition_ctc_adapter(self):
             --seed 42
         """.split()
 
-        if is_cuda_and_apex_available():
+        if is_torch_fp16_available_on_device(torch_device):
             testargs.append("--fp16")
 
         with patch.object(sys, "argv", testargs):
@@ -493,7 +492,7 @@ def test_run_speech_recognition_seq2seq(self):
             --seed 42
         """.split()
 
-        if is_cuda_and_apex_available():
+        if is_torch_fp16_available_on_device(torch_device):
             testargs.append("--fp16")
 
         with patch.object(sys, "argv", testargs):
@@ -525,7 +524,7 @@ def test_run_audio_classification(self):
             --seed 42
         """.split()
 
-        if is_cuda_and_apex_available():
+        if is_torch_fp16_available_on_device(torch_device):
             testargs.append("--fp16")
 
         with patch.object(sys, "argv", testargs):
@@ -551,7 +550,7 @@ def test_run_wav2vec2_pretraining(self):
             --seed 42
         """.split()
 
-        if is_cuda_and_apex_available():
+        if is_torch_fp16_available_on_device(torch_device):
             testargs.append("--fp16")
 
         with patch.object(sys, "argv", testargs):
@@ -579,7 +578,7 @@ def test_run_vit_mae_pretraining(self):
             --seed 42
         """.split()
 
-        if is_cuda_and_apex_available():
+        if is_torch_fp16_available_on_device(torch_device):
             testargs.append("--fp16")
 
         with patch.object(sys, "argv", testargs):
@@ -604,7 +603,7 @@ def test_run_semantic_segmentation(self):
             --seed 32
         """.split()
 
-        if is_cuda_and_apex_available():
+        if is_torch_fp16_available_on_device(torch_device):
             testargs.append("--fp16")
 
         with patch.object(sys, "argv", testargs):
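For reference, a device-count helper of this kind boils down to dispatching on the device string. The sketch below is a hypothetical illustration under that assumption, not the actual transformers.testing_utils implementation:

    import torch

    def device_count_sketch(device: str) -> int:
        # Hypothetical stand-in for backend_device_count: dispatch on the
        # device string; the real helper in transformers.testing_utils may
        # differ in both coverage and mechanics.
        if device == "cuda":
            return torch.cuda.device_count()
        if device == "xpu" and hasattr(torch, "xpu"):
            return torch.xpu.device_count()
        if device == "npu" and hasattr(torch, "npu"):  # needs torch_npu installed
            return torch.npu.device_count()
        return 1  # cpu, mps, etc. act as a single device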
