Created trainer, client, and datasource for personalized federated learning based on self-supervised learning and provided examples. (#365)

Co-authored-by: Sijia Chen <[email protected]>
Co-authored-by: Ningxin Su <[email protected]>
Co-authored-by: Baochun Li <[email protected]>
Co-authored-by: Fei Wang <[email protected]>
Co-authored-by: Yufei-Kang <[email protected]>
6 people authored Nov 4, 2023
1 parent 9769476 commit a2c5e87
Showing 140 changed files with 783 additions and 943 deletions.
2 changes: 1 addition & 1 deletion .gitmodules
@@ -8,5 +8,5 @@
path = examples/pfedrlnas/VIT/nasvit_wrapper/NASViT
url = https://github.com/facebookresearch/NASViT
[submodule "examples/controlnet_split_learning/ControlNet"]
path = examples/controlnet_split_learning/ControlNet
path = examples/split_learning/controlnet_split_learning/ControlNet
url = https://github.com/lllyasviel/ControlNet
3 changes: 3 additions & 0 deletions docs/configuration.md
@@ -34,6 +34,8 @@ The type of the server.
- `split_learning` a client following the Split Learning algorithm. When this client is used, `clients.do_test` in configuration should be set as `False` because in split learning, we conduct the test on the server.
- `fedavg_personalized` a client saves its local layers before sending the shared global model to the server after local training.
- `self_supervised_learning` a client to prepare the datasource for personalized learning based on self-supervised learning.
```

```{admonition} **total_clients**
@@ -392,6 +394,7 @@ The type of the trainer. The following types are available:
- `basic`: a basic trainer with a standard training loop.
- `diff_privacy`: a trainer that supports local differential privacy in its training loop by adding noise to the gradients during each step of training.
- `split_learning`: a trainer that supports the split learning framework.
- `self_supervised_learning`: a trainer that supports personalized federated learning based on self supervised learning.
```{admonition} max_physical_batch_size
The limit on the physical batch size when using the `diff_privacy` trainer. The default value is 128. The GPU memory usage of one process training the ResNet-18 model is around 2817 MB.
169 changes: 63 additions & 106 deletions docs/examples.md

Large diffs are not rendered by default.

6 changes: 1 addition & 5 deletions docs/installation.md
@@ -101,17 +101,13 @@ In general, the following is the recommended starting point for `.vscode/setting
```
{
"python.linting.enabled": true,
"python.formatting.provider": "black",
"editor.formatOnSave": true,
"workbench.editor.enablePreview": false
}
```
It goes without saying that `/absolute/path/to/project/home/directory` should be replaced with the actual path in the specific development environment.
When working in Visual Studio Code as your development environment, two of our colour theme favourites are called `Bluloco` (both of its light and dark variants) and `City Lights` (dark). They are both excellent and very thoughtfully designed.
It goes without saying that the `Python` extension is required to be installed in Visual Studio Code, which represents Microsoft's modern language server for Python.
The `Black Formatter`, `PyLint`, and `Python` extensions are required to be installed in Visual Studio Code.
````
29 files renamed without changes.
@@ -139,11 +139,9 @@ trainer:
pers_epoch_log_interval: 2
pers_epoch_model_log_interval: 10

# The machine learning model,
# it behaves as the encoder for the ssl method
# the final fc layer will be removed
# however, in the central test, we do not use this
# but use the custom model
# The machine learning model, it behaves as the encoder for the SSL method
# the final fc layer will be removed however, in the central test, we do not
# use this but use the custom model
model_type: vit
model_name: T2t_vit_14
personalized_model_name: T2t_vit_14
116 changes: 0 additions & 116 deletions examples/personalized_fl/README.md

This file was deleted.

14 changes: 6 additions & 8 deletions examples/personalized_fl/apfl/apfl.py
@@ -1,15 +1,13 @@
"""
The implementation of APFL method.
An implementation of Adaptive Personalized Federated Learning (APFL).
Yuyang Deng, et al., Adaptive Personalized Federated Learning
Y. Deng, et al., "Adaptive Personalized Federated Learning"
paper address: https://arxiv.org/pdf/2003.13461.pdf
URL: https://arxiv.org/pdf/2003.13461.pdf
Official code: None
Third-party code:
- https://github.com/MLOPTPSU/FedTorch/blob/main/main.py
- https://github.com/MLOPTPSU/FedTorch/blob/main/fedtorch/comms/trainings/federated/apfl.py
https://github.com/MLOPTPSU/FedTorch/blob/main/main.py
https://github.com/MLOPTPSU/FedTorch/blob/main/fedtorch/comms/trainings/federated/apfl.py
"""

import apfl_trainer
@@ -20,7 +18,7 @@

def main():
"""
A personalized federated learning session for APFL approach.
A personalized federated learning session using APFL.
"""
trainer = apfl_trainer.Trainer
client = personalized_client.Client(trainer=trainer)
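The body of `main()` is truncated in this view. In Plato examples of this kind, the remainder typically builds a personalized server and starts the session; the following is a minimal sketch of that assumed completion, mirroring the imports shown in the Ditto example of this commit, not the repository's exact code:

```python
import apfl_trainer

from plato.clients import fedavg_personalized as personalized_client
from plato.servers import fedavg_personalized as personalized_server


def main():
    """A personalized federated learning session using APFL (sketch)."""
    trainer = apfl_trainer.Trainer
    client = personalized_client.Client(trainer=trainer)
    # Assumption: the session is driven by the fedavg_personalized server,
    # as in the other personalized FL examples in this commit.
    server = personalized_server.Server(trainer=trainer)
    server.run(client)


if __name__ == "__main__":
    main()
```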
16 changes: 8 additions & 8 deletions examples/personalized_fl/apfl/apfl_trainer.py
@@ -1,20 +1,20 @@
"""
A personalized federated learning trainer For APFL.
A personalized federated learning trainer using APFL.
"""
import os
import logging
import os

import numpy as np
import torch

from plato.trainers import basic
from plato.models import registry as models_registry
from plato.config import Config
from plato.models import registry as models_registry
from plato.trainers import basic


class Trainer(basic.Trainer):
"""
A trainer using the APFL algorithm to jointly train the global and
personalized models.
A trainer using the APFL algorithm to train both global and personalized models.
"""

def __init__(self, model=None, callbacks=None):
@@ -77,7 +77,7 @@ def train_run_start(self, config):
self.personalized_model.train()

def train_run_end(self, config):
"""Saving the alpha."""
"""Saves alpha to a file identified by the client id, and saves the personalized model."""
super().train_run_end(config)

# Save the alpha to the file
@@ -94,7 +94,7 @@ def train_run_end(self, config):
torch.save(self.personalized_model.state_dict(), model_path)

def train_step_end(self, config, batch=None, loss=None):
"""Updating the alpha of APFL before each iteration."""
"""Updates alpha in APFL before each iteration."""
super().train_step_end(config, batch, loss)

# Update alpha based on Eq. 10 in the paper
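The hunk above is cut off at the comment referencing Eq. 10 of the APFL paper. For readers without the paper at hand, the mixing-weight update takes roughly the following form; this is a sketch of the equation rather than the repository's code, and `alpha`, `lr`, `global_model`, and `personalized_model` are placeholder names:

```python
import torch


def update_alpha(alpha, lr, global_model, personalized_model):
    """Sketch of the APFL mixing-weight update (Eq. 10 in the paper).

    The gradient of the mixed loss with respect to alpha is the inner
    product of (v - w) with the personalized model's gradient, where v is
    the personalized model and w is the global model.
    """
    grad_alpha = 0.0
    for w_param, v_param in zip(
        global_model.parameters(), personalized_model.parameters()
    ):
        if v_param.grad is not None:
            diff = v_param.data - w_param.data
            grad_alpha += torch.sum(diff * v_param.grad.data).item()

    alpha = alpha - lr * grad_alpha
    # Keep the mixing weight inside [0, 1].
    return max(0.0, min(1.0, alpha))
```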
@@ -34,14 +34,10 @@ algorithm:
# Aggregation algorithm
type: fedavg_personalized

# Important hyper-parameters
# False for performing Per-FedAvg(FO), others for Per-FedAvg(HF)
hessian_free: False
alpha: 0.01 # 1e-2
beta: 0.001

personalization:

# the ratio of clients participanting in training
participating_client_ratio: 1.0

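For context on the `hessian_free`, `alpha`, and `beta` entries in the hunk above: Per-FedAvg performs a MAML-style meta-update in which `alpha` is the inner (adaptation) step size and `beta` the outer (meta) step size; `hessian_free: False` selects the first-order variant, Per-FedAvg(FO), while the HF variant approximates the Hessian-vector product instead. A rough sketch of the first-order update follows; every name in it is illustrative, not taken from the repository:

```python
import copy

import torch


def perfedavg_fo_step(model, loss_fn, support_batch, query_batch, alpha, beta):
    """Sketch of one Per-FedAvg(FO) meta-update on a single client.

    First adapts a temporary copy of the model with step size alpha on the
    support batch, then applies the gradient of the adapted copy (evaluated
    on the query batch) to the original model with step size beta. The
    first-order variant simply drops the second-order (Hessian) term.
    """
    # Inner step: adapt a throwaway copy of the model.
    adapted = copy.deepcopy(model)
    inner_opt = torch.optim.SGD(adapted.parameters(), lr=alpha)
    inputs, targets = support_batch
    inner_opt.zero_grad()
    loss_fn(adapted(inputs), targets).backward()
    inner_opt.step()

    # Outer step: evaluate the adapted copy on a fresh batch and apply its
    # gradients to the original model's parameters.
    inputs, targets = query_batch
    adapted.zero_grad()
    loss_fn(adapted(inputs), targets).backward()
    with torch.no_grad():
        for param, adapted_param in zip(model.parameters(), adapted.parameters()):
            if adapted_param.grad is not None:
                param -= beta * adapted_param.grad
```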
14 changes: 7 additions & 7 deletions examples/personalized_fl/ditto/ditto.py
@@ -1,23 +1,23 @@
"""
The implementation of Ditto method based on the pFL framework of Plato.
An implementation of the Ditto personalized federated learning algorithm.
Reference:
Tian Li, et al., Ditto: Fair and robust federated learning through personalization, 2021:
https://proceedings.mlr.press/v139/li21h.html
T. Li, et al., "Ditto: Fair and robust federated learning through personalization,"
in the Proceedings of ICML 2021.
Official code: https://github.com/litian96/ditto
Third-party code: https://github.com/lgcollins/FedRep
https://proceedings.mlr.press/v139/li21h.html
Source code: https://github.com/litian96/ditto
"""

import ditto_trainer

from plato.servers import fedavg_personalized as personalized_server
from plato.clients import fedavg_personalized as personalized_client


def main():
"""
A personalized federated learning session for Ditto approach.
A personalized federated learning session with Ditto.
"""
trainer = ditto_trainer.Trainer
client = personalized_client.Client(trainer=trainer)
23 changes: 15 additions & 8 deletions examples/personalized_fl/ditto/ditto_trainer.py
@@ -1,5 +1,5 @@
"""
A personalized federated learning trainer using Ditto.
A personalized federated learning trainer with Ditto.
"""
import os
import copy
@@ -14,31 +14,36 @@


class Trainer(basic.Trainer):
"""A personalized federated learning trainer using the Ditto algorithm."""
"""
A trainer with Ditto, which first trains the global model and then trains
the personalized model.
"""

def __init__(self, model=None, callbacks=None):
super().__init__(model, callbacks)

# The lambda (used in the paper)
# The lambda adjusts the gradients
self.ditto_lambda = Config().algorithm.ditto_lambda

# The personalized model
# Get the personalized model
if model is None:
self.personalized_model = models_registry.get()
else:
self.personalized_model = model()

# The global model weights received from the server, which is the w^t in
# the paper
# The global model weights, which is w^t in the paper
self.initial_wnet_params = None

def train_run_start(self, config):
super().train_run_start(config)

# Make a copy of the model before local training starts, which will be used when optimizing
# the personalized model
self.initial_wnet_params = copy.deepcopy(self.model.cpu().state_dict())

def train_run_end(self, config):
"""Perform personalized training, proposed in Ditto."""
"""
Optimize the personalized model for epochs following Algorithm 1.
"""
super().train_run_end(config)

logging.info(
@@ -49,6 +54,7 @@ def train_run_end(self, config):
self.client_id,
)

# Load personalized model
model_path = Config().params["model_path"]
model_name = Config().trainer.model_name
filename = f"{model_path}/{model_name}_{self.client_id}_v_net.pth"
@@ -64,6 +70,7 @@

self.personalized_model.to(self.device)
self.personalized_model.train()

for epoch in range(1, config["epochs"] + 1):
epoch_loss_meter.reset()
for __, (examples, labels) in enumerate(self.train_loader):
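The personalization loop above is truncated in this view. For reference, the personalized update in Ditto (Algorithm 1 of the paper) regularizes the personalized model v towards the global weights w^t received at the start of the round, scaled by `ditto_lambda`. The following is a sketch of one batch step under that reading, with placeholder names rather than the trainer's actual helpers:

```python
import torch


def ditto_personalized_step(
    personalized_model,
    optimizer,
    loss_fn,
    examples,
    labels,
    initial_wnet_params,
    ditto_lambda,
):
    """Sketch of one Ditto personalization step on a single batch.

    Minimizes the local loss plus a proximal correction that keeps the
    personalized model v close to the global weights w^t (stored here in
    initial_wnet_params, a state_dict copied before local training).
    """
    optimizer.zero_grad()
    loss = loss_fn(personalized_model(examples), labels)
    loss.backward()
    optimizer.step()

    # Proximal correction: v <- v - lr * lambda * (v - w^t).
    lr = optimizer.param_groups[0]["lr"]
    with torch.no_grad():
        for name, param in personalized_model.named_parameters():
            w_t = initial_wnet_params[name].to(param.device)
            param -= lr * ditto_lambda * (param - w_t)
```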
16 changes: 7 additions & 9 deletions examples/personalized_fl/fedavg_finetune/fedavg_finetune.py
@@ -1,14 +1,12 @@
"""
An implementation of the personalized learning variant of FedAvg.
This implementation first trains a global model using conventional FedAvg until
a target number of rounds has been reached. In the final `personalization`
round, each client will use its local data samples to further fine-tune the
shared global model for a number of epochs, and then the server will compute the
average client test accuracy.
The core idea is to achieve personalized FL in two stages:
First, it trains a global model using conventional FedAvg until convergence.
Second, each client freezes the trained global model and optimizes the other
parts.
Due to its simplicity, no work has been proposed that specifically discusses
this algorithm but only utilizes it as the baseline for personalized federated
learning.
Due to its simplicity, no papers specifically discussed or proposed this
algorithm; they only utilized it as their baseline for comparisons.
"""

from plato.clients import fedavg_personalized as personalized_client
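Since the revised docstring describes the second stage only in prose, a short sketch of one reading of it may help: the shared layers of the received global model are frozen and only the remaining part, assumed here to be the final classification layer, is fine-tuned on local data. The `fc` prefix and learning rate below are purely illustrative:

```python
import torch


def freeze_shared_layers(model, trainable_prefix="fc", lr=0.01):
    """Sketch of the fine-tuning stage described in the docstring above.

    Freezes every parameter of the received global model except those whose
    names start with trainable_prefix, so that the local epochs only update
    that part, and returns an optimizer over the trainable parameters.
    """
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(trainable_prefix)

    # Only optimize the parameters that remain trainable.
    return torch.optim.SGD(
        (p for p in model.parameters() if p.requires_grad), lr=lr
    )
```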
