-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Cleaned up the source code of AsyncFilter (#378)
* Cleaned up the code. * Added configuration files and a readme file. * Cleaned up the comments. * Updated the readme. * Update examples.md
- Loading branch information
1 parent
399cfe2
commit 48d4660
Showing
12 changed files
with
1,106 additions
and
224 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Reproducing AsyncFilter | ||
|
||
## Setting up your Python environment | ||
|
||
It is recommended that [Miniforge](https://github.com/conda-forge/miniforge) is used to manage Python packages. Before using *Plato*, first install Miniforge, update your `conda` environment, and then create a new `conda` environment with Python 3.9 using the command: | ||
|
||
```shell | ||
conda update conda -y | ||
conda create -n plato -y python=3.9 | ||
conda activate plato | ||
``` | ||
|
||
where `plato` is the preferred name of your new environment. | ||
|
||
The next step is to install the required Python packages. PyTorch should be installed following the advice of its [getting started website](https://pytorch.org/get-started/locally/). The typical command in Linux with CUDA GPU support, for example, would be: | ||
|
||
```shell | ||
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 | ||
``` | ||
|
||
In macOS (without GPU support), the recommended command would be: | ||
|
||
```shell | ||
pip install torch==1.13.1 torchvision==0.14.1 | ||
``` | ||
Additionally, install scikit-learn package: | ||
|
||
```shell | ||
pip install scikit-learn | ||
``` | ||
## Installing Plato | ||
|
||
Navigate to the Plato directory and install the latest version from GitHub as a local pip package: | ||
|
||
```shell | ||
cd ../.. | ||
pip install . | ||
``` | ||
|
||
# Running experiments in plato/examples/detector folder | ||
Navigate to the examples/detector folder to start running experiments: | ||
```shell | ||
cd examples/detector | ||
``` | ||
|
||
## Set up the configuration file | ||
A variety of configuration files are provided for different experiments. Below are examples for reproducing key experiments from the paper: | ||
|
||
### Example 1: Section 5.2 - Running AsyncFilter on CIFAR-10 | ||
#### Download the dataset | ||
|
||
```shell | ||
python detector.py -c asyncfilter_cifar_2.yml -d | ||
``` | ||
|
||
#### Run the experiments | ||
```shell | ||
python detector.py -c asyncfilter_cifar_2.yml | ||
``` | ||
### Example 2: Section 5.3 - Running AsyncFilter Under LIE Attack on CINIC-10 (Concentration Factor: 0.01) | ||
#### Download the dataset | ||
|
||
```shell | ||
python detector.py -c asyncfilter_cinic_3.yml -d | ||
``` | ||
#### Run the experiments | ||
```shell | ||
python detector.py -c asyncfilter_cinic_3.yml | ||
``` | ||
### Example 3: Section 5.6 - Running AsyncFilter Under LIE Attack on FashionMNIST (Server Staleness Limit: 10) | ||
|
||
#### Download the dataset | ||
|
||
```shell | ||
python detector.py -c asyncfilter_fashionmnist_6.yml -d | ||
``` | ||
#### Run the experiments | ||
```shell | ||
python detector.py -c asyncfilter_fashionmnist_6.yml | ||
``` | ||
|
||
### Customizing Experiments | ||
For further experimentation, you can modify the configuration files to suit your requirements and reproduce the results. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
clients: | ||
# Type | ||
type: simple | ||
|
||
# The total number of clients | ||
total_clients: 100 | ||
|
||
# The number of clients selected in each round | ||
per_round: 100 | ||
|
||
# Should the clients compute test accuracy locally? | ||
do_test: true | ||
random_seed: 1 | ||
speed_simulation: true | ||
|
||
# The distribution of client speeds | ||
simulation_distribution: | ||
distribution: zipf # zipf is used. | ||
s: 1.2 | ||
sleep_simulation: true | ||
|
||
# If we are simulating client training times, what is the average training time? | ||
avg_training_time: 10 | ||
attack_type: LIE | ||
lambada_value: 2 | ||
attacker_ids: 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20 #,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50 | ||
|
||
server: | ||
address: 127.0.0.1 | ||
port: 5002 | ||
random_seed: 1 | ||
sychronous: false | ||
simulate_wall_time: true | ||
minimum_clients_aggregated: 40 | ||
staleness_bound: 10 | ||
checkpoint_path: results/CIFAR/test/checkpoint | ||
model_path: results/CIFAR/test/model | ||
|
||
|
||
data: | ||
# The training and testing dataset | ||
datasource: CIFAR10 | ||
|
||
# Number of samples in each partition | ||
partition_size: 10000 | ||
|
||
# IID or non-IID? | ||
sampler: noniid | ||
concentration: 0.1 | ||
random_seed: 1 | ||
|
||
trainer: | ||
# The type of the trainer | ||
type: basic | ||
|
||
# The maximum number of training rounds | ||
rounds: 100 | ||
|
||
# The maximum number of clients running concurrently | ||
max_concurrency: 2 | ||
|
||
# The target accuracy | ||
target_accuracy: 0.88 | ||
|
||
# The machine learning model | ||
model_name: vgg_16 | ||
|
||
# Number of epoches for local training in each communication round | ||
epochs: 5 | ||
batch_size: 128 | ||
optimizer: Adam | ||
|
||
algorithm: | ||
# Aggregation algorithm | ||
type: fedavg | ||
|
||
parameters: | ||
model: | ||
num_classes: 10 | ||
|
||
optimizer: | ||
lr: 0.01 | ||
weight_decay: 0.0 | ||
results: | ||
# Write the following parameter(s) into a CSV | ||
types: round, accuracy, elapsed_time, comm_time, round_time | ||
result_path: /data/ykang/plato/results/asyncfilter/cifar | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
clients: | ||
# Type | ||
type: simple | ||
|
||
# The total number of clients | ||
total_clients: 100 | ||
|
||
# The number of clients selected in each round | ||
per_round: 100 | ||
|
||
# Should the clients compute test accuracy locally? | ||
do_test: true | ||
random_seed: 1 | ||
|
||
# The distribution of client speeds | ||
simulation_distribution: | ||
distribution: zipf # zipf is used. | ||
s: 1.2 | ||
sleep_simulation: true | ||
speed_simulation: true | ||
|
||
# If we are simulating client training times, what is the average training time? | ||
avg_training_time: 10 | ||
attack_type: LIE | ||
lambada_value: 2 | ||
attacker_ids: 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20 #,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50 | ||
|
||
|
||
server: | ||
address: 127.0.0.1 | ||
port: 6332 | ||
random_seed: 1 | ||
sychronous: false | ||
simulate_wall_time: true | ||
minimum_clients_aggregated: 40 | ||
detector_type: AsyncFilter | ||
staleness_bound: 20 | ||
checkpoint_path: results/CIFAR/test/checkpoint | ||
model_path: results/CIFAR/test/model | ||
|
||
|
||
data: | ||
# The training and testing dataset | ||
datasource: CINIC10 | ||
|
||
# Where the dataset is located | ||
data_path: data/CINIC-10 | ||
|
||
# | ||
download_url: http://iqua.ece.toronto.edu/baochun/CINIC-10.tar.gz | ||
|
||
# Number of samples in each partition | ||
partition_size: 10000 | ||
|
||
# IID or non-IID? | ||
sampler: noniid | ||
concentration: 0.1 | ||
random_seed: 1 | ||
|
||
trainer: | ||
# The type of the trainer | ||
type: basic | ||
|
||
# The maximum number of training rounds | ||
rounds: 100 | ||
|
||
# The maximum number of clients running concurrently | ||
max_concurrency: 4 | ||
|
||
# The target accuracy | ||
target_accuracy: 0.88 | ||
|
||
# The machine learning model | ||
model_name: vgg_16 | ||
|
||
# Number of epoches for local training in each communication round | ||
epochs: 5 | ||
batch_size: 128 | ||
optimizer: SGD | ||
|
||
algorithm: | ||
# Aggregation algorithm | ||
type: fedavg | ||
|
||
parameters: | ||
model: | ||
num_classes: 10 | ||
|
||
optimizer: | ||
lr: 0.01 | ||
momentum: 0.5 | ||
weight_decay: 0.0 | ||
results: | ||
# Write the following parameter(s) into a CSV | ||
types: round, accuracy, elapsed_time, comm_time, round_time | ||
result_path: /data/ykang/plato/results/asyncfilter/cinic |
Oops, something went wrong.