diff --git a/tests/models/llama/test_inference_llama.py b/tests/models/llama/test_inference_llama.py index e69f8983..4ce30563 100644 --- a/tests/models/llama/test_inference_llama.py +++ b/tests/models/llama/test_inference_llama.py @@ -1,7 +1,14 @@ import unittest +from transformers import AutoModelForCausalLM + +device = "cuda" + class LLaMaInferenceTest(unittest.TestCase): def test_foo(self): - assert 1 == 1 + ckpt = "meta-llama/Llama-2-7b" + model = AutoModelForCausalLM.from_pretrained(ckpt) + model.train() + model.to(device) diff --git a/tests/models/llama/test_train_llama.py b/tests/models/llama/test_train_llama.py index 935ed3fe..89ba28bd 100644 --- a/tests/models/llama/test_train_llama.py +++ b/tests/models/llama/test_train_llama.py @@ -1,7 +1,14 @@ import unittest +from transformers import AutoModelForCausalLM + +device = "cuda" + class LLaMaTrainingTest(unittest.TestCase): def test_foo(self): - assert 1 == 2 + ckpt = "meta-llama/Llama-2-7b" + model = AutoModelForCausalLM.from_pretrained(ckpt) + model.train() + model.to(device) \ No newline at end of file