Skip to content

Commit

Permalink
try
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Feb 7, 2024
1 parent 9e5d347 commit f34e742
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
5 changes: 3 additions & 2 deletions tests/models/llama/test_inference_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from ..test_model import ModelInferenceTestMixin

torch.backends.cudnn.deterministic = True
torch.backends.cuda.matmul.allow_tf32 = False
Expand All @@ -11,9 +12,9 @@
device = "cuda"


class LLaMaInferenceTest(unittest.TestCase):
class LLaMaInferenceTest(ModelInferenceTestMixin, unittest.TestCase):

def test_foo(self):
def test_inference(self):
ckpt = "meta-llama/Llama-2-7b-hf"
hf_token = token=os.getenv("HF_HUB_READ_TOKEN", None)

Expand Down
21 changes: 17 additions & 4 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import os
import sys
from unittest.mock import patch
from transformers.testing_utils import TestCasePlus


SRC_DIRS = [
Expand Down Expand Up @@ -49,7 +51,16 @@
import run_wav2vec2_pretraining_no_trainer


class ModelTrainingTestMixin:
class ModelTrainingTestMixin(TestCasePlus):

def get_results(self, output_dir):
path = os.path.join(output_dir, "all_results.json")
if os.path.exists(path):
with open(path, "r") as f:
results = json.load(f)
else:
raise ValueError(f"can't find {path}")
return results

def prepare_command(self):
raise NotImplementedError
Expand All @@ -58,11 +69,13 @@ def test_model_training(self):

testargs = self.prepare_command().split()

tmp_dir = self.get_auto_remove_tmp_dir()
with patch.object(sys, "argv", testargs):
run_image_classification.main()
# result = get_results(tmp_dir)
result = self.get_results(tmp_dir)
# self.assertGreaterEqual(result["eval_accuracy"], 0.8)


class TestModelInference:
pass
class ModelInferenceTestMixin:
def test_inference(self):
raise NotImplementedError

0 comments on commit f34e742

Please sign in to comment.