-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
93 additions
and
10 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,31 @@ | ||
import os | ||
import unittest | ||
import torch | ||
|
||
from transformers import AutoModelForCausalLM | ||
from ..test_model import TestModelTraining | ||
|
||
device = "cuda" | ||
|
||
class LLamaModelTrainingTest(TestModelTraining, unittest.TestCase): | ||
|
||
class LLaMaTrainingTest(unittest.TestCase): | ||
def prepare_command(self): | ||
tmp_dir = "my_dir" | ||
|
||
def test_foo(self): | ||
ckpt = "meta-llama/Llama-2-7b-hf" | ||
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, token=os.getenv("HF_HUB_READ_TOKEN", None)) | ||
model.train() | ||
model.to(device) | ||
testargs = f""" | ||
run_image_classification.py | ||
--output_dir {tmp_dir} | ||
--model_name_or_path google/vit-base-patch16-224-in21k | ||
--dataset_name hf-internal-testing/cats_vs_dogs_sample | ||
--do_train | ||
--do_eval | ||
--learning_rate 1e-4 | ||
--per_device_train_batch_size 2 | ||
--per_device_eval_batch_size 1 | ||
--remove_unused_columns False | ||
--overwrite_output_dir True | ||
--dataloader_num_workers 16 | ||
--metric_for_best_model accuracy | ||
--max_steps 10 | ||
--train_val_split 0.1 | ||
--seed 42 | ||
--label_column_name labels | ||
""" | ||
|
||
return testargs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import os | ||
import sys | ||
from unittest.mock import patch | ||
|
||
|
||
SRC_DIRS = [ | ||
os.path.join(os.path.dirname(__file__), "..", "transformers/examples/pytorch", dirname) | ||
for dirname in [ | ||
"text-generation", | ||
"text-classification", | ||
"token-classification", | ||
"language-modeling", | ||
"multiple-choice", | ||
"question-answering", | ||
"summarization", | ||
"translation", | ||
"image-classification", | ||
"speech-recognition", | ||
"audio-classification", | ||
"speech-pretraining", | ||
"image-pretraining", | ||
"semantic-segmentation", | ||
] | ||
] | ||
|
||
print(SRC_DIRS) | ||
|
||
sys.path.extend(SRC_DIRS) | ||
|
||
|
||
if SRC_DIRS is not None: | ||
import run_audio_classification | ||
import run_clm | ||
import run_generation | ||
import run_glue | ||
import run_image_classification | ||
import run_mae | ||
import run_mlm | ||
import run_ner | ||
import run_qa as run_squad | ||
import run_semantic_segmentation | ||
import run_seq2seq_qa as run_squad_seq2seq | ||
import run_speech_recognition_ctc | ||
import run_speech_recognition_ctc_adapter | ||
import run_speech_recognition_seq2seq | ||
import run_summarization | ||
import run_swag | ||
import run_translation | ||
import run_wav2vec2_pretraining_no_trainer | ||
|
||
|
||
class TestModelTraining: | ||
|
||
def prepare_command(self): | ||
raise NotImplementedError | ||
|
||
def test_model_training(self): | ||
|
||
testargs = self.prepare_command().split() | ||
|
||
with patch.object(sys, "argv", testargs): | ||
run_image_classification.main() | ||
# result = get_results(tmp_dir) | ||
# self.assertGreaterEqual(result["eval_accuracy"], 0.8) | ||
|
||
|
||
class TestModelInference: | ||
pass |