From 9477be94773d2789c22ff322176124411fab321d Mon Sep 17 00:00:00 2001 From: Ningxin Su Date: Thu, 25 Apr 2024 10:25:11 -0400 Subject: [PATCH 1/2] Added FedMos support. --- .../fedmos/fedmos.py | 26 +++++ .../fedmos/fedmos_MNIST_lenet5.yml | 71 ++++++++++++++ .../fedmos/fedmos_trainer.py | 58 +++++++++++ .../fedmos/optimizers.py | 95 +++++++++++++++++++ 4 files changed, 250 insertions(+) create mode 100644 examples/customized_client_training/fedmos/fedmos.py create mode 100644 examples/customized_client_training/fedmos/fedmos_MNIST_lenet5.yml create mode 100644 examples/customized_client_training/fedmos/fedmos_trainer.py create mode 100644 examples/customized_client_training/fedmos/optimizers.py diff --git a/examples/customized_client_training/fedmos/fedmos.py b/examples/customized_client_training/fedmos/fedmos.py new file mode 100644 index 000000000..4675be2ad --- /dev/null +++ b/examples/customized_client_training/fedmos/fedmos.py @@ -0,0 +1,26 @@ +""" +An implementation of the FedMos algorithm. + +X. Wang, Y. Chen, Y. Li, X. Liao, H. Jin and B. Li, "FedMoS: Taming Client Drift in Federated Learning with Double Momentum and Adaptive Selection," IEEE INFOCOM 2023 + +Paper: https://ieeexplore.ieee.org/document/10228957 + +Source code: https://github.com/Distributed-Learning-Networking-Group/FedMoS +""" + +from plato.servers import fedavg +from plato.clients import simple + +import fedmos_trainer + + +def main(): + """A Plato federated learning training session using FedDyn.""" + trainer = fedmos_trainer.Trainer + client = simple.Client(trainer=trainer) + server = fedavg.Server() + server.run(client) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/customized_client_training/fedmos/fedmos_MNIST_lenet5.yml b/examples/customized_client_training/fedmos/fedmos_MNIST_lenet5.yml new file mode 100644 index 000000000..594a71e4c --- /dev/null +++ b/examples/customized_client_training/fedmos/fedmos_MNIST_lenet5.yml @@ -0,0 +1,71 @@ +clients: + # Type + type: simple + + # The total number of clients + total_clients: 1000 + + # The number of clients selected in each round + per_round: 10 + + # Should the clients compute test accuracy locally? + do_test: false + + random_seed: 1 + +server: + address: 127.0.0.1 + port: 8000 + synchronous: true + + checkpoint_path: models/fedmos/mnist + model_path: models/fedmos/mnist + +data: + # The training and testing dataset + datasource: MNIST + + # Number of samples in each partition + partition_size: 600 + + # IID or non-IID? + sampler: noniid + + # The concentration parameter for the Dirichlet distribution + concentration: 5 + + # The random seed for sampling data + random_seed: 1 + +trainer: + # The type of the trainer + type: basic + + # The maximum number of training rounds + rounds: 10 + + # The maximum number of clients running concurrently + max_concurrency: 3 + + # The target accuracy + target_accuracy: 0.94 + + # Number of epochs for local training in each communication round + epochs: 20 + batch_size: 10 + optimizer: SGD + + # The machine learning model + model_name: lenet5 + +algorithm: + # Aggregation algorithm + type: fedavg + a: 0.1 + mu: 0.001 + +parameters: + optimizer: + lr: 0.03 + momentum: 0.0 # learning rate is fixed as in Appendix C.2 + weight_decay: 0.0 diff --git a/examples/customized_client_training/fedmos/fedmos_trainer.py b/examples/customized_client_training/fedmos/fedmos_trainer.py new file mode 100644 index 000000000..ffdd054d1 --- /dev/null +++ b/examples/customized_client_training/fedmos/fedmos_trainer.py @@ -0,0 +1,58 @@ +""" +An implementation of the FedMos algorithm. + +X. Wang, Y. Chen, Y. Li, X. Liao, H. Jin and B. Li, "FedMoS: Taming Client Drift in Federated Learning with Double Momentum and Adaptive Selection," IEEE INFOCOM 2023 + +Paper: https://ieeexplore.ieee.org/document/10228957 + +Source code: https://github.com/Distributed-Learning-Networking-Group/FedMoS +""" +import copy + +from plato.config import Config +from plato.trainers import basic + +from optimizers import FedMosOptimizer + +# pylint:disable=no-member +class Trainer(basic.Trainer): + """ + FedMos's Trainer. + """ + + def __init__(self, model=None, callbacks=None): + super().__init__(model, callbacks) + self.local_param_tmpl = None + + def get_optimizer(self, model): + """ Get the optimizer of the Fedmos.""" + a = Config().algorithm.a if hasattr(Config().algorithm, "a") else 0.9 + mu = Config().algorithm.mu if hasattr(Config().algorithm, "mu") else 0.9 + lr = Config().parameters.optimizer.lr if hasattr(Config().parameters.optimizer, "lr") else 0.01 + + return FedMosOptimizer(model.parameters(), lr=lr, a=a, mu=mu) + + + def perform_forward_and_backward_passes(self, config, examples, labels): + """Perform forward and backward passes in the training loop.""" + self.optimizer.zero_grad() + + outputs = self.model(examples) + + loss = self._loss_criterion(outputs, labels) + self._loss_tracker.update(loss, labels.size(0)) + + if "create_graph" in config: + loss.backward(create_graph=config["create_graph"]) + else: + loss.backward() + + self.optimizer.update_momentum() + self.optimizer.step(copy.deepcopy(self.local_param_tmpl)) + + return loss + + def train_run_start(self, config): + super().train_run_start(config) + # At the beginning of each round, the client records the local model + self.local_param_tmpl = copy.deepcopy(self.model) diff --git a/examples/customized_client_training/fedmos/optimizers.py b/examples/customized_client_training/fedmos/optimizers.py new file mode 100644 index 000000000..866f4f996 --- /dev/null +++ b/examples/customized_client_training/fedmos/optimizers.py @@ -0,0 +1,95 @@ +import torch +import time + +class FedMosOptimizer(torch.optim.Optimizer): + def __init__(self, params, lr, a=1., mu=0.): + defaults = dict(lr=lr, a=a, mu=mu) + super(FedMosOptimizer, self).__init__(params, defaults) + + def clone_grad(self): + for group in self.param_groups: + for p in group['params']: + gt = p.grad.data + if gt is None: + continue + self.state[p]['gt_prev'] = gt.clone().detach() + + def get_grad(self): + grad = [] + for group in self.param_groups: + for p in group['params']: + gt = p.grad.data + if gt is None: + continue + grad += [gt.clone().detach().cpu().numpy()] + return grad + + def update_momentum(self): + for group in self.param_groups: + for p in group['params']: + gt = p.grad.data # grad + if gt is None: + continue + a = group['a'] + state = self.state[p] + if len(state) == 0: + # State initialization + state['gt_prev'] = torch.zeros_like(p.data) + state['dt'] = gt.clone() + continue + + # state['gt_prev'] = torch.zeros_like(p.data) + # state['dt'] = gt.clone() + + gt_prev = state['gt_prev'] + # assert not torch.allclose(gt, gt_prev), 'Please call clone_grad() in the preious step.' + dt = state['dt'] + # print(torch.equal(dt-gt_prev, torch.zeros_like(dt-gt_prev))) + # print(dt-gt_prev) + state['dt'] = gt + (1-a)*(dt - gt_prev) + state['gt_prev'] = gt.clone().detach() + # state['gt_prev'] = None + + # def update_momentum(self, net_prev): + # for group in self.param_groups: + # for p, p_prev in zip(group['params'], net_prev.parameters()): + # gt = p.grad.data # grad + # if gt is None: + # continue + # a = group['a'] + # state = self.state[p] + # if len(state) == 0: + # # State initialization + # # state['gt_prev'] = torch.zeros_like(p.data) + # state['dt'] = gt.clone() + # continue + + # # state['gt_prev'] = torch.zeros_like(p.data) + # # state['dt'] = gt.clone() + + # gt_prev = p_prev.grad.data + # # assert not torch.allclose(gt, gt_prev), 'Please call clone_grad() in the preious step.' + # dt = state['dt'] + # # print(torch.equal(dt-gt_prev, torch.zeros_like(dt-gt_prev))) + # # print(dt-gt_prev) + # state['dt'] = gt + (1-a)*(dt - gt_prev) + # # state['gt_prev'] = gt.clone().detach() + # # state['gt_prev'] = None + + def step(self, local_net): + for group in self.param_groups: + # For different groups, we might want to use different lr, regularizer, ... + for p, local_p in zip(group['params'], local_net.parameters()): + state = self.state[p] + if len(state) == 0: + raise Exception('Please call update_momentum() first.') + + lr, mu = group['lr'], group['mu'] + dt = state['dt'] + prox = p.data - local_p.data + p.data.add_(dt, alpha=-lr) + p.data.add_(prox, alpha=-mu) + + + + \ No newline at end of file From 97016854fbc8f62d87e0bf4e5592cd306fbff148 Mon Sep 17 00:00:00 2001 From: Ningxin Su Date: Thu, 25 Apr 2024 10:42:28 -0400 Subject: [PATCH 2/2] Updated the documentation of FedMos. --- docs/examples.md | 10 +++++++ .../fedmos/optimizers.py | 29 +------------------ 2 files changed, 11 insertions(+), 28 deletions(-) diff --git a/docs/examples.md b/docs/examples.md index fac6c83e7..00d4dbc13 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -165,6 +165,16 @@ python examples/customized_client_training/fedti/fedti.py -c examples/customized Gal et al., “[POLARIS: An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion](https://arxiv.org/pdf/2208.01618.pdf), ” Arxiv, 2022. ```` +````{admonition} **FedMos** +FedMoS is a communication-efficient FL framework with coupled double momentum-based update and adaptive client selection, to jointly mitigate the intrinsic variance. + +```shell +python examples/customized_client_training/fedmos/fedmos.py -c examples/customized_client_training/fedmos/fedmos_MNIST_lenet5.yml +``` +```{note} +X. Wang, Y. Chen, Y. Li, X. Liao, H. Jin and B. Li, "FedMoS: Taming Client Drift in Federated Learning with Double Momentum and Adaptive Selection," IEEE INFOCOM 2023. +```` + #### Client Selection Algorithms diff --git a/examples/customized_client_training/fedmos/optimizers.py b/examples/customized_client_training/fedmos/optimizers.py index 866f4f996..644f8e84a 100644 --- a/examples/customized_client_training/fedmos/optimizers.py +++ b/examples/customized_client_training/fedmos/optimizers.py @@ -1,5 +1,4 @@ import torch -import time class FedMosOptimizer(torch.optim.Optimizer): def __init__(self, params, lr, a=1., mu=0.): @@ -49,33 +48,7 @@ def update_momentum(self): state['dt'] = gt + (1-a)*(dt - gt_prev) state['gt_prev'] = gt.clone().detach() # state['gt_prev'] = None - - # def update_momentum(self, net_prev): - # for group in self.param_groups: - # for p, p_prev in zip(group['params'], net_prev.parameters()): - # gt = p.grad.data # grad - # if gt is None: - # continue - # a = group['a'] - # state = self.state[p] - # if len(state) == 0: - # # State initialization - # # state['gt_prev'] = torch.zeros_like(p.data) - # state['dt'] = gt.clone() - # continue - - # # state['gt_prev'] = torch.zeros_like(p.data) - # # state['dt'] = gt.clone() - - # gt_prev = p_prev.grad.data - # # assert not torch.allclose(gt, gt_prev), 'Please call clone_grad() in the preious step.' - # dt = state['dt'] - # # print(torch.equal(dt-gt_prev, torch.zeros_like(dt-gt_prev))) - # # print(dt-gt_prev) - # state['dt'] = gt + (1-a)*(dt - gt_prev) - # # state['gt_prev'] = gt.clone().detach() - # # state['gt_prev'] = None - + def step(self, local_net): for group in self.param_groups: # For different groups, we might want to use different lr, regularizer, ...