diff --git a/examples/split_learning/llm_split_learning/split_learning_trainer.py b/examples/split_learning/llm_split_learning/split_learning_trainer.py
index 88a73417f..f73a8ac44 100644
--- a/examples/split_learning/llm_split_learning/split_learning_trainer.py
+++ b/examples/split_learning/llm_split_learning/split_learning_trainer.py
@@ -102,8 +102,9 @@ def __init__(self, model=None, callbacks=None):
         self.model.resize_token_embeddings(len(self.tokenizer))
         # self.training args for huggingface training
         parser = HfArgumentParser(TrainingArguments)
+
         (self.training_args,) = parser.parse_args_into_dataclasses(
-            args=["--output_dir=/tmp", "--report_to=none"]
+            args=["--report_to=none"]
         )
 
         # Redesign the evaluation stage.
@@ -132,8 +133,6 @@ def test_model_split_learning(self, batch_size, testset, sampler=None):
 
         # save other metric information such as accuracy
         tester.log_metrics("eval", metrics)
-        tester.save_metrics("eval", metrics)
-
         return metrics["eval_accuracy"]
 
     # Redesign the training stage specific to Split Learning.
diff --git a/examples/split_learning/llm_split_learning/split_learning_wikitext2_gpt2.yml b/examples/split_learning/llm_split_learning/split_learning_wikitext2_gpt2.yml
index 1c12d146b..e10c55496 100644
--- a/examples/split_learning/llm_split_learning/split_learning_wikitext2_gpt2.yml
+++ b/examples/split_learning/llm_split_learning/split_learning_wikitext2_gpt2.yml
@@ -12,7 +12,7 @@ clients:
     do_test: false
 
     # Split learning iterations for each client
-    iteration: 20
+    iteration: 1
 
 server:
     type: split_learning