diff --git a/.gitmodules b/.gitmodules index 2c27271df..062941da5 100644 --- a/.gitmodules +++ b/.gitmodules @@ -8,5 +8,5 @@ path = examples/pfedrlnas/VIT/nasvit_wrapper/NASViT url = https://github.com/facebookresearch/NASViT [submodule "examples/controlnet_split_learning/ControlNet"] - path = examples/controlnet_split_learning/ControlNet + path = examples/split_learning/controlnet_split_learning/ControlNet url = https://github.com/lllyasviel/ControlNet diff --git a/docs/configuration.md b/docs/configuration.md index 2d287af41..e0ac91282 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -34,6 +34,8 @@ The type of the server. - `split_learning` a client following the Split Learning algorithm. When this client is used, `clients.do_test` in configuration should be set as `False` because in split learning, we conduct the test on the server. - `fedavg_personalized` a client saves its local layers before sending the shared global model to the server after local training. + +- `self_supervised_learning` a client to prepare the datasource for personalized learning based on self-supervised learning. ``` ```{admonition} **total_clients** @@ -392,6 +394,7 @@ The type of the trainer. The following types are available: - `basic`: a basic trainer with a standard training loop. - `diff_privacy`: a trainer that supports local differential privacy in its training loop by adding noise to the gradients during each step of training. - `split_learning`: a trainer that supports the split learning framework. +- `self_supervised_learning`: a trainer that supports personalized federated learning based on self supervised learning. ```{admonition} max_physical_batch_size The limit on the physical batch size when using the `diff_privacy` trainer. The default value is 128. The GPU memory usage of one process training the ResNet-18 model is around 2817 MB. diff --git a/docs/examples.md b/docs/examples.md index 54dc0e132..c4ec7d18d 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -38,7 +38,7 @@ H. Wu, P. Wang. “[Fast-Convergent Federated Learning with Adaptive Weighti ``` ```` -#### Secure Aggregation +#### Secure Aggregation with Homomorphic Encryption ````{admonition} **MaskCrypt** MaskCrypt is a secure federated learning system based on homomorphic encryption. Instead of encrypting all the model updates, MaskCrypt encrypts only part of them to balance the tradeoff between security and efficiency. In this example, clients only select 5% of the model updates to encrypt during the learning process. The number of encrypted weights is determined by `encrypt_ratio`, which can be adjusted in the configuration file. A random mask will be adopted if `random_mask` is set to true. @@ -110,10 +110,10 @@ N. Su, B. Li. “[Asynchronous Federated Unlearning](https://iqua.ece.toront #### Gradient Leakage Attacks and Defences ````{admonition} **Gradient leakage attacks and defenses** -Gradient leakage attacks and their defenses have been extensively studied in the research literature on federated learning. In `examples/dlg/`, several attacks, including `DLG`, `iDLG`, and `csDLG`, have been implemented, as well as several defense mechanisms, including `Soteria`, `GradDefense`, `Differential Privacy`, `Gradient Compression`, and `Outpost`. A variety of methods in the trainer API has been used in their implementations. +Gradient leakage attacks and their defenses have been extensively studied in the research literature on federated learning. 
In `examples/gradient_leakage_attacks/`, several attacks, including `DLG`, `iDLG`, and `csDLG`, have been implemented, as well as several defense mechanisms, including `Soteria`, `GradDefense`, `Differential Privacy`, `Gradient Compression`, and `Outpost`. A variety of methods in the trainer API has been used in their implementations. Refer to `examples/dlg/README.md` for more details. ```shell -python examples/dlg/dlg.py -c examples/dlg/reconstruction_emnist.yml --cpu +python examples/gradient_leakage_attacks/dlg.py -c examples/gradient_leakage_attacks/reconstruction_emnist.yml --cpu ``` ```` @@ -228,33 +228,33 @@ Vepakomma et al., “[Split Learning for Health: Distributed Deep Learning w ``` ```` +````{admonition} **Split Learning for Training ControlNet** +ControlNet is a conditional image generation model that only finetunes the control network without updating parameters in the large diffusion model. It has a more complicated structure than the usual deep learning model. Hence, to train a ControlNet with split learning, the control network and a part of the diffusion model are on the clients and the remaining part of the diffusion model is on the server. The forwarding and backwarding processes are specifically designed according to the inputs and training targets of the image generation based on diffusion models. + +```shell +python examples/split_learning/controlnet_split_learning/split_learning_main.py -c examples/split_learning/controlnet_split_learning/split_learning.yml +``` +```` #### Personalized Federated Learning Algorithms ````{admonition} **FedRep** -FedRep is an algorithm for learning a shared data representation across clients and unique, personalized local ``heads'' for each client. In this implementation, after each round of local training, only the representation on each client is retrieved and uploaded to the server for aggregation. - -FedRep belongs to personalized federated learning. -Please read `examples/personalized_fl/README.md` for more details about how to run the code. +FedRep learns a shared data representation (the global layers) across clients and a unique, personalized local ``head'' (the local layers) for each client. In this implementation, after each round of local training, only the representation on each client is retrieved and uploaded to the server for aggregation. ```shell -python examples/personalized_fl/fedrep/fedrep.py -c examples/personalized_fl/configs/fedrep_CIFAR10_resnet18.yml -b pflExperiments +python examples/personalized_fl/fedrep/fedrep.py -c examples/personalized_fl/configs/fedrep_CIFAR10_resnet18.yml ``` ```{note} -Collins et al., “[Exploiting Shared Representations for Personalized Federated Learning](http://proceedings.mlr.press/v139/collins21a/collins21a.pdf), -” in Proc. International Conference on Machine Learning (ICML), 2021. +Collins et al., “[Exploiting Shared Representations for Personalized Federated Learning](http://proceedings.mlr.press/v139/collins21a/collins21a.pdf), ” in Proc. International Conference on Machine Learning (ICML), 2021. ``` ```` ````{admonition} **FedBABU** -FedBABU argued that a better federated global model performance does not constantly improve personalization. In this algorithm, it only updates the body of the model during FL training. In this implementation, the head is frozen at the beginning of each local training epoch through the API ```train_run_start```. - -FedBABU belongs to personalized federated learning. 
-Please read `examples/personalized_fl/README.md` for more details about how to run the code. +FedBABU only updates the global layers of the model during FL training. The local layers are frozen at the beginning of each local training epoch. ```shell -python examples/personalized_fl/fedbabu/fedbabu.py -c examples/personalized_fl/configs/fedbabu_CIFAR10_resnet18.yml -b pflExperiments +python examples/personalized_fl/fedbabu/fedbabu.py -c examples/personalized_fl/configs/fedbabu_CIFAR10_resnet18.yml ``` ```{note} @@ -264,13 +264,10 @@ Oh et al., “[FedBABU: Towards Enhanced Representation for Federated Image ```` ````{admonition} **APFL** -APFL is a synchronous personalized federated learning algorithm that jointly optimizes the global model and personalized models by interpolating between local and personalized models. It has been quite widely cited and compared with in the personalized federated learning literature. In this example, once the global model is received, each client will carry out a regular local update, and then conduct a personalized optimization to acquire a trained personalized model. The trained global model and the personalized model will subsequently be combined using the parameter "alpha," which can be dynamically updated. - -APFL belongs to personalized federated learning. -Please read `examples/personalized_fl/README.md` for more details about how to run the code. +APFL jointly optimizes the global model and personalized models by interpolating between local and personalized models. Once the global model is received, each client will carry out a regular local update, and then conduct a personalized optimization to acquire a trained personalized model. The trained global model and the personalized model will subsequently be combined using the parameter "alpha," which can be dynamically updated. ```shell -python examples/personalized_fl/apfl/apfl.py -c examples/personalized_fl/configs/apfl_CIFAR10_resnet18.yml -b pflExperiments +python examples/personalized_fl/apfl/apfl.py -c examples/personalized_fl/configs/apfl_CIFAR10_resnet18.yml ``` ```{note} @@ -280,79 +277,58 @@ Deng et al., “[Adaptive Personalized Federated Learning](https://arxiv.org ```` ````{admonition} **FedPer** -FedPer is a synchronous personalized federated learning algorithm that learns a global representation and personalized heads, but makes simultaneous local updates for both sets of parameters, therefore makes the same number of local updates for the head and the representation on each local round. - -FedPer belongs to personalized federated learning. -Please read `examples/personalized_fl/README.md` for more details about how to run the code. +FedPer learns a global representation and personalized heads, but makes simultaneous local updates for both sets of parameters, therefore makes the same number of local updates for the head and the representation on each local round. ```shell -python examples/personalized_fl/fedper/fedper.py -c examples/personalized_fl/configs/fedper_CIFAR10_resnet18.yml -b pflExperiments +python examples/personalized_fl/fedper/fedper.py -c examples/personalized_fl/configs/fedper_CIFAR10_resnet18.yml ``` ```{note} -Arivazhagan et al., “[Federated learning with personalization layers](https://arxiv.org/abs/1912.00818), -” in Arxiv, 2019. +Arivazhagan et al., “[Federated learning with personalization layers](https://arxiv.org/abs/1912.00818), ” in Arxiv, 2019. 
``` ```` ````{admonition} **LG-FedAvg** -LG-FedAvg is a synchronous personalized federated learning algorithm that learns local representations and a global head. Therefore, only the head of one model is exchanged between the server and clients, while each client maintains a body of the model as its personalized encoder. - -LG-FedAvg belongs to personalized federated learning. -Please read `examples/personalized_fl/README.md` for more details about how to run the code. +With LG-FedAvg only the global layers of a model are sent to the server for aggregation, while each client keeps local layers to itself. ```shell -python examples/personalized_fl/lgfedavg/lgfedavg.py -c examples/personalized_fl/configs/lgfedavg_CIFAR10_resnet18.yml -b pflExperiments +python examples/personalized_fl/lgfedavg/lgfedavg.py -c examples/personalized_fl/configs/lgfedavg_CIFAR10_resnet18.yml ``` ```{note} -Liang et al., “[Think Locally, Act Globally: Federated Learning with Local and Global Representations](https://arxiv.org/abs/2001.01523), -” in Proc. NeurIPS, 2019. +Liang et al., “[Think Locally, Act Globally: Federated Learning with Local and Global Representations](https://arxiv.org/abs/2001.01523), ” in Proc. NeurIPS, 2019. ``` ```` ````{admonition} **Ditto** -Ditto is another synchronous personalized federated learning algorithm that jointly optimizes the global model and personalized models by learning local models that are encouraged to be close together by global regularization. In this example, once the global model is received, each client will carry out a regular local update followed by a Ditto solver to optimize the personalized model. - -Ditto belongs to personalized federated learning. -Please read `examples/personalized_fl/README.md` for more details about how to run the code. +Ditto jointly optimizes the global model and personalized models by learning local models that are encouraged to be close together by global regularization. In this example, once the global model is received, each client will carry out a regular local update and then optimizes the personalized model. ```shell -python examples/personalized_fl/ditto/ditto.py -c examples/personalized_fl/configs/ditto_CIFAR10_resnet18.yml -b pflExperiments +python examples/personalized_fl/ditto/ditto.py -c examples/personalized_fl/configs/ditto_CIFAR10_resnet18.yml ``` ```{note} -Li et al., “[Ditto: Fair and robust federated learning through personalization](https://proceedings.mlr.press/v139/li21h.html), -” in Proc ICML, 2021. +Li et al., “[Ditto: Fair and robust federated learning through personalization](https://proceedings.mlr.press/v139/li21h.html), ” in Proc ICML, 2021. ``` ```` -````{admonition} **PerFedAvg** -PerFedAvg focuses the personalized federated learning in which our goal is to find an initial shared model that current or new users can easily adapt to their local dataset by performing one or a few steps of gradient descent with respect to their own data. Specifically, it introduces the Model-Agnostic Meta-Learning (MAML) framework into the local update of federated learning. - -PerFedAvg belongs to personalized federated learning. -Please read `examples/personalized_fl/README.md` for more details about how to run the code. +````{admonition} **Per-FedAvg** +Per-FedAvg uses the Model-Agnostic Meta-Learning (MAML) framework to perform local training during the regular training rounds. It performs two forward and backward passes with fixed learning rates in each iteration. 
```shell -python examples/personalized_fl/perfedavg/perfedavg.py -c examples/personalized_fl/configs/perfedavg_CIFAR10_resnet18.yml -b pflExperiments +python examples/personalized_fl/perfedavg/perfedavg.py -c examples/personalized_fl/configs/perfedavg_CIFAR10_resnet18.yml ``` ```{note} -Fallah et al., “[Ditto: Personalized federated learning with theoretical guarantees: -A model-agnostic meta-learning approach](https://proceedings.neurips.cc/paper/2020/hash/24389bfe4fe2eba8bf9aa9203a44cdad-Abstract.html), -” in Proc NeurIPS, 2020. +Fallah et al., “[Personalized Federated Learning with Theoretical Guarantees: A Model-Agnostic Meta-Learning Approach](https://proceedings.neurips.cc/paper/2020/hash/24389bfe4fe2eba8bf9aa9203a44cdad-Abstract.html), ” in Proc NeurIPS, 2020. ``` ```` ````{admonition} **Hermes** Hermes utilizes structured pruning to improve both communication efficiency and inference efficiency of federated learning. It prunes channels with the lowest magnitudes in each local model and adjusts the pruning amount based on each local model’s test accuracy and its previous pruning amount. When the server aggregates pruned updates, it only averages parameters that were not pruned on all clients. - -Hermes belongs to personalized federated learning. -Please read `examples/personalized_fl/README.md` for more details about how to run the code. - ```shell -python examples/personalized_fl/hermes/hermes.py -c examples/personalized_fl/configs/hermes_CIFAR10_resnet18.yml -b pflExperiments +python examples/personalized_fl/hermes/hermes.py -c examples/personalized_fl/configs/hermes_CIFAR10_resnet18.yml ``` ```{note} @@ -364,7 +340,7 @@ Li et al., “[Hermes: An Efficient Federated Learning Framework for Heterog #### Personalized Federated Learning Algorithms based on Self-Supervised Learning ````{admonition} **Self Supervised Learning** -This category aims to achieve personalized federated learning by introducing self-supervised learning (SSL) to the training schema. In the context of self-supervised learning (SSL), the model is trained to learn representations from unlabeled data. Thus, the model is capable of extracting generic representations. A higher performance can be achieved in subsequent tasks with the trained model as the encoder. Such a benefit of SSL is introduced into personalized FL by relying on the learning objective of SSL to train the global model. After reaching convergence, each client can download the trained global model to extract features from local samples. A high-quality personalized model, typically a linear network, is prone to be achieved under those extracted features. The code is available under `examples/ssl/`. And under `algorithms/` of the folder, the following algorithms are implemented: +This category aims to achieve personalized federated learning by introducing self-supervised learning (SSL) to the training process. With SSL, an encoder model is trained to learn representations from unlabeled data. A higher performance can be achieved in subsequent tasks with the trained encoder. Only the encoder model is globally aggregated and shared during the regular training process. After reaching convergence, each client can download the trained global model to extract features from local samples. 
In this category, the following algorithms have been implemented: - SimCLR [1] - BYOL [2] @@ -373,60 +349,44 @@ This category aims to achieve personalized federated learning by introducing sel - SwAV [5] - SMoG [6] - FedEMA [7] - -Please read `examples/ssl/README.md` for more details about how to run the code. - -```shell -python examples/ssl/algorithms/simclr/simclr.py -c examples/ssl/configs/simclr_MNIST_lenet5.yml -b pflExperiments -``` - -```shell -python examples/ssl/algorithms/simclr/simclr.py -c examples/ssl/configs/simclr_CIFAR10_resnet18.yml -b pflExperiments -``` - -```shell -python examples/ssl/algorithms/byol/byol.py -c examples/ssl/configs/byol_CIFAR10_resnet18.yml -b pflExperiments -``` - -```shell -python examples/ssl/algorithms/simsiam/simsiam.py -c examples/ssl/configs/simsiam_CIFAR10_resnet18.yml -b pflExperiments -``` - -```shell -python examples/ssl/algorithms/moco/mocov2.py -c examples/ssl/configs/mocov2_CIFAR10_resnet18.yml -b pflExperiments -``` - -```shell -python examples/ssl/algorithms/swav/swav.py -c examples/ssl/configs/swav_CIFAR10_resnet18.yml -b pflExperiments -``` +- Calibre ```shell -python examples/ssl/algorithms/smog/smog.py -c examples/ssl/configs/smog_CIFAR10_resnet18.yml -b pflExperiments -``` - -```shell -python examples/ssl/algorithms/fedema/fedema.py -c examples/ssl/configs/fedema_CIFAR10_resnet18.yml -b pflExperiments +python examples/ssl/simclr/simclr.py -c examples/ssl/configs/simclr_MNIST_lenet5.yml +python examples/ssl/simclr/simclr.py -c examples/ssl/configs/simclr_CIFAR10_resnet18.yml +python examples/ssl/byol/byol.py -c examples/ssl/configs/byol_CIFAR10_resnet18.yml +python examples/ssl/simsiam/simsiam.py -c examples/ssl/configs/simsiam_CIFAR10_resnet18.yml +python examples/ssl/moco/mocov2.py -c examples/ssl/configs/mocov2_CIFAR10_resnet18.yml +python examples/ssl/swav/swav.py -c examples/ssl/configs/swav_CIFAR10_resnet18.yml +python examples/ssl/smog/smog.py -c examples/ssl/configs/smog_CIFAR10_resnet18.yml +python examples/ssl/fedema/fedema.py -c examples/ssl/configs/fedema_CIFAR10_resnet18.yml +python examples/ssl/calibre/calibre.py -c examples/ssl/configs/calibre_CIFAR10_resnet18.yml ``` ```{note} -[1]. Chen et al., “[A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709),” in Proc. ICML, 2020. +[1] Chen et al., “[A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709),” in Proc. ICML, 2020. -[2]. Grill et al., “[Bootstrap Your Own Latent A New Approach to Self-Supervised Learning](https://arxiv.org/pdf/2006.07733.pdf), ” in Proc. NeurIPS, 2020. +[2] Grill et al., “[Bootstrap Your Own Latent A New Approach to Self-Supervised Learning](https://arxiv.org/pdf/2006.07733.pdf), ” in Proc. NeurIPS, 2020. -[3]. Chen et al., “[Exploring Simple Siamese Representation Learning](https://arxiv.org/pdf/2011.10566.pdf), ” in Proc. CVPR, 2021. +[3] Chen et al., “[Exploring Simple Siamese Representation Learning](https://arxiv.org/pdf/2011.10566.pdf), ” in Proc. CVPR, 2021. -[4]. Chen et al., “[Improved Baselines with Momentum Contrastive Learning](https://arxiv.org/abs/2003.04297), ” in ArXiv, 2020. +[4] Chen et al., “[Improved Baselines with Momentum Contrastive Learning](https://arxiv.org/abs/2003.04297), ” in ArXiv, 2020. -[5]. Caron et al., “[Unsupervised Learning of Visual Features by Contrasting Cluster Assignments](https://arxiv.org/abs/2006.09882), ” in Proc. NeurIPS, 2022. 
+[5] Caron et al., “[Unsupervised Learning of Visual Features by Contrasting Cluster Assignments](https://arxiv.org/abs/2006.09882), ” in Proc. NeurIPS, 2022. -[6]. Pang et al., “[Unsupervised Visual Representation Learning by Synchronous Momentum Grouping](https://arxiv.org/pdf/2006.07733.pdf), ” in Proc. ECCV, 2022. +[6] Pang et al., “[Unsupervised Visual Representation Learning by Synchronous Momentum Grouping](https://arxiv.org/pdf/2006.07733.pdf), ” in Proc. ECCV, 2022. -[7]. Zhuang et al., “[Divergence-aware federated self-supervised learning](https://arxiv.org/pdf/2204.04385.pdf), ” in Proc. ICLR, 2022. +[7] Zhuang et al., “[Divergence-Aware Federated Self-Supervised Learning](https://arxiv.org/pdf/2204.04385.pdf), ” in Proc. ICLR, 2022. +``` ``` +Calibre is currently only supported on GPUs or Apple Silicon Chip. One should run on a GPU device or on MAC OS with adding the argument -m. +``` + ```` -#### Federated Learning Algorithms based on Neural Architecture Search and Model Search +#### Algorithms based on Neural Architecture Search and Model Search + ````{admonition} **FedRLNAS** FedRLNAS is an algorithm designed to conduct Federated Neural Architecture Search without sending the entire supernet to the clients. Instead, clients still perform conventional model training as in Federated Averaging, and the server will search for the best model architecture. In this example, the server overrides ```aggregate_weights()``` to aggregate updates from subnets of different architectures into the supernet, and implements architecture parameter updates in ```weights_aggregated()```. In its implementation, only only DARTS search space is supported. @@ -535,7 +495,7 @@ python3 ./examples/model_search/sysheterofl/sysheterofl.py -c examples/model_sea Tempo is proposed to improve training performance in three-layer federated learning. It adaptively tunes the number of each client's local training epochs based on the difference between its edge server's locally aggregated model and the current global model. ```shell -python examples/tempo/tempo.py -c examples/tempo/tempo_MNIST_lenet5.yml +python examples/three_layer_fl/tempo/tempo.py -c examples/three_layer_fl/tempo/tempo_MNIST_lenet5.yml ``` ```{note} @@ -548,18 +508,17 @@ Ying et al., “[Tempo: Improving Training Performance in Cross-Silo Federat FedSaw is proposed to improve training performance in three-layer federated learning with L1-norm structured pruning. Edge servers and clients pruned their updates before sending them out. FedSaw adaptively tunes the pruning amount of each edge server and its clients based on the difference between the edge server's locally aggregated model and the current global model. ```shell -python examples/fedsaw/fedsaw.py -c examples/fedsaw/fedsaw_MNIST_lenet5.yml +python examples/three_layer_fl/fedsaw/fedsaw.py -c examples/three_layer_fl/fedsaw/fedsaw_MNIST_lenet5.yml ``` ```` -#### Algorithms Not Yet Categorized - +#### Model Pruning Algorithms ````{admonition} **FedSCR** FedSCR uses structured pruning to prune each update’s entire filters and channels if their summed parameter values are below a particular threshold. ```shell -python examples/fedscr/fedscr.py -c examples/fedscr/fedscr_MNIST_lenet5.yml +python examples/model_pruning/fedscr/fedscr.py -c examples/model_pruning/fedscr/fedscr_MNIST_lenet5.yml ``` ```{note} @@ -574,13 +533,13 @@ Sub-FedAvg aims to obtain a personalized model for each client with non-i.i.d. 
l For two-layer federated learning: ```shell -python examples/sub_fedavg/subfedavg.py -c examples/sub_fedavg/subfedavg_MNIST_lenet5.yml +python examples/model_pruning/sub_fedavg/subfedavg.py -c examples/model_pruning/sub_fedavg/subfedavg_MNIST_lenet5.yml ``` For three-layer federated learning: ```shell -python examples/sub_fedavg/subcs.py -c examples/sub_fedavg/subcs_MNIST_lenet5.yml +python examples/model_pruning/sub_fedavg/subcs.py -c examples/model_pruning/sub_fedavg/subcs_MNIST_lenet5.yml ``` ```{note} @@ -589,12 +548,10 @@ Vahidian et al., “[Personalized Federated Learning by Structured and Unstr ``` ```` - - -With the recent redesign of the Plato API, the following list is outdated and will be updated as they are tested again. +With the redesign of the Plato API, the following list is outdated and will be updated as they are tested again. | Method | Notes | Tested | -| :----------------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :----: | +|:------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:| | [Adaptive Freezing](https://henryhxu.github.io/share/chen-icdcs21.pdf) | Change directory to `examples/adaptive_freezing` and run `python adaptive_freezing.py -c `. | Yes | | [Gradient-Instructed Frequency Tuning](https://github.com/TL-System/plato/blob/main/examples/adaptive_sync/papers/adaptive_sync.pdf) | Change directory to `examples/adaptive_sync` and run `python adaptive_sync.py -c `. | Yes | | [Attack Adaptive](https://arxiv.org/pdf/2102.05257.pdf) | Change directory to `examples/attack_adaptive` and run `python attack_adaptive.py -c `. | Yes | diff --git a/docs/installation.md b/docs/installation.md index 89a17108f..811dbed3e 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -101,17 +101,13 @@ In general, the following is the recommended starting point for `.vscode/setting ``` { - "python.linting.enabled": true, - "python.formatting.provider": "black", "editor.formatOnSave": true, "workbench.editor.enablePreview": false } ``` -It goes without saying that `/absolute/path/to/project/home/directory` should be replaced with the actual path in the specific development environment. - When working in Visual Studio Code as your development environment, two of our colour theme favourites are called `Bluloco` (both of its light and dark variants) and `City Lights` (dark). They are both excellent and very thoughtfully designed. 
-It goes without saying that the `Python` extension is required to be installed in Visual Studio Code, which represents Microsoft's modern language server for Python. +The `Black Formatter`, `PyLint`, and `Python` extensions are required to be installed in Visual Studio Code. ```` diff --git a/examples/dlg/README.md b/examples/gradient_leakage_attacks/README.md similarity index 100% rename from examples/dlg/README.md rename to examples/gradient_leakage_attacks/README.md diff --git a/examples/dlg/convergence_cifar10.yml b/examples/gradient_leakage_attacks/convergence_cifar10.yml similarity index 100% rename from examples/dlg/convergence_cifar10.yml rename to examples/gradient_leakage_attacks/convergence_cifar10.yml diff --git a/examples/dlg/convergence_emnist.yml b/examples/gradient_leakage_attacks/convergence_emnist.yml similarity index 100% rename from examples/dlg/convergence_emnist.yml rename to examples/gradient_leakage_attacks/convergence_emnist.yml diff --git a/examples/dlg/defense/GradDefense/clip.py b/examples/gradient_leakage_attacks/defense/GradDefense/clip.py similarity index 100% rename from examples/dlg/defense/GradDefense/clip.py rename to examples/gradient_leakage_attacks/defense/GradDefense/clip.py diff --git a/examples/dlg/defense/GradDefense/compensate.py b/examples/gradient_leakage_attacks/defense/GradDefense/compensate.py similarity index 100% rename from examples/dlg/defense/GradDefense/compensate.py rename to examples/gradient_leakage_attacks/defense/GradDefense/compensate.py diff --git a/examples/dlg/defense/GradDefense/dataloader.py b/examples/gradient_leakage_attacks/defense/GradDefense/dataloader.py similarity index 100% rename from examples/dlg/defense/GradDefense/dataloader.py rename to examples/gradient_leakage_attacks/defense/GradDefense/dataloader.py diff --git a/examples/dlg/defense/GradDefense/perturb.py b/examples/gradient_leakage_attacks/defense/GradDefense/perturb.py similarity index 100% rename from examples/dlg/defense/GradDefense/perturb.py rename to examples/gradient_leakage_attacks/defense/GradDefense/perturb.py diff --git a/examples/dlg/defense/GradDefense/sensitivity.py b/examples/gradient_leakage_attacks/defense/GradDefense/sensitivity.py similarity index 100% rename from examples/dlg/defense/GradDefense/sensitivity.py rename to examples/gradient_leakage_attacks/defense/GradDefense/sensitivity.py diff --git a/examples/dlg/defense/Outpost/perturb.py b/examples/gradient_leakage_attacks/defense/Outpost/perturb.py similarity index 100% rename from examples/dlg/defense/Outpost/perturb.py rename to examples/gradient_leakage_attacks/defense/Outpost/perturb.py diff --git a/examples/dlg/dlg.py b/examples/gradient_leakage_attacks/dlg.py similarity index 100% rename from examples/dlg/dlg.py rename to examples/gradient_leakage_attacks/dlg.py diff --git a/examples/dlg/dlg_client.py b/examples/gradient_leakage_attacks/dlg_client.py similarity index 100% rename from examples/dlg/dlg_client.py rename to examples/gradient_leakage_attacks/dlg_client.py diff --git a/examples/dlg/dlg_model.py b/examples/gradient_leakage_attacks/dlg_model.py similarity index 100% rename from examples/dlg/dlg_model.py rename to examples/gradient_leakage_attacks/dlg_model.py diff --git a/examples/dlg/dlg_server.py b/examples/gradient_leakage_attacks/dlg_server.py similarity index 100% rename from examples/dlg/dlg_server.py rename to examples/gradient_leakage_attacks/dlg_server.py diff --git a/examples/dlg/dlg_trainer.py b/examples/gradient_leakage_attacks/dlg_trainer.py similarity 
index 100% rename from examples/dlg/dlg_trainer.py rename to examples/gradient_leakage_attacks/dlg_trainer.py diff --git a/examples/dlg/reconstruction_cifar10.yml b/examples/gradient_leakage_attacks/reconstruction_cifar10.yml similarity index 100% rename from examples/dlg/reconstruction_cifar10.yml rename to examples/gradient_leakage_attacks/reconstruction_cifar10.yml diff --git a/examples/dlg/reconstruction_emnist.yml b/examples/gradient_leakage_attacks/reconstruction_emnist.yml similarity index 100% rename from examples/dlg/reconstruction_emnist.yml rename to examples/gradient_leakage_attacks/reconstruction_emnist.yml diff --git a/examples/dlg/utils/evaluations.py b/examples/gradient_leakage_attacks/utils/evaluations.py similarity index 100% rename from examples/dlg/utils/evaluations.py rename to examples/gradient_leakage_attacks/utils/evaluations.py diff --git a/examples/dlg/utils/modules.py b/examples/gradient_leakage_attacks/utils/modules.py similarity index 100% rename from examples/dlg/utils/modules.py rename to examples/gradient_leakage_attacks/utils/modules.py diff --git a/examples/dlg/utils/plot.py b/examples/gradient_leakage_attacks/utils/plot.py similarity index 100% rename from examples/dlg/utils/plot.py rename to examples/gradient_leakage_attacks/utils/plot.py diff --git a/examples/dlg/utils/pseudorandom.py b/examples/gradient_leakage_attacks/utils/pseudorandom.py similarity index 100% rename from examples/dlg/utils/pseudorandom.py rename to examples/gradient_leakage_attacks/utils/pseudorandom.py diff --git a/examples/dlg/utils/utils.py b/examples/gradient_leakage_attacks/utils/utils.py similarity index 100% rename from examples/dlg/utils/utils.py rename to examples/gradient_leakage_attacks/utils/utils.py diff --git a/examples/fedscr/fedscr.py b/examples/model_pruning/fedscr/fedscr.py similarity index 100% rename from examples/fedscr/fedscr.py rename to examples/model_pruning/fedscr/fedscr.py diff --git a/examples/fedscr/fedscr_MNIST_lenet5.yml b/examples/model_pruning/fedscr/fedscr_MNIST_lenet5.yml similarity index 100% rename from examples/fedscr/fedscr_MNIST_lenet5.yml rename to examples/model_pruning/fedscr/fedscr_MNIST_lenet5.yml diff --git a/examples/fedscr/fedscr_client.py b/examples/model_pruning/fedscr/fedscr_client.py similarity index 100% rename from examples/fedscr/fedscr_client.py rename to examples/model_pruning/fedscr/fedscr_client.py diff --git a/examples/fedscr/fedscr_server.py b/examples/model_pruning/fedscr/fedscr_server.py similarity index 100% rename from examples/fedscr/fedscr_server.py rename to examples/model_pruning/fedscr/fedscr_server.py diff --git a/examples/fedscr/fedscr_trainer.py b/examples/model_pruning/fedscr/fedscr_trainer.py similarity index 100% rename from examples/fedscr/fedscr_trainer.py rename to examples/model_pruning/fedscr/fedscr_trainer.py diff --git a/examples/sub_fedavg/subcs.py b/examples/model_pruning/sub_fedavg/subcs.py similarity index 100% rename from examples/sub_fedavg/subcs.py rename to examples/model_pruning/sub_fedavg/subcs.py diff --git a/examples/sub_fedavg/subcs_MNIST_lenet5.yml b/examples/model_pruning/sub_fedavg/subcs_MNIST_lenet5.yml similarity index 100% rename from examples/sub_fedavg/subcs_MNIST_lenet5.yml rename to examples/model_pruning/sub_fedavg/subcs_MNIST_lenet5.yml diff --git a/examples/sub_fedavg/subfedavg.py b/examples/model_pruning/sub_fedavg/subfedavg.py similarity index 100% rename from examples/sub_fedavg/subfedavg.py rename to examples/model_pruning/sub_fedavg/subfedavg.py diff --git 
a/examples/sub_fedavg/subfedavg_MNIST_lenet5.yml b/examples/model_pruning/sub_fedavg/subfedavg_MNIST_lenet5.yml similarity index 100% rename from examples/sub_fedavg/subfedavg_MNIST_lenet5.yml rename to examples/model_pruning/sub_fedavg/subfedavg_MNIST_lenet5.yml diff --git a/examples/sub_fedavg/subfedavg_client.py b/examples/model_pruning/sub_fedavg/subfedavg_client.py similarity index 100% rename from examples/sub_fedavg/subfedavg_client.py rename to examples/model_pruning/sub_fedavg/subfedavg_client.py diff --git a/examples/sub_fedavg/subfedavg_pruning.py b/examples/model_pruning/sub_fedavg/subfedavg_pruning.py similarity index 100% rename from examples/sub_fedavg/subfedavg_pruning.py rename to examples/model_pruning/sub_fedavg/subfedavg_pruning.py diff --git a/examples/sub_fedavg/subfedavg_trainer.py b/examples/model_pruning/sub_fedavg/subfedavg_trainer.py similarity index 100% rename from examples/sub_fedavg/subfedavg_trainer.py rename to examples/model_pruning/sub_fedavg/subfedavg_trainer.py diff --git a/examples/model_search/fedtp/FedTP_CIFAR10_T2TVIT14_NonIID03_scratch.yml b/examples/model_search/fedtp/FedTP_CIFAR10_T2TVIT14_NonIID03_scratch.yml index 040957cb1..20d782052 100644 --- a/examples/model_search/fedtp/FedTP_CIFAR10_T2TVIT14_NonIID03_scratch.yml +++ b/examples/model_search/fedtp/FedTP_CIFAR10_T2TVIT14_NonIID03_scratch.yml @@ -139,11 +139,9 @@ trainer: pers_epoch_log_interval: 2 pers_epoch_model_log_interval: 10 - # The machine learning model, - # it behaves as the encoder for the ssl method - # the final fc layer will be removed - # however, in the central test, we do not use this - # but use the custom model + # The machine learning model, it behaves as the encoder for the SSL method + # the final fc layer will be removed however, in the central test, we do not + # use this but use the custom model model_type: vit model_name: T2t_vit_14 personalized_model_name: T2t_vit_14 diff --git a/examples/personalized_fl/README.md b/examples/personalized_fl/README.md deleted file mode 100644 index 11455d80f..000000000 --- a/examples/personalized_fl/README.md +++ /dev/null @@ -1,116 +0,0 @@ -## Implemented Algorithms - -To better compare the performance of different personalized federated learning approaches, we implemented the following algorithms. - -### Baseline Algorithm - -We implemented the FedAvg with finetuning, referred to as FedAvg_finetune, as the baseline algorithm for personalized federated learning (FL). The code is available under `algorithms/fedavg_finetune/`. 
- -### Classical personalized FL approaches - -- {apfl} [Deng et al., "Adaptive Personalized Federated Learning", Arxiv 2020, citation 374](https://arxiv.org/pdf/2003.13461.pdf) - None - -- {fedper} [Arivazhagan et al., "Federated Learning with Personalization Layers", Arxiv 2019, citation 486](https://browse.arxiv.org/pdf/1912.00818.pdf) - [Third-party code](https://github.com/ki-ljl/FedPer) - -- {ditto} [Li et.al "Ditto: Fair and robust federated learning through personalization", ICML2020, citation 419](https://proceedings.mlr.press/v139/li21h.html) - [Official code](https://github.com/litian96/ditto) - -- {fedbabu} [Oh et.al "FedBABU: Toward Enhanced Representation for Federated Image Classification", ICLR 2022, citation 74](https://openreview.net/pdf?id=HuaYQfggn5u) - [Official code](https://github.com/jhoon-oh/FedBABU) - -- {fedrep} [Collins et al., "Exploiting Shared Representations for Personalized Federated -Learning", ICML21, citation 289](https://arxiv.org/abs/2102.07078) - [Official code](https://github.com/lgcollins/FedRep) - -- {lgfedavg} [Liang et al., "Think Locally, Act Globally: Federated Learning with Local and Global Representations", NeurIPS 2019, citation 359](https://arxiv.org/abs/2001.01523) - [Official code](https://github.com/pliang279/LG-FedAvg) - -- {perfedavg} [Fallah et al., "Personalized federated learning with theoretical guarantees: -A model-agnostic meta-learning approach", NeurIPS 2019, citation 502](https://proceedings.neurips.cc/paper/2020/hash/24389bfe4fe2eba8bf9aa9203a44cdad-Abstract.html) - [Third-party code](https://github.com/jhoon-oh/FedBABU) - -- {hermes} [Li et al., "Hermes: An Efficient Federated Learning Framework for Heterogeneous Mobile Clients", ACM MobiCom 21, citation 75](https://www.ang-li.com/assets/pdf/hermes.pdf) - None - -2. Perform the algorithms by running: - -```bash -python algorithms/fedavg_finetune/fedavg_finetune.py -c algorithms/configs/fedavg_finetune_CIFAR10_resnet18.yml -b pflExperiments -``` - ---- - -```bash -python algorithms/apfl/apfl.py -c algorithms/configs/apfl_CIFAR10_resnet18.yml -b pflExperiments -``` - -```bash -python algorithms/fedrep/fedrep.py -c algorithms/configs/fedrep_CIFAR10_resnet18.yml -b pflExperiments -``` - -```bash -python algorithms/fedbabu/fedbabu.py -c algorithms/configs/fedbabu_CIFAR10_resnet18.yml -b pflExperiments -``` - -```bash -python algorithms/ditto/ditto.py -c algorithms/configs/ditto_CIFAR10_resnet18.yml -b pflExperiments -``` - -```bash -python algorithms/fedper/fedper.py -c algorithms/configs/fedper_CIFAR10_resnet18.yml -b pflExperiments -``` - -```bash -python algorithms/lgfedavg/lgfedavg.py -c algorithms/configs/lgfedavg_CIFAR10_resnet18.yml -b pflExperiments -``` - -```bash -python algorithms/perfedavg/perfedavg.py -c algorithms/configs/perfedavg_CIFAR10_resnet18.yml -b pflExperiments -``` - -```bash -python algorithms/hermes/hermes.py -c algorithms/configs/hermes_CIFAR10_resnet18.yml -b pflExperiments -``` - - - -## Hyper-parameters - -All hyper-parameters should be placed under the `algorithm` block of the configuration file. - -### For `fedavg_personalized` -- global_layer_names: This is a list in which each item is a string presenting the parameter name. When you utilize the `fedavg_personalized.py` as the algorithm, the `global_layer_names` is required to be set under the `algorithm` block of the configuration file. Then, only the parameters contained in the `global_layer_names` will be the global model to be exchanged between the server and clients. 
Thus, server aggregation will be performed only on these parameters. If this hyper-parameter is not set, all parameters of the defined model will be used by default. For example, - ```yaml - algorithm: - global_layer_names: - - conv1 - - bn1 - - layer1 - - layer2 - - layer3 - - layer4 - ``` - -- local_layer_names: This is a list in which each item is a string presenting the parameter name. Once you set the `local_layer_names`, the client receives a portion of the model from the server. To embrace all parameters (i.e., the whole model) during the local update, you should set `local_layer_names` to indicate which parts parameters will be loaded from the local side to merge with the received ones. This is more like: the whole model is A+B. The client receives A from the server. Then, the client loads B from the local side to merge with A. For example, - ```yaml - algorithm: - local_layer_names: - - linear - ``` - -### For Personalization - -All hyper-parameters related to personalization should be placed under the `personalization` sub-block of the `algorithm` block. - -- model_name: A string to indicate the personalized model name. This is not mandatory as if it is omitted, it will be assumed to be the same as the global model. Default: `model_name` under the `trainer` block. For example, - ```yaml - algorithm: - personalization: - # the personalized model name - # this can be omitted - model_name: resnet_18 - ``` - -- participating_clients_ratio: A float to show the proportion of clients participating in the federated training process. The value ranges from 0.0 to 1.0 while 1.0 means that all clients will participant in training. Default: 1.0. For example, - ```yaml - algorithm: - personalization: - # the ratio of clients participanting in training - participating_clients_ratio: 0.6 - ``` - - diff --git a/examples/personalized_fl/apfl/apfl.py b/examples/personalized_fl/apfl/apfl.py index 114bf18ab..475201c69 100644 --- a/examples/personalized_fl/apfl/apfl.py +++ b/examples/personalized_fl/apfl/apfl.py @@ -1,15 +1,13 @@ """ -The implementation of APFL method. +An implementation of Adaptive Personalized Federated Learning (APFL). -Yuyang Deng, et al., Adaptive Personalized Federated Learning +Y. Deng, et al., "Adaptive Personalized Federated Learning" -paper address: https://arxiv.org/pdf/2003.13461.pdf +URL: https://arxiv.org/pdf/2003.13461.pdf -Official code: None Third-party code: -- https://github.com/MLOPTPSU/FedTorch/blob/main/main.py -- https://github.com/MLOPTPSU/FedTorch/blob/main/fedtorch/comms/trainings/federated/apfl.py - +https://github.com/MLOPTPSU/FedTorch/blob/main/main.py +https://github.com/MLOPTPSU/FedTorch/blob/main/fedtorch/comms/trainings/federated/apfl.py """ import apfl_trainer @@ -20,7 +18,7 @@ def main(): """ - A personalized federated learning session for APFL approach. + A personalized federated learning session using APFL. """ trainer = apfl_trainer.Trainer client = personalized_client.Client(trainer=trainer) diff --git a/examples/personalized_fl/apfl/apfl_trainer.py b/examples/personalized_fl/apfl/apfl_trainer.py index eb8d09768..5cde1ab1b 100644 --- a/examples/personalized_fl/apfl/apfl_trainer.py +++ b/examples/personalized_fl/apfl/apfl_trainer.py @@ -1,20 +1,20 @@ """ -A personalized federated learning trainer For APFL. +A personalized federated learning trainer using APFL. 
""" -import os import logging +import os + import numpy as np import torch -from plato.trainers import basic -from plato.models import registry as models_registry from plato.config import Config +from plato.models import registry as models_registry +from plato.trainers import basic class Trainer(basic.Trainer): """ - A trainer using the APFL algorithm to jointly train the global and - personalized models. + A trainer using the APFL algorithm to train both global and personalized models. """ def __init__(self, model=None, callbacks=None): @@ -77,7 +77,7 @@ def train_run_start(self, config): self.personalized_model.train() def train_run_end(self, config): - """Saving the alpha.""" + """Saves alpha to a file identified by the client id, and saves the personalized model.""" super().train_run_end(config) # Save the alpha to the file @@ -94,7 +94,7 @@ def train_run_end(self, config): torch.save(self.personalized_model.state_dict(), model_path) def train_step_end(self, config, batch=None, loss=None): - """Updating the alpha of APFL before each iteration.""" + """Updates alpha in APFL before each iteration.""" super().train_step_end(config, batch, loss) # Update alpha based on Eq. 10 in the paper diff --git a/examples/personalized_fl/configs/perfedavg_CIFAR10_resnet18.yml b/examples/personalized_fl/configs/perfedavg_CIFAR10_resnet18.yml index ca1f70983..14e9b0cca 100644 --- a/examples/personalized_fl/configs/perfedavg_CIFAR10_resnet18.yml +++ b/examples/personalized_fl/configs/perfedavg_CIFAR10_resnet18.yml @@ -34,14 +34,10 @@ algorithm: # Aggregation algorithm type: fedavg_personalized - # Important hyper-parameters - # False for performing Per-FedAvg(FO), others for Per-FedAvg(HF) - hessian_free: False alpha: 0.01 # 1e-2 beta: 0.001 personalization: - # the ratio of clients participanting in training participating_client_ratio: 1.0 diff --git a/examples/personalized_fl/ditto/ditto.py b/examples/personalized_fl/ditto/ditto.py index 7fc6c0079..509753811 100644 --- a/examples/personalized_fl/ditto/ditto.py +++ b/examples/personalized_fl/ditto/ditto.py @@ -1,13 +1,12 @@ """ -The implementation of Ditto method based on the pFL framework of Plato. +An implementation of the Ditto personalized federated learning algorithm. -Reference: -Tian Li, et al., Ditto: Fair and robust federated learning through personalization, 2021: - https://proceedings.mlr.press/v139/li21h.html +T. Li, et al., "Ditto: Fair and robust federated learning through personalization," +in the Proceedings of ICML 2021. -Official code: https://github.com/litian96/ditto -Third-party code: https://github.com/lgcollins/FedRep +https://proceedings.mlr.press/v139/li21h.html +Source code: https://github.com/litian96/ditto """ import ditto_trainer @@ -15,9 +14,10 @@ from plato.servers import fedavg_personalized as personalized_server from plato.clients import fedavg_personalized as personalized_client + def main(): """ - A personalized federated learning session for Ditto approach. + A personalized federated learning session with Ditto. """ trainer = ditto_trainer.Trainer client = personalized_client.Client(trainer=trainer) diff --git a/examples/personalized_fl/ditto/ditto_trainer.py b/examples/personalized_fl/ditto/ditto_trainer.py index 5e18f3ac0..de81a8f9b 100644 --- a/examples/personalized_fl/ditto/ditto_trainer.py +++ b/examples/personalized_fl/ditto/ditto_trainer.py @@ -1,5 +1,5 @@ """ -A personalized federated learning trainer using Ditto. +A personalized federated learning trainer with Ditto. 
""" import os import copy @@ -14,31 +14,36 @@ class Trainer(basic.Trainer): - """A personalized federated learning trainer using the Ditto algorithm.""" + """ + A trainer with Ditto, which first trains the global model and then trains + the personalized model. + """ def __init__(self, model=None, callbacks=None): super().__init__(model, callbacks) - - # The lambda (used in the paper) + # The lambda adjusts the gradients self.ditto_lambda = Config().algorithm.ditto_lambda - # The personalized model + # Get the personalized model if model is None: self.personalized_model = models_registry.get() else: self.personalized_model = model() - # The global model weights received from the server, which is the w^t in - # the paper + # The global model weights, which is w^t in the paper self.initial_wnet_params = None def train_run_start(self, config): super().train_run_start(config) + # Make a copy of the model before local training starts, which will be used when optimizing + # the personalized model self.initial_wnet_params = copy.deepcopy(self.model.cpu().state_dict()) def train_run_end(self, config): - """Perform personalized training, proposed in Ditto.""" + """ + Optimize the personalized model for epochs following Algorithm 1. + """ super().train_run_end(config) logging.info( @@ -49,6 +54,7 @@ def train_run_end(self, config): self.client_id, ) + # Load personalized model model_path = Config().params["model_path"] model_name = Config().trainer.model_name filename = f"{model_path}/{model_name}_{self.client_id}_v_net.pth" @@ -64,6 +70,7 @@ def train_run_end(self, config): self.personalized_model.to(self.device) self.personalized_model.train() + for epoch in range(1, config["epochs"] + 1): epoch_loss_meter.reset() for __, (examples, labels) in enumerate(self.train_loader): diff --git a/examples/personalized_fl/fedavg_finetune/fedavg_finetune.py b/examples/personalized_fl/fedavg_finetune/fedavg_finetune.py index 7840fed1e..e63815ccf 100644 --- a/examples/personalized_fl/fedavg_finetune/fedavg_finetune.py +++ b/examples/personalized_fl/fedavg_finetune/fedavg_finetune.py @@ -1,14 +1,12 @@ """ -An implementation of the personalized learning variant of FedAvg. +This implementation first trains a global model using conventional FedAvg until +a target number of rounds has been reached. In the final `personalization` +round, each client will use its local data samples to further fine-tune the +shared global model for a number of epochs, and then the server will compute the +average client test accuracy. -The core idea is to achieve personalized FL in two stages: -First, it trains a global model using conventional FedAvg until convergence. -Second, each client freezes the trained global model and optimizes the other -parts. - -Due to its simplicity, no work has been proposed that specifically discusses -this algorithm but only utilizes it as the baseline for personalized federated -learning. +Due to its simplicity, no papers specifically discussed or proposed this +algorithm; they only utilized it as their baseline for comparisons. """ from plato.clients import fedavg_personalized as personalized_client diff --git a/examples/personalized_fl/fedbabu/fedbabu.py b/examples/personalized_fl/fedbabu/fedbabu.py index 34f81f420..5919ae27b 100644 --- a/examples/personalized_fl/fedbabu/fedbabu.py +++ b/examples/personalized_fl/fedbabu/fedbabu.py @@ -1,7 +1,7 @@ """ An implementation of the FedBABU algorithm. -J. 
Oh, et al., "FedBABU: Toward Enhanced Representation for Federated Image Classification," +Oh, et al., "FedBABU: Toward Enhanced Representation for Federated Image Classification," in the Proceedings of ICLR 2022. https://openreview.net/pdf?id=HuaYQfggn5u @@ -17,7 +17,7 @@ def main(): """ - A personalized federated learning session for FedBABU algorithm under the supervised setting. + A personalized federated learning session with FedBABU. """ trainer = fedbabu_trainer.Trainer client = personalized_client.Client(trainer=trainer) diff --git a/examples/personalized_fl/fedbabu/fedbabu_trainer.py b/examples/personalized_fl/fedbabu/fedbabu_trainer.py index 2e8a042a7..94000f7d1 100644 --- a/examples/personalized_fl/fedbabu/fedbabu_trainer.py +++ b/examples/personalized_fl/fedbabu/fedbabu_trainer.py @@ -1,5 +1,5 @@ """ -A personalized federated learning trainer for FedBABU. +A personalized federated learning trainer with FedBABU. """ from plato.config import Config from plato.trainers import basic @@ -7,14 +7,13 @@ class Trainer(basic.Trainer): - """A trainer to freeze and activate layers of one model - for normal and personalized learning processes.""" + """ + A trainer with FedBABU, which freezes the global model layers in the final + personalization round, and freezes the local layers instead in the regular + rounds before the target number of rounds has been reached. + """ def train_run_start(self, config): - """According to FedBABU, - 1. freeze first part of the model during federated training phase. - 2. freeze second part of the personalized model during personalized learning phase. - """ super().train_run_start(config) if self.current_round > Config().trainer.rounds: trainer_utils.freeze_model( @@ -28,7 +27,6 @@ def train_run_start(self, config): ) def train_run_end(self, config): - """Activate the model.""" super().train_run_end(config) if self.current_round > Config().trainer.rounds: diff --git a/examples/personalized_fl/fedper/fedper_trainer.py b/examples/personalized_fl/fedper/fedper_trainer.py index ce687938b..8ed33e128 100644 --- a/examples/personalized_fl/fedper/fedper_trainer.py +++ b/examples/personalized_fl/fedper/fedper_trainer.py @@ -1,5 +1,5 @@ """ -A personalized federated learning trainer for FedPer. +A personalized federated learning trainer with FedPer. """ from plato.config import Config from plato.trainers import basic @@ -7,13 +7,12 @@ class Trainer(basic.Trainer): - """A trainer to freeze and activate layers of one model - for normal and personalized learning processes.""" + """ + A trainer with FedPer, which freezes the global model layers in the final + personalization round. + """ def train_run_start(self, config): - """According to FedPer, - Freeze body of the model during personalization. - """ super().train_run_start(config) if self.current_round > Config().trainer.rounds: trainer_utils.freeze_model( @@ -24,6 +23,7 @@ def train_run_start(self, config): def train_run_end(self, config): """Activate the model.""" super().train_run_end(config) + if self.current_round > Config().trainer.rounds: trainer_utils.activate_model( self.model, Config().algorithm.global_layer_names diff --git a/examples/personalized_fl/fedrep/fedrep_trainer.py b/examples/personalized_fl/fedrep/fedrep_trainer.py index e3679aa9a..7fa7dfc49 100644 --- a/examples/personalized_fl/fedrep/fedrep_trainer.py +++ b/examples/personalized_fl/fedrep/fedrep_trainer.py @@ -1,5 +1,5 @@ """ -A trainer for FedRep approach. +A trainer with FedRep. 
""" from plato.trainers import basic from plato.config import Config @@ -8,10 +8,10 @@ class Trainer(basic.Trainer): - """A trainer for FedRep.""" + """A trainer with FedRep.""" def train_run_start(self, config): - """Freeze the global layers during the personalization round.""" + """Freeze the global layers during the final personalization round.""" super().train_run_start(config) if self.current_round > Config().trainer.rounds: diff --git a/examples/personalized_fl/hermes/hermes.py b/examples/personalized_fl/hermes/hermes.py index 891ebb5e3..d88f1220f 100644 --- a/examples/personalized_fl/hermes/hermes.py +++ b/examples/personalized_fl/hermes/hermes.py @@ -1,9 +1,10 @@ """ -A federated learning training session using Hermes +A federated learning training session using Hermes. A. Li, J. Sun, P. Li, Y. Pu, H. Li, and Y. Chen, -“Hermes: An Efficient Federated Learning Framework for Heterogeneous Mobile Clients,” -in Proc. 27th Annual International Conference on Mobile Computing and Networking (MobiCom), 2021. +“Hermes: An Efficient Federated Learning Framework for Heterogeneous Mobile +Clients,” in Proc. 27th Annual International Conference on Mobile Computing and +Networking (MobiCom), 2021. """ from hermes_callback import HermesCallback diff --git a/examples/personalized_fl/hermes/hermes_callback.py b/examples/personalized_fl/hermes/hermes_callback.py index e28a78771..95195acc2 100644 --- a/examples/personalized_fl/hermes/hermes_callback.py +++ b/examples/personalized_fl/hermes/hermes_callback.py @@ -1,55 +1,9 @@ """ Callback for attaching a pruning mask to the payload if pruning had been conducted. """ - -import os -import pickle import logging -from typing import OrderedDict as OrderedDictType - - +from hermes_processor import SendMaskProcessor from plato.callbacks.client import ClientCallback -from plato.processors import base -from plato.config import Config - - -class SendMaskProcessor(base.Processor): - """ - Implements a processor for attaching a pruning mask to the payload if pruning - had been conducted - """ - - def __init__(self, client_id, **kwargs) -> None: - super().__init__(**kwargs) - - self.client_id = client_id - - def process(self, data: OrderedDictType): - model_name = ( - Config().trainer.model_name - if hasattr(Config().trainer, "model_name") - else "custom" - ) - model_path = Config().params["model_path"] - - mask_filename = f"{model_path}/{model_name}_client{self.client_id}_mask.pth" - if os.path.exists(mask_filename): - with open(mask_filename, "rb") as payload_file: - client_mask = pickle.load(payload_file) - data = [data, client_mask] - else: - data = [data, None] - - if data[1] is not None: - if self.client_id is None: - logging.info( - "[Server #%d] Pruning mask attached to payload.", self.server_id - ) - else: - logging.info( - "[Client #%d] Pruning mask attached to payload.", self.client_id - ) - return data class HermesCallback(ClientCallback): diff --git a/examples/personalized_fl/hermes/hermes_processor.py b/examples/personalized_fl/hermes/hermes_processor.py new file mode 100644 index 000000000..77c02904f --- /dev/null +++ b/examples/personalized_fl/hermes/hermes_processor.py @@ -0,0 +1,51 @@ +""" +An outbound processor for Hermes to load a mask from the local file system on the client, +and attach it to the payload. 
+""" + +import os +import pickle +import logging +from typing import OrderedDict + +from plato.processors import base +from plato.config import Config + + +class SendMaskProcessor(base.Processor): + """ + Implements a processor for attaching a pruning mask to the payload if pruning + had been conducted. + """ + + def __init__(self, client_id, **kwargs) -> None: + super().__init__(**kwargs) + + self.client_id = client_id + + def process(self, data: OrderedDict): + model_name = ( + Config().trainer.model_name + if hasattr(Config().trainer, "model_name") + else "custom" + ) + model_path = Config().params["model_path"] + + mask_filename = f"{model_path}/{model_name}_client{self.client_id}_mask.pth" + if os.path.exists(mask_filename): + with open(mask_filename, "rb") as payload_file: + client_mask = pickle.load(payload_file) + data = [data, client_mask] + else: + data = [data, None] + + if data[1] is not None: + if self.client_id is None: + logging.info( + "[Server #%d] Pruning mask attached to payload.", self.server_id + ) + else: + logging.info( + "[Client #%d] Pruning mask attached to payload.", self.client_id + ) + return data diff --git a/examples/personalized_fl/hermes/hermes_pruning.py b/examples/personalized_fl/hermes/hermes_pruning.py index c57a60bbe..c41fdbc30 100644 --- a/examples/personalized_fl/hermes/hermes_pruning.py +++ b/examples/personalized_fl/hermes/hermes_pruning.py @@ -69,7 +69,7 @@ def structured_pruning(model, pruning_rate, adjust_rate=0.0): mask = [] if adjust_rate == 0: - for __, layer in model.named_parameters(): + for layer in model.modules(): if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear)): pruning_rates.append(pruning_rate) else: @@ -93,18 +93,16 @@ def structured_pruning(model, pruning_rate, adjust_rate=0.0): ) / weight_nums[step] pruning_rates[step] = (pruning_rates[step] * 100) / (100 - adjust_rate) - step = 0 - step = 0 - for __, layer in model.named_parameters(): + for layer in model.modules(): if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear)): amount = pruning_rates[step] prune.ln_structured(layer, "weight", amount, norm, dim) + for name, buffer in layer.named_buffers(): + if "mask" in name: + mask.append(buffer.cpu().numpy()) step += 1 - - for name, buffer in model.named_buffers(): - if "mask" in name: - mask.append(buffer.cpu().numpy()) + prune.remove(layer, "weight") return mask @@ -122,5 +120,4 @@ def apply_mask(model, mask, device): device = layer.weight.device prune.custom_from_mask(layer, "weight", mask[step].to(device)) step += 1 - return model diff --git a/examples/personalized_fl/hermes/hermes_server.py b/examples/personalized_fl/hermes/hermes_server.py index dfed0b959..2c0d39d76 100644 --- a/examples/personalized_fl/hermes/hermes_server.py +++ b/examples/personalized_fl/hermes/hermes_server.py @@ -96,7 +96,6 @@ async def aggregate_weights(self, updates, baseline_weights, weights_received): def update_client_model(self, aggregated_clients_models, updates): """Update clients' models.""" - for client_model, update in zip(aggregated_clients_models, updates): received_client_id = update.client_id if received_client_id in self.aggregated_clients_model: @@ -105,11 +104,10 @@ def update_client_model(self, aggregated_clients_models, updates): def customize_server_payload(self, payload): """Customizes the server payload before sending to the client.""" - # If the client has already begun the learning of a personalized model + # If the client has already begun training a personalized model # in a previous communication round, the personalized 
file is loaded and # sent to the client for continued training. Otherwise, if the client is # selected for the first time, it receives the pre-initialized model. - if self.selected_client_id in self.aggregated_clients_model: # replace the payload for the current client with the personalized model payload = self.aggregated_clients_model[self.selected_client_id] diff --git a/examples/personalized_fl/hermes/hermes_trainer.py b/examples/personalized_fl/hermes/hermes_trainer.py index cd4648a23..09d9d86dd 100644 --- a/examples/personalized_fl/hermes/hermes_trainer.py +++ b/examples/personalized_fl/hermes/hermes_trainer.py @@ -1,10 +1,5 @@ """ -The training loop that takes place on clients. - -As each client of Hermes do not hold the personalized model but receives one -from the server, there is no need to do any operations on the personalized_model. -But the received model will be personalized model directly. - +The trainer used by clients using Hermes. """ import logging @@ -42,8 +37,9 @@ def __init__(self, model=None, callbacks=None): def train_run_start(self, config): """Conducts pruning if needed before training.""" - # Evaluate if structured pruning should be conducted super().train_run_start(config) + + # Evaluate if structured pruning should be conducted self.datasource = datasources_registry.get(client_id=self.client_id) self.testset = self.datasource.get_test_set() logging.info( diff --git a/examples/personalized_fl/lgfedavg/lgfedavg.py b/examples/personalized_fl/lgfedavg/lgfedavg.py index 74fe7680a..fee7ad327 100644 --- a/examples/personalized_fl/lgfedavg/lgfedavg.py +++ b/examples/personalized_fl/lgfedavg/lgfedavg.py @@ -1,9 +1,12 @@ """ -The implementation of LG-FedAvg method based on the plato's pFL code. +An implementation of LG-FedAvg. -Paul Pu Liang, et al., Think Locally, Act Globally: Federated Learning -with Local and Global Representations. https://arxiv.org/abs/2001.01523 +P. Liang, et al., "Think Locally, Act Globally: Federated Learning +with Local and Global Representations," Arxiv 2020. +https://arxiv.org/abs/2001.01523 + +Source code: https://github.com/pliang279/LG-FedAvg """ import lgfedavg_trainer diff --git a/examples/personalized_fl/lgfedavg/lgfedavg_trainer.py b/examples/personalized_fl/lgfedavg/lgfedavg_trainer.py index dc3da07bc..472a212c5 100644 --- a/examples/personalized_fl/lgfedavg/lgfedavg_trainer.py +++ b/examples/personalized_fl/lgfedavg/lgfedavg_trainer.py @@ -1,5 +1,5 @@ """ -A personalized federated learning trainer using LG-FedAvg. +A personalized federated learning trainer with LG-FedAvg. """ from plato.trainers import basic from plato.config import Config @@ -7,23 +7,25 @@ class Trainer(basic.Trainer): - """A personalized federated learning trainer using the LG-FedAvg algorithm.""" + """ + The training loop in LG-FedAvg performs two forward and backward passes in + one iteration. It first freezes the global model layers and trains the local + layers, and then freezes the local layers and trains global layers to finish + one training loop. 
+ """ def perform_forward_and_backward_passes(self, config, examples, labels): """Performing one iteration of LG-FedAvg.""" - # LG-FedAvg will first only train local layers + # LG-FedAvg first only trains local layers trainer_utils.freeze_model(self.model, Config().algorithm.global_layer_names) trainer_utils.activate_model(self.model, Config().algorithm.local_layer_names) super().perform_forward_and_backward_passes(config, examples, labels) - # Then, LG-FedAvg will only train non-local layers + # Secondly, LG-FedAvg only trains non-local layers trainer_utils.activate_model(self.model, Config().algorithm.global_layer_names) - trainer_utils.freeze_model( - self.model, - Config().algorithm.local_layer_names, - ) + trainer_utils.freeze_model(self.model, Config().algorithm.local_layer_names) loss = super().perform_forward_and_backward_passes(config, examples, labels) diff --git a/examples/personalized_fl/perfedavg/perfedavg.py b/examples/personalized_fl/perfedavg/perfedavg.py index 1f6b0daaf..8ffe122cb 100644 --- a/examples/personalized_fl/perfedavg/perfedavg.py +++ b/examples/personalized_fl/perfedavg/perfedavg.py @@ -1,14 +1,13 @@ """ -The implementation of Per-FedAvg method based on the plato's -pFL code. +A federated learning training session using Per-FedAvg. -Reference -Alireza Fallah, et al., Personalized federated learning with theoretical guarantees: -A model-agnostic meta-learning approach, NeurIPS 2020. -https://proceedings.neurips.cc/paper/2020/hash/24389bfe4fe2eba8bf9aa9203a44cdad-Abstract.html +A. Fallah, et al., “Personalized Federated Learning with Theoretical Guarantees: +A Model-Agnostic Meta-Learning Approach,” in Proc. Advances in Neural +Information Processing Systems (NeurIPS), 2020. -Third-party code: https://github.com/jhoon-oh/FedBABU +https://dl.acm.org/doi/abs/10.5555/3495724.3496024 +Third-party code: https://github.com/jhoon-oh/FedBABU """ import perfedavg_trainer @@ -16,9 +15,10 @@ from plato.servers import fedavg_personalized as personalized_server from plato.clients import fedavg_personalized as personalized_client + def main(): """ - A personalized federated learning session for PerFedAvg approach. + A personalized federated learning session using the Per-FedAvg algorithm. """ trainer = perfedavg_trainer.Trainer client = personalized_client.Client(trainer=trainer) diff --git a/examples/personalized_fl/perfedavg/perfedavg_trainer.py b/examples/personalized_fl/perfedavg/perfedavg_trainer.py index 1c60968bf..9a6c48587 100644 --- a/examples/personalized_fl/perfedavg/perfedavg_trainer.py +++ b/examples/personalized_fl/perfedavg/perfedavg_trainer.py @@ -1,5 +1,5 @@ """ -A personalized federated learning trainer using Per-FedAvg +A personalized federated learning trainer using Per-FedAvg. 
""" import copy @@ -14,31 +14,27 @@ class Trainer(basic.Trainer): def __init__(self, model=None, callbacks=None): super().__init__(model, callbacks) - # the iterator for the dataloader self.iter_trainloader = None def train_epoch_start(self, config): - """Defining the iterator for the train dataloader.""" + """Runs at the start of each epoch.""" super().train_epoch_start(config) + self.iter_trainloader = iter(self.train_loader) def perform_forward_and_backward_passes(self, config, examples, labels): - """Perform forward and backward passes in the training loop.""" + """Performs forward and backward passes in the training loop.""" if self.current_round > Config().trainer.rounds: - # No meta learning in the fine-tuning in the final personalization round + # During the final personalization round, the normal training loop is used return super().perform_forward_and_backward_passes(config, examples, labels) else: - alpha = Config().algorithm.alpha - beta = Config().algorithm.beta - - # Put the current model weights into the other meta model + # Save a copy of the current model weights past_model_params = copy.deepcopy(list(self.model.parameters())) - # Step 1 - # Update model with learning rate alpha. + # Step 1: Update the model with a fixed learning rate, alpha, in Algorithm 1 for g in self.optimizer.param_groups: - g["lr"] = alpha + g["lr"] = Config().algorithm.alpha self.optimizer.zero_grad() logits = self.model(examples) @@ -46,30 +42,25 @@ def perform_forward_and_backward_passes(self, config, examples, labels): loss.backward() self.optimizer.step() - # Step 2 - # Calculate the meta gradients + # Step 2: Compute the meta gradients with a fixed learning rate, beta, in Algorithm 1 for g in self.optimizer.param_groups: - g["lr"] = beta + g["lr"] = Config().algorithm.beta self.optimizer.zero_grad() examples, labels = next(self.iter_trainloader) examples, labels = examples.to(self.device), labels.to(self.device) - logits = self.model(examples) - loss = self._loss_criterion(logits, labels) self._loss_tracker.update(loss, labels.size(0)) loss.backward() - # Step 3 - # Update model weights with meta model's gradients - # The model parameter is only updated here, in each iteration. - # Use the gradients by step 2 to update the weights before step 1. 
+ # Step 3: Restore the model weights saved before step 1 for model_param, past_model_param in zip( self.model.parameters(), past_model_params ): model_param.data = past_model_param.data.clone() + # Update the model with the meta gradients from step 2 self.optimizer.step() return loss diff --git a/examples/controlnet_split_learning/ControlNet b/examples/split_learning/controlnet_split_learning/ControlNet similarity index 100% rename from examples/controlnet_split_learning/ControlNet rename to examples/split_learning/controlnet_split_learning/ControlNet diff --git a/examples/controlnet_split_learning/OrgModel/cldm_client.py b/examples/split_learning/controlnet_split_learning/OrgModel/cldm_client.py similarity index 100% rename from examples/controlnet_split_learning/OrgModel/cldm_client.py rename to examples/split_learning/controlnet_split_learning/OrgModel/cldm_client.py diff --git a/examples/controlnet_split_learning/OrgModel/cldm_client_safe.py b/examples/split_learning/controlnet_split_learning/OrgModel/cldm_client_safe.py similarity index 100% rename from examples/controlnet_split_learning/OrgModel/cldm_client_safe.py rename to examples/split_learning/controlnet_split_learning/OrgModel/cldm_client_safe.py diff --git a/examples/controlnet_split_learning/OrgModel/cldm_server.py b/examples/split_learning/controlnet_split_learning/OrgModel/cldm_server.py similarity index 100% rename from examples/controlnet_split_learning/OrgModel/cldm_server.py rename to examples/split_learning/controlnet_split_learning/OrgModel/cldm_server.py diff --git a/examples/controlnet_split_learning/OrgModel/cldm_server_safe.py b/examples/split_learning/controlnet_split_learning/OrgModel/cldm_server_safe.py similarity index 100% rename from examples/controlnet_split_learning/OrgModel/cldm_server_safe.py rename to examples/split_learning/controlnet_split_learning/OrgModel/cldm_server_safe.py diff --git a/examples/controlnet_split_learning/OrgModel/cldm_v15_client.yaml b/examples/split_learning/controlnet_split_learning/OrgModel/cldm_v15_client.yaml similarity index 100% rename from examples/controlnet_split_learning/OrgModel/cldm_v15_client.yaml rename to examples/split_learning/controlnet_split_learning/OrgModel/cldm_v15_client.yaml diff --git a/examples/controlnet_split_learning/OrgModel/cldm_v15_client_safe.yaml b/examples/split_learning/controlnet_split_learning/OrgModel/cldm_v15_client_safe.yaml similarity index 100% rename from examples/controlnet_split_learning/OrgModel/cldm_v15_client_safe.yaml rename to examples/split_learning/controlnet_split_learning/OrgModel/cldm_v15_client_safe.yaml diff --git a/examples/controlnet_split_learning/OrgModel/cldm_v15_server.yaml b/examples/split_learning/controlnet_split_learning/OrgModel/cldm_v15_server.yaml similarity index 100% rename from examples/controlnet_split_learning/OrgModel/cldm_v15_server.yaml rename to examples/split_learning/controlnet_split_learning/OrgModel/cldm_v15_server.yaml diff --git a/examples/controlnet_split_learning/OrgModel/cldm_v15_server_safe.yaml b/examples/split_learning/controlnet_split_learning/OrgModel/cldm_v15_server_safe.yaml similarity index 100% rename from examples/controlnet_split_learning/OrgModel/cldm_v15_server_safe.yaml rename to examples/split_learning/controlnet_split_learning/OrgModel/cldm_v15_server_safe.yaml diff --git a/examples/controlnet_split_learning/OrgModel/model.py b/examples/split_learning/controlnet_split_learning/OrgModel/model.py similarity index 100% rename from 
examples/controlnet_split_learning/OrgModel/model.py rename to examples/split_learning/controlnet_split_learning/OrgModel/model.py diff --git a/examples/controlnet_split_learning/controlnet_datasource.py b/examples/split_learning/controlnet_split_learning/controlnet_datasource.py similarity index 100% rename from examples/controlnet_split_learning/controlnet_datasource.py rename to examples/split_learning/controlnet_split_learning/controlnet_datasource.py diff --git a/examples/controlnet_split_learning/dataset/dataset_basic.py b/examples/split_learning/controlnet_split_learning/dataset/dataset_basic.py similarity index 100% rename from examples/controlnet_split_learning/dataset/dataset_basic.py rename to examples/split_learning/controlnet_split_learning/dataset/dataset_basic.py diff --git a/examples/controlnet_split_learning/dataset/dataset_celeba.py b/examples/split_learning/controlnet_split_learning/dataset/dataset_celeba.py similarity index 100% rename from examples/controlnet_split_learning/dataset/dataset_celeba.py rename to examples/split_learning/controlnet_split_learning/dataset/dataset_celeba.py diff --git a/examples/controlnet_split_learning/dataset/dataset_coco.py b/examples/split_learning/controlnet_split_learning/dataset/dataset_coco.py similarity index 100% rename from examples/controlnet_split_learning/dataset/dataset_coco.py rename to examples/split_learning/controlnet_split_learning/dataset/dataset_coco.py diff --git a/examples/controlnet_split_learning/dataset/dataset_fill50k.py b/examples/split_learning/controlnet_split_learning/dataset/dataset_fill50k.py similarity index 100% rename from examples/controlnet_split_learning/dataset/dataset_fill50k.py rename to examples/split_learning/controlnet_split_learning/dataset/dataset_fill50k.py diff --git a/examples/controlnet_split_learning/dataset/dataset_omniglot.py b/examples/split_learning/controlnet_split_learning/dataset/dataset_omniglot.py similarity index 100% rename from examples/controlnet_split_learning/dataset/dataset_omniglot.py rename to examples/split_learning/controlnet_split_learning/dataset/dataset_omniglot.py diff --git a/examples/controlnet_split_learning/dataset/download_coco.py b/examples/split_learning/controlnet_split_learning/dataset/download_coco.py similarity index 100% rename from examples/controlnet_split_learning/dataset/download_coco.py rename to examples/split_learning/controlnet_split_learning/dataset/download_coco.py diff --git a/examples/controlnet_split_learning/split_learning.yml b/examples/split_learning/controlnet_split_learning/split_learning.yml similarity index 100% rename from examples/controlnet_split_learning/split_learning.yml rename to examples/split_learning/controlnet_split_learning/split_learning.yml diff --git a/examples/controlnet_split_learning/split_learning_algorithm.py b/examples/split_learning/controlnet_split_learning/split_learning_algorithm.py similarity index 100% rename from examples/controlnet_split_learning/split_learning_algorithm.py rename to examples/split_learning/controlnet_split_learning/split_learning_algorithm.py diff --git a/examples/controlnet_split_learning/split_learning_client.py b/examples/split_learning/controlnet_split_learning/split_learning_client.py similarity index 100% rename from examples/controlnet_split_learning/split_learning_client.py rename to examples/split_learning/controlnet_split_learning/split_learning_client.py diff --git a/examples/controlnet_split_learning/split_learning_main.py 
b/examples/split_learning/controlnet_split_learning/split_learning_main.py similarity index 100% rename from examples/controlnet_split_learning/split_learning_main.py rename to examples/split_learning/controlnet_split_learning/split_learning_main.py diff --git a/examples/controlnet_split_learning/split_learning_safe.yml b/examples/split_learning/controlnet_split_learning/split_learning_safe.yml similarity index 100% rename from examples/controlnet_split_learning/split_learning_safe.yml rename to examples/split_learning/controlnet_split_learning/split_learning_safe.yml diff --git a/examples/controlnet_split_learning/split_learning_server.py b/examples/split_learning/controlnet_split_learning/split_learning_server.py similarity index 100% rename from examples/controlnet_split_learning/split_learning_server.py rename to examples/split_learning/controlnet_split_learning/split_learning_server.py diff --git a/examples/controlnet_split_learning/split_learning_trainer.py b/examples/split_learning/controlnet_split_learning/split_learning_trainer.py similarity index 100% rename from examples/controlnet_split_learning/split_learning_trainer.py rename to examples/split_learning/controlnet_split_learning/split_learning_trainer.py diff --git a/examples/ssl/algorithms/byol/byol.py b/examples/ssl/algorithms/byol/byol.py deleted file mode 100644 index 1e43be96e..000000000 --- a/examples/ssl/algorithms/byol/byol.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -The implementation for the BYOL [1] method. - -Reference: -[1]. Jean-Bastien Grill, et al., Bootstrap Your Own Latent A New Approach to Self-Supervised Learning. -https://arxiv.org/pdf/2006.07733.pdf. - -Source code: https://github.com/lucidrains/byol-pytorch -The third-party code: https://github.com/sthalles/PyTorch-BYOL -""" - -from plato.servers import fedavg_personalized as personalized_server - -from ssl import ssl_client -from ssl import ssl_datasources - -from byol_model import BYOLModel -import byol_trainer - - -def main(): - """ - A personalized federated learning session for BYOL approach. - """ - trainer = byol_trainer.Trainer - client = ssl_client.Client( - model=BYOLModel, - datasource=ssl_datasources.SSLDataSource, - trainer=trainer, - ) - server = personalized_server.Server(model=BYOLModel, trainer=trainer) - - server.run(client) - - -if __name__ == "__main__": - main() diff --git a/examples/ssl/algorithms/calibre/clustering.py b/examples/ssl/algorithms/calibre/clustering.py deleted file mode 100644 index 3742f5720..000000000 --- a/examples/ssl/algorithms/calibre/clustering.py +++ /dev/null @@ -1,29 +0,0 @@ -""" -Clustering based on the encodinds. 
-""" - -import torch - - -def kmeans_clustering(features, n_clusters, max_iters=100): - """Computing the keams""" - # Initialize centroids randomly - centroids = features[torch.randperm(features.size(0))[:n_clusters]] - - for _ in range(max_iters): - # Assign each data point to the nearest centroid - distances = torch.cdist(features, centroids) # Compute distances - cluster_ids = torch.argmin(distances, dim=1) # Assign labels - - # Update centroids as the mean of the assigned data points - new_centroids = torch.stack( - [features[cluster_ids == i].mean(0) for i in range(n_clusters)] - ) - - # Check for convergence - if torch.all(new_centroids == centroids): - break - - centroids = new_centroids - - return cluster_ids, centroids diff --git a/examples/ssl/algorithms/moco/mocov2.py b/examples/ssl/algorithms/moco/mocov2.py deleted file mode 100644 index ec3f73216..000000000 --- a/examples/ssl/algorithms/moco/mocov2.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -The implementation for the MoCoV2 [2] method, which is the enhanced version of MoCoV1 [1], -for personalized federated learning. - -Reference: -[1]. Kaiming He, et al., Momentum Contrast for Unsupervised Visual Representation Learning, -CVPR 2020. https://arxiv.org/abs/1911.05722. - -[2]. Xinlei Chen, et al., Improved Baselines with Momentum Contrastive Learning, ArXiv, 2020. -https://arxiv.org/abs/2003.04297. - -The official code: https://github.com/facebookresearch/moco. - -""" - -from plato.servers import fedavg_personalized as personalized_server - -from ssl import ssl_client -from ssl import ssl_datasources - -import mocov2_model -import mocov2_trainer - - -def main(): - """ - A personalized federated learning session for BYOL approach. - """ - client = ssl_client.Client( - model=mocov2_model.MoCoV2, - datasource=ssl_datasources.SSLDataSource, - trainer=mocov2_trainer.Trainer, - ) - server = personalized_server.Server( - model=mocov2_model.MoCoV2, - trainer=mocov2_trainer.Trainer, - ) - - server.run(client) - - -if __name__ == "__main__": - main() diff --git a/examples/ssl/algorithms/simclr/simclr.py b/examples/ssl/algorithms/simclr/simclr.py deleted file mode 100644 index f8bcb0e46..000000000 --- a/examples/ssl/algorithms/simclr/simclr.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -The implementation for the SimCLR [1] method for personalized federated learning. - -[1]. Ting Chen, et al., A Simple Framework for Contrastive Learning of Visual Representations, -ICML 2020. https://arxiv.org/abs/2002.05709 - -The official code: https://github.com/google-research/simclr - -The structure of our SimCLR and the classifier is the same as the ones used in -the work https://github.com/spijkervet/SimCLR.git. - -""" -from plato.servers import fedavg_personalized as personalized_server -from ssl import ssl_datasources -from ssl import ssl_client -from ssl import ssl_trainer - -from simclr_model import SimCLRModel - - -def main(): - """ - A personalized federated learning session for SimCLR approach. - """ - trainer = ssl_trainer.Trainer - client = ssl_client.Client( - model=SimCLRModel, - datasource=ssl_datasources.SSLDataSource, - trainer=trainer, - ) - server = personalized_server.Server(model=SimCLRModel, trainer=trainer) - - server.run(client) - - -if __name__ == "__main__": - main() diff --git a/examples/ssl/algorithms/swav/swav.py b/examples/ssl/algorithms/swav/swav.py deleted file mode 100644 index 42334b702..000000000 --- a/examples/ssl/algorithms/swav/swav.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -The implementation for the SwAV [1] method. 
- -Reference: -[1]. Mathilde Caron, et al., Unsupervised Learning of Visual Features by Contrasting Cluster Assignments. -https://arxiv.org/abs/2006.09882, NeurIPS 2020. - -Source code: https://github.com/facebookresearch/swav -""" - -from plato.servers import fedavg_personalized as personalized_server - -from ssl import ssl_client -from ssl import ssl_trainer -from ssl import ssl_datasources - - -import swav_model - - -def main(): - """ - A pFL session for SwaV approach. - """ - client = ssl_client.Client( - model=swav_model.SwaV, - datasource=ssl_datasources.SSLDataSource, - trainer=ssl_trainer.Trainer, - ) - server = personalized_server.Server( - model=swav_model.SwaV, trainer=ssl_trainer.Trainer - ) - - server.run(client) - - -if __name__ == "__main__": - main() diff --git a/examples/ssl/byol/byol.py b/examples/ssl/byol/byol.py new file mode 100644 index 000000000..7ada5e7e0 --- /dev/null +++ b/examples/ssl/byol/byol.py @@ -0,0 +1,34 @@ +""" +An implementation of the BYOL algorithm. + +Jean-Bastien Grill, et al., Bootstrap Your Own Latent A New Approach to Self-Supervised Learning. +https://arxiv.org/pdf/2006.07733.pdf. + +Source code: https://github.com/lucidrains/byol-pytorch or https://github.com/sthalles/PyTorch-BYOL. +""" + +from plato.servers import fedavg_personalized as personalized_server +from plato.clients import self_supervised_learning as ssl_client +from plato.datasources import self_supervised_learning as ssl_datasource + +from byol_model import BYOLModel +import byol_trainer + + +def main(): + """ + A self-supervised federated learning session with BYOL. + """ + trainer = byol_trainer.Trainer + client = ssl_client.Client( + model=BYOLModel, + datasource=ssl_datasource.SSLDataSource, + trainer=trainer, + ) + server = personalized_server.Server(model=BYOLModel, trainer=trainer) + + server.run(client) + + +if __name__ == "__main__": + main() diff --git a/examples/ssl/algorithms/byol/byol_model.py b/examples/ssl/byol/byol_model.py similarity index 67% rename from examples/ssl/algorithms/byol/byol_model.py rename to examples/ssl/byol/byol_model.py index 67dc71365..3858e7484 100644 --- a/examples/ssl/algorithms/byol/byol_model.py +++ b/examples/ssl/byol/byol_model.py @@ -30,39 +30,41 @@ def __init__(self, encoder=None): model_name=encoder_name, **encoder_params ) - self.encoding_dim = self.encoder.encoding_dim - - # A projector project higher dimension features to output dimensions. 
+        # A projector projects higher dimension features to output dimensions
         self.projector = BYOLProjectionHead(
-            self.encoding_dim,
+            self.encoder.encoding_dim,
             Config().trainer.projection_hidden_dim,
             Config().trainer.projection_out_dim,
         )

+        # A predictor maps the projected features to the prediction dimensions
         self.predictor = BYOLPredictionHead(
             Config().trainer.projection_out_dim,
             Config().trainer.prediction_hidden_dim,
             Config().trainer.prediction_out_dim,
         )

+        # The momentum encoder and projector, which work in
+        # a momentum manner
         self.momentum_encoder = copy.deepcopy(self.encoder)
         self.momentum_projector = copy.deepcopy(self.projector)

+        # Deactivate the requires_grad flag for all parameters
         deactivate_requires_grad(self.momentum_encoder)
         deactivate_requires_grad(self.momentum_projector)

-    def forward_view(self, sample):
-        """Foward one view sample to get the output."""
-        encoded_sample = self.encoder(sample).flatten(start_dim=1)
-        projected_sample = self.projector(encoded_sample)
-        output = self.predictor(projected_sample)
+    def forward_view(self, view_sample):
+        """Forward one view to get the output."""
+        encoded_view = self.encoder(view_sample).flatten(start_dim=1)
+        projected_view = self.projector(encoded_view)
+        output = self.predictor(projected_view)
         return output

-    def forward_momentum(self, sample):
-        """Foward one view sample to get the output in a momentum manner."""
-        encoded_example = self.momentum_encoder(sample).flatten(start_dim=1)
-        projected_example = self.momentum_projector(encoded_example)
-        projected_example = projected_example.detach()
-        return projected_example
+    def forward_momentum(self, view_sample):
+        """Forward one view to get the output in a momentum manner."""
+        encoded_view = self.momentum_encoder(view_sample).flatten(start_dim=1)
+        projected_view = self.momentum_projector(encoded_view)
+        projected_view = projected_view.detach()
+        return projected_view

     def forward(self, multiview_samples):
         """Main forward function of the model."""
diff --git a/examples/ssl/algorithms/fedema/fedema_trainer.py b/examples/ssl/byol/byol_trainer.py
similarity index 61%
rename from examples/ssl/algorithms/fedema/fedema_trainer.py
rename to examples/ssl/byol/byol_trainer.py
index 3a4146103..d20fc73ae 100644
--- a/examples/ssl/algorithms/fedema/fedema_trainer.py
+++ b/examples/ssl/byol/byol_trainer.py
@@ -1,28 +1,34 @@
 """
-Implementation of the trainer for FedEMA.
+A self-supervised federated learning trainer with BYOL.
 """
-
 from lightly.utils.scheduler import cosine_schedule
 from lightly.models.utils import update_momentum
-from ssl import ssl_trainer
+
+from plato.trainers import self_supervised_learning as ssl_trainer
 from plato.trainers import loss_criterion
 from plato.config import Config


 class Trainer(ssl_trainer.Trainer):
-    """A trainer for FedEMA."""
+    """
+    A trainer with BYOL, which computes the BYOL loss and the momentum value
+    at the start of each epoch; the model is then updated step-wise with this
+    value in a momentum manner.
+    """

     def __init__(self, model=None, callbacks=None):
         super().__init__(model, callbacks)
+        # The momentum value used to update the model
+        # with Exponential Moving Average
         self.momentum_val = 0

     def get_ssl_criterion(self):
-        """A wrapper to connect ssl loss with plato."""
+        """Compute the loss proposed by BYOL."""
         defined_ssl_loss = loss_criterion.get()

-        def compute_plato_loss(outputs, labels):
+        def compute_loss(outputs, labels):
             if isinstance(outputs, (list, tuple)):
                 loss = 0.5 * (
                     defined_ssl_loss(*outputs[0]) + defined_ssl_loss(*outputs[1])
@@ -31,10 +37,12 @@ def compute_plato_loss(outputs, labels):
             else:
                 return defined_ssl_loss(outputs)

-        return compute_plato_loss
+        return compute_loss

     def train_epoch_start(self, config):
-        """Operations before starting one epoch."""
+        """
+        At the start of one epoch, the momentum value should be computed.
+        """
         super().train_epoch_start(config)
         epoch = self.current_epoch
         total_epochs = config["epochs"] * config["rounds"]
@@ -44,12 +52,14 @@ def train_epoch_start(self, config):

     def train_step_start(self, config, batch=None):
         """
-        At the start of every iteration,
-        update the models for generating momentum
-        with new momemtum parameter: momentum value.
+        At the start of every iteration, the model should be updated based on the
+        momentum value in a momentum manner.
         """
         super().train_step_start(config)
         if not self.current_round > Config().trainer.rounds:
+            # Update the model based on the momentum value
+            # Specifically, it updates the parameters of `momentum_encoder` with
+            # an Exponential Moving Average of `encoder`
             update_momentum(
                 self.model.encoder, self.model.momentum_encoder, m=self.momentum_val
             )
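The BYOL trainer above drives the momentum update with lightly's `cosine_schedule` and `update_momentum` utilities: the decay moves from 0.996 towards 1 over all training epochs, and the momentum encoder is refreshed at every training step. A minimal sketch of these two calls is shown below; the helper name `update_momentum_encoder` and its arguments are illustrative placeholders, not part of the example code.

```python
from lightly.models.utils import update_momentum
from lightly.utils.scheduler import cosine_schedule


def update_momentum_encoder(model, current_epoch, total_epochs):
    """Compute the EMA decay for this epoch and refresh the momentum encoder."""
    # Decay value that follows a cosine schedule from 0.996 towards 1
    momentum_val = cosine_schedule(current_epoch, total_epochs, 0.996, 1)

    # Copies an exponential moving average of `encoder` into `momentum_encoder`
    update_momentum(model.encoder, model.momentum_encoder, m=momentum_val)
```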
diff --git a/examples/ssl/algorithms/calibre/calibre.py b/examples/ssl/calibre/calibre.py
similarity index 63%
rename from examples/ssl/algorithms/calibre/calibre.py
rename to examples/ssl/calibre/calibre.py
index 515b3ab80..3f848e672 100644
--- a/examples/ssl/algorithms/calibre/calibre.py
+++ b/examples/ssl/calibre/calibre.py
@@ -1,9 +1,8 @@
 """
-Implementation of our Calibre algorithm.
+An implementation of the Calibre algorithm.
 """
-from ssl import ssl_datasources
-from ssl import ssl_client
-
+from plato.clients import self_supervised_learning as ssl_client
+from plato.datasources import self_supervised_learning as ssl_datasource
 import calibre_model
 import calibre_trainer
@@ -13,11 +12,11 @@

 def main():
     """
-    A personalized federated learning session for SimCLR approach.
+    A self-supervised federated learning session with Calibre.
     """
     client = ssl_client.Client(
         model=calibre_model.CalibreNet,
-        datasource=ssl_datasources.SSLDataSource,
+        datasource=ssl_datasource.SSLDataSource,
         trainer=calibre_trainer.Trainer,
         callbacks=[
             calibre_callback.CalibreCallback,
diff --git a/examples/ssl/calibre/calibre_callback.py b/examples/ssl/calibre/calibre_callback.py
new file mode 100644
index 000000000..dd22623d5
--- /dev/null
+++ b/examples/ssl/calibre/calibre_callback.py
@@ -0,0 +1,27 @@
+"""
+Callback for adding the divergence rate to the payload.
+"""
+
+
+import calibre_processor
+
+from plato.callbacks.client import ClientCallback
+
+
+class CalibreCallback(ClientCallback):
+    """
+    A client callback that adds the divergence rate computed locally to the
+    payload sent to the server.
+    """
+
+    def on_outbound_ready(self, client, report, outbound_processor):
+        """
+        Insert an AddDivergenceRateProcessor to the list of outbound processors.
+        """
+        send_payload_processor = calibre_processor.AddDivergenceRateProcessor(
+            client_id=client.client_id,
+            trainer=client.trainer,
+            name="AddDivergenceRateProcessor",
+        )
+
+        outbound_processor.processors.insert(0, send_payload_processor)
diff --git a/examples/ssl/algorithms/calibre/calibre_loss.py b/examples/ssl/calibre/calibre_loss.py
similarity index 95%
rename from examples/ssl/algorithms/calibre/calibre_loss.py
rename to examples/ssl/calibre/calibre_loss.py
index d9beba284..f08fa31bf 100644
--- a/examples/ssl/algorithms/calibre/calibre_loss.py
+++ b/examples/ssl/calibre/calibre_loss.py
@@ -22,7 +22,7 @@
 from collections import OrderedDict

 import torch
-import torch.nn as nn
+from torch import nn
 from lightly import loss as lightly_loss

 from clustering import kmeans_clustering
@@ -41,7 +41,7 @@ def __init__(
         main_loss: str,
         main_loss_params: dict,
         auxiliary_losses: List[str] = None,
-        auxiliary_losses_params: List[dict] = None,
+        auxiliary_loss_params: List[dict] = None,
         losses_weight: List[float] = None,
         device: str = "cpu",
     ):
@@ -56,9 +56,9 @@ def __init__(
         # The auxiliary losses and the corresponding parameters
         if auxiliary_losses is None:
             auxiliary_losses = []
-        if auxiliary_losses_params is None:
-            auxiliary_losses_params = []
-        assert len(auxiliary_losses) == len(auxiliary_losses_params)
+        if auxiliary_loss_params is None:
+            auxiliary_loss_params = []
+        assert len(auxiliary_losses) == len(auxiliary_loss_params)

         # The weights of these losses set in the config file
         losses_weight = losses_weight._asdict()
@@ -77,7 +77,7 @@ def __init__(
         for loss in auxiliary_losses:
             if loss in losses_weight:
                 self.loss_weights_params[loss] = {
-                    "params": auxiliary_losses_params[loss]._asdict(),
+                    "params": auxiliary_loss_params[loss]._asdict(),
                     "weight": losses_weight[loss],
                 }

diff --git a/examples/ssl/algorithms/calibre/calibre_model.py b/examples/ssl/calibre/calibre_model.py
similarity index 100%
rename from examples/ssl/algorithms/calibre/calibre_model.py
rename to examples/ssl/calibre/calibre_model.py
diff --git a/examples/ssl/algorithms/calibre/calibre_callback.py b/examples/ssl/calibre/calibre_processor.py
similarity index 50%
rename from examples/ssl/algorithms/calibre/calibre_callback.py
rename to examples/ssl/calibre/calibre_processor.py
index 56f8f8392..a08eb8357 100644
--- a/examples/ssl/algorithms/calibre/calibre_callback.py
+++ b/examples/ssl/calibre/calibre_processor.py
@@ -1,5 +1,5 @@
 """
-Callback for adding the divergence rate to the payload.
+An outbound processor for the Calibre algorithm to attach the divergence rate saved locally on the client to the payload.
 """
 import os
 import logging
@@ -7,7 +7,6 @@

 import torch

-from plato.callbacks.client import ClientCallback
 from plato.processors import base
 from plato.config import Config

@@ -24,6 +23,7 @@ def __init__(self, client_id, trainer, **kwargs) -> None:
         self.trainer = trainer

     def process(self, data: OrderedDict):
+        """Attach the divergence rate computed during local training to the payload."""
         model_path = Config().params["model_path"]
         filename = f"client_{self.client_id}_divergence_rate.pth"
         save_path = os.path.join(model_path, filename)
@@ -36,28 +36,3 @@ def process(self, data: OrderedDict):
             "[Client #%d] Divergence Rate attached to payload.", self.client_id
         )
         return data
-
-
-class CalibreCallback(ClientCallback):
-    """
-    A client callback that adds the divergence rate computed locally to the
-    payload sent to the server.
-    """
-
-    def on_outbound_ready(self, client, report, outbound_processor):
-        """
-        Insert a AddDivergenceRateProcessor to the list of outbound processors.
-        """
-        send_payload_processor = AddDivergenceRateProcessor(
-            client_id=client.client_id,
-            trainer=client.trainer,
-            name="AddDivergenceRateProcessor",
-        )
-
-        outbound_processor.processors.insert(0, send_payload_processor)
-
-        logging.info(
-            "[%s] List of outbound processors: %s.",
-            client,
-            outbound_processor.processors,
-        )
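Taken together, the Calibre client-side files implement the following flow: after local training, the trainer clusters the encodings of its local samples, computes a divergence rate, and saves it to disk; the `AddDivergenceRateProcessor` above then loads that value and attaches it to the outbound payload. The sketch below illustrates the computation under simplified assumptions; the helper name `divergence_rate` is hypothetical, and unlike `clustering.py` it neither normalizes the centroids nor moves tensors to the trainer's device.

```python
import torch
from sklearn.cluster import KMeans


def divergence_rate(encodings: torch.Tensor, n_clusters: int = 10) -> torch.Tensor:
    """Mean distance between local encodings and their assigned cluster centroids."""
    kmeans = KMeans(n_clusters=n_clusters, n_init="auto").fit(encodings.numpy())
    centroids = torch.from_numpy(kmeans.cluster_centers_).float()
    labels = torch.from_numpy(kmeans.labels_)

    per_cluster = []
    for cluster_id in range(n_clusters):
        members = encodings[labels == cluster_id]
        if len(members) > 0:
            # Distance of each member to its centroid, averaged within the cluster
            per_cluster.append(torch.norm(members - centroids[cluster_id], dim=1).mean())
    return torch.stack(per_cluster).mean()


# The outbound payload then becomes a list: [model_weights, divergence_rate(encodings)]
```

The server, shown next, sums the received divergence rates and uses them to weight each client's delta during aggregation.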
diff --git a/examples/ssl/algorithms/calibre/calibre_server.py b/examples/ssl/calibre/calibre_server.py
similarity index 93%
rename from examples/ssl/algorithms/calibre/calibre_server.py
rename to examples/ssl/calibre/calibre_server.py
index 08e265f63..e575940ff 100644
--- a/examples/ssl/algorithms/calibre/calibre_server.py
+++ b/examples/ssl/calibre/calibre_server.py
@@ -1,5 +1,5 @@
 """
-A base server for Calibre to perform divergence-aware global aggregation.
+A self-supervised learning server for Calibre to perform divergence-aware global aggregation.

 After each client clusters local samples based on their encodings, there will be
 local clusters where each cluster's divergence is computed as the normalized distance
@@ -37,7 +37,7 @@ def __init__(
         self.divergence_rates_received = []

     async def aggregate_deltas(self, updates, deltas_received):
-        """Add the divergence rates to deltas."""
+        """Apply the divergence rates as weights to the deltas."""
         total_divergence = torch.sum(self.divergence_rates_received)
         # Normalize the delta with the divergence rates
         for i, update in enumerate(deltas_received):
diff --git a/examples/ssl/algorithms/calibre/calibre_trainer.py b/examples/ssl/calibre/calibre_trainer.py
similarity index 70%
rename from examples/ssl/algorithms/calibre/calibre_trainer.py
rename to examples/ssl/calibre/calibre_trainer.py
index 3391075b0..14cab1021 100644
--- a/examples/ssl/algorithms/calibre/calibre_trainer.py
+++ b/examples/ssl/calibre/calibre_trainer.py
@@ -1,5 +1,5 @@
 """
-Implementation of the trainer for Calibre algorithm.
+A self-supervised federated learning trainer with Calibre.
 """

 import os
@@ -7,20 +7,24 @@

 import torch

-from ssl import ssl_trainer
+from plato.trainers import self_supervised_learning as ssl_trainer
+from plato.config import Config

 from calibre_loss import CalibreLoss
 from clustering import kmeans_clustering

-from plato.config import Config
-

 class Trainer(ssl_trainer.Trainer):
-    """A trainer for the Calibre method."""
+    """
+    A trainer with Calibre, which computes Calibre's loss as well as the
+    divergence of the local clusters, measured as the normalized distance
+    between the points and their centroid.
+    """

     def get_ssl_criterion(self):
-        """A wrapper to connect ssl loss with plato."""
+        """Get the loss of Calibre."""
+        # Get the main loss criterion
         loss_criterion_name = (
             Config().trainer.loss_criterion
             if hasattr(Config.trainer, "loss_criterion")
@@ -32,17 +36,20 @@ def get_ssl_criterion(self):
             else {}
         )

+        # Get the auxiliary losses, which are regularizers in the
+        # objective function
         auxiliary_losses = (
             Config().algorithm.auxiliary_loss_criterions
             if hasattr(Config.algorithm, "auxiliary_loss_criterions")
             else []
         )
-        auxiliary_losses_params = (
+        auxiliary_loss_params = (
             Config().algorithm.auxiliary_loss_criterions_param._asdict()
             if hasattr(Config.algorithm, "auxiliary_loss_criterions_param")
             else {}
         )

+        # Get the weights for these losses
         losses_weight = (
             Config().algorithm.losses_weight
             if hasattr(Config.algorithm, "losses_weight")
@@ -53,35 +60,38 @@ def get_ssl_criterion(self):
             main_loss=loss_criterion_name,
             main_loss_params=loss_criterion_params,
             auxiliary_losses=auxiliary_losses,
-            auxiliary_losses_params=auxiliary_losses_params,
+            auxiliary_loss_params=auxiliary_loss_params,
             losses_weight=losses_weight,
             device=self.device,
         )

-        def compute_plato_loss(outputs, labels):
+        def compute_loss(outputs, labels):
             if isinstance(outputs, (list, tuple)):
                 return defined_ssl_loss(*outputs, labels=labels)
             else:
                 return defined_ssl_loss(outputs, labels=labels)

-        return compute_plato_loss
+        return compute_loss

     def compute_divergence_rate(self, encodings):
-        """Compute the divergence rate of the local model"""
+        """
+        Compute the divergence rate, which is the normalized distance between the points
+        and the corresponding centroid.
+        """
         cluster_ids_x, cluster_centers = kmeans_clustering(encodings, n_clusters=10)
-        clusters_id = torch.unique(cluster_ids_x, return_counts=False)
-        clusters_divergence = torch.zeros(size=(len(clusters_id),), device=self.device)
-        for cluster_id in clusters_id:
+        cluster_ids = torch.unique(cluster_ids_x, return_counts=False)
+        cluster_divergence = torch.zeros(size=(len(cluster_ids),), device=self.device)
+        for cluster_id in cluster_ids:
             cluster_center = cluster_centers[cluster_id]
             cluster_elems = encodings[cluster_ids_x == cluster_id]
             distance = torch.norm(cluster_elems - cluster_center, dim=1)
             divergence = torch.mean(distance)
-            clusters_divergence[cluster_id] = divergence
+            cluster_divergence[cluster_id] = divergence

-        return torch.mean(clusters_divergence)
+        return torch.mean(cluster_divergence)

     def get_optimizer(self, model):
-        """Getting the optimizer"""
+        """Get the optimizer."""
         optimizer = super().get_optimizer(model)
         if self.current_round > Config().trainer.rounds:
             # Add another self.model's parameters to the existing optimizer
@@ -91,7 +101,11 @@ def get_optimizer(self, model):
         return optimizer

     def train_run_end(self, config):
-        """Get the features of local samples after training."""
+        """
+        Compute the divergence rate based on the features of local samples learned
+        during training. Then, the computed value is saved to disk so that it can be
+        loaded when the client sends it to the server.
+ """ super().train_run_end(config) personalized_train_loader = torch.utils.data.DataLoader( @@ -106,7 +120,7 @@ def train_run_end(self, config): sample_encodings = None with torch.no_grad(): - for _, (examples, _) in enumerate(personalized_train_loader): + for examples, _ in personalized_train_loader: examples = examples.to(self.device) features = self.model.encoder(examples) diff --git a/examples/ssl/calibre/clustering.py b/examples/ssl/calibre/clustering.py new file mode 100644 index 000000000..b362745f1 --- /dev/null +++ b/examples/ssl/calibre/clustering.py @@ -0,0 +1,24 @@ +""" +Clustering based on encodings. +""" + +import torch + +from sklearn.cluster import KMeans + + +def kmeans_clustering(features, n_clusters, max_iter=200): + """Cluster features using the K-means algorithm.""" + device = features.device + + features = features.detach().cpu().numpy() + kmeans = KMeans(n_init="auto", n_clusters=n_clusters, max_iter=max_iter).fit( + features + ) + cluster_ids = torch.from_numpy(kmeans.labels_).int() + centroids = torch.from_numpy(kmeans.cluster_centers_).float() + centroids = torch.nn.functional.normalize(centroids, dim=1) + + cluster_ids = cluster_ids.to(device) + centroids = centroids.to(device) + return cluster_ids, centroids diff --git a/examples/ssl/algorithms/calibre/prototype_loss.py b/examples/ssl/calibre/prototype_loss.py similarity index 100% rename from examples/ssl/algorithms/calibre/prototype_loss.py rename to examples/ssl/calibre/prototype_loss.py diff --git a/examples/ssl/configs/byol_CIFAR10_resnet18.yml b/examples/ssl/configs/byol_CIFAR10_resnet18.yml index afc1cd1e7..34573221f 100644 --- a/examples/ssl/configs/byol_CIFAR10_resnet18.yml +++ b/examples/ssl/configs/byol_CIFAR10_resnet18.yml @@ -1,5 +1,4 @@ clients: - # The total number of clients total_clients: 10 @@ -25,6 +24,8 @@ data: !include cifar10_ssl_noniid.yml trainer: + type: self_supervised_learning + # The maximum number of training rounds rounds: 2 diff --git a/examples/ssl/configs/calibre_CIFAR10_resnet18.yml b/examples/ssl/configs/calibre_CIFAR10_resnet18.yml index f3f7f7a10..895a336d2 100644 --- a/examples/ssl/configs/calibre_CIFAR10_resnet18.yml +++ b/examples/ssl/configs/calibre_CIFAR10_resnet18.yml @@ -25,6 +25,8 @@ data: !include cifar10_ssl_noniid.yml trainer: + type: self_supervised_learning + # LeNet5 model with the basic trainer # The maximum number of training rounds rounds: 10 diff --git a/examples/ssl/configs/fedema_CIFAR10_resnet18.yml b/examples/ssl/configs/fedema_CIFAR10_resnet18.yml index a78e7dbe8..cb68b418e 100644 --- a/examples/ssl/configs/fedema_CIFAR10_resnet18.yml +++ b/examples/ssl/configs/fedema_CIFAR10_resnet18.yml @@ -25,6 +25,8 @@ data: !include cifar10_ssl_noniid.yml trainer: + type: self_supervised_learning + # The maximum number of training rounds rounds: 2 diff --git a/examples/ssl/configs/mocov2_CIFAR10_resnet18.yml b/examples/ssl/configs/mocov2_CIFAR10_resnet18.yml index dfcfc5634..d299f4831 100644 --- a/examples/ssl/configs/mocov2_CIFAR10_resnet18.yml +++ b/examples/ssl/configs/mocov2_CIFAR10_resnet18.yml @@ -25,6 +25,8 @@ data: !include cifar10_ssl_noniid.yml trainer: + type: self_supervised_learning + # The maximum number of training rounds rounds: 2 diff --git a/examples/ssl/configs/simclr_CIFAR10_resnet18.yml b/examples/ssl/configs/simclr_CIFAR10_resnet18.yml index 9b04d2a2f..4e6786f6f 100644 --- a/examples/ssl/configs/simclr_CIFAR10_resnet18.yml +++ b/examples/ssl/configs/simclr_CIFAR10_resnet18.yml @@ -25,6 +25,8 @@ data: !include cifar10_ssl_noniid.yml trainer: 
+  type: self_supervised_learning
+
   # The maximum number of training rounds
   rounds: 2
diff --git a/examples/ssl/configs/simsiam_CIFAR10_resnet18.yml b/examples/ssl/configs/simsiam_CIFAR10_resnet18.yml
index 20efdc91d..aba25f732 100644
--- a/examples/ssl/configs/simsiam_CIFAR10_resnet18.yml
+++ b/examples/ssl/configs/simsiam_CIFAR10_resnet18.yml
@@ -25,6 +25,8 @@ data: !include cifar10_ssl_noniid.yml

 trainer:
+  type: self_supervised_learning
+
   # The maximum number of training rounds
   rounds: 2
diff --git a/examples/ssl/configs/smog_CIFAR10_resnet18.yml b/examples/ssl/configs/smog_CIFAR10_resnet18.yml
index ca2a022d3..d9527b989 100644
--- a/examples/ssl/configs/smog_CIFAR10_resnet18.yml
+++ b/examples/ssl/configs/smog_CIFAR10_resnet18.yml
@@ -25,6 +25,8 @@ data: !include cifar10_ssl_noniid.yml

 trainer:
+  type: self_supervised_learning
+
   # The maximum number of training rounds
   rounds: 2
diff --git a/examples/ssl/configs/swav_CIFAR10_resnet18.yml b/examples/ssl/configs/swav_CIFAR10_resnet18.yml
index c2dcba716..ad7250269 100644
--- a/examples/ssl/configs/swav_CIFAR10_resnet18.yml
+++ b/examples/ssl/configs/swav_CIFAR10_resnet18.yml
@@ -25,6 +25,8 @@ data: !include cifar10_ssl_noniid.yml

 trainer:
+  type: self_supervised_learning
+
   # The maximum number of training rounds
   rounds: 2
diff --git a/examples/ssl/description.txt b/examples/ssl/description.txt
deleted file mode 100644
index 028eed4af..000000000
--- a/examples/ssl/description.txt
+++ /dev/null
@@ -1 +0,0 @@
-ssl: A library of convenient classes for personalized federated learning algorithms based on self-supervised learning (SSL)
\ No newline at end of file
diff --git a/examples/ssl/algorithms/fedema/fedema.py b/examples/ssl/fedema/fedema.py
similarity index 66%
rename from examples/ssl/algorithms/fedema/fedema.py
rename to examples/ssl/fedema/fedema.py
index d7042bf2f..edbb92262 100644
--- a/examples/ssl/algorithms/fedema/fedema.py
+++ b/examples/ssl/fedema/fedema.py
@@ -1,12 +1,12 @@
 """
-The implementation for the FedEMA proposed by the work [1].
+An implementation of the FedEMA algorithm.

 Zhuang, et.al, "Divergence-aware Federated Self-Supervised Learning", ICLR22.
 https://arxiv.org/pdf/2204.04385.pdf.
 """
-from ssl import ssl_datasources
-from ssl import ssl_client
+from plato.clients import self_supervised_learning as ssl_client
+from plato.datasources import self_supervised_learning as ssl_datasource

 import fedema_server
 import fedema_trainer
@@ -15,10 +15,12 @@


 def main():
-    """Running the FedEMA approach."""
+    """
+    A self-supervised federated learning session with FedEMA.
+    """
     client = ssl_client.Client(
         model=fedema_model.BYOLModel,
-        datasource=ssl_datasources.SSLDataSource,
+        datasource=ssl_datasource.SSLDataSource,
         trainer=fedema_trainer.Trainer,
         callbacks=[
             fedema_callback.FedEMACallback,
diff --git a/examples/ssl/fedema/fedema_callback.py b/examples/ssl/fedema/fedema_callback.py
new file mode 100644
index 000000000..5f601b78c
--- /dev/null
+++ b/examples/ssl/fedema/fedema_callback.py
@@ -0,0 +1,24 @@
+"""
+The ClientCallback used by FedEMA to attach the FedEMA processor.
+"""
+
+import fedema_processor
+
+from plato.callbacks.client import ClientCallback
+
+
+class FedEMACallback(ClientCallback):
+    """
+    A client callback that dynamically computes the divergence between the received model
+    and the local model.
+    """
+
+    def on_inbound_received(self, client, inbound_processor):
+        """
+        Insert a GlobalLocalDivergenceProcessor to the list of inbound processors.
+        """
+        extract_payload_processor = fedema_processor.GlobalLocalDivergenceProcessor(
+            trainer=client.trainer,
+            name="GlobalLocalDivergenceProcessor",
+        )
+        inbound_processor.processors.insert(0, extract_payload_processor)
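The callback above only wires the `GlobalLocalDivergenceProcessor` into the inbound pipeline. The update it performs follows the divergence-aware idea in the FedEMA paper: the client blends the received global weights with its locally saved weights using an EMA coefficient that grows with their divergence. The sketch below is a rough rendering of that idea under stated assumptions; the function name, the dictionary-of-tensors interface, and the clipping at 1.0 are illustrative, while the real processor in `fedema_processor.py` works layer by layer with the scale computed by the server.

```python
import torch


def divergence_aware_update(local_weights, global_weights, divergence_scale):
    """Blend global and local weights based on how far apart they are."""
    # L2 distance between the two sets of weights
    l2_distance = torch.sqrt(
        sum(
            torch.sum((global_weights[name] - local_weights[name]) ** 2)
            for name in global_weights
        )
    )

    # The larger the divergence, the more the client keeps of its local model
    mu = min(float(divergence_scale * l2_distance), 1.0)
    return {
        name: mu * local_weights[name] + (1 - mu) * global_weights[name]
        for name in global_weights
    }
```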
diff --git a/examples/ssl/algorithms/fedema/fedema_model.py b/examples/ssl/fedema/fedema_model.py
similarity index 54%
rename from examples/ssl/algorithms/fedema/fedema_model.py
rename to examples/ssl/fedema/fedema_model.py
index 1d833f307..e9a23ceb3 100644
--- a/examples/ssl/algorithms/fedema/fedema_model.py
+++ b/examples/ssl/fedema/fedema_model.py
@@ -13,7 +13,7 @@


 class BYOLModel(nn.Module):
-    """The model structure of BYOL."""
+    """The model structure of FedEMA."""

     def __init__(self, encoder=None):
         super().__init__()
@@ -31,45 +31,48 @@ def __init__(self, encoder=None):
             model_name=encoder_name, **encoder_params
         )

-        self.encoding_dim = self.encoder.encoding_dim
-
-        # A projector project higher dimension features to output dimensions.
+        # A projector projects higher dimension features to
+        # output dimensions
         self.projector = BYOLProjectionHead(
-            self.encoding_dim,
+            self.encoder.encoding_dim,
             Config().trainer.projection_hidden_dim,
             Config().trainer.projection_out_dim,
         )

+        # A predictor predicts the output from the projected features
         self.predictor = BYOLPredictionHead(
             Config().trainer.projection_out_dim,
             Config().trainer.prediction_hidden_dim,
             Config().trainer.prediction_out_dim,
         )

+        # The momentum encoder and projector, which work in
+        # a momentum manner
         self.momentum_encoder = copy.deepcopy(self.encoder)
         self.momentum_projector = copy.deepcopy(self.projector)

+        # Deactivate the requires_grad flag for all parameters
         deactivate_requires_grad(self.momentum_encoder)
         deactivate_requires_grad(self.momentum_projector)

-    def forward_direct(self, sample):
-        """Foward one sample to get the output."""
-        encoded_sample = self.encoder(sample).flatten(start_dim=1)
-        projected_sample = self.projector(encoded_sample)
-        output = self.predictor(projected_sample)
+    def forward_view(self, view_sample):
+        """Forward one view to get the output."""
+        encoded_view = self.encoder(view_sample).flatten(start_dim=1)
+        projected_view = self.projector(encoded_view)
+        output = self.predictor(projected_view)
         return output

-    def forward_momentum(self, sample):
-        """Foward one sample to get the output in a momentum manner."""
-        encoded_example = self.momentum_encoder(sample).flatten(start_dim=1)
-        projected_example = self.momentum_projector(encoded_example)
-        projected_example = projected_example.detach()
-        return projected_example
+    def forward_momentum(self, view_sample):
+        """Forward one view to get the output in a momentum manner."""
+        encoded_view = self.momentum_encoder(view_sample).flatten(start_dim=1)
+        projected_view = self.momentum_projector(encoded_view)
+        projected_view = projected_view.detach()
+        return projected_view

     def forward(self, multiview_samples):
         """Main forward function of the model."""
-        sample1, sample2 = multiview_samples
-        output1 = self.forward_direct(sample1)
-        projected_sample1 = self.forward_momentum(sample1)
-        output2 = self.forward_direct(sample2)
-        projected_sample2 = self.forward_momentum(sample2)
-        return (output1, projected_sample2), (output2, projected_sample1)
+        view_sample1, view_sample2 = multiview_samples
+        output1 = self.forward_view(view_sample1)
+        momentum1 = self.forward_momentum(view_sample1)
+        output2 = self.forward_view(view_sample2)
+        momentum2 = self.forward_momentum(view_sample2)
+        return (output1, momentum2), (output2, momentum1)
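`BYOLModel.forward()` above returns two (prediction, momentum projection) pairs, one per augmented view, and the trainers in this example average a self-supervised loss over the two pairs. A minimal sketch is given below, assuming lightly's `NegativeCosineSimilarity` as the criterion; in Plato, the actual criterion is resolved from the configuration through `plato.trainers.loss_criterion`.

```python
from lightly.loss import NegativeCosineSimilarity

criterion = NegativeCosineSimilarity()


def byol_style_loss(outputs):
    """Average the loss over the two (prediction, momentum projection) pairs."""
    (output1, momentum2), (output2, momentum1) = outputs
    return 0.5 * (criterion(output1, momentum2) + criterion(output2, momentum1))
```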
diff --git a/examples/ssl/algorithms/fedema/fedema_callback.py b/examples/ssl/fedema/fedema_processor.py
similarity index 70%
rename from examples/ssl/algorithms/fedema/fedema_callback.py
rename to examples/ssl/fedema/fedema_processor.py
index e132cad78..a657ef9c4 100644
--- a/examples/ssl/algorithms/fedema/fedema_callback.py
+++ b/examples/ssl/fedema/fedema_processor.py
@@ -1,13 +1,14 @@
 """
-Customized processor for FedEMA.
+An inbound processor for FedEMA to calculate the divergence between the received payload
+and the locally saved model weights, and then update the model layers using this divergence.
 """
+
 import logging
 from typing import Any

 import utils

 from plato.config import Config
-from plato.callbacks.client import ClientCallback
 from plato.processors import base


@@ -18,8 +19,8 @@ class GlobalLocalDivergenceProcessor(base.Processor):
     """

     def process(self, data: Any) -> Any:
-        """Processing the received payload by assigning layers of local model of
-        each client."""
+        """Process the received payload by updating the layers using
+        the model divergence."""

         divergence_scale = data[1]

@@ -66,20 +67,3 @@ def process(self, data: Any) -> Any:
             )

         return data[0]
-
-
-class FedEMACallback(ClientCallback):
-    """
-    A client callback that dynamically compute the divergence between the received model
-    and the local model.
-    """
-
-    def on_inbound_received(self, client, inbound_processor):
-        """
-        Insert an GlobalLocalDivergenceProcessor to the list of inbound processors.
-        """
-        extract_payload_processor = GlobalLocalDivergenceProcessor(
-            trainer=client.trainer,
-            name="GlobalLocalDivergenceProcessor",
-        )
-        inbound_processor.processors.insert(0, extract_payload_processor)
diff --git a/examples/ssl/algorithms/fedema/fedema_server.py b/examples/ssl/fedema/fedema_server.py
similarity index 89%
rename from examples/ssl/algorithms/fedema/fedema_server.py
rename to examples/ssl/fedema/fedema_server.py
index b359b4d07..e055eed20 100644
--- a/examples/ssl/algorithms/fedema/fedema_server.py
+++ b/examples/ssl/fedema/fedema_server.py
@@ -1,8 +1,5 @@
 """
-Implementation of the server for the FedEMA .
-
-Note:
-    Divergence is abbreviated as divg
+Implementation of the server for the FedEMA.
 """
 import os
 import logging
@@ -14,7 +11,7 @@


 class Server(personalized_server.Server):
-    """A personalized federated learning server using the pFL-CMA's EMA method."""
+    """A server for the FedEMA method to compute the model divergence."""

     def __init__(
         self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
@@ -27,7 +24,7 @@ def __init__(
             callbacks=callbacks,
         )

-        # The lambda used in the paper
+        # Set the lambda used in the paper
         self.clients_divg_scale = {
             client_id: 0.0 for client_id in range(1, self.total_clients + 1)
         }
@@ -60,7 +57,7 @@ def weights_aggregated(self, updates):

         # To compute the divergence scale adaptively
         # and within the computing rounds
-        if self.adaptive_divg_scale and not (self.current_round > divg_before_round):
+        if self.adaptive_divg_scale and self.current_round <= divg_before_round:
             clients_id = [update.report.client_id for update in updates]

             # Compute the divergence scale based on the distance between
@@ -85,7 +82,8 @@ def weights_aggregated(self, updates):
                 encoder_layer_names=encoder_layer_names,
             )

-            # the global L2 norm over a list of tensors.
+ # Compute L2 norm between the aggregated encoder + # and client encoder l2_distance = utils.get_parameters_diff( parameter_a=aggregated_encoder, parameter_b=client_encoder, @@ -97,6 +95,7 @@ def weights_aggregated(self, updates): tau = Config().algorithm.divergence_scale_tau client_divg_scale = tau / l2_distance + # Assign the divergence scale to the client self.clients_divg_scale[client_id] = client_divg_scale def customize_server_payload(self, payload): diff --git a/examples/ssl/algorithms/byol/byol_trainer.py b/examples/ssl/fedema/fedema_trainer.py similarity index 64% rename from examples/ssl/algorithms/byol/byol_trainer.py rename to examples/ssl/fedema/fedema_trainer.py index c372b71aa..0c529256c 100644 --- a/examples/ssl/algorithms/byol/byol_trainer.py +++ b/examples/ssl/fedema/fedema_trainer.py @@ -1,28 +1,34 @@ """ -A trainer for BYOL to rewrite the loss wrappe. +A self-supervised federated learning trainer with FedEMA. """ + from lightly.utils.scheduler import cosine_schedule from lightly.models.utils import update_momentum - -from ssl import ssl_trainer +from plato.trainers import self_supervised_learning as ssl_trainer from plato.trainers import loss_criterion from plato.config import Config class Trainer(ssl_trainer.Trainer): - """A trainer for BYOL to rewrite the loss wrappe.""" + """ + A trainer with FedEMA, which computes FedEMA's loss and computes the + momentum value at the start of each epoch; thus the model will be updated + step-wise based on this value in a momentum manner. + """ def __init__(self, model=None, callbacks=None): super().__init__(model, callbacks) + # The momentum value used to update the model + # with Exponential Moving Average self.momentum_val = 0 def get_ssl_criterion(self): """A wrapper to connect ssl loss with plato.""" defined_ssl_loss = loss_criterion.get() - def compute_plato_loss(outputs, labels): + def compute_loss(outputs, labels): if isinstance(outputs, (list, tuple)): loss = 0.5 * ( defined_ssl_loss(*outputs[0]) + defined_ssl_loss(*outputs[1]) @@ -31,12 +37,11 @@ def compute_plato_loss(outputs, labels): else: return defined_ssl_loss(outputs) - return compute_plato_loss + return compute_loss def train_epoch_start(self, config): """ - Before the start of one epoch, - prepare the momentum value for updating momentum outputs. + At the start of one epoch, the momentum value should be computed. """ super().train_epoch_start(config) epoch = self.current_epoch @@ -47,12 +52,14 @@ def train_epoch_start(self, config): def train_step_start(self, config, batch=None): """ - At the start of every iteration, - update the models for generating momentum - with new momemtum parameter: momentum value. + At the start of every iteration, the model should be updated based on the + momentum value in a momentum manner. """ super().train_step_start(config) if not self.current_round > Config().trainer.rounds: + # Update the model based on the momentum value + # Specifically, it updates parameters of `encoder` with + # Exponential Moving Average of `encoder_momentum` update_momentum( self.model.encoder, self.model.momentum_encoder, m=self.momentum_val ) diff --git a/examples/ssl/algorithms/fedema/utils.py b/examples/ssl/fedema/utils.py similarity index 95% rename from examples/ssl/algorithms/fedema/utils.py rename to examples/ssl/fedema/utils.py index 6f7bfcb96..a5fc703b7 100644 --- a/examples/ssl/algorithms/fedema/utils.py +++ b/examples/ssl/fedema/utils.py @@ -1,5 +1,5 @@ """ -Tools used in algorithm FedEMA +Tools used by the FedEMA algorithm. 
""" from collections import OrderedDict import torch @@ -20,7 +20,7 @@ def extract_encoder(model_layers, encoder_layer_names): def get_parameters_diff(parameter_a: OrderedDict, parameter_b: OrderedDict): - """Get the difference between two sets of parameters""" + """Get the difference between two sets of parameters.""" # Compute the divergence between encoders of local and global models l2_distance = 0.0 for paraml, paramg in zip(parameter_a.items(), parameter_b.items()): diff --git a/examples/ssl/moco/mocov2.py b/examples/ssl/moco/mocov2.py new file mode 100644 index 000000000..5f0917c0b --- /dev/null +++ b/examples/ssl/moco/mocov2.py @@ -0,0 +1,39 @@ +""" +An implementation of the MoCoV2 algorithm. + +K. He, et al., "Momentum Contrast for Unsupervised Visual Representation Learning," +CVPR 2020. https://arxiv.org/abs/1911.05722. + +X. Chen, et al., "Improved Baselines with Momentum Contrastive Learning," ArXiv, 2020. +https://arxiv.org/abs/2003.04297. + +Source code: https://github.com/facebookresearch/moco + +""" +from plato.servers import fedavg_personalized as personalized_server +from plato.clients import self_supervised_learning as ssl_client +from plato.datasources import self_supervised_learning as ssl_datasource + +import mocov2_model +import mocov2_trainer + + +def main(): + """ + A self-supervised federated learning session with MoCoV2. + """ + client = ssl_client.Client( + model=mocov2_model.MoCoV2, + datasource=ssl_datasource.SSLDataSource, + trainer=mocov2_trainer.Trainer, + ) + server = personalized_server.Server( + model=mocov2_model.MoCoV2, + trainer=mocov2_trainer.Trainer, + ) + + server.run(client) + + +if __name__ == "__main__": + main() diff --git a/examples/ssl/algorithms/moco/mocov2_model.py b/examples/ssl/moco/mocov2_model.py similarity index 85% rename from examples/ssl/algorithms/moco/mocov2_model.py rename to examples/ssl/moco/mocov2_model.py index af6118cb6..18f114c0c 100644 --- a/examples/ssl/algorithms/moco/mocov2_model.py +++ b/examples/ssl/moco/mocov2_model.py @@ -1,6 +1,5 @@ """ A model for the MoCoV2 method. - """ import copy from torch import nn @@ -12,6 +11,8 @@ class MoCoV2(nn.Module): + """A model structure for the MoCoV2 method.""" + def __init__(self, encoder=None): super().__init__() @@ -19,7 +20,7 @@ def __init__(self, encoder=None): encoder_params = ( Config().params.encoder if hasattr(Config().params, "encoder") else {} ) - # Define the encoder. + # Define the encoder if encoder is not None: self.encoder = encoder else: @@ -27,7 +28,7 @@ def __init__(self, encoder=None): model_name=encoder_name, **encoder_params ) - # Define heads. + # Define the projector self.projector = MoCoProjectionHead( self.encoder.encoding_dim, Config().trainer.projection_hidden_dim, @@ -37,21 +38,18 @@ def __init__(self, encoder=None): self.encoder_momentum = copy.deepcopy(self.encoder) self.projector_momentum = copy.deepcopy(self.projector) + # Deactivate the requires_grad flag for all parameters deactivate_requires_grad(self.encoder_momentum) deactivate_requires_grad(self.projector_momentum) def forward_view(self, view_sample): - """ - Foward one view sample to get the output. - """ + """Foward one view sample to get the output.""" query = self.encoder(view_sample).flatten(start_dim=1) query = self.projector(query) return query def forward_momentum(self, view_sample): - """ - Foward one view sample to get the output in a momentum manner. 
- """ + """Foward one view sample to get the output in a momentum manner.""" key = self.encoder_momentum(view_sample).flatten(start_dim=1) key = self.projector_momentum(key).detach() return key diff --git a/examples/ssl/algorithms/moco/mocov2_trainer.py b/examples/ssl/moco/mocov2_trainer.py similarity index 56% rename from examples/ssl/algorithms/moco/mocov2_trainer.py rename to examples/ssl/moco/mocov2_trainer.py index c16640980..62ffa7993 100644 --- a/examples/ssl/algorithms/moco/mocov2_trainer.py +++ b/examples/ssl/moco/mocov2_trainer.py @@ -1,40 +1,48 @@ """ -A base trainer for MoCoV2 algorithm. +A self-supervised federated learning trainer with MoCoV2. """ from lightly.models.utils import update_momentum from lightly.utils.scheduler import cosine_schedule -from ssl import ssl_trainer +from plato.trainers import self_supervised_learning as ssl_trainer from plato.config import Config class Trainer(ssl_trainer.Trainer): - """A personalized federated learning trainer with self-supervised learning.""" + """ + A trainer with MoCoV2, which updates the momentum value at the start + of each training epoch and updates the model based on this value in a + momentum manner in each training step. + """ def __init__(self, model=None, callbacks=None): super().__init__(model, callbacks) + # The momentum value used to update the model + # with Exponential Moving Average self.momentum_val = 0 def train_epoch_start(self, config): - """Operations before starting one epoch.""" + """Update the momentum value.""" super().train_epoch_start(config) - epoch = self.current_epoch total_epochs = config["epochs"] * config["rounds"] - global_epoch = (self.current_round - 1) * config["epochs"] + epoch + global_epoch = (self.current_round - 1) * config["epochs"] + self.current_epoch + # Compute the momentum value during the regular federated training process if not self.current_round > Config().trainer.rounds: self.momentum_val = cosine_schedule(global_epoch, total_epochs, 0.996, 1) def train_step_start(self, config, batch=None): """ - At the start of every iteration, - update the models for generating momentum - with new momemtum parameter: momentum value. + At the start of every iteration, the model should be updated based on the + momentum value. 
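Both the FedEMA and MoCoV2 trainers pair lightly's `cosine_schedule` with `update_momentum`: the momentum value is recomputed once per epoch and each training step folds the online encoder into its frozen momentum copy. The sketch below shows the same two helpers in isolation, with toy `nn.Linear` modules standing in for `self.model.encoder` and `self.model.encoder_momentum` and illustrative epoch and round counts.

```python
from torch import nn
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.utils.scheduler import cosine_schedule

encoder = nn.Linear(32, 16)                 # stand-in for the online encoder
encoder_momentum = nn.Linear(32, 16)        # stand-in for the momentum encoder
encoder_momentum.load_state_dict(encoder.state_dict())
deactivate_requires_grad(encoder_momentum)  # the momentum branch is never backpropagated

epochs, rounds = 5, 10  # illustrative values for config["epochs"] and config["rounds"]
for global_epoch in range(epochs * rounds):
    # The momentum value ramps from 0.996 towards 1.0 over the full schedule.
    momentum_val = cosine_schedule(global_epoch, epochs * rounds, 0.996, 1)
    # Each step updates the momentum copy with an Exponential Moving Average:
    # ema_param = momentum_val * ema_param + (1 - momentum_val) * online_param.
    update_momentum(encoder, encoder_momentum, m=momentum_val)
```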
""" super().train_step_start(config) if not self.current_round > Config().trainer.rounds: + # Update the model based on the momentum value + # Specifically, it updates parameters of `encoder` with + # Exponential Moving Average of `encoder_momentum` update_momentum( self.model.encoder, self.model.encoder_momentum, m=self.momentum_val ) diff --git a/examples/ssl/requirements.txt b/examples/ssl/requirements.txt deleted file mode 100644 index 13f6a92cd..000000000 --- a/examples/ssl/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -numpy -matplotlib -scipy -scikit-learn -lightly \ No newline at end of file diff --git a/examples/ssl/setup.py b/examples/ssl/setup.py deleted file mode 100644 index 077b283aa..000000000 --- a/examples/ssl/setup.py +++ /dev/null @@ -1,58 +0,0 @@ -import io -import os -import re - -import setuptools - - -def get_long_description(): - base_dir = os.path.abspath(os.path.dirname(__file__)) - with io.open(os.path.join(base_dir, "description.txt"), encoding="utf-8") as f: - return f.read() - - -def get_requirements(): - with open("requirements.txt") as f: - return f.read().splitlines() - - -def get_version(): - current_dir = os.path.abspath(os.path.dirname(__file__)) - version_file = os.path.join(current_dir, "ssl", "__init__.py") - with io.open(version_file, encoding="utf-8") as f: - return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', f.read(), re.M).group(1) - - -setuptools.setup( - name="pfl-bases", - version=get_version(), - author="", - license="Apache-2.0", - description="Packaged version of the Plato-related framework for personalized federated learning research", - long_description=get_long_description(), - long_description_content_type="text/markdown", - url="https://github.com/TL-System/plato/tree/main/examples/pfl", - packages=setuptools.find_packages(exclude=["tests"]), - python_requires=">=3.7", - install_requires=get_requirements(), - extras_require={"tests": ["pytest"]}, - include_package_data=True, - options={"bdist_wheel": {"python_tag": "py36.py37.py38.py39"}}, - classifiers=[ - "Development Status :: 3 - Alpha", - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", - "Intended Audience :: Developers", - "Intended Audience :: Science/Research", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Topic :: Software Development :: Libraries", - "Topic :: Software Development :: Libraries :: Python layers", - "Topic :: Education", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - ], - keywords="fit, plot, experiments, visualization", -) diff --git a/examples/ssl/simclr/simclr.py b/examples/ssl/simclr/simclr.py new file mode 100644 index 000000000..2e7b5db7f --- /dev/null +++ b/examples/ssl/simclr/simclr.py @@ -0,0 +1,31 @@ +""" +An implementation of the SimCLR algorithm. + +T. Chen, et al., "A Simple Framework for Contrastive Learning of Visual Representations," ICML 2020. + +https://arxiv.org/abs/2002.05709 + +Source code: https://github.com/google-research/simclr or https://github.com/spijkervet/SimCLR.git. + +""" +from plato.servers import fedavg_personalized as personalized_server +from plato.clients import self_supervised_learning as ssl_client +from plato.datasources import self_supervised_learning as ssl_datasource + +from simclr_model import SimCLRModel + + +def main(): + """ + A self-supervised federated learning session with SimCLR. 
+ """ + client = ssl_client.Client( + model=SimCLRModel, datasource=ssl_datasource.SSLDataSource + ) + server = personalized_server.Server(model=SimCLRModel) + + server.run(client) + + +if __name__ == "__main__": + main() diff --git a/examples/ssl/algorithms/simclr/simclr_model.py b/examples/ssl/simclr/simclr_model.py similarity index 93% rename from examples/ssl/algorithms/simclr/simclr_model.py rename to examples/ssl/simclr/simclr_model.py index 0872df2eb..bd9a585ff 100644 --- a/examples/ssl/algorithms/simclr/simclr_model.py +++ b/examples/ssl/simclr/simclr_model.py @@ -1,5 +1,5 @@ """ -Model speicifcally used in SimCLR algorithm. +Model for the SimCLR algorithm. """ from torch import nn from lightly.models.modules.heads import SimCLRProjectionHead @@ -34,7 +34,7 @@ def __init__(self, encoder=None): ) def forward(self, multiview_samples): - """Forward two batch of contrastive samples.""" + """Forward the contrastive samples.""" view_sample1, view_sample2 = multiview_samples encoded_sample1 = self.encoder(view_sample1) encoded_sample2 = self.encoder(view_sample2) diff --git a/examples/ssl/algorithms/simsiam/simsiam.py b/examples/ssl/simsiam/simsiam.py similarity index 51% rename from examples/ssl/algorithms/simsiam/simsiam.py rename to examples/ssl/simsiam/simsiam.py index 153225a1b..1887a534c 100644 --- a/examples/ssl/algorithms/simsiam/simsiam.py +++ b/examples/ssl/simsiam/simsiam.py @@ -1,18 +1,15 @@ """ -The implementation for the SimSiam [1] method. +An implementation of the SimSiam algorithm. -[1]. Xinlei Chen, et al., Exploring Simple Siamese Representation Learning. +X. Chen, et al., "Exploring Simple Siamese Representation Learning." https://arxiv.org/pdf/2011.10566.pdf -Reference: -Source code: https://github.com/facebookresearch/simsiam -Third-party code: https://github.com/PatrickHua/SimSiam +Source code: https://github.com/facebookresearch/simsiam or https://github.com/PatrickHua/SimSiam """ from plato.servers import fedavg_personalized as personalized_server - -from ssl import ssl_client -from ssl import ssl_datasources +from plato.clients import self_supervised_learning as ssl_client +from plato.datasources import self_supervised_learning as ssl_datasource import simsiam_trainer import simsiam_model @@ -20,11 +17,11 @@ def main(): """ - A personalized federated learning session for SimSiam approach. + A self-supervised federated learning session with SimSiam. 
""" client = ssl_client.Client( model=simsiam_model.SimSiam, - datasource=ssl_datasources.SSLDataSource, + datasource=ssl_datasource.SSLDataSource, trainer=simsiam_trainer.Trainer, ) server = personalized_server.Server( diff --git a/examples/ssl/algorithms/simsiam/simsiam_model.py b/examples/ssl/simsiam/simsiam_model.py similarity index 68% rename from examples/ssl/algorithms/simsiam/simsiam_model.py rename to examples/ssl/simsiam/simsiam_model.py index c9f7ccbae..f34507613 100644 --- a/examples/ssl/algorithms/simsiam/simsiam_model.py +++ b/examples/ssl/simsiam/simsiam_model.py @@ -27,29 +27,30 @@ def __init__(self, encoder=None): self.encoder = encoder_registry.get( model_name=encoder_name, **encoder_params ) - + # A projector projects higher dimension encodings to projections self.projector = SimSiamProjectionHead( self.encoder.encoding_dim, Config().trainer.projection_hidden_dim, Config().trainer.projection_out_dim, ) + # A predictor predicts the output from the projected features self.predictor = SimSiamPredictionHead( Config().trainer.projection_out_dim, Config().trainer.prediction_hidden_dim, Config().trainer.prediction_out_dim, ) - def forward_view(self, sample): + def forward_view(self, view_sample): """Foward one view sample to get the output.""" - encoded_sample = self.encoder(sample).flatten(start_dim=1) - projected_sample = self.projector(encoded_sample) - output = self.predictor(projected_sample) - projected_sample = projected_sample.detach() - return projected_sample, output + encoded_view = self.encoder(view_sample).flatten(start_dim=1) + projected_view = self.projector(encoded_view) + output = self.predictor(projected_view) + projected_view = projected_view.detach() + return projected_view, output def forward(self, multiview_samples): """Main forward function of the model.""" view_sample1, view_sample2 = multiview_samples - projected_sample1, output1 = self.forward_view(view_sample1) - projected_sample2, output2 = self.forward_view(view_sample2) - return (projected_sample1, output2), (projected_sample2, output1) + projected_view1, output1 = self.forward_view(view_sample1) + projected_view2, output2 = self.forward_view(view_sample2) + return (projected_view1, output2), (projected_view2, output1) diff --git a/examples/ssl/algorithms/simsiam/simsiam_trainer.py b/examples/ssl/simsiam/simsiam_trainer.py similarity index 59% rename from examples/ssl/algorithms/simsiam/simsiam_trainer.py rename to examples/ssl/simsiam/simsiam_trainer.py index 7ec0f242d..70e2b4233 100644 --- a/examples/ssl/algorithms/simsiam/simsiam_trainer.py +++ b/examples/ssl/simsiam/simsiam_trainer.py @@ -1,20 +1,19 @@ """ -A base trainer for simsiam algorithm. +A self-supervised federated learning trainer with SimSiam. 
""" from plato.trainers import loss_criterion - -from ssl import ssl_trainer +from plato.trainers import self_supervised_learning as ssl_trainer class Trainer(ssl_trainer.Trainer): - """A trainer for SimSiam to rewrite the loss wrapper.""" + """A trainer with SimSiam to compute the loss.""" def get_ssl_criterion(self): - """A wrapper to connect ssl loss with plato.""" + """Get the loss proposed by the SimSiam.""" defined_ssl_loss = loss_criterion.get() - def compute_plato_loss(outputs, labels): + def compute_loss(outputs, labels): if isinstance(outputs, (list, tuple)): loss = 0.5 * ( defined_ssl_loss(*outputs[0]) + defined_ssl_loss(*outputs[1]) @@ -23,4 +22,4 @@ def compute_plato_loss(outputs, labels): else: return defined_ssl_loss(outputs) - return compute_plato_loss + return compute_loss diff --git a/examples/ssl/algorithms/smog/smog.py b/examples/ssl/smog/smog.py similarity index 50% rename from examples/ssl/algorithms/smog/smog.py rename to examples/ssl/smog/smog.py index cbf98c48a..ed2fc29a7 100644 --- a/examples/ssl/algorithms/smog/smog.py +++ b/examples/ssl/smog/smog.py @@ -1,16 +1,14 @@ """ -The implementation for the SMoG [1] method. +An implementation of the SMoG algorithm. -[1]. Bo Pang, et al., Unsupervised Visual Representation Learning by Synchronous Momentum Grouping. -ECCV, 2022. https://arxiv.org/pdf/2006.07733.pdf. -""" +B. Pang, et al., "Unsupervised Visual Representation Learning by Synchronous Momentum Grouping," ECCV, 2022. +https://arxiv.org/pdf/2006.07733.pdf +""" from plato.servers import fedavg_personalized as personalized_server - -from ssl import ssl_client -from ssl import ssl_datasources - +from plato.clients import self_supervised_learning as ssl_client +from plato.datasources import self_supervised_learning as ssl_datasource import smog_trainer import smog_model @@ -18,11 +16,11 @@ def main(): """ - A personalized federated learning session for SMoG approach. + A self-supervised federated learning session with SMoG. """ client = ssl_client.Client( model=smog_model.SMoG, - datasource=ssl_datasources.SSLDataSource, + datasource=ssl_datasource.SSLDataSource, trainer=smog_trainer.Trainer, ) server = personalized_server.Server( diff --git a/examples/ssl/algorithms/smog/smog_model.py b/examples/ssl/smog/smog_model.py similarity index 89% rename from examples/ssl/algorithms/smog/smog_model.py rename to examples/ssl/smog/smog_model.py index a0dd40328..abd34201b 100644 --- a/examples/ssl/algorithms/smog/smog_model.py +++ b/examples/ssl/smog/smog_model.py @@ -1,5 +1,5 @@ """ -A model for the SMoG method. +The model for the SMoG algorithm. 
""" import copy @@ -19,7 +19,7 @@ class SMoG(nn.Module): - """Core structure of the SMoG model.""" + """The structure of the SMoG model.""" def __init__(self, encoder=None): super().__init__() @@ -58,12 +58,15 @@ def __init__(self, encoder=None): Config().trainer.prediction_out_dim, ) + # Deepcopy the encoder and projector to create the momentum self.encoder_momentum = copy.deepcopy(self.encoder) self.projector_momentum = copy.deepcopy(self.projector) + # Deactivate the requires_grad flag for all parameters deactivate_requires_grad(self.encoder_momentum) deactivate_requires_grad(self.projector_momentum) + # Set the necessary hyper-parameter for SMoG self.n_groups = Config().trainer.n_groups n_prototypes = Config().trainer.n_prototypes beta = Config().trainer.smog_beta @@ -98,16 +101,16 @@ def reset_momentum_weights(self): deactivate_requires_grad(self.encoder_momentum) deactivate_requires_grad(self.projector_momentum) - def forward_view(self, views): + def forward_view(self, view_sample): """Foward one view sample to get the output.""" - encoded_features = self.encoder(views).flatten(start_dim=1) + encoded_features = self.encoder(view_sample).flatten(start_dim=1) projected_features = self.projector(encoded_features) prediction = self.predictor(projected_features) return projected_features, prediction - def forward_momentum(self, samples): + def forward_momentum(self, view_sample): """Foward one view sample to get the output in a momentum manner.""" - features = self.encoder_momentum(samples).flatten(start_dim=1) + features = self.encoder_momentum(view_sample).flatten(start_dim=1) encoded = self.projector_momentum(features) return encoded diff --git a/examples/ssl/algorithms/smog/smog_trainer.py b/examples/ssl/smog/smog_trainer.py similarity index 65% rename from examples/ssl/algorithms/smog/smog_trainer.py rename to examples/ssl/smog/smog_trainer.py index a050f2ddc..810faf622 100644 --- a/examples/ssl/algorithms/smog/smog_trainer.py +++ b/examples/ssl/smog/smog_trainer.py @@ -1,31 +1,39 @@ """ -The implemetation of the trainer for SMoG approach. +A self-supervised federated learning trainer with SMoG. """ import os -from ssl import ssl_trainer - import torch from lightly.loss.memory_bank import MemoryBankModule from lightly.models.utils import update_momentum from lightly.utils.scheduler import cosine_schedule +from plato.trainers import self_supervised_learning as ssl_trainer from plato.config import Config class Trainer(ssl_trainer.Trainer): - """A base trainer for SMoG approach.""" + """ + A trainer with SMoG, which computes the momentum value to update the model + in each training step and loads the 'memory_bank' to facilitate the model forward. + After training, the 'memory_bank' from the trained model will be saved to disk for + subsequent learning. + """ def __init__(self, model=None, callbacks=None): super().__init__(model, callbacks) + # The momentum value used to update the model + # with Exponential Moving Average self.momentum_val = 0 # Set training steps self.global_step = 0 - # Set the memory bank because we - # reset the group features every 300 iterations + # Set the memory bank and its size + # The reset_interval used here is the common term to show + # how many iterations we reset this memory bank. 
+ # The number used by the authors is 300 self.reset_interval = ( Config().trainer.reset_interval if hasattr(Config().trainer, "reset_interval") @@ -36,7 +44,7 @@ def __init__(self, model=None, callbacks=None): ) def train_epoch_start(self, config): - """Operation before starting one epoch.""" + """Compute the momentum value before starting one epoch of training.""" super().train_epoch_start(config) epoch = self.current_epoch total_epochs = config["epochs"] * config["rounds"] @@ -46,7 +54,11 @@ def train_epoch_start(self, config): self.momentum_val = cosine_schedule(epoch, total_epochs, 0.996, 1) def train_step_start(self, config, batch=None): - """Operation before starting one iteration.""" + """ + Update the model based on the computed momentum value in each training step. + And reset the 'memory bank' along with all momentum values when the number of + collected features in this bank reaches its full size. + """ super().train_step_start(config) if not self.current_round > Config().trainer.rounds: @@ -54,10 +66,15 @@ def train_step_start(self, config, batch=None): self.global_step += 1 if self.global_step > 0 and self.global_step % self.reset_interval == 0: - # Reset group features and weights every some iterations + # Reset group features and momentum weights when the memory bank is + # full, i.e., the number of features added to the bank + # in each iteration step, reaches its full size. self.model.reset_group_features(memory_bank=self.memory_bank) self.model.reset_momentum_weights() else: + # Update the model based on the momentum value + # Specifically, it updates parameters of `encoder` with + # Exponential Moving Average of `encoder_momentum` update_momentum( self.model.encoder, self.model.encoder_momentum, m=self.momentum_val ) @@ -67,12 +84,14 @@ def train_step_start(self, config, batch=None): m=self.momentum_val, ) - # update the local iteration + # Update the local iteration for the model self.model.n_iteration = batch def train_run_start(self, config): """Load the memory bank from file system.""" super().train_run_start(config) + # Load the memory bank from the file system during + # regular federated training if not self.current_round > Config().trainer.rounds: model_path = Config().params["model_path"] filename_bank = f"client_{self.client_id}_bank.pth" @@ -87,6 +106,8 @@ def train_run_start(self, config): def train_run_end(self, config): """Save the memory bank to the file system.""" super().train_run_end(config) + # Load the memory bank from the file system during + # regular federated training if not self.current_round > Config().trainer.rounds: model_path = Config().params["model_path"] filename_bank = f"client_{self.client_id}_bank.pth" diff --git a/examples/ssl/ssl/__init__.py b/examples/ssl/ssl/__init__.py deleted file mode 100644 index f102a9cad..000000000 --- a/examples/ssl/ssl/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.0.1" diff --git a/examples/ssl/swav/swav.py b/examples/ssl/swav/swav.py new file mode 100644 index 000000000..18b7d31ec --- /dev/null +++ b/examples/ssl/swav/swav.py @@ -0,0 +1,31 @@ +""" +An implementation of the SwAV algorithm. + +M. Caron, et al., "Unsupervised Learning of Visual Features by Contrasting Cluster Assignments," NeurIPS 2020. 
+ +https://arxiv.org/abs/2006.09882 + +Source code: https://github.com/facebookresearch/swav +""" +from plato.servers import fedavg_personalized as personalized_server +from plato.clients import self_supervised_learning as ssl_client +from plato.datasources import self_supervised_learning as ssl_datasource + +import swav_model + + +def main(): + """ + A self-supervised federated learning session with SwAV. + """ + client = ssl_client.Client( + model=swav_model.SwaV, + datasource=ssl_datasource.SSLDataSource, + ) + server = personalized_server.Server(model=swav_model.SwaV) + + server.run(client) + + +if __name__ == "__main__": + main() diff --git a/examples/ssl/algorithms/swav/swav_model.py b/examples/ssl/swav/swav_model.py similarity index 72% rename from examples/ssl/algorithms/swav/swav_model.py rename to examples/ssl/swav/swav_model.py index dc3de4ce2..5201f5308 100644 --- a/examples/ssl/algorithms/swav/swav_model.py +++ b/examples/ssl/swav/swav_model.py @@ -1,29 +1,28 @@ """ -A model for the SwAV method. +A model for the SwAV algorithm. """ from torch import nn from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes - from plato.models.cnn_encoder import Model as encoder_registry from plato.config import Config class SwaV(nn.Module): - """The model structure for the SwaV.""" + """The structure of the SwAV Model.""" def __init__(self, encoder=None): super().__init__() - # Define the encoder. + # Define the encoder encoder_name = Config().trainer.encoder_name encoder_params = ( Config().params.encoder if hasattr(Config().params, "encoder") else {} ) - # Define the encoder. + # Define the encoder if encoder is not None: self.encoder = encoder else: @@ -31,12 +30,14 @@ def __init__(self, encoder=None): model_name=encoder_name, **encoder_params ) - # Define the projector. 
+ # Define the projector self.projector = SwaVProjectionHead( self.encoder.encoding_dim, Config().trainer.projection_hidden_dim, Config().trainer.projection_out_dim, ) + # Define the prototypes which behave as the core + # part of the SwAV algorithm self.prototypes = SwaVPrototypes( Config().trainer.projection_out_dim, n_prototypes=Config().trainer.n_prototypes, @@ -44,14 +45,14 @@ def __init__(self, encoder=None): def forward_view(self, view_sample): """Foward views of the samples""" - encoded_sample = self.encoder(view_sample).flatten(start_dim=1) - projected_sample = self.projector(encoded_sample) - normalized_sample = nn.functional.normalize(projected_sample, dim=1, p=2) - outputs = self.prototypes(normalized_sample) + encoded_view = self.encoder(view_sample).flatten(start_dim=1) + projected_view = self.projector(encoded_view) + normalized_view = nn.functional.normalize(projected_view, dim=1, p=2) + outputs = self.prototypes(normalized_view) return outputs def forward(self, multiview_samples): - """Forward multiview samples.""" + """Forward multiview samples""" self.prototypes.normalize() multi_crop_features = [self.forward_view(views) for views in multiview_samples] high_resolution = multi_crop_features[:2] diff --git a/plato/clients/registry.py b/plato/clients/registry.py index a32ad1574..519debd89 100644 --- a/plato/clients/registry.py +++ b/plato/clients/registry.py @@ -7,12 +7,19 @@ import logging from plato.config import Config -from plato.clients import simple, mistnet, fedavg_personalized, split_learning +from plato.clients import ( + self_supervised_learning, + simple, + mistnet, + fedavg_personalized, + split_learning, +) registered_clients = { "simple": simple.Client, "mistnet": mistnet.Client, "fedavg_personalized": fedavg_personalized.Client, + "self_supervised_learning": self_supervised_learning.Client, "split_learning": split_learning.Client, } diff --git a/examples/ssl/ssl/ssl_client.py b/plato/clients/self_supervised_learning.py similarity index 65% rename from examples/ssl/ssl/ssl_client.py rename to plato/clients/self_supervised_learning.py index c08f20451..888356d3a 100644 --- a/examples/ssl/ssl/ssl_client.py +++ b/plato/clients/self_supervised_learning.py @@ -1,7 +1,12 @@ """ -A self-supervised learning (SSL) setting the personalized datasource for -the client. The datasets used in personalization are different from the ones used in -the regular federated learning with SSL. +A self-supervised learning (SSL) client prepares a personalized datasource for +the personalization process, which will be performed after finishing the FL +training process with SSL. + +Specifically, the conventional FL training process with SSL will train the model +with the datasource and objective function of SSL. Yet, the datasource used in +personalization should be one of supervised learning. Therefore, a client needs +to prepare the personalized datasource. 
""" from plato.datasources import registry as datasources_registry @@ -9,7 +14,7 @@ class Client(simple.Client): - """An SSL client to prepare the datasource for the personalization.""" + """An SSL client to prepare the datasource for personalization.""" def __init__( self, @@ -28,6 +33,7 @@ def __init__( callbacks=callbacks, trainer_callbacks=trainer_callbacks, ) + # The datasource used in personalization self.personalized_datasource = None def configure(self) -> None: diff --git a/plato/datasources/registry.py b/plato/datasources/registry.py index f995a1bfe..c7559e72d 100644 --- a/plato/datasources/registry.py +++ b/plato/datasources/registry.py @@ -68,11 +68,13 @@ def get(client_id: int = 0, **kwargs): """Get the data source with the provided name.""" datasource_name = ( - kwargs["datasource_name"] if "datasource_name" in kwargs else Config().data.datasource + kwargs["datasource_name"] + if "datasource_name" in kwargs + else Config().data.datasource ) logging.info("Data source: %s", datasource_name) - + if datasource_name == "kinetics700": from plato.datasources import kinetics diff --git a/examples/ssl/ssl/ssl_datasources.py b/plato/datasources/self_supervised_learning.py similarity index 93% rename from examples/ssl/ssl/ssl_datasources.py rename to plato/datasources/self_supervised_learning.py index 5080ad99a..51176dad4 100644 --- a/examples/ssl/ssl/ssl_datasources.py +++ b/plato/datasources/self_supervised_learning.py @@ -12,6 +12,7 @@ from plato.config import Config +# The normalizations for different datasets MNIST_NORMALIZE = {"mean": [0.1307], "std": [0.3081]} FashionMNIST_NORMALIZE = {"mean": [0.1307], "std": [0.3081]} CIFAR10_NORMALIZE = {"mean": [0.491, 0.482, 0.447], "std": [0.247, 0.243, 0.262]} @@ -19,7 +20,7 @@ IMAGENET_NORMALIZE = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]} STL10_NORMALIZE = {"mean": [0.4914, 0.4823, 0.4466], "std": [0.247, 0.243, 0.261]} -datasets_normalization = { +dataset_normalizations = { "MNIST": MNIST_NORMALIZE, "FashionMNIST": FashionMNIST_NORMALIZE, "CIFAR10": CIFAR10_NORMALIZE, @@ -63,7 +64,7 @@ def get_transforms(): # Get the data normalization for the datasource datasource_name = Config().data.datasource - transform_params["normalize"] = datasets_normalization[datasource_name] + transform_params["normalize"] = dataset_normalizations[datasource_name] # Get the SSL transform if transform_name in registered_transforms: dataset_transform = registered_transforms[transform_name]( @@ -94,9 +95,3 @@ def __init__(self): self.datasource = datasources_registry.get(**data_transforms) self.trainset = self.datasource.trainset self.testset = self.datasource.testset - - def num_train_examples(self): - return len(self.trainset) - - def num_test_examples(self): - return len(self.testset) diff --git a/plato/trainers/registry.py b/plato/trainers/registry.py index a9e876270..280b1b778 100644 --- a/plato/trainers/registry.py +++ b/plato/trainers/registry.py @@ -41,6 +41,11 @@ def get(model=None, callbacks=None): from plato.trainers import huggingface return huggingface.Trainer(model=model, callbacks=callbacks) + + elif Config().trainer.type == "self_supervised_learning": + from plato.trainers import self_supervised_learning + + return self_supervised_learning.Trainer(model=model, callbacks=callbacks) elif trainer_name in registered_trainers: return registered_trainers[trainer_name](model=model, callbacks=callbacks) else: diff --git a/examples/ssl/ssl/ssl_trainer.py b/plato/trainers/self_supervised_learning.py similarity index 85% rename from 
examples/ssl/ssl/ssl_trainer.py rename to plato/trainers/self_supervised_learning.py index eaa072cc7..2277867c8 100644 --- a/examples/ssl/ssl/ssl_trainer.py +++ b/plato/trainers/self_supervised_learning.py @@ -1,13 +1,17 @@ """ A self-supervised learning (SSL) trainer for SSL training and testing. -Federated SSL trains the global model based on the data loader and objective function -of SSL algorithms. For this unsupervised learning process, we cannot test the model directly -as the model only extracts features from the data. Therefore, we use the KNN as a classifier -to get the accuracy of the global model during the regular federated training process. +Federated learning with SSL trains the global model based on the data loader and +objective function of SSL algorithms. For this unsupervised learning process, we +cannot test the model directly as the model only extracts features from the +data. Therefore, we use KNN as a classifier to get the accuracy of the global +model during the regular federated training process. -In the personalization phase, each client trains a linear layer locally, -based on the features extracted by the trained global model. +In the personalization process, each client trains a linear layer locally, based +on the features extracted by the trained global model. + +The accuracy obtained by KNN during the regular federated training rounds may +not be used to compare with the accuracy in supervised learning methods. """ import logging @@ -27,18 +31,19 @@ class SSLSamples(UserList): def to(self, device): """Assign a list of views into the specific device.""" - for example_idx, example in enumerate(self.data): - if isinstance(example, torch.Tensor): - example = example.to(device) + for view_idx, view in enumerate(self.data): + if isinstance(view, torch.Tensor): + view = view.to(device) - self[example_idx] = example + self[view_idx] = view return self.data class MultiViewCollateWrapper(MultiViewCollate): - """An interface to connect the collate from lightly with the data loading schema of - Plato.""" + """ + An interface to connect collate from lightly with the data loading schema in Plato. + """ def __call__(self, batch): """Turn a batch of tuples into a single tuple.""" @@ -81,7 +86,7 @@ def set_personalized_datasets(self, trainset, testset): def get_train_loader(self, batch_size, trainset, sampler, **kwargs): """Obtain the training loader based on the learning mode.""" - # Get the training loader for the personalization + # Get the trainloader for personalization if self.current_round > Config().trainer.rounds: return torch.utils.data.DataLoader( dataset=self.personalized_trainset, @@ -114,9 +119,8 @@ def get_optimizer(self, model): def get_ssl_criterion(self): """ - Get the loss criterion for the SSL. - Some SSL algorithms, for example, BYOL, will overwrite this function for - specific loss functions. + Get the loss criterion for SSL. Some SSL algorithms, for example, + BYOL, will overwrite this function for specific loss functions. 
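The personalization pass implemented further below trains a local linear layer on features extracted by the frozen global encoder, as the docstring describes. A minimal linear-probe sketch of that phase, with a toy encoder, illustrative shapes, and random stand-in data:

```python
import torch
from torch import nn

# Freeze the trained encoder; only the local linear layer is optimized.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64))
for param in encoder.parameters():
    param.requires_grad = False

local_layers = nn.Linear(64, 10)
optimizer = torch.optim.SGD(local_layers.parameters(), lr=0.01)
loss_criterion = nn.CrossEntropyLoss()

examples = torch.randn(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))

optimizer.zero_grad()
with torch.no_grad():
    features = encoder(examples)   # features come from the frozen encoder
loss = loss_criterion(local_layers(features), labels)
loss.backward()
optimizer.step()
```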
""" # Get loss criterion for the SSL @@ -125,16 +129,16 @@ def get_ssl_criterion(self): # We need to wrap the loss function to make it compatible # with different types of outputs # The types of the outputs can vary from Tensor to a list of Tensors - def compute_plato_loss(outputs, labels): + def compute_loss(outputs, __): if isinstance(outputs, (list, tuple)): return ssl_loss_function(*outputs) - else: - return ssl_loss_function(outputs) - return compute_plato_loss + return ssl_loss_function(outputs) + + return compute_loss def get_loss_criterion(self): - """Return the loss criterion for the SSL.""" + """Return the loss criterion for SSL.""" # Get loss criterion for the subsequent training process if self.current_round > Config().trainer.rounds: loss_criterion_type = Config().algorithm.personalization.loss_criterion @@ -151,7 +155,7 @@ def get_loss_criterion(self): return self.get_ssl_criterion() def get_lr_scheduler(self, config, optimizer): - # Get the lr scheduler for the personalization + # Get the lr scheduler for personalization if self.current_round > Config().trainer.rounds: lr_scheduler = Config().algorithm.personalization.lr_scheduler lr_params = Config().parameters.personalization.learning_rate._asdict() @@ -183,11 +187,11 @@ def perform_forward_and_backward_passes(self, config, examples, labels): extract features into the local layers. """ - # Perform the SSL training in the first Config().trainer.rounds rounds + # Perform SSL training in the first `Config().trainer.rounds`` rounds if not self.current_round > Config().trainer.rounds: return super().perform_forward_and_backward_passes(config, examples, labels) - # Perform the personalization after the final round + # Perform personalization after the final round # Perform the local update on self.local_layers self.optimizer.zero_grad() @@ -262,10 +266,10 @@ def test_model(self, config, testset, sampler=None, **kwargs): return accuracy else: - # Test the personalization in each round. + # Test the personalized model in each round. # For SSL, the way to test the trained model before personalization is - # to use the KNN as a classifier to evaluate the extracted features. + # to use the KNN as a classifier to evaluate the extracted features. logging.info("[Client #%d] Testing the model with KNN.", self.client_id) @@ -280,11 +284,11 @@ def test_model(self, config, testset, sampler=None, **kwargs): testset, batch_size=batch_size, shuffle=False, sampler=sampler ) # For evaluating self-supervised performance, we need to calculate - # distance between training samples and testing samples. + # distance between training samples and testing samples. 
train_encodings, train_labels = self.collect_encodings(train_loader) test_encodings, test_labels = self.collect_encodings(test_loader) - # Build the KNN and Perform the prediction + # Build KNN and perform the prediction distances = torch.cdist(test_encodings, train_encodings, p=2) knn = distances.topk(1, largest=False) nearest_idx = knn.indices diff --git a/plato/utils/trainer_utils.py b/plato/utils/trainer_utils.py index 1d9a48d5a..2bcfb94eb 100644 --- a/plato/utils/trainer_utils.py +++ b/plato/utils/trainer_utils.py @@ -14,7 +14,7 @@ def freeze_model(model, layer_names=None): def activate_model(model, layer_names=None): - """Defreeze a part of the model.""" + """Activate a part of the model.""" if layer_names is not None: for name, param in model.named_parameters(): if any(param_name in name for param_name in layer_names): diff --git a/requirements.txt b/requirements.txt index dfc72a69d..d7b3c2ac7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,8 +11,4 @@ gym zstd torch-optimizer timm -lightly -scikit-learn -pytest -transformers -peft \ No newline at end of file +lightly \ No newline at end of file
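Since only part of the KNN evaluation in `test_model` is visible in the hunk above, the following self-contained sketch summarizes the 1-nearest-neighbour accuracy it computes over frozen SSL encodings. The helper name and the random stand-in encodings are illustrative only; the trainer collects real encodings with `collect_encodings`.

```python
import torch


def knn_accuracy(train_encodings, train_labels, test_encodings, test_labels):
    """Classify each test encoding by its nearest training encoding (1-NN)."""
    distances = torch.cdist(test_encodings, train_encodings, p=2)
    nearest_idx = distances.topk(1, largest=False).indices.squeeze(1)
    predictions = train_labels[nearest_idx]
    return (predictions == test_labels).float().mean().item()


train_enc, train_lab = torch.randn(100, 64), torch.randint(0, 10, (100,))
test_enc, test_lab = torch.randn(20, 64), torch.randint(0, 10, (20,))
print(f"1-NN accuracy on random encodings: "
      f"{knn_accuracy(train_enc, train_lab, test_enc, test_lab):.2f}")
```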