diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/llama/__init__.py b/tests/models/llama/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/llama/test_train_llama.py b/tests/models/llama/test_train_llama.py index 91dce94c..1138fdd8 100644 --- a/tests/models/llama/test_train_llama.py +++ b/tests/models/llama/test_train_llama.py @@ -1,16 +1,31 @@ -import os import unittest -import torch -from transformers import AutoModelForCausalLM +from ..test_model import TestModelTraining -device = "cuda" +class LLamaModelTrainingTest(TestModelTraining, unittest.TestCase): -class LLaMaTrainingTest(unittest.TestCase): + def prepare_command(self): + tmp_dir = "my_dir" - def test_foo(self): - ckpt = "meta-llama/Llama-2-7b-hf" - model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, token=os.getenv("HF_HUB_READ_TOKEN", None)) - model.train() - model.to(device) \ No newline at end of file + testargs = f""" + run_image_classification.py + --output_dir {tmp_dir} + --model_name_or_path google/vit-base-patch16-224-in21k + --dataset_name hf-internal-testing/cats_vs_dogs_sample + --do_train + --do_eval + --learning_rate 1e-4 + --per_device_train_batch_size 2 + --per_device_eval_batch_size 1 + --remove_unused_columns False + --overwrite_output_dir True + --dataloader_num_workers 16 + --metric_for_best_model accuracy + --max_steps 10 + --train_val_split 0.1 + --seed 42 + --label_column_name labels + """ + + return testargs diff --git a/tests/models/test_model.py b/tests/models/test_model.py new file mode 100644 index 00000000..ad9d42aa --- /dev/null +++ b/tests/models/test_model.py @@ -0,0 +1,68 @@ +import os +import sys +from unittest.mock import patch + + +SRC_DIRS = [ + os.path.join(os.path.dirname(__file__), "..", "transformers/examples/pytorch", dirname) + for dirname in [ + "text-generation", + "text-classification", + "token-classification", + "language-modeling", + "multiple-choice", + "question-answering", + "summarization", + "translation", + "image-classification", + "speech-recognition", + "audio-classification", + "speech-pretraining", + "image-pretraining", + "semantic-segmentation", + ] +] + +print(SRC_DIRS) + +sys.path.extend(SRC_DIRS) + + +if SRC_DIRS is not None: + import run_audio_classification + import run_clm + import run_generation + import run_glue + import run_image_classification + import run_mae + import run_mlm + import run_ner + import run_qa as run_squad + import run_semantic_segmentation + import run_seq2seq_qa as run_squad_seq2seq + import run_speech_recognition_ctc + import run_speech_recognition_ctc_adapter + import run_speech_recognition_seq2seq + import run_summarization + import run_swag + import run_translation + import run_wav2vec2_pretraining_no_trainer + + +class TestModelTraining: + + def prepare_command(self): + raise NotImplementedError + + def test_model_training(self): + + testargs = self.prepare_command().split() + + with patch.object(sys, "argv", testargs): + run_image_classification.main() + # result = get_results(tmp_dir) + # self.assertGreaterEqual(result["eval_accuracy"], 0.8) + + +class TestModelInference: + pass