Skip to content

Commit

Permalink
try
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Feb 7, 2024
1 parent c870159 commit b398fc6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 71 deletions.
27 changes: 7 additions & 20 deletions tests/models/llama/test_train_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,14 @@

class LLamaModelTrainingTest(ModelTrainingTestMixin, TestCasePlus):

def prepare_command(self, **kwargs):
# WIP
def get_training_script(self):
raise NotImplementedError

@unittest.skip("WIP")
def prepare_training_command(self, **kwargs):
output_dir = kwargs.get("output_dir", "my_dir")

testargs = f"""
run_image_classification.py
--output_dir {output_dir}
--model_name_or_path google/vit-base-patch16-224-in21k
--dataset_name hf-internal-testing/cats_vs_dogs_sample
--do_train
--do_eval
--learning_rate 1e-4
--per_device_train_batch_size 2
--per_device_eval_batch_size 1
--remove_unused_columns False
--overwrite_output_dir True
--dataloader_num_workers 16
--metric_for_best_model accuracy
--max_steps 10
--train_val_split 0.1
--seed 42
--label_column_name labels
"""
testargs = f""""""

return testargs
61 changes: 10 additions & 51 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,53 +14,13 @@
from accelerate.state import AcceleratorState, PartialState


SRC_DIRS = [
os.path.join("/transformers/examples/pytorch", dirname)
for dirname in [
"text-generation",
"text-classification",
"token-classification",
"language-modeling",
"multiple-choice",
"question-answering",
"summarization",
"translation",
"image-classification",
"speech-recognition",
"audio-classification",
"speech-pretraining",
"image-pretraining",
"semantic-segmentation",
]
]

print(SRC_DIRS)

sys.path.extend(SRC_DIRS)


if SRC_DIRS is not None:
import run_audio_classification
import run_clm
import run_generation
import run_glue
import run_image_classification
import run_mae
import run_mlm
import run_ner
import run_qa as run_squad
import run_semantic_segmentation
import run_seq2seq_qa as run_squad_seq2seq
import run_speech_recognition_ctc
import run_speech_recognition_ctc_adapter
import run_speech_recognition_seq2seq
import run_summarization
import run_swag
import run_translation
import run_wav2vec2_pretraining_no_trainer
class ModelTrainingTestMixin:

def get_training_script(self):
raise NotImplementedError

class ModelTrainingTestMixin:
def prepare_training_command(self, **kwargs):
raise NotImplementedError

def get_results(self, output_dir):
path = os.path.join(output_dir, "all_results.json")
Expand All @@ -71,16 +31,13 @@ def get_results(self, output_dir):
raise ValueError(f"can't find {path}")
return results

def prepare_command(self, **kwargs):
raise NotImplementedError

def test_model_training(self):
def test_training(self):

tmp_dir = self.get_auto_remove_tmp_dir()
testargs = self.prepare_command(output_dir=tmp_dir).split()
testargs = self.prepare_training_command(output_dir=tmp_dir).split()

with patch.object(sys, "argv", testargs):
run_image_classification.main()
self.get_training_script().main()
result = self.get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.8)

Expand All @@ -90,6 +47,8 @@ def test_inference(self):
raise NotImplementedError


# Copied from https://github.com/huggingface/transformers/blob/308d2b90049b4979a949a069aa4f43b2788254d6/src/transformers/testing_utils.py#L1335 # noqa
# (with minimal set of methods and their contents)
class TestCasePlus(unittest.TestCase):

def setUp(self):
Expand Down

0 comments on commit b398fc6

Please sign in to comment.