From 04d35d5676db52ce568400c5c7c42679ae3a0c77 Mon Sep 17 00:00:00 2001
From: Fei Wang <917522022@qq.com>
Date: Thu, 9 Nov 2023 21:30:30 -0500
Subject: [PATCH] Check spelling in comments and names

---
 .../llm_split_learning/split_learning_llm_model.py | 11 +++++------
 .../llm_split_learning/split_learning_main.py      |  2 +-
 .../llm_split_learning/split_learning_trainer.py   |  4 ++--
 plato/trainers/split_learning.py                   | 13 ++++++-------
 4 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/examples/split_learning/llm_split_learning/split_learning_llm_model.py b/examples/split_learning/llm_split_learning/split_learning_llm_model.py
index d48aea64a..aa898e09e 100644
--- a/examples/split_learning/llm_split_learning/split_learning_llm_model.py
+++ b/examples/split_learning/llm_split_learning/split_learning_llm_model.py
@@ -1,7 +1,6 @@
 """
-Obtain LLM models from Huggingface, specifically designed for split learning
+Obtain LLM models from HuggingFace, specifically designed for split learning
 """
-
 import torch
 from transformers import AutoModelForCausalLM, AutoConfig
 from peft import get_peft_model, LoraConfig
@@ -18,8 +17,8 @@ def get_lora_model(model):
 
 def get_module(start_module: torch.nn.Module, module_names):
     """
-    Recursively get a pytorch module starting from the start module with
-    a given list of module names. 
+    Recursively get a PyTorch module starting from the start module with
+    a given list of module names.
     """
     module = start_module
     for module_name in module_names:
@@ -29,7 +28,7 @@ def get_module(start_module: torch.nn.Module, module_names):
 
 class BaseModel(torch.nn.Module):
     """
-    The basic model loading hugginface model used for the server model and the client model
+    The basic model loading HuggingFace model used for the server model and the client model
     """
 
     def __init__(self, *args, **kwargs) -> None:
@@ -122,7 +121,7 @@ def __init__(self, *args, **kwargs) -> None:
         # In this design, we have two copies of the model
         # The first copy of the model is the whole model which is used for test.
         # The second copy of the model only contains the layers on the server
-        # used for training. 
+        # used for training.
         self.server_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             config=self.config,
diff --git a/examples/split_learning/llm_split_learning/split_learning_main.py b/examples/split_learning/llm_split_learning/split_learning_main.py
index 50fbe14dc..91a10b05d 100644
--- a/examples/split_learning/llm_split_learning/split_learning_main.py
+++ b/examples/split_learning/llm_split_learning/split_learning_main.py
@@ -1,5 +1,5 @@
 """
-Use Split learning to finetune Huggingface large language model.
+Finetune HuggingFace large language models using split learning.
 """
 import split_learning_trainer
 from split_learning_llm_model import ServerModel, ClientModel
diff --git a/examples/split_learning/llm_split_learning/split_learning_trainer.py b/examples/split_learning/llm_split_learning/split_learning_trainer.py
index 7082394d1..bbaff0446 100644
--- a/examples/split_learning/llm_split_learning/split_learning_trainer.py
+++ b/examples/split_learning/llm_split_learning/split_learning_trainer.py
@@ -100,7 +100,7 @@ def __init__(self, model=None, callbacks=None):
         embedding_size = self.model.get_input_embeddings().weight.shape[0]
         if len(self.tokenizer) > embedding_size:
             self.model.resize_token_embeddings(len(self.tokenizer))
-        # self.training args for huggingface training
+        # self.training args for HuggingFace training
         parser = HfArgumentParser(TrainingArguments)
 
         (self.training_args,) = parser.parse_args_into_dataclasses(
@@ -139,7 +139,7 @@ def test_model_split_learning(self, batch_size, testset, sampler=None):
         return metrics["eval_accuracy"]
 
     # Redesign the training stage specific to Split Learning.
-    def process_training_samples_before_retreiving(self, training_samples):
+    def process_training_samples_before_retrieving(self, training_samples):
         inputs = training_samples["input_ids"]
         labels = training_samples["labels"]
         for index, input_item in enumerate(inputs):
diff --git a/plato/trainers/split_learning.py b/plato/trainers/split_learning.py
index f6cc68fa7..eae7d9dc1 100644
--- a/plato/trainers/split_learning.py
+++ b/plato/trainers/split_learning.py
@@ -13,7 +13,6 @@
 https://arxiv.org/pdf/2112.01637.pdf
 """
-
 import logging
 import os
@@ -99,7 +98,7 @@ def get_train_samples(self, batch_size, trainset, sampler):
         data_loader = iter(data_loader)
         self.training_samples = next(data_loader)
         # Wrap the training samples with datasource and sampler to be fed into Plato trainer
-        self.training_samples = self.process_training_samples_before_retreiving(
+        self.training_samples = self.process_training_samples_before_retrieving(
             self.training_samples
         )
         return self.training_samples
@@ -190,7 +189,7 @@ def test_model(self, config, testset, sampler=None, **kwargs):
         """
         Evaluates the model with the provided test dataset and test sampler.
 
-        Auguments:
+        Arguments:
             testset: the test dataset.
             sampler: the test sampler. The default is None.
             kwargs (optional): Additional keyword arguments.
@@ -200,19 +199,19 @@
         return accuracy
 
     # API functions for split learning
-    def process_training_samples_before_retreiving(self, training_samples) -> ...:
-        """Process training samples before completing retreiving samples."""
+    def process_training_samples_before_retrieving(self, training_samples) -> ...:
+        """Process training samples before completing retrieving samples."""
         return training_samples
 
     def process_samples_before_client_forwarding(self, examples) -> ...:
-        """Process the examples befor client conducting forwarding."""
+        """Process the examples before client conducting forwarding."""
        return examples
 
     # pylint:disable=unused-argument
     def server_forward_from(self, batch, config) -> (..., ..., int):
         """
         The event for server completing training by forwarding from intermediate features.
-        Uses may override this function for training different models with split learning.
+        Users may override this function for training different models with split learning.
 
         Inputs:
             batch: the batch of inputs for forwarding.