Added support for FedMoS (#373)
NingxinSu authored Apr 25, 2024
1 parent 379f362 commit 659f081
Showing 5 changed files with 233 additions and 0 deletions.
10 changes: 10 additions & 0 deletions docs/examples.md
@@ -165,6 +165,16 @@ python examples/customized_client_training/fedti/fedti.py -c examples/customized
Gal et al., “[An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion](https://arxiv.org/pdf/2208.01618.pdf),” arXiv, 2022.
````

````{admonition} **FedMoS**
FedMoS is a communication-efficient federated learning framework that combines a coupled double momentum-based update with adaptive client selection to jointly mitigate client drift and intrinsic variance; a sketch of its local update rule follows this example.
```shell
python examples/customized_client_training/fedmos/fedmos.py -c examples/customized_client_training/fedmos/fedmos_MNIST_lenet5.yml
```
```{note}
X. Wang, Y. Chen, Y. Li, X. Liao, H. Jin, and B. Li, “FedMoS: Taming Client Drift in Federated Learning with Double Momentum and Adaptive Selection,” IEEE INFOCOM 2023.
```
````
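
For reference, here is a minimal sketch of the local update implemented by the optimizer added in this commit (`optimizers.py` below); the notation is ours, not the paper's:

$$
d_t = g_t + (1 - a)\,\bigl(d_{t-1} - g_{t-1}\bigr), \qquad
w \leftarrow w - \eta\, d_t - \mu\,\bigl(w - w_{\mathrm{start}}\bigr)
$$

where $g_t$ is the current mini-batch gradient, $d_t$ the double-momentum buffer, $\eta$ the learning rate, $a$ and $\mu$ the `a` and `mu` settings in the configuration file, and $w_{\mathrm{start}}$ the snapshot of the model recorded at the beginning of the round.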


#### Client Selection Algorithms

26 changes: 26 additions & 0 deletions examples/customized_client_training/fedmos/fedmos.py
@@ -0,0 +1,26 @@
"""
An implementation of the FedMos algorithm.
X. Wang, Y. Chen, Y. Li, X. Liao, H. Jin and B. Li, "FedMoS: Taming Client Drift in Federated Learning with Double Momentum and Adaptive Selection," IEEE INFOCOM 2023
Paper: https://ieeexplore.ieee.org/document/10228957
Source code: https://github.com/Distributed-Learning-Networking-Group/FedMoS
"""

from plato.servers import fedavg
from plato.clients import simple

import fedmos_trainer


def main():
"""A Plato federated learning training session using FedDyn."""
trainer = fedmos_trainer.Trainer
client = simple.Client(trainer=trainer)
server = fedavg.Server()
server.run(client)


if __name__ == "__main__":
main()
71 changes: 71 additions & 0 deletions examples/customized_client_training/fedmos/fedmos_MNIST_lenet5.yml
@@ -0,0 +1,71 @@
clients:
    # Type
    type: simple

    # The total number of clients
    total_clients: 1000

    # The number of clients selected in each round
    per_round: 10

    # Should the clients compute test accuracy locally?
    do_test: false

    random_seed: 1

server:
    address: 127.0.0.1
    port: 8000
    synchronous: true

    checkpoint_path: models/fedmos/mnist
    model_path: models/fedmos/mnist

data:
    # The training and testing dataset
    datasource: MNIST

    # Number of samples in each partition
    partition_size: 600

    # IID or non-IID?
    sampler: noniid

    # The concentration parameter for the Dirichlet distribution
    concentration: 5

    # The random seed for sampling data
    random_seed: 1

trainer:
    # The type of the trainer
    type: basic

    # The maximum number of training rounds
    rounds: 10

    # The maximum number of clients running concurrently
    max_concurrency: 3

    # The target accuracy
    target_accuracy: 0.94

    # The number of epochs for local training in each communication round
    epochs: 20
    batch_size: 10
    optimizer: SGD

    # The machine learning model
    model_name: lenet5

algorithm:
    # Aggregation algorithm
    type: fedavg

    # FedMoS hyperparameters
    a: 0.1
    mu: 0.001

parameters:
    optimizer:
        lr: 0.03  # learning rate is fixed as in Appendix C.2
        momentum: 0.0
        weight_decay: 0.0
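
For orientation only (a sketch, not part of the committed files): with the settings above, the trainer below constructs the FedMoS optimizer roughly as follows, reading `algorithm.a`, `algorithm.mu`, and `parameters.optimizer.lr` through Plato's `Config()`. It assumes the example is launched with `-c fedmos_MNIST_lenet5.yml` so that the configuration is loaded.

```python
from torch import nn

from plato.config import Config
from optimizers import FedMosOptimizer

# With fedmos_MNIST_lenet5.yml loaded, these resolve to a=0.1, mu=0.001, lr=0.03.
a = Config().algorithm.a
mu = Config().algorithm.mu
lr = Config().parameters.optimizer.lr

model = nn.Linear(784, 10)  # hypothetical stand-in for the lenet5 model named in the config
optimizer = FedMosOptimizer(model.parameters(), lr=lr, a=a, mu=mu)
```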
58 changes: 58 additions & 0 deletions examples/customized_client_training/fedmos/fedmos_trainer.py
@@ -0,0 +1,58 @@
"""
An implementation of the FedMos algorithm.
X. Wang, Y. Chen, Y. Li, X. Liao, H. Jin and B. Li, "FedMoS: Taming Client Drift in Federated Learning with Double Momentum and Adaptive Selection," IEEE INFOCOM 2023
Paper: https://ieeexplore.ieee.org/document/10228957
Source code: https://github.com/Distributed-Learning-Networking-Group/FedMoS
"""
import copy

from plato.config import Config
from plato.trainers import basic

from optimizers import FedMosOptimizer

# pylint:disable=no-member
class Trainer(basic.Trainer):
"""
FedMos's Trainer.
"""

def __init__(self, model=None, callbacks=None):
super().__init__(model, callbacks)
self.local_param_tmpl = None

def get_optimizer(self, model):
""" Get the optimizer of the Fedmos."""
a = Config().algorithm.a if hasattr(Config().algorithm, "a") else 0.9
mu = Config().algorithm.mu if hasattr(Config().algorithm, "mu") else 0.9
lr = Config().parameters.optimizer.lr if hasattr(Config().parameters.optimizer, "lr") else 0.01

return FedMosOptimizer(model.parameters(), lr=lr, a=a, mu=mu)


def perform_forward_and_backward_passes(self, config, examples, labels):
"""Perform forward and backward passes in the training loop."""
self.optimizer.zero_grad()

outputs = self.model(examples)

loss = self._loss_criterion(outputs, labels)
self._loss_tracker.update(loss, labels.size(0))

if "create_graph" in config:
loss.backward(create_graph=config["create_graph"])
else:
loss.backward()

self.optimizer.update_momentum()
self.optimizer.step(copy.deepcopy(self.local_param_tmpl))

return loss

def train_run_start(self, config):
super().train_run_start(config)
# At the beginning of each round, the client records the local model
self.local_param_tmpl = copy.deepcopy(self.model)
68 changes: 68 additions & 0 deletions examples/customized_client_training/fedmos/optimizers.py
@@ -0,0 +1,68 @@
"""The FedMoS optimizer, implementing the coupled double momentum-based update."""
import torch


class FedMosOptimizer(torch.optim.Optimizer):
    """A double-momentum optimizer with a proximal pull towards the round-start model."""

    def __init__(self, params, lr, a=1.0, mu=0.0):
        defaults = dict(lr=lr, a=a, mu=mu)
        super().__init__(params, defaults)

    def clone_grad(self):
        """Saves a detached copy of the current gradients in the optimizer state."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["gt_prev"] = p.grad.data.clone().detach()

    def get_grad(self):
        """Returns the current gradients as a list of NumPy arrays."""
        grad = []
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad += [p.grad.data.clone().detach().cpu().numpy()]
        return grad

    def update_momentum(self):
        """Updates the double-momentum buffer: d_t = g_t + (1 - a) * (d_{t-1} - g_{t-1})."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                gt = p.grad.data
                a = group["a"]
                state = self.state[p]
                if len(state) == 0:
                    # State initialization
                    state["gt_prev"] = torch.zeros_like(p.data)
                    state["dt"] = gt.clone()
                    continue

                gt_prev = state["gt_prev"]
                dt = state["dt"]
                state["dt"] = gt + (1 - a) * (dt - gt_prev)
                state["gt_prev"] = gt.clone().detach()

    def step(self, local_net):
        """Takes an update step along d_t, with a proximal pull towards local_net."""
        for group in self.param_groups:
            # Different groups may use different lr, a, and mu
            for p, local_p in zip(group["params"], local_net.parameters()):
                state = self.state[p]
                if len(state) == 0:
                    raise RuntimeError("Please call update_momentum() first.")

                lr, mu = group["lr"], group["mu"]
                dt = state["dt"]
                prox = p.data - local_p.data
                p.data.add_(dt, alpha=-lr)
                p.data.add_(prox, alpha=-mu)
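
A minimal, self-contained usage sketch of `FedMosOptimizer` outside of Plato, mirroring the call order used by `fedmos_trainer.py` above (zero_grad, backward, `update_momentum()`, then `step()` with the round-start model); the toy model and random data here are ours, not part of the commit:

```python
import copy

import torch
from torch import nn

from optimizers import FedMosOptimizer

model = nn.Linear(10, 2)  # toy stand-in for the lenet5 model used in the example
round_start_model = copy.deepcopy(model)  # snapshot taken at the start of the round
optimizer = FedMosOptimizer(model.parameters(), lr=0.03, a=0.1, mu=0.001)
loss_fn = nn.CrossEntropyLoss()

for _ in range(5):  # a few local steps
    examples = torch.randn(16, 10)
    labels = torch.randint(0, 2, (16,))

    optimizer.zero_grad()
    loss = loss_fn(model(examples), labels)
    loss.backward()

    optimizer.update_momentum()        # maintain d_t = g_t + (1 - a) * (d_{t-1} - g_{t-1})
    optimizer.step(round_start_model)  # move along d_t, plus a proximal pull towards the snapshot
```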