Commit
add qwen pretrain test
DesmonDay committed Oct 30, 2023
1 parent e7b53f6 · commit 51b6375
Showing 5 changed files with 27 additions and 5 deletions.
1 change: 1 addition & 0 deletions llm/run_pretrain.py
@@ -115,6 +115,7 @@ class ModelArguments:
     Arguments pertaining to which model/config/tokenizer we are going to pre-train from.
     """

+    model_type: Optional[str] = field(default="llama", metadata={"help": "Use for CI test."})
     model_name_or_path: str = field(
         default="__internal_testing__/tiny-random-llama",
         metadata={
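For context, the new `model_type` field rides on the dataclass-to-CLI machinery, so CI can select a model with `--model_type qwen`. A minimal sketch, assuming `PdArgumentParser` mirrors the Hugging Face `HfArgumentParser` API:

# Minimal sketch of how the dataclass field surfaces as a CLI flag.
# Assumption: paddlenlp.trainer.PdArgumentParser mirrors HfArgumentParser.
from dataclasses import dataclass, field
from typing import Optional

from paddlenlp.trainer import PdArgumentParser


@dataclass
class ModelArguments:
    model_type: Optional[str] = field(default="llama", metadata={"help": "Use for CI test."})


(model_args,) = PdArgumentParser([ModelArguments]).parse_args_into_dataclasses(
    args=["--model_type", "qwen"]
)
print(model_args.model_type)  # -> "qwen"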
1 change: 1 addition & 0 deletions paddlenlp/trainer/trainer.py
@@ -2510,6 +2510,7 @@ def predict(
             test_dataloader,
             description="Prediction",
             ignore_keys=ignore_keys,
+            prediction_loss_only=True if self.compute_metrics is None else None,
             metric_key_prefix=metric_key_prefix,
             max_eval_iters=self.args.max_evaluate_steps,
         )
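A note on this one-line change: passing `None` instead of `False` when `compute_metrics` is unset lets the downstream loop fall back to its own default rather than being forced off. A sketch of that convention, assuming PaddleNLP's evaluation loop follows the usual Trainer pattern:

# Sketch of the usual fallback inside the evaluation loop (an assumption
# about the convention, not code copied from trainer.py):
prediction_loss_only = (
    prediction_loss_only
    if prediction_loss_only is not None
    else self.args.prediction_loss_only
)
# compute_metrics is None -> True  (skip logits/labels, keep only the loss)
# compute_metrics is set  -> None  (defer to args.prediction_loss_only)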
5 changes: 4 additions & 1 deletion tests/fixtures/llm/pretrain.yaml
@@ -25,6 +25,9 @@ pretrain:
   chatglm:
     model_type: chatglm
     model_name_or_path: __internal_testing__/tiny-fused-chatglm
+  qwen:
+    model_type: qwen
+    model_name_or_path: __internal_testing__/tiny-fused-qwen

 inference-predict:
   default:

@@ -44,4 +44,4 @@ inference-infer:
   dtype: float16
   batch_size: 2
   decode_strategy: greedy_search
-  max_length: 20
+  max_length: 20
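The new fixture block mirrors the `ModelArguments` fields above. A hypothetical lookup of the `qwen` entry (plain `yaml.safe_load`; the real tests use their own config loader):

# Hypothetical fixture lookup for the new qwen block.
import yaml

with open("tests/fixtures/llm/pretrain.yaml") as f:
    fixtures = yaml.safe_load(f)

qwen_cfg = fixtures["pretrain"]["qwen"]
print(qwen_cfg["model_type"])          # -> "qwen"
print(qwen_cfg["model_name_or_path"])  # -> "__internal_testing__/tiny-fused-qwen"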
24 changes: 20 additions & 4 deletions tests/llm/test_pretrain.py
@@ -31,6 +31,7 @@
     ["model_dir"],
     [
         ["llama"],
+        ["qwen"],
     ],
 )
 class PretrainTest(LLMTest, unittest.TestCase):
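`@parameterized_class` stamps out one copy of the test class per row, so the added `["qwen"]` row reruns every test in `PretrainTest` with `self.model_dir == "qwen"`. A self-contained sketch of the mechanism using the `parameterized` package:

# Sketch of the parameterization mechanism.
import unittest

from parameterized import parameterized_class


@parameterized_class(
    ["model_dir"],
    [
        ["llama"],
        ["qwen"],
    ],
)
class DemoTest(unittest.TestCase):
    def test_model_dir(self):
        # Runs twice: once with model_dir="llama", once with model_dir="qwen".
        self.assertIn(self.model_dir, ("llama", "qwen"))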
@@ -41,16 +42,30 @@ def setUp(self) -> None:
         LLMTest.setUp(self)

         self.dataset_dir = tempfile.mkdtemp()
-        self.model_codes_dir = os.path.join(self.root_path, self.model_dir)
-        sys.path.insert(0, self.model_codes_dir)
+        if self.model_dir != "qwen":
+            self.model_codes_dir = os.path.join(self.root_path, self.model_dir)
+            sys.path.insert(0, self.model_codes_dir)
+        else:
+            self.model_codes_dir = self.root_path

     def tearDown(self) -> None:
         LLMTest.tearDown(self)

-        sys.path.remove(self.model_codes_dir)
+        if self.model_dir != "qwen":
+            sys.path.remove(self.model_codes_dir)

         shutil.rmtree(self.dataset_dir)

     def test_pretrain(self):
+
+        pretrain_flag = False
+        for key, value in sys.modules.items():
+            if "run_pretrain" in key:
+                pretrain_flag = True
+                break
+        if pretrain_flag:
+            del sys.modules["run_pretrain"]
+
         # Run pretrain
         URL = "https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_ids.npy"
         URL2 = "https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_idx.npz"
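The `sys.modules` purge at the top of `test_pretrain` exists because Python caches imports: once the llama parameterization has imported `run_pretrain`, a later import for qwen would silently reuse the cached module. Evicting the cache entry forces a fresh import; a minimal illustration:

# Python's import cache returns the same module object on re-import,
# so a stale run_pretrain must be evicted between parameterized runs.
import sys

if "run_pretrain" in sys.modules:
    del sys.modules["run_pretrain"]  # next `import run_pretrain` re-executes the file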
@@ -67,5 +82,6 @@ def test_pretrain(self):

         main()

-        self.run_predictor({"inference_model": True})
+        if self.model_dir != "qwen":
+            self.run_predictor({"inference_model": True})
         self.run_predictor({"inference_model": False})
1 change: 1 addition & 0 deletions tests/requirements.txt
@@ -7,3 +7,4 @@ tool_helpers
 fast_tokenizer_python
 sacremoses
 pydantic==1.10.9
+tiktoken
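`tiktoken` becomes a test dependency because the tiny Qwen tokenizer relies on it. A quick import-and-roundtrip sanity check (the encoding name is illustrative, not the one Qwen uses):

# Sanity check for the new dependency; "gpt2" is just an illustrative encoding.
import tiktoken

enc = tiktoken.get_encoding("gpt2")
assert enc.decode(enc.encode("hello world")) == "hello world"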
