Skip to content

Commit

Permalink
try
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Feb 7, 2024
1 parent 90653b0 commit 581a748
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 10 deletions.
Empty file added tests/models/__init__.py
Empty file.
Empty file added tests/models/llama/__init__.py
Empty file.
35 changes: 25 additions & 10 deletions tests/models/llama/test_train_llama.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,31 @@
import os
import unittest
import torch

from transformers import AutoModelForCausalLM
from ..test_model import TestModelTraining

device = "cuda"

class LLamaModelTrainingTest(TestModelTraining, unittest.TestCase):

class LLaMaTrainingTest(unittest.TestCase):
def prepare_command(self):
tmp_dir = "my_dir"

def test_foo(self):
ckpt = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, token=os.getenv("HF_HUB_READ_TOKEN", None))
model.train()
model.to(device)
testargs = f"""
run_image_classification.py
--output_dir {tmp_dir}
--model_name_or_path google/vit-base-patch16-224-in21k
--dataset_name hf-internal-testing/cats_vs_dogs_sample
--do_train
--do_eval
--learning_rate 1e-4
--per_device_train_batch_size 2
--per_device_eval_batch_size 1
--remove_unused_columns False
--overwrite_output_dir True
--dataloader_num_workers 16
--metric_for_best_model accuracy
--max_steps 10
--train_val_split 0.1
--seed 42
--label_column_name labels
"""

return testargs
68 changes: 68 additions & 0 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import os
import sys
from unittest.mock import patch


SRC_DIRS = [
os.path.join(os.path.dirname(__file__), "..", "transformers/examples/pytorch", dirname)
for dirname in [
"text-generation",
"text-classification",
"token-classification",
"language-modeling",
"multiple-choice",
"question-answering",
"summarization",
"translation",
"image-classification",
"speech-recognition",
"audio-classification",
"speech-pretraining",
"image-pretraining",
"semantic-segmentation",
]
]

print(SRC_DIRS)

sys.path.extend(SRC_DIRS)


if SRC_DIRS is not None:
import run_audio_classification
import run_clm
import run_generation
import run_glue
import run_image_classification
import run_mae
import run_mlm
import run_ner
import run_qa as run_squad
import run_semantic_segmentation
import run_seq2seq_qa as run_squad_seq2seq
import run_speech_recognition_ctc
import run_speech_recognition_ctc_adapter
import run_speech_recognition_seq2seq
import run_summarization
import run_swag
import run_translation
import run_wav2vec2_pretraining_no_trainer


class TestModelTraining:

def prepare_command(self):
raise NotImplementedError

def test_model_training(self):

testargs = self.prepare_command().split()

with patch.object(sys, "argv", testargs):
run_image_classification.main()
# result = get_results(tmp_dir)
# self.assertGreaterEqual(result["eval_accuracy"], 0.8)


class TestModelInference:
pass

0 comments on commit 581a748

Please sign in to comment.