diff --git a/tests/models/llama/test_train_llama.py b/tests/models/llama/test_train_llama.py
index d28db4c2..3d6005fe 100644
--- a/tests/models/llama/test_train_llama.py
+++ b/tests/models/llama/test_train_llama.py
@@ -5,27 +5,14 @@ class LLamaModelTrainingTest(ModelTrainingTestMixin, TestCasePlus):
-    def prepare_command(self, **kwargs):
+    # WIP
+    def get_training_script(self):
+        raise NotImplementedError
+
+    @unittest.skip("WIP")
+    def prepare_training_command(self, **kwargs):
         output_dir = kwargs.get("output_dir", "my_dir")
-        testargs = f"""
-            run_image_classification.py
-            --output_dir {output_dir}
-            --model_name_or_path google/vit-base-patch16-224-in21k
-            --dataset_name hf-internal-testing/cats_vs_dogs_sample
-            --do_train
-            --do_eval
-            --learning_rate 1e-4
-            --per_device_train_batch_size 2
-            --per_device_eval_batch_size 1
-            --remove_unused_columns False
-            --overwrite_output_dir True
-            --dataloader_num_workers 16
-            --metric_for_best_model accuracy
-            --max_steps 10
-            --train_val_split 0.1
-            --seed 42
-            --label_column_name labels
-            """
+        testargs = f""""""
         return testargs
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 56e9f75b..6531433e 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -14,53 +14,13 @@
 from accelerate.state import AcceleratorState, PartialState
 
-SRC_DIRS = [
-    os.path.join("/transformers/examples/pytorch", dirname)
-    for dirname in [
-        "text-generation",
-        "text-classification",
-        "token-classification",
-        "language-modeling",
-        "multiple-choice",
-        "question-answering",
-        "summarization",
-        "translation",
-        "image-classification",
-        "speech-recognition",
-        "audio-classification",
-        "speech-pretraining",
-        "image-pretraining",
-        "semantic-segmentation",
-    ]
-]
-
-print(SRC_DIRS)
-
-sys.path.extend(SRC_DIRS)
-
-
-if SRC_DIRS is not None:
-    import run_audio_classification
-    import run_clm
-    import run_generation
-    import run_glue
-    import run_image_classification
-    import run_mae
-    import run_mlm
-    import run_ner
-    import run_qa as run_squad
-    import run_semantic_segmentation
-    import run_seq2seq_qa as run_squad_seq2seq
-    import run_speech_recognition_ctc
-    import run_speech_recognition_ctc_adapter
-    import run_speech_recognition_seq2seq
-    import run_summarization
-    import run_swag
-    import run_translation
-    import run_wav2vec2_pretraining_no_trainer
+class ModelTrainingTestMixin:
+    def get_training_script(self):
+        raise NotImplementedError
 
-class ModelTrainingTestMixin:
+    def prepare_training_command(self, **kwargs):
+        raise NotImplementedError
 
     def get_results(self, output_dir):
         path = os.path.join(output_dir, "all_results.json")
@@ -71,16 +31,13 @@ def get_results(self, output_dir):
             raise ValueError(f"can't find {path}")
         return results
 
-    def prepare_command(self, **kwargs):
-        raise NotImplementedError
-
-    def test_model_training(self):
+    def test_training(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
 
-        testargs = self.prepare_command(output_dir=tmp_dir).split()
+        testargs = self.prepare_training_command(output_dir=tmp_dir).split()
 
         with patch.object(sys, "argv", testargs):
-            run_image_classification.main()
+            self.get_training_script().main()
             result = self.get_results(tmp_dir)
             self.assertGreaterEqual(result["eval_accuracy"], 0.8)
 
@@ -90,6 +47,8 @@
     def test_inference(self):
         raise NotImplementedError
 
 
+# Copied from https://github.com/huggingface/transformers/blob/308d2b90049b4979a949a069aa4f43b2788254d6/src/transformers/testing_utils.py#L1335 # noqa
+# (with minimal set of methods and their contents)
 class TestCasePlus(unittest.TestCase):
     def setUp(self):