diff --git a/tests/models/llama/test_train_llama.py b/tests/models/llama/test_train_llama.py index 1138fdd8..48357175 100644 --- a/tests/models/llama/test_train_llama.py +++ b/tests/models/llama/test_train_llama.py @@ -1,9 +1,9 @@ import unittest -from ..test_model import TestModelTraining +from ..test_model import ModelTrainingTestMixin -class LLamaModelTrainingTest(TestModelTraining, unittest.TestCase): +class LLamaModelTrainingTest(ModelTrainingTestMixin, unittest.TestCase): def prepare_command(self): tmp_dir = "my_dir" diff --git a/tests/models/test_model.py b/tests/models/test_model.py index bd90d5d0..8e188d98 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -49,7 +49,7 @@ import run_wav2vec2_pretraining_no_trainer -class TestModelTraining: +class ModelTrainingTestMixin: def prepare_command(self): raise NotImplementedError