Check spelling in comments and names
silviafeiwang committed Nov 10, 2023
1 parent e5092df commit 04d35d5
Showing 4 changed files with 14 additions and 16 deletions.
@@ -1,7 +1,6 @@
"""
Obtain LLM models from Huggingface, specifically designed for split learning
Obtain LLM models from HuggingFace, specifically designed for split learning
"""

import torch
from transformers import AutoModelForCausalLM, AutoConfig
from peft import get_peft_model, LoraConfig
@@ -18,8 +17,8 @@ def get_lora_model(model):

def get_module(start_module: torch.nn.Module, module_names):
"""
Recursively get a pytorch module starting from the start module with
a given list of module names.
Recursively get a PyTorch module starting from the start module with
a given list of module names.
"""
module = start_module
for module_name in module_names:
@@ -29,7 +28,7 @@ def get_module(start_module: torch.nn.Module, module_names):

class BaseModel(torch.nn.Module):
"""
The basic model loading hugginface model used for the server model and the client model
The basic model loading HuggingFace model used for the server model and the client model
"""

def __init__(self, *args, **kwargs) -> None:
@@ -122,7 +121,7 @@ def __init__(self, *args, **kwargs) -> None:
# In this design, we have two copies of the model
# The first copy of the model is the whole model which is used for test.
# The second copy of the model only contains the layers on the server
# used for training.
# used for training.
self.server_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
config=self.config,
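For context on the get_module helper whose docstring is corrected above: a minimal sketch of such a recursive lookup, assuming a getattr-based walk. The diff only shows the docstring and the loop header, so the body and the example attribute names below are assumptions, not the repository's exact code.

import torch

def get_module(start_module: torch.nn.Module, module_names):
    """Recursively get a PyTorch module starting from the start module
    with a given list of module names."""
    module = start_module
    for module_name in module_names:
        # Walk one attribute level per name; nn.ModuleList children are
        # registered under string indices, so names like "0" also resolve.
        module = getattr(module, module_name)
    return module

# Illustrative call: on a GPT-2-style checkpoint this would return the stack
# of transformer blocks (the attribute names are assumptions):
# blocks = get_module(model, ["transformer", "h"])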
@@ -1,5 +1,5 @@
"""
Use Split learning to finetune Huggingface large language model.
Finetune HuggingFace large language models using split learning.
"""
import split_learning_trainer
from split_learning_llm_model import ServerModel, ClientModel
@@ -100,7 +100,7 @@ def __init__(self, model=None, callbacks=None):
embedding_size = self.model.get_input_embeddings().weight.shape[0]
if len(self.tokenizer) > embedding_size:
self.model.resize_token_embeddings(len(self.tokenizer))
# self.training args for huggingface training
# self.training args for HuggingFace training
parser = HfArgumentParser(TrainingArguments)

(self.training_args,) = parser.parse_args_into_dataclasses(
@@ -139,7 +139,7 @@ def test_model_split_learning(self, batch_size, testset, sampler=None):
return metrics["eval_accuracy"]

# Redesign the training stage specific to Split Learning.
def process_training_samples_before_retreiving(self, training_samples):
def process_training_samples_before_retrieving(self, training_samples):
inputs = training_samples["input_ids"]
labels = training_samples["labels"]
for index, input_item in enumerate(inputs):
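The trainer hunk above resizes the token embeddings whenever the tokenizer has outgrown the model's embedding table, then builds HuggingFace TrainingArguments through HfArgumentParser. A stand-alone sketch of those two steps, assuming an arbitrary causal LM; the model name, added token, and argument values are illustrative only:

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
)

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})  # tokenizer now exceeds the embedding table

# Grow the input embedding matrix to match the enlarged tokenizer.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))

# Parse HuggingFace training arguments from an explicit argument list.
parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "/tmp/split_learning_llm", "--per_device_train_batch_size", "4"]
)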
13 changes: 6 additions & 7 deletions plato/trainers/split_learning.py
@@ -13,7 +13,6 @@
https://arxiv.org/pdf/2112.01637.pdf
"""

import logging
import os

@@ -99,7 +98,7 @@ def get_train_samples(self, batch_size, trainset, sampler):
data_loader = iter(data_loader)
self.training_samples = next(data_loader)
# Wrap the training samples with datasource and sampler to be fed into Plato trainer
self.training_samples = self.process_training_samples_before_retreiving(
self.training_samples = self.process_training_samples_before_retrieving(
self.training_samples
)
return self.training_samples
@@ -190,7 +189,7 @@ def test_model(self, config, testset, sampler=None, **kwargs):
"""
Evaluates the model with the provided test dataset and test sampler.
Auguments:
Arguments:
testset: the test dataset.
sampler: the test sampler. The default is None.
kwargs (optional): Additional keyword arguments.
@@ -200,19 +199,19 @@ def test_model(self, config, testset, sampler=None, **kwargs):
return accuracy

# API functions for split learning
def process_training_samples_before_retreiving(self, training_samples) -> ...:
"""Process training samples before completing retreiving samples."""
def process_training_samples_before_retrieving(self, training_samples) -> ...:
"""Process training samples before completing retrieving samples."""
return training_samples

def process_samples_before_client_forwarding(self, examples) -> ...:
"""Process the examples befor client conducting forwarding."""
"""Process the examples before client conducting forwarding."""
return examples

# pylint:disable=unused-argument
def server_forward_from(self, batch, config) -> (..., ..., int):
"""
The event for server completing training by forwarding from intermediate features.
Uses may override this function for training different models with split learning.
Users may override this function for training different models with split learning.
Inputs:
batch: the batch of inputs for forwarding.
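The hooks at the end of this file are the override surface the docstring points at ("Users may override this function for training different models with split learning"). A hypothetical subclass sketch follows; the base class below is an illustrative stand-in with the hook names copied from the diff, not Plato's actual trainer, and the dict layout of training_samples is assumed to match the LLM trainer above.

class SplitLearningHooks:
    """Illustrative stand-in for the trainer's split-learning hooks."""

    def process_training_samples_before_retrieving(self, training_samples):
        """Process training samples before completing retrieving samples."""
        return training_samples

    def process_samples_before_client_forwarding(self, examples):
        """Process the examples before the client conducts forwarding."""
        return examples


class TruncatingHooks(SplitLearningHooks):
    """Hypothetical override: clip every sample to a fixed maximum length."""

    def __init__(self, max_length=512):
        self.max_length = max_length

    def process_training_samples_before_retrieving(self, training_samples):
        # Assumes the dict layout used by the LLM trainer above:
        # "input_ids" and "labels" hold lists of token-id sequences.
        for key in ("input_ids", "labels"):
            training_samples[key] = [seq[: self.max_length] for seq in training_samples[key]]
        return training_samples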
