diff --git a/tests/models/llama/test_inference_llama.py b/tests/models/llama/test_inference_llama.py
index 3a1c9dc9..f41fcc8b 100644
--- a/tests/models/llama/test_inference_llama.py
+++ b/tests/models/llama/test_inference_llama.py
@@ -3,6 +3,7 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from ..test_model import ModelInferenceTestMixin
 
 
 torch.backends.cudnn.deterministic = True
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -11,9 +12,9 @@
 device = "cuda"
 
 
-class LLaMaInferenceTest(unittest.TestCase):
+class LLaMaInferenceTest(ModelInferenceTestMixin, unittest.TestCase):
 
-    def test_foo(self):
+    def test_inference(self):
         ckpt = "meta-llama/Llama-2-7b-hf"
         hf_token = os.getenv("HF_HUB_READ_TOKEN", None)
 
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 8e188d98..9759b781 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -1,6 +1,8 @@
+import json
 import os
 import sys
 from unittest.mock import patch
+from transformers.testing_utils import TestCasePlus
 
 
 SRC_DIRS = [
@@ -49,7 +51,16 @@
 import run_wav2vec2_pretraining_no_trainer
 
 
-class ModelTrainingTestMixin:
+class ModelTrainingTestMixin(TestCasePlus):
+
+    def get_results(self, output_dir):
+        path = os.path.join(output_dir, "all_results.json")
+        if os.path.exists(path):
+            with open(path, "r") as f:
+                results = json.load(f)
+        else:
+            raise ValueError(f"can't find {path}")
+        return results
 
     def prepare_command(self):
         raise NotImplementedError
@@ -58,11 +69,13 @@
     def test_model_training(self):
         testargs = self.prepare_command().split()
+        tmp_dir = self.get_auto_remove_tmp_dir()
 
         with patch.object(sys, "argv", testargs):
             run_image_classification.main()
-        # result = get_results(tmp_dir)
+        result = self.get_results(tmp_dir)
         # self.assertGreaterEqual(result["eval_accuracy"], 0.8)
 
 
-class TestModelInference:
-    pass
+class ModelInferenceTestMixin:
+    def test_inference(self):
+        raise NotImplementedError
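
To make the `prepare_command()` contract concrete, here is a minimal sketch of a model-specific training test, assuming it lives in a subpackage next to `tests/models/test_model.py` (mirroring the llama layout above). The class name, checkpoint, dataset, and flags are illustrative assumptions, not part of this diff; only the `prepare_command()` hook and `run_image_classification` come from the mixin itself.

```python
# Hypothetical concrete training test built on ModelTrainingTestMixin.
# The class name, checkpoint, dataset, and flags are illustrative
# assumptions; only the prepare_command() contract comes from the diff.
from ..test_model import ModelTrainingTestMixin


class ViTTrainingTest(ModelTrainingTestMixin):
    def prepare_command(self):
        # test_model_training() splits this string and patches it into
        # sys.argv, so the first token stands in for the program name and
        # the rest are consumed by the example script's argument parser.
        return (
            "run_image_classification.py"
            " --model_name_or_path google/vit-base-patch16-224-in21k"
            " --dataset_name beans"
            " --do_train"
            " --do_eval"
            " --output_dir ./tmp/vit_test"
        )
```

Since `ModelTrainingTestMixin` now derives from `TestCasePlus`, which is itself a `unittest.TestCase` subclass, concrete training tests like the sketch above do not need to inherit `unittest.TestCase` separately.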