Added support for FedMoS #373

Merged 2 commits on Apr 25, 2024
10 changes: 10 additions & 0 deletions docs/examples.md
@@ -165,6 +165,16 @@ python examples/customized_client_training/fedti/fedti.py -c examples/customized
Gal et al., “[POLARIS: An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion](https://arxiv.org/pdf/2208.01618.pdf), ” Arxiv, 2022.
````

````{admonition} **FedMoS**
FedMoS is a communication-efficient federated learning framework that couples a double momentum-based update with adaptive client selection to jointly mitigate the intrinsic variance from local updates and client sampling; its local update rule is sketched after this admonition.

```shell
python examples/customized_client_training/fedmos/fedmos.py -c examples/customized_client_training/fedmos/fedmos_MNIST_lenet5.yml
```
```{note}
X. Wang, Y. Chen, Y. Li, X. Liao, H. Jin and B. Li, "FedMoS: Taming Client Drift in Federated Learning with Double Momentum and Adaptive Selection," IEEE INFOCOM 2023.
````
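For reference, a sketch of the client-side update rule as implemented by `optimizers.py` in this pull request, where g_t is the current stochastic gradient, d_t the momentum buffer, x^(0) the model recorded at the start of the round, and a, mu, and the learning rate eta correspond to the `a`, `mu`, and `lr` settings in the configuration file:

```latex
% Momentum buffer maintained per parameter (update_momentum):
d_t = g_t + (1 - a)\,(d_{t-1} - g_{t-1})

% Parameter step with a proximal pull toward the round-start model (step):
x_{t+1} = x_t - \eta\, d_t - \mu\,(x_t - x^{(0)})
```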


#### Client Selection Algorithms

26 changes: 26 additions & 0 deletions examples/customized_client_training/fedmos/fedmos.py
@@ -0,0 +1,26 @@
"""
An implementation of the FedMoS algorithm.

X. Wang, Y. Chen, Y. Li, X. Liao, H. Jin and B. Li, "FedMoS: Taming Client Drift in Federated Learning with Double Momentum and Adaptive Selection," IEEE INFOCOM 2023

Paper: https://ieeexplore.ieee.org/document/10228957

Source code: https://github.com/Distributed-Learning-Networking-Group/FedMoS
"""

from plato.servers import fedavg
from plato.clients import simple

import fedmos_trainer


def main():
"""A Plato federated learning training session using FedDyn."""
trainer = fedmos_trainer.Trainer
client = simple.Client(trainer=trainer)
server = fedavg.Server()
server.run(client)


if __name__ == "__main__":
main()
71 changes: 71 additions & 0 deletions examples/customized_client_training/fedmos/fedmos_MNIST_lenet5.yml
@@ -0,0 +1,71 @@
clients:
# Type
type: simple

# The total number of clients
total_clients: 1000

# The number of clients selected in each round
per_round: 10

# Should the clients compute test accuracy locally?
do_test: false

random_seed: 1

server:
address: 127.0.0.1
port: 8000
synchronous: true

checkpoint_path: models/fedmos/mnist
model_path: models/fedmos/mnist

data:
# The training and testing dataset
datasource: MNIST

# Number of samples in each partition
partition_size: 600

# IID or non-IID?
sampler: noniid

# The concentration parameter for the Dirichlet distribution
concentration: 5

# The random seed for sampling data
random_seed: 1

trainer:
# The type of the trainer
type: basic

# The maximum number of training rounds
rounds: 10

# The maximum number of clients running concurrently
max_concurrency: 3

# The target accuracy
target_accuracy: 0.94

# Number of epochs for local training in each communication round
epochs: 20
batch_size: 10
optimizer: SGD

# The machine learning model
model_name: lenet5

algorithm:
# Aggregation algorithm
type: fedavg
a: 0.1
mu: 0.001

parameters:
optimizer:
lr: 0.03
momentum: 0.0 # learning rate is fixed as in Appendix C.2
weight_decay: 0.0
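As a minimal sketch of how the values above reach the optimizer (assuming the process was launched with `-c examples/customized_client_training/fedmos/fedmos_MNIST_lenet5.yml`, so Plato's `Config` singleton has been initialized from this file), the trainer below reads the FedMoS hyperparameters with the following fallbacks:

```python
from plato.config import Config

# FedMoS hyperparameters, with the same fallbacks used in fedmos_trainer.py
# when a value is absent from the configuration file.
a = Config().algorithm.a if hasattr(Config().algorithm, "a") else 0.9
mu = Config().algorithm.mu if hasattr(Config().algorithm, "mu") else 0.9
lr = (
    Config().parameters.optimizer.lr
    if hasattr(Config().parameters.optimizer, "lr")
    else 0.01
)

# With fedmos_MNIST_lenet5.yml loaded, this prints a=0.1, mu=0.001, lr=0.03.
print(f"FedMoS optimizer settings: a={a}, mu={mu}, lr={lr}")
```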
58 changes: 58 additions & 0 deletions examples/customized_client_training/fedmos/fedmos_trainer.py
@@ -0,0 +1,58 @@
"""
An implementation of the FedMoS algorithm.

X. Wang, Y. Chen, Y. Li, X. Liao, H. Jin and B. Li, "FedMoS: Taming Client Drift in Federated Learning with Double Momentum and Adaptive Selection," IEEE INFOCOM 2023

Paper: https://ieeexplore.ieee.org/document/10228957

Source code: https://github.com/Distributed-Learning-Networking-Group/FedMoS
"""
import copy

from plato.config import Config
from plato.trainers import basic

from optimizers import FedMosOptimizer

# pylint:disable=no-member
class Trainer(basic.Trainer):
"""
    The FedMoS trainer.
"""

def __init__(self, model=None, callbacks=None):
super().__init__(model, callbacks)
self.local_param_tmpl = None

def get_optimizer(self, model):
""" Get the optimizer of the Fedmos."""
a = Config().algorithm.a if hasattr(Config().algorithm, "a") else 0.9
mu = Config().algorithm.mu if hasattr(Config().algorithm, "mu") else 0.9
lr = Config().parameters.optimizer.lr if hasattr(Config().parameters.optimizer, "lr") else 0.01

return FedMosOptimizer(model.parameters(), lr=lr, a=a, mu=mu)


def perform_forward_and_backward_passes(self, config, examples, labels):
"""Perform forward and backward passes in the training loop."""
self.optimizer.zero_grad()

outputs = self.model(examples)

loss = self._loss_criterion(outputs, labels)
self._loss_tracker.update(loss, labels.size(0))

if "create_graph" in config:
loss.backward(create_graph=config["create_graph"])
else:
loss.backward()

self.optimizer.update_momentum()
self.optimizer.step(copy.deepcopy(self.local_param_tmpl))

return loss

def train_run_start(self, config):
super().train_run_start(config)
# At the beginning of each round, the client records the local model
self.local_param_tmpl = copy.deepcopy(self.model)
68 changes: 68 additions & 0 deletions examples/customized_client_training/fedmos/optimizers.py
@@ -0,0 +1,68 @@
"""The FedMoS optimizer: a double momentum update with a proximal pull toward the round-start model."""
import torch


class FedMosOptimizer(torch.optim.Optimizer):
    """Implements the double momentum-based local update of FedMoS."""

    def __init__(self, params, lr, a=1.0, mu=0.0):
        defaults = dict(lr=lr, a=a, mu=mu)
        super().__init__(params, defaults)

    def clone_grad(self):
        """Cache a detached copy of the current gradients as gt_prev."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                self.state[p]['gt_prev'] = p.grad.data.clone().detach()

    def get_grad(self):
        """Return the current gradients as a list of NumPy arrays."""
        grad = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad += [p.grad.data.clone().detach().cpu().numpy()]
        return grad

    def update_momentum(self):
        """Maintain the momentum buffer: d_t = g_t + (1 - a) * (d_{t-1} - g_{t-1})."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                gt = p.grad.data  # current stochastic gradient
                a = group['a']
                state = self.state[p]
                if len(state) == 0:
                    # State initialization on the first call
                    state['gt_prev'] = torch.zeros_like(p.data)
                    state['dt'] = gt.clone()
                    continue

                gt_prev = state['gt_prev']
                dt = state['dt']
                state['dt'] = gt + (1 - a) * (dt - gt_prev)
                state['gt_prev'] = gt.clone().detach()

    def step(self, local_net):
        """Apply the momentum step plus a proximal pull toward the round-start model."""
        for group in self.param_groups:
            # Different groups may use a different lr, regularizer, etc.
            for p, local_p in zip(group['params'], local_net.parameters()):
                state = self.state[p]
                if len(state) == 0:
                    raise RuntimeError('Please call update_momentum() first.')

                lr, mu = group['lr'], group['mu']
                dt = state['dt']
                prox = p.data - local_p.data  # distance to the round-start model
                p.data.add_(dt, alpha=-lr)  # momentum step
                p.data.add_(prox, alpha=-mu)  # proximal step
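
A self-contained toy example (not part of the pull request; the model and data are made up for illustration) showing the call order that `perform_forward_and_backward_passes` in `fedmos_trainer.py` relies on: `zero_grad()`, `backward()`, `update_momentum()`, then `step()` with a snapshot of the round-start model:

```python
import copy

import torch

from optimizers import FedMosOptimizer

# Toy model and data, purely for illustration.
model = torch.nn.Linear(10, 2)
examples = torch.randn(32, 10)
labels = torch.randint(0, 2, (32,))
loss_criterion = torch.nn.CrossEntropyLoss()

# Snapshot taken at the start of the round, used for the proximal term.
round_start_model = copy.deepcopy(model)

optimizer = FedMosOptimizer(model.parameters(), lr=0.03, a=0.1, mu=0.001)

for _ in range(5):  # a few local steps
    optimizer.zero_grad()
    loss = loss_criterion(model(examples), labels)
    loss.backward()
    optimizer.update_momentum()  # maintain the double momentum buffer d_t
    # Step with a proximal pull toward the round-start snapshot.
    optimizer.step(copy.deepcopy(round_start_model))
    print(f"loss: {loss.item():.4f}")
```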



