From 48d46603c884b93528ea5a39d777e3ce48553955 Mon Sep 17 00:00:00 2001 From: Yufei Kang <41565676+Yufei-Kang@users.noreply.github.com> Date: Mon, 30 Sep 2024 21:39:36 -0400 Subject: [PATCH] Cleaned up the source code of AsyncFilter (#378) * Cleaned up the code. * Added configuration files and a readme file. * Cleaned up the comments. * Updated the readme. * Update examples.md --- docs/examples.md | 13 + examples/detector/README.md | 83 ++++ examples/detector/aggregations.py | 12 +- examples/detector/asyncfilter_cifar_2.yml | 88 ++++ examples/detector/asyncfilter_cinic_2.yml | 96 ++++ ...t_lenet5.yml => asyncfilter_fashion_6.yml} | 69 ++- examples/detector/attacks.py | 293 +++++++----- examples/detector/defences.py | 90 +++- examples/detector/detector.py | 2 +- examples/detector/detector_client.py | 41 -- examples/detector/detector_server.py | 120 ++++- examples/detector/detectors.py | 423 ++++++++++++++++++ 12 files changed, 1106 insertions(+), 224 deletions(-) create mode 100644 examples/detector/README.md create mode 100644 examples/detector/asyncfilter_cifar_2.yml create mode 100644 examples/detector/asyncfilter_cinic_2.yml rename examples/detector/{femnist_lenet5.yml => asyncfilter_fashion_6.yml} (59%) delete mode 100644 examples/detector/detector_client.py create mode 100644 examples/detector/detectors.py diff --git a/docs/examples.md b/docs/examples.md index 39eb8f90f..520089089 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -538,6 +538,19 @@ FedSaw is proposed to improve training performance in three-layer federated lear python examples/three_layer_fl/fedsaw/fedsaw.py -c examples/three_layer_fl/fedsaw/fedsaw_MNIST_lenet5.yml ``` ```` +#### Poisoning Detection Algorithms +````{admonition} **AsyncFilter** +AsyncFilter is proposed to defend against untargeted poisoning attacks in asynchronous federated learning with a server filter. With statistical analysis, AsyncFilter identifies potential poisoned model updates and filters them out before the server aggregation stage. + +```shell +python examples/detector/detector.py -c examples/detector/asyncfilter_fashion_6.yml +``` + +```{note} +Kang et al., “[AsyncFilter: Detecting Poisoning Attacks in Asynchronous Federated Learning](http://iqua.ece.toronto.edu/papers/ykang-middleware25.pdf) +&rdquo: in the Proceedings of the 25th ACM/IFIP International Middleware Conference (Middleware), 2024. +``` +```` #### Model Pruning Algorithms diff --git a/examples/detector/README.md b/examples/detector/README.md new file mode 100644 index 000000000..319c63cc9 --- /dev/null +++ b/examples/detector/README.md @@ -0,0 +1,83 @@ +# Reproducing AsyncFilter + +## Setting up your Python environment + +It is recommended that [Miniforge](https://github.com/conda-forge/miniforge) is used to manage Python packages. Before using *Plato*, first install Miniforge, update your `conda` environment, and then create a new `conda` environment with Python 3.9 using the command: + +```shell +conda update conda -y +conda create -n plato -y python=3.9 +conda activate plato +``` + +where `plato` is the preferred name of your new environment. + +The next step is to install the required Python packages. PyTorch should be installed following the advice of its [getting started website](https://pytorch.org/get-started/locally/). The typical command in Linux with CUDA GPU support, for example, would be: + +```shell +pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 +``` + +In macOS (without GPU support), the recommended command would be: + +```shell +pip install torch==1.13.1 torchvision==0.14.1 +``` +Additionally, install scikit-learn package: + +```shell +pip install scikit-learn +``` +## Installing Plato + +Navigate to the Plato directory and install the latest version from GitHub as a local pip package: + +```shell +cd ../.. +pip install . +``` + +# Running experiments in plato/examples/detector folder +Navigate to the examples/detector folder to start running experiments: +```shell +cd examples/detector +``` + +## Set up the configuration file +A variety of configuration files are provided for different experiments. Below are examples for reproducing key experiments from the paper: + +### Example 1: Section 5.2 - Running AsyncFilter on CIFAR-10 +#### Download the dataset + +```shell +python detector.py -c asyncfilter_cifar_2.yml -d +``` + +#### Run the experiments +```shell +python detector.py -c asyncfilter_cifar_2.yml +``` +### Example 2: Section 5.3 - Running AsyncFilter Under LIE Attack on CINIC-10 (Concentration Factor: 0.01) +#### Download the dataset + +```shell +python detector.py -c asyncfilter_cinic_3.yml -d +``` +#### Run the experiments +```shell +python detector.py -c asyncfilter_cinic_3.yml +``` +### Example 3: Section 5.6 - Running AsyncFilter Under LIE Attack on FashionMNIST (Server Staleness Limit: 10) + +#### Download the dataset + +```shell +python detector.py -c asyncfilter_fashionmnist_6.yml -d +``` +#### Run the experiments +```shell +python detector.py -c asyncfilter_fashionmnist_6.yml +``` + +### Customizing Experiments +For further experimentation, you can modify the configuration files to suit your requirements and reproduce the results. \ No newline at end of file diff --git a/examples/detector/aggregations.py b/examples/detector/aggregations.py index 7f1b43c53..0372fa388 100644 --- a/examples/detector/aggregations.py +++ b/examples/detector/aggregations.py @@ -77,12 +77,12 @@ def bulyan(updates, baseline_weights, weights_attacked): """Aggregate weight updates from the clients using bulyan.""" total_clients = Config().clients.total_clients - num_attackers = len(Config().clients.attacker_ids) # ? + num_attackers = len(Config().clients.attacker_ids) remaining_weights = flatten_weights(weights_attacked) bulyan_cluster = [] - # Search for bulyan cluster based on distance + # Search for bulyan cluster based on distances while (len(bulyan_cluster) < (total_clients - 2 * num_attackers)) and ( len(bulyan_cluster) < (total_clients - 2 - num_attackers) ): @@ -104,7 +104,7 @@ def bulyan(updates, baseline_weights, weights_attacked): : len(remaining_weights) - 2 - num_attackers ] - # Add candidate into bulyan cluster + # Add candidates into bulyan cluster bulyan_cluster = ( remaining_weights[indices[0]][None, :] if not len(bulyan_cluster) @@ -149,7 +149,7 @@ def krum(updates, baseline_weights, weights_attacked): remaining_weights = flatten_weights(weights_attacked) - num_attackers_selected = 2 # ? + num_attackers_selected = 2 distances = [] for weight in remaining_weights: @@ -339,7 +339,7 @@ def afa(updates, baseline_weights, weights_attacked): bad_set.append(remove_id) else: - for counter, weight in enumerate(flattened_weights): # we for loop this + for counter, weight in enumerate(flattened_weights): if cos_sims[counter] > (model_median + epsilon * model_std): remove_set.append(1) remove_id = ( @@ -353,7 +353,7 @@ def afa(updates, baseline_weights, weights_attacked): temp_tensor2 = flattened_weights_copy[delete_id + 1 :] flattened_weights_copy = torch.cat( (temp_tensor1, temp_tensor2), dim=0 - ) # but we changes it in the loop, maybe we should get a copy + ) bad_set.append(remove_id) epsilon += delta_ep diff --git a/examples/detector/asyncfilter_cifar_2.yml b/examples/detector/asyncfilter_cifar_2.yml new file mode 100644 index 000000000..b47f31dba --- /dev/null +++ b/examples/detector/asyncfilter_cifar_2.yml @@ -0,0 +1,88 @@ +clients: + # Type + type: simple + + # The total number of clients + total_clients: 100 + + # The number of clients selected in each round + per_round: 100 + + # Should the clients compute test accuracy locally? + do_test: true + random_seed: 1 + speed_simulation: true + + # The distribution of client speeds + simulation_distribution: + distribution: zipf # zipf is used. + s: 1.2 + sleep_simulation: true + + # If we are simulating client training times, what is the average training time? + avg_training_time: 10 + attack_type: LIE + lambada_value: 2 + attacker_ids: 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20 #,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50 + +server: + address: 127.0.0.1 + port: 5002 + random_seed: 1 + sychronous: false + simulate_wall_time: true + minimum_clients_aggregated: 40 + staleness_bound: 10 + checkpoint_path: results/CIFAR/test/checkpoint + model_path: results/CIFAR/test/model + + +data: + # The training and testing dataset + datasource: CIFAR10 + + # Number of samples in each partition + partition_size: 10000 + + # IID or non-IID? + sampler: noniid + concentration: 0.1 + random_seed: 1 + +trainer: + # The type of the trainer + type: basic + + # The maximum number of training rounds + rounds: 100 + + # The maximum number of clients running concurrently + max_concurrency: 2 + + # The target accuracy + target_accuracy: 0.88 + + # The machine learning model + model_name: vgg_16 + + # Number of epoches for local training in each communication round + epochs: 5 + batch_size: 128 + optimizer: Adam + +algorithm: + # Aggregation algorithm + type: fedavg + +parameters: + model: + num_classes: 10 + + optimizer: + lr: 0.01 + weight_decay: 0.0 +results: + # Write the following parameter(s) into a CSV + types: round, accuracy, elapsed_time, comm_time, round_time + result_path: /data/ykang/plato/results/asyncfilter/cifar + diff --git a/examples/detector/asyncfilter_cinic_2.yml b/examples/detector/asyncfilter_cinic_2.yml new file mode 100644 index 000000000..126957d3e --- /dev/null +++ b/examples/detector/asyncfilter_cinic_2.yml @@ -0,0 +1,96 @@ +clients: + # Type + type: simple + + # The total number of clients + total_clients: 100 + + # The number of clients selected in each round + per_round: 100 + + # Should the clients compute test accuracy locally? + do_test: true + random_seed: 1 + + # The distribution of client speeds + simulation_distribution: + distribution: zipf # zipf is used. + s: 1.2 + sleep_simulation: true + speed_simulation: true + + # If we are simulating client training times, what is the average training time? + avg_training_time: 10 + attack_type: LIE + lambada_value: 2 + attacker_ids: 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20 #,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50 + + +server: + address: 127.0.0.1 + port: 6332 + random_seed: 1 + sychronous: false + simulate_wall_time: true + minimum_clients_aggregated: 40 + detector_type: AsyncFilter + staleness_bound: 20 + checkpoint_path: results/CIFAR/test/checkpoint + model_path: results/CIFAR/test/model + + +data: + # The training and testing dataset + datasource: CINIC10 + + # Where the dataset is located + data_path: data/CINIC-10 + + # + download_url: http://iqua.ece.toronto.edu/baochun/CINIC-10.tar.gz + + # Number of samples in each partition + partition_size: 10000 + + # IID or non-IID? + sampler: noniid + concentration: 0.1 + random_seed: 1 + +trainer: + # The type of the trainer + type: basic + + # The maximum number of training rounds + rounds: 100 + + # The maximum number of clients running concurrently + max_concurrency: 4 + + # The target accuracy + target_accuracy: 0.88 + + # The machine learning model + model_name: vgg_16 + + # Number of epoches for local training in each communication round + epochs: 5 + batch_size: 128 + optimizer: SGD + +algorithm: + # Aggregation algorithm + type: fedavg + +parameters: + model: + num_classes: 10 + + optimizer: + lr: 0.01 + momentum: 0.5 + weight_decay: 0.0 +results: + # Write the following parameter(s) into a CSV + types: round, accuracy, elapsed_time, comm_time, round_time + result_path: /data/ykang/plato/results/asyncfilter/cinic diff --git a/examples/detector/femnist_lenet5.yml b/examples/detector/asyncfilter_fashion_6.yml similarity index 59% rename from examples/detector/femnist_lenet5.yml rename to examples/detector/asyncfilter_fashion_6.yml index 102f0359c..f6c5a3b3e 100644 --- a/examples/detector/femnist_lenet5.yml +++ b/examples/detector/asyncfilter_fashion_6.yml @@ -3,63 +3,51 @@ clients: type: simple # The total number of clients - total_clients: 10 + total_clients: 100 # The number of clients selected in each round - per_round: 2 + per_round: 100 # Should the clients compute test accuracy locally? - do_test: - false - - # Whether client heterogeneity should be simulated - speed_simulation: true + do_test: false + random_seed: 1 # The distribution of client speeds simulation_distribution: distribution: zipf # zipf is used. s: 1.2 - - # The maximum amount of time for clients to sleep after each epoch - #max_sleep_time: 10 - - # Should clients really go to sleep, or should we just simulate the sleep times? sleep_simulation: true + speed_simulation: true # If we are simulating client training times, what is the average training time? avg_training_time: 10 + attack_type: LIE + lambada_value: 2 + attacker_ids: 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20 #,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40 #,41,42,43,44,45,46,47,48,49,50 - random_seed: 1 - attack_type: Min-Sum - attacker_ids: 1,2,3,4,5,6,7,8,9,10 server: address: 127.0.0.1 - port: 2910 + port: 5602 random_seed: 1 - synchronous: true + synchronous: false simulate_wall_time: true - #defence_type: + minimum_clients_aggregated: 40 + staleness_bound: 10 #20 + detector_type: AsyncFilter + checkpoint_path: results/fashion/test/checkpoint + model_path: results/fashion/test/model - checkpoint_path: results/FEMNIST/test/checkpoint - model_path: results/FEMNIST/test/model data: # The training and testing dataset - datasource: FEMNIST - - reload_data: true + datasource: FashionMNIST # Number of samples in each partition - #partition_size: 1000 + partition_size: 2000 # IID or non-IID? - sampler: all_inclusive - - #concentration: 1 - - #testset_size: 1000 - - # The random seed for sampling data + sampler: noniid + concentration: 0.1 random_seed: 1 trainer: @@ -67,22 +55,22 @@ trainer: type: basic # The maximum number of training rounds - rounds: 150 + rounds: 100 # The maximum number of clients running concurrently - max_concurrency: 6 + max_concurrency: 5 # The target accuracy - target_accuracy: 1.0 + target_accuracy: 0.88 + # The machine learning model + model_name: lenet5 + # Number of epoches for local training in each communication round epochs: 5 batch_size: 32 optimizer: SGD - # The machine learning model - model_name: lenet5 - algorithm: # Aggregation algorithm type: fedavg @@ -91,11 +79,8 @@ parameters: optimizer: lr: 0.01 momentum: 0.9 - weight_decay: 0 - model: - num_classes: 62 - + weight_decay: 0.0 results: # Write the following parameter(s) into a CSV types: round, accuracy, elapsed_time, comm_time, round_time - result_path: /data/ykang/plato/results/attackDefence/obs/femnist + result_path: /data/ykang/plato/results/asyncfilter/fashionmnist diff --git a/examples/detector/attacks.py b/examples/detector/attacks.py index 091213df3..400a0dfa5 100644 --- a/examples/detector/attacks.py +++ b/examples/detector/attacks.py @@ -12,7 +12,8 @@ from collections import OrderedDict import numpy as np import os - +import torch.nn.functional as F +import pickle def get(): """Get an attack for malicious clients based on the configuration file.""" @@ -24,7 +25,7 @@ def get(): if attack_type is None: logging.info(f"No attack is applied.") - return lambda x: x + return lambda a,x,b,c: x if attack_type in registered_attacks: registered_attack = registered_attacks[attack_type] @@ -33,23 +34,44 @@ def get(): raise ValueError(f"No such attack: {attack_type}") - -def perform_model_poisoning(weights_received, poison_value): - # Poison the reveiced weights based on calculated poison value. - weights_poisoned = [] - for weight_received in weights_received: +def poisoning_performance_evaluation(clean_weights, poisoned_values, poisoned_weights): + # calculate deviated norm of poisoned weights from clean weights + # poisoned_values is a long tensor, and it's same for all clients + logging.info(f"poisoning performance evaluation (torch.norm(poisoned_value)): %s", torch.norm(poisoned_values)) + logging.info(f"norm of clean weights are: %s", torch.norm(clean_weights, dim=1)) + flatten_poisoned_weights = flatten_weights(poisoned_weights) + logging.info(f"norm of poisoned weights: %s", torch.norm(flatten_poisoned_weights,dim=1)) + +def perform_model_poisoning(weights_received, poison_value): + # Poison the received weights based on calculated poison value. + # The "poison_value" means modifications on original weights + for index, weight_received in enumerate(weights_received): start_index = 0 - weight_poisoned = OrderedDict() for name, weight in weight_received.items(): - weight_poisoned[name] = poison_value[ + if weights_received[index][name].dtype == torch.int64: + weights_received[index][name] += poison_value[ start_index : start_index + len(weight.view(-1)) - ].reshape(weight.shape) + ].reshape(weight.shape).long() + + else: + weights_received[index][name] += poison_value[ + start_index : start_index + len(weight.view(-1)) + ].reshape(weight.shape) start_index += len(weight.view(-1)) + + return weights_received - weights_poisoned.append(weight_poisoned) - return weights_poisoned +def flatten_weight(weight): + flattened_weight = [] + for name in weight.keys(): + flattened_weight = ( + weight[name].view(-1) + if not len(flattened_weight) + else torch.cat((flattened_weight, weight[name].view(-1))) + ) + return flattened_weight def flatten_weights(weights): flattened_weights = [] @@ -70,11 +92,10 @@ def flatten_weights(weights): ) return flattened_weights - def compute_sali_indicator(): # Add importance pruning to the attack sali_map_vector = torch.load(Config().algorithm.map_path) - + sparsity = Config().algorithm.sparsity shrink_level = Config().algorithm.shrink_level inflation_level = Config().algorithm.inflation_level @@ -89,69 +110,107 @@ def compute_sali_indicator(): return sali_indicators_vector - -def smoothing(keywords, value): +def smoothing(keywords,value): total_clients = Config().clients.total_clients num_attackers = len(Config().clients.attacker_ids) clients_per_round = Config().clients.per_round - malicious_expectation = (num_attackers / total_clients) * clients_per_round - if num_attackers < malicious_expectation: # how to know num_attackers? + malicious_expectation = ( + num_attackers /total_clients + ) * clients_per_round + if num_attackers < malicious_expectation: momentum = Config().algorithm.high_momentum else: momentum = Config().algorithm.low_momentum # Smooth poison value - file_path = "./" + keywords + "_model_updates_history.pt" - if os.path.exists(file_path): + file_path = "./"+keywords+"_model_updates_history.pt" + if os.path.exists(file_path): last_model_re = torch.load(file_path) value = (1 - momentum) * value + momentum * last_model_re - torch.save(value, file_path) + torch.save(value,file_path) return value +def gassian_attack(baseline_weights,weights_received,deltas_received,num_attackers): + # calculate poison value based on Gassian distribution + attacker_weights = flatten_weights(weights_received) + baseline_weights = flatten_weight(baseline_weights) -def lie_attack(weights_received): - """ - Attack name: Little is enough - - Reference: - - Baruch et al., "A little is enough: Circumventing defenses for distributed learning," in Proceedings of Advances in Neural Information Processing Systems (NeurIPS) 2019. - - https://proceedings.neurips.cc/paper_files/paper/2019/file/ec1c59141046cd1866bbbcdfb6ae31d4-Paper.pdf - """ - - total_clients = Config().clients.total_clients - num_attackers = len(Config().clients.attacker_ids) + # set standard deviation and mean value for gassian noise + std_dev = 1 + mean = 0 + poison_value = torch.randn_like(attacker_weights[0]) * std_dev + mean + + # Perform model poisoning + weights_poisoned = perform_model_poisoning(weights_received, poison_value) + logging.info(f"Finished Gassian model poisoning attack.") + # poisoning_performance_evaluation(attacker_weights, poison_value, weights_poisoned) + return weights_poisoned +def lambda_attack(baseline_weights, weights_received, deltas_received,num_attackers): attacker_weights = flatten_weights(weights_received) + attacker_deltas = flatten_weights(deltas_received) + baseline_weights = flatten_weight(baseline_weights) + - # Calculate perturbation range - s_value = total_clients / 2 + 1 - num_attackers - possibility = (total_clients - s_value) / total_clients - z_value = norm.cdf(possibility) + lamda = Config().clients.lambada_value + direction = -1 # opposite direction; + perturbation_vector = direction * torch.mean(attacker_deltas, dim=0) + poison_value = lamda * perturbation_vector + #deltas_poisoned = perform_model_poisoning(deltas_received, poison_value) - weights_avg = torch.mean(attacker_weights, dim=0) - weights_std = torch.std(attacker_weights, dim=0) - - # Calculate poisoning model - poison_value = weights_avg + z_value * weights_std # Perform model poisoning weights_poisoned = perform_model_poisoning(weights_received, poison_value) - logging.info(f"Finished LIE model poisoning attack.") - return weights_poisoned + logging.info(f"Finished Lambda model poisoning attack.") + #poisoning_performance_evaluation(attacker_weights, poison_value, weights_poisoned) + return weights_poisoned -def oblivion_lie_attack(weights_received): +def lie_attack(baseline_weights,weights_received,deltas_received,num_attackers): """ - Attack name: Little is enough with Oblivion + Attack name: Little is enough - """ + Reference: - total_clients = Config().clients.total_clients - num_attackers = len(Config().clients.attacker_ids) + Baruch et al., "A little is enough: Circumventing defenses for distributed learning," in Proceedings of Advances in Neural Information Processing Systems (NeurIPS) 2019. + + https://proceedings.neurips.cc/paper_files/paper/2019/file/ec1c59141046cd1866bbbcdfb6ae31d4-Paper.pdf + """ + if num_attackers == 1 : + # LIE attack does not apply to solo attacker + return weights_received + else : + total_clients = Config().clients.per_round + #num_attackers = len(Config().clients.attacker_ids) + + attacker_weights = flatten_weights(weights_received) + attacker_deltas = flatten_weights(deltas_received) + baseline_weights = flatten_weight(baseline_weights) + + # Calculate perturbation range + s_value = total_clients / 2 + 1 - num_attackers + possibility = (total_clients - s_value) / total_clients + z_value = norm.cdf(possibility) + + deltas_avg = torch.mean(attacker_deltas, dim=0) + weights_std = -1 * torch.std(attacker_deltas, dim=0) + + # Calculate poisoning model + poison_value = z_value * weights_std + # Perform model poisoning + weights_poisoned = perform_model_poisoning(weights_received, poison_value) + logging.info(f"Finished LIE model poisoning attack.") + #poisoning_performance_evaluation(attacker_weights, poison_value, weights_poisoned) + return weights_poisoned + +def oblivion_lie_attack(weights_received, deltas_received,num_attackers): + """Little is enough""" + """ LIE importance dynamic momentum attack """ + total_clients = Config().clients.per_round + #num_attackers = len(Config().clients.attacker_ids) attacker_weights = flatten_weights(weights_received) + attacker_deltas = flatten_weights(deltas_received) # Calculate perturbation range s_value = total_clients / 2 + 1 - num_attackers @@ -162,7 +221,7 @@ def oblivion_lie_attack(weights_received): weights_std = torch.std(attacker_weights, dim=0) # Smooth benign model updates - weights_avg = smoothing("benign", weights_avg) + weights_avg=smoothing("benign", weights_avg) # Importance pruning sali_indicators_vector = compute_sali_indicator() @@ -171,39 +230,42 @@ def oblivion_lie_attack(weights_received): poison_value = weights_avg + z_value * weights_std # Smooth poison value - poison_value = smoothing("poisoned", poison_value) + poison_value=smoothing("poisoned", poison_value) # Perform model poisoning weights_poisoned = perform_model_poisoning(weights_received, poison_value) - logging.info(f"Finished LIE model poisoning attack (with Oblivion).") + logging.info(f"Finished LIE Oblivion model poisoning attack.") + #poisoning_performance_evaluation(attacker_weights, poison_value, weights_poisoned) return weights_poisoned +def min_max_attack(baseline_weights, weights_received, deltas_received, num_attackers ,dev_type="unit_vec"): + #attacker_weights = flatten_weights(weights_received) + attacker_weights = flatten_weights(deltas_received) -def min_max_attack(weights_received, dev_type="unit_vec"): - attacker_weights = flatten_weights(weights_received) - - weights_avg = torch.mean(attacker_weights, 0) + #weights_avg = torch.mean(attacker_weights, 0) + weights_avg = torch.mean(attacker_weights,0) # Generate perturbation vectors (Inverse unit vector by default) if dev_type == "unit_vec": # Inverse unit vector - perturbation_vector = weights_avg / torch.norm(weights_avg) + perturbation_vector = weights_avg / torch.norm( + weights_avg + ) elif dev_type == "sign": # Inverse sign perturbation_vector = torch.sign(weights_avg) elif dev_type == "std": # Inverse standard deviation - perturbation_vector = torch.std(attacker_weights, 0) - + perturbation_vector = torch.std(attacker_weights, 0) + # Calculate the maximum distance between any two benign updates (unpoisoned) max_distance = torch.tensor([0]) for attacker_weight in attacker_weights: distance = torch.norm((attacker_weights - attacker_weight), dim=1) ** 2 max_distance = torch.max(max_distance, torch.max(distance)) - - # Search for lambda such that its maximum distance from any other gradient is bounded - lambda_value = torch.Tensor([50.0]).float() + # Search for lambda such that its maximum distance from any other gradient is bounded + lambda_value = torch.Tensor([10000.0]).float() threshold = 1e-5 lambda_step = lambda_value lambda_succ = 0 @@ -221,18 +283,23 @@ def min_max_attack(weights_received, dev_type="unit_vec"): lambda_step = lambda_step / 2 - poison_value = weights_avg - lambda_succ * perturbation_vector - + #poison_value = weights_avg - lambda_succ * perturbation_vector + poison_value = -1 * lambda_succ * perturbation_vector # Perform model poisoning weights_poisoned = perform_model_poisoning(weights_received, poison_value) logging.info(f"Finished Min-Max model poisoning attack.") + #poisoning_performance_evaluation(attacker_weights, poison_value, weights_poisoned) return weights_poisoned - def oblivion_min_max_attack(weights_received, dev_type="unit_vec"): """ - Attack name: Min-max with Oblivion + Attack: Min-Max + + Reference: + + Shejwalkar et al., “Manipulating the Byzantine: Optimizing model poisoning attacks and defenses for federated learning,” in Proceedings of 28th Annual Network and Distributed System Security Symposium (NDSS), 2021 + https://www.ndss-symposium.org/ndss-paper/manipulating-the-byzantine-optimizing-model-poisoning-attacks-and-defenses-for-federated-learning/ """ attacker_weights = flatten_weights(weights_received) @@ -240,32 +307,34 @@ def oblivion_min_max_attack(weights_received, dev_type="unit_vec"): weights_avg = torch.mean(attacker_weights, 0) # Smooth benign model updates - weights_avg = smoothing("benign", weights_avg) + weights_avg=smoothing("benign", weights_avg) # Generate perturbation vectors (Inverse unit vector by default) if dev_type == "unit_vec": # Inverse unit vector - perturbation_vector = weights_avg / torch.norm(weights_avg) + perturbation_vector = weights_avg / torch.norm( + weights_avg + ) elif dev_type == "sign": # Inverse sign perturbation_vector = torch.sign(weights_avg) elif dev_type == "std": # Inverse standard deviation - perturbation_vector = torch.std(attacker_weights, 0) - + perturbation_vector = torch.std(attacker_weights, 0) + # Importance pruning sali_indicators_vector = compute_sali_indicator() perturbation_vector = perturbation_vector * sali_indicators_vector - + # Calculate the maximum distance between any two benign updates (unpoisoned) max_distance = torch.tensor([0]) for attacker_weight in attacker_weights: distance = torch.norm((attacker_weights - attacker_weight), dim=1) ** 2 max_distance = torch.max(max_distance, torch.max(distance)) - - # Search for lambda such that its maximum distance from any other gradient is bounded + + # Search for lambda such that its maximum distance from any other gradient is bounded lambda_value = torch.Tensor([50.0]).float() - threshold = 1e-5 + threshold = 1 #e-3 #1e-5 lambda_step = lambda_value lambda_succ = 0 @@ -285,17 +354,16 @@ def oblivion_min_max_attack(weights_received, dev_type="unit_vec"): poison_value = weights_avg - lambda_succ * perturbation_vector # Smooth poison value - poison_value = smoothing("poisoned", poison_value) + poison_value=smoothing("poisoned", poison_value) # Perform model poisoning weights_poisoned = perform_model_poisoning(weights_received, poison_value) logging.info(f"Finished Min-Max model poisoning attack (with Oblivion).") return weights_poisoned - -def min_sum_attack(weights_received, dev_type="unit_vec"): +def min_sum_attack(baseline_weights,weights_received, deltas_received, num_attackers, dev_type="unit_vec"): """ - Attack: Min-Max + Attack: Min-Sum Reference: @@ -305,69 +373,80 @@ def min_sum_attack(weights_received, dev_type="unit_vec"): """ attacker_weights = flatten_weights(weights_received) + attacker_deltas = flatten_weights(deltas_received) - weights_avg = torch.mean(attacker_weights, 0) + # deltas_avg = torch.mean(attacker_deltas, 0) + weights_avg = torch.mean(attacker_weights,0) # Generate perturbation vectors (Inverse unit vector by default) if dev_type == "unit_vec": # Inverse unit vector - perturbation_vector = weights_avg / torch.norm(weights_avg) + perturbation_vector = weights_avg / torch.norm( + weights_avg + ) elif dev_type == "sign": # Inverse sign perturbation_vector = torch.sign(weights_avg) elif dev_type == "std": # Inverse standard deviation - perturbation_vector = torch.std(attacker_weights, 0) + perturbation_vector = torch.std(attacker_deltas, 0) # Calculate the minimal sum of squared distances of benign update from the other benign updates - min_sum_distance = torch.tensor([0]) - for attacker_weight in attacker_weights: - distance = torch.norm((attacker_weights - attacker_weight), dim=1) ** 2 - min_sum_distance = torch.min(min_sum_distance, torch.sum(distance)) + min_sum_distance = torch.tensor([50000000000]) + for attacker_weight in attacker_deltas: + distance = torch.norm((attacker_deltas - attacker_weight), dim=1) ** 2 + min_sum_distance = torch.min(min_sum_distance,torch.sum(distance)) # Search for lambda - lambda_value = torch.Tensor([50.0]).float() + lambda_value = torch.Tensor([10000.0]).float() threshold = 1e-5 lambda_step = lambda_value lambda_succ = 0 while torch.abs(lambda_succ - lambda_value) > threshold: poison_value = weights_avg - lambda_value * perturbation_vector - distance = torch.norm((attacker_weights - poison_value), dim=1) ** 2 + distance = torch.norm((attacker_deltas - poison_value), dim=1) ** 2 score = torch.sum(distance) - if score <= min_sum_distance: lambda_succ = lambda_value lambda_value = lambda_value + lambda_step / 2 else: + lambda_succ = lambda_value # incase no succ + lambda_value = lambda_value - lambda_step / 2 lambda_step = lambda_step / 2 - - poison_value = weights_avg - lambda_succ * perturbation_vector - + poison_value = lambda_succ * perturbation_vector # perform model poisoning weights_poisoned = perform_model_poisoning(weights_received, poison_value) logging.info(f"Finished Min-Sum model poisoning attack.") + #poisoning_performance_evaluation(attacker_weights, poison_value, weights_poisoned) return weights_poisoned - def oblivion_min_sum_attack(weights_received, dev_type="unit_vec"): """ - Attack name: Min-sum with Oblivion + Attack: Min-Max + + Reference: + Shejwalkar et al., “Manipulating the Byzantine: Optimizing model poisoning attacks and defenses for federated learning,” in Proceedings of 28th Annual Network and Distributed System Security Symposium (NDSS), 2021 + + https://www.ndss-symposium.org/ndss-paper/manipulating-the-byzantine-optimizing-model-poisoning-attacks-and-defenses-for-federated-learning/ """ + attacker_weights = flatten_weights(weights_received) weights_avg = torch.mean(attacker_weights, 0) # Smooth benign model updates - weights_avg = smoothing("benign", weights_avg) + weights_avg=smoothing("benign", weights_avg) # Generate perturbation vectors (Inverse unit vector by default) if dev_type == "unit_vec": # Inverse unit vector - perturbation_vector = weights_avg / torch.norm(weights_avg) + perturbation_vector = weights_avg / torch.norm( + weights_avg + ) elif dev_type == "sign": # Inverse sign perturbation_vector = torch.sign(weights_avg) @@ -383,7 +462,7 @@ def oblivion_min_sum_attack(weights_received, dev_type="unit_vec"): min_sum_distance = torch.tensor([0]) for attacker_weight in attacker_weights: distance = torch.norm((attacker_weights - attacker_weight), dim=1) ** 2 - min_sum_distance = torch.min(min_sum_distance, torch.sum(distance)) + min_sum_distance = torch.min(min_sum_distance,torch.sum(distance)) # Search for lambda lambda_value = torch.Tensor([50.0]).float() @@ -407,7 +486,7 @@ def oblivion_min_sum_attack(weights_received, dev_type="unit_vec"): poison_value = weights_avg - lambda_succ * perturbation_vector # Smooth poison value - poison_value = smoothing("poisoned", poison_value) + poison_value=smoothing("poisoned", poison_value) # perform model poisoning weights_poisoned = perform_model_poisoning(weights_received, poison_value) @@ -418,12 +497,7 @@ def oblivion_min_sum_attack(weights_received, dev_type="unit_vec"): def compute_lambda(attacker_weights, global_model_last_round, num_attackers): """Compute the lambda value for fang's attack.""" distances = [] - ( - num_benign_clients, - d, - ) = ( - attacker_weights.shape - ) # impractical, not sure how many benign clients are included. + num_benign_clients, d = attacker_weights.shape # impractical, not sure how many benign clients are included. for weight in attacker_weights: distance = torch.norm((attacker_weights - weight), dim=1) @@ -504,7 +578,7 @@ def fang_attack(weights_received): attacker_weights = flatten_weights(weights_received) weights_avg = torch.mean(attacker_weights, 0) - global_model_last_round = weights_avg # ? + global_model_last_round = weights_avg #? lambda_value = compute_lambda( attacker_weights, global_model_last_round, num_attackers ) @@ -543,10 +617,13 @@ def fang_attack(weights_received): registered_attacks = { "LIE": lie_attack, - "Oblivison-lie": oblivion_lie_attack, + "Oblivion-lie":oblivion_lie_attack, "Min-Max": min_max_attack, - "Oblivision-minmax": oblivion_min_max_attack, + "Oblivion-minmax":oblivion_min_max_attack, "Min-Sum": min_sum_attack, - "Oblivion-minsum": oblivion_min_sum_attack, + "Oblivion-minsum":oblivion_min_sum_attack, "Fang": fang_attack, + "Lambda_attack": lambda_attack, + "Gassian_attack": gassian_attack, + } diff --git a/examples/detector/defences.py b/examples/detector/defences.py index 2427c0cff..58f179791 100644 --- a/examples/detector/defences.py +++ b/examples/detector/defences.py @@ -1,11 +1,11 @@ import torch import logging from plato.config import Config -from scipy.stats import norm +from collections import OrderedDict +import numpy as np registered_defences = {} - def get(): defence_type = ( @@ -20,6 +20,92 @@ def get(): if defence_type in registered_defences: registered_defence = registered_defences[defence_type] + logging.info(f"Clients perform {defence_type} defence.") return registered_defence raise ValueError(f"No such defence: {defence_type}") + +def flatten_weights(weights): + flattened_weights = [] + + for weight in weights: + flattened_weight = [] + for name in weight.keys(): + flattened_weight = ( + weight[name].view(-1) + if not len(flattened_weight) + else torch.cat((flattened_weight, weight[name].view(-1))) + ) + + flattened_weights = ( + flattened_weight[None, :] + if not len(flattened_weights) + else torch.cat((flattened_weights, flattened_weight[None, :]), 0) + ) + return flattened_weights + +def median(weights_attacked): + """Aggregate weight updates from the clients using median.""" + """ + + deltas_received = self.compute_weight_deltas(updates) + reports = [report for (__, report, __, __) in updates] + clients_id = [client for (client, __, __, __) in updates] + + # The number of malicious clients is known to the server. + # This setting seems unreasonable + n_attackers = 0 + for client_id in clients_id: + if client_id <= self.n_attackers: + n_attackers = n_attackers + 1 + + name_list = [] + for name, delta in deltas_received[0].items(): + name_list.append(name) + + all_deltas_vector = [] + for i, delta_received in enumerate(deltas_received): + delta_vector = [] + for name in name_list: + delta_vector = ( + delta_received[name].view(-1) + if not len(delta_vector) + else torch.cat((delta_vector, delta_received[name].view(-1))) + ) + all_deltas_vector = ( + delta_vector[None, :] + if not len(all_deltas_vector) + else torch.cat((all_deltas_vector, delta_vector[None, :]), 0) + ) + + n_clients = all_deltas_vector.shape[0] + logging.info("[%s] n_clients: %d", self, n_clients) + logging.info("[%s] n_attackers: %d", self, n_attackers) + """ + weights_attacked = flatten_weights(weights_attacked) + + median_delta_vector = torch.median(weights_attacked, dim=0)[0] + # name list? + #median_update = { + # name: self.trainer.zeros(weights.shape) + # for name, weights in deltas_received[0].items() + #} + + start_index = 0 + median_update = OrderedDict() + + for weight in weights_attacked: + for name in weight.keys(): + median_update[name] = median_delta_vector[ + start_index : start_index + len(median_update[name].view(-1)) + ].reshape(median_update[name].shape) + start_index = start_index + len(median_update[name].view(-1)) + + #for name in name_list: + #median_update[name] = median_delta_vector[ + #start_index : start_index + len(median_update[name].view(-1)) + #].reshape(median_update[name].shape) + #start_index = start_index + len(median_update[name].view(-1)) + logging.info(f"Finished Median server aggregation.") + return median_update + diff --git a/examples/detector/detector.py b/examples/detector/detector.py index 1ea324cb4..d6e7d3522 100644 --- a/examples/detector/detector.py +++ b/examples/detector/detector.py @@ -2,7 +2,7 @@ An implementation of the attack-defence scenario. """ -from plato.servers.fedavg import Server +from detector_server import Server from plato.clients.simple import Client from plato.trainers.basic import Trainer diff --git a/examples/detector/detector_client.py b/examples/detector/detector_client.py deleted file mode 100644 index b8d6b8d8f..000000000 --- a/examples/detector/detector_client.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -A federated learning client that is capable to perform model poisoning attacks. -""" - -import logging -import os -import pickle - -from plato.clients import simple -from plato.config import Config -from plato.utils import fonts - -import attacks - - -class Client(simple.Client): - """A client who is able to perform model poisoning attack""" - - def __init__( - self, - model=None, - datasource=None, - algorithm=None, - trainer=None, - callbacks=None, - trainer_callbacks=None, - ): - super().__init__( - model, datasource, algorithm, trainer, callbacks, trainer_callbacks - ) - self.is_attacker = None - self.attack_type = None - - def configure(self) -> None: - """Initialize the attack related parameter""" - super().configure() - - self.is_attacker = self.client_id in Config().clients.attacker_ids - self.attack_type = Config().clients.attack_type - - diff --git a/examples/detector/detector_server.py b/examples/detector/detector_server.py index 52760c757..fae53cbf7 100644 --- a/examples/detector/detector_server.py +++ b/examples/detector/detector_server.py @@ -3,18 +3,18 @@ """ import logging - +import os from plato.config import Config from plato.servers import fedavg from collections import OrderedDict import attacks as attack_registry -import defences as defence_registry +import detectors as defence_registry import aggregations as aggregation_registry import numpy as np import torch import defences - +import csv from typing import Mapping class Server(fedavg.Server): def __init__( @@ -29,6 +29,8 @@ def __init__( ) self.attacker_list = None self.attack_type = None + self.blacklist = [] + self.pre_blacklist = [] def configure(self): """Initialize defence related parameter""" @@ -37,12 +39,21 @@ def configure(self): self.attacker_list = [ int(value) for value in Config().clients.attacker_ids.split(",") ] - self.attack_type = Config().clients.attack_type + self.attack_type = ( + Config().clients.attack_type + if hasattr(Config().clients, "attack_type") + else None + ) logging.info(f"self.attacker_ids: %s", self.attacker_list) logging.info(f"attack_type: %s", self.attack_type) def choose_clients(self, clients_pool, clients_count): + # remove clients in blacklist from available clients pool + #logging.info(f"len of clients pool before removal: %d", len(clients_pool)) + clients_pool = list(filter(lambda x: x not in self.blacklist, clients_pool)) + #logging.info(f"len of cliets pool after removal: %d", len(clients_pool)) + selected_clients = super().choose_clients(clients_pool, clients_count) # recording how many attackers are selected this round to track the defence performance @@ -74,35 +85,95 @@ def weights_attacked(self, weights_received): if update.client_id in self.attacker_list: attacker_weights.append(weight) - # Attacker server perform attack based on attack type - attack = attack_registry.get() - weights_attacked = attack( - attacker_weights - ) # weights and updates are different, think about which is more convenient? - - # Put poisoned model back to weights received for further aggregation - counter = 0 - for i, update in enumerate(self.updates): - if update.client_id in self.attacker_list: - weights_received[i] = weights_attacked[counter] - counter += 1 + # Extract model updates + baseline_weights = self.algorithm.extract_weights() + deltas_received = self.algorithm.compute_weight_deltas( + baseline_weights, attacker_weights + ) + # Get attackers selected at this round + received_ids = [update.client_id for update in self.updates] + num_attackers = len([i for i in received_ids if i in self.attacker_list]) + + if num_attackers > 0: + # Attacker server perform attack based on attack type + attack = attack_registry.get() + weights_attacked = attack( + baseline_weights, attacker_weights, deltas_received, num_attackers + ) # weights and updates are different, think about which is more convenient? + + # Put poisoned model back to weights received for further aggregation + counter = 0 + for i, update in enumerate(self.updates): + if update.client_id in self.attacker_list: + weights_received[i] = weights_attacked[counter] + counter += 1 + return weights_received + + def detect_analysis(self, detected_malicious_ids, received_ids): + "print out detect accuracy, positive rate and negative rate" + logging.info(f"detected ids: %s", detected_malicious_ids) + real_malicious_ids = [i for i in received_ids if i in self.attacker_list] + logging.info(f"real attackers id: %s", real_malicious_ids) + if len(real_malicious_ids) != 0: + correct = 0 + wrong = 0 + for i in detected_malicious_ids: + if i in real_malicious_ids: + correct += 1 + logging.info(f"correctly detectes attacker %d", i) + else: + wrong += 1 + logging.info(f"wrongly classify benign client %i into attacker",i) + detection_accuracy = correct / (len(real_malicious_ids) * 1.0) + with open('detection_accuracy.csv', 'a', newline='') as file: + writer = csv.writer(file) + writer.writerow([detection_accuracy]) + logging.info(f"detection_accuracy is: %.2f",detection_accuracy) + logging.info(f"Missing %d attackers.",len(real_malicious_ids)*1.0 - correct ) + logging.info(f"falsely classified %d clients: ", wrong) def weights_filter(self, weights_attacked): # Identify poisoned updates and remove it from all received updates. defence = defence_registry.get() + if defence is None: + return weights_attacked - weights_approved = defence(weights_attacked) - # get a balck list for attackers_detected this round - - # Remove identified attacker from clients pool. Never select that client again. - # for attacker in attackers_detected: - # self.clients_pool.remove(attacker) + # Extract the current model updates (deltas) + baseline_weights = self.algorithm.extract_weights() + deltas_attacked = self.algorithm.compute_weight_deltas( + baseline_weights, weights_attacked + ) + received_ids = [update.client_id for update in self.updates] + received_staleness = [update.staleness for update in self.updates] + malicious_ids, weights_approved = defence(baseline_weights, weights_attacked, deltas_attacked,received_ids,received_staleness) + + ids = [received_ids[i] for i in malicious_ids] + + cummulative_detect = 0 + for id_temp in self.blacklist: + + if id_temp in self.attacker_list: + cummulative_detect += 1 + #logging.info(f"cummulative detect: %d",cummulative_detect) + + #logging.info(f"Cumulative detection: %.2f", (cummulative_detect) * 1.0 / len(self.attacker_list)) + #logging.info(f"Mistakenly classfied: %d benign clients so far.", (len(self.blacklist)-cummulative_detect)) + #logging.info(f"Blacklist is: %s", self.blacklist) + """ + self.blacklist[name].append() + # Remove identified attacker from client pool. Never select that client again. + for i in ids: + self.clients_pool.remove(i) + logging.info(f"Remove attacker %d from available client pool.", i) + """ + # Analyze detection performance. + # self.detect_analysis(ids, received_ids) return weights_approved - + async def aggregate_weights(self, updates,baseline_weights, weights_received): """Aggregate the reported weight updates from the selected clients.""" @@ -114,10 +185,11 @@ async def aggregate_weights(self, updates,baseline_weights, weights_received): deltas = await self.aggregate_deltas(self.updates, deltas_received) updated_weights = self.algorithm.update_weights(deltas) return updated_weights - + # if secure aggregation is applied. aggregation = aggregation_registry.get() weights_aggregated = aggregation(updates, baseline_weights, weights_received) return weights_aggregated + diff --git a/examples/detector/detectors.py b/examples/detector/detectors.py new file mode 100644 index 000000000..a3ba0eea6 --- /dev/null +++ b/examples/detector/detectors.py @@ -0,0 +1,423 @@ +import torch +import os +import logging +from plato.config import Config +import numpy as np +from sklearn.cluster import KMeans, AgglomerativeClustering +import pickle +import os +import torch.nn.functional as F +from collections import OrderedDict + +# Configure logging +logging.basicConfig(filename='app.log', filemode='w', level=logging.INFO) + +def get(): + + detector_type = ( + Config().server.detector_type + if hasattr(Config().server, "detector_type") + else None + ) + + if detector_type is None: + logging.info("No defence is applied.") + return None + + if detector_type in registered_detectors: + registered_defence = registered_detectors[detector_type] + logging.info(f"Clients perform {detector_type} attack.") + return registered_defence + + raise ValueError(f"No such defence: {detector_type}") + + +def flatten_weights(weights): + flattened_weights = [] + + for weight in weights: + flattened_weight = [] + for name in weight.keys(): + flattened_weight = ( + weight[name].view(-1) + if not len(flattened_weight) + else torch.cat((flattened_weight, weight[name].view(-1))) + ) + + flattened_weights = ( + flattened_weight[None, :] + if not len(flattened_weights) + else torch.cat((flattened_weights, flattened_weight[None, :]), 0) + ) + return flattened_weights + + +def flatten_weight(weight): + + flattened_weight = [] + for name in weight.keys(): + flattened_weight = ( + weight[name].view(-1) + if not len(flattened_weight) + else torch.cat((flattened_weight, weight[name].view(-1))) + ) + return flattened_weight + + +def lbfgs(weights_attacked, global_weights_record, gradients_record, last_weights): + # Approximate integrated Hessian value + # Transfer lists of tensor into tensor matrix + global_weights = torch.stack(global_weights_record) + gradients = torch.stack(gradients_record) + global_times_gradients = torch.matmul(global_weights, gradients.T) + global_times_global = torch.matmul(global_weights, global_weights.T) + + # Get its diagonal matrix and lower triangular submatrix + R_k = np.triu(global_times_gradients.numpy()) + L_k = global_times_gradients - torch.tensor(R_k) + + # Step 3 in Algorithm 1 + sigma_k = torch.matmul( + torch.transpose(global_weights_record[-1], 0, -1), gradients_record[-1] + ) / ( + torch.matmul(torch.transpose(gradients_record[-1], 0, -1), gradients_record[-1]) + ) + D_k_diag = torch.diag(global_times_gradients) + + upper_mat = torch.cat(((sigma_k * global_times_global), L_k), dim=1) + lower_mat = torch.cat((L_k.T, -torch.diag(D_k_diag)), dim=1) + mat = torch.cat((upper_mat, lower_mat), dim=0) + mat_inv = torch.inverse(mat) + + v = ( + weights_attacked - last_weights + ) # deltas_attacked from selected clients + v = torch.mean(v, dim=0) + approx_prod = sigma_k * v + p_mat = torch.cat( + (torch.matmul(global_weights, (sigma_k * v).T), torch.matmul(gradients, v.T)), dim=0 + ) + approx_prod -= torch.matmul( + torch.matmul( + torch.cat((sigma_k * global_weights.T, gradients.T), dim=1), mat_inv + ), + p_mat, + ).T + return approx_prod + + +def gap_statistics(score): + + nrefs = 10 + ks = range(1, 3) + gaps = np.zeros(len(ks)) + gapDiff = np.zeros(len(ks) - 1) + sdk = np.zeros(len(ks)) + min = np.min(score) + max = np.max(score) + 1 #! + score = (score - min) / (max - min) + for i, k in enumerate(ks): + estimator = KMeans(n_clusters=k) + estimator.fit(score.reshape(-1, 1)) + label_pred = estimator.labels_ + center = estimator.cluster_centers_ + Wk = np.sum( + [np.square(score[m] - center[label_pred[m]]) for m in range(len(score))] + ) + + WkRef = np.zeros(nrefs) + for j in range(nrefs): + rand = np.random.uniform(0, 1, len(score)) + estimator = KMeans(n_clusters=k) + estimator.fit(rand.reshape(-1, 1)) + label_pred = estimator.labels_ + center = estimator.cluster_centers_ + WkRef[j] = np.sum( + [np.square(rand[m] - center[label_pred[m]]) for m in range(len(rand))] + ) + gaps[i] = np.log(np.mean(WkRef)) - np.log(Wk) + sdk[i] = np.sqrt((1.0 + nrefs) / nrefs) * np.std(np.log(WkRef)) + + if i > 0: + gapDiff[i - 1] = gaps[i - 1] - gaps[i] + sdk[i] + + select_k = None + for i in range(len(gapDiff)): + if gapDiff[i] >= 0: + select_k = i + 1 + break + if select_k == 1: + print("No attack detected!") + return 0 + else: + print("Attack Detected!") + return 1 + + +def detection(score): + estimator = KMeans(n_clusters=2) + estimator.fit(score.reshape(-1, 1)) + label_pred = estimator.labels_ + # Print the members in each cluster + for cluster in np.unique(label_pred): + cluster_members = score[label_pred == cluster] + logging.info(f"Cluster {cluster + 1} members: %s",cluster_members) + if np.mean(score[label_pred == 0]) < np.mean(score[label_pred == 1]): + # cluster with smaller mean value is clean clients + clean_ids = np.where(label_pred == 0)[0] + malicious_ids = np.where(label_pred == 1)[0] + else: + clean_ids = np.where(label_pred == 1)[0] + malicious_ids = np.where(label_pred == 0)[0] + #logging.info(f"clean_ids: %s", clean_ids) + return malicious_ids, clean_ids + +def detection_cos(score): + estimator = KMeans(n_clusters=3) + estimator.fit(score.reshape(-1, 1)) + label_pred = estimator.labels_ + # Print the members in each cluster + for cluster in np.unique(label_pred): + cluster_members = score[label_pred == cluster] + logging.info(f"Cluster {cluster + 1} members: %s",cluster_members) + if ((np.mean(score[label_pred == 0]) > np.mean(score[label_pred == 1])) and (np.mean(score[label_pred == 0]) > np.mean(score[label_pred == 2]))): + # cluster with larger value is attacker + clean_ids = np.concatenate((np.where(label_pred == 1)[0], np.where(label_pred == 2)[0])) + malicious_ids = np.where(label_pred == 0)[0] + elif ((np.mean(score[label_pred == 1]) > np.mean(score[label_pred == 0])) and (np.mean(score[label_pred == 1]) > np.mean(score[label_pred == 2]))): + clean_ids = np.concatenate((np.where(label_pred == 0)[0], np.where(label_pred == 2)[0])) + malicious_ids = np.where(label_pred == 1)[0] + else: + clean_ids = np.concatenate((np.where(label_pred == 1)[0], np.where(label_pred == 0)[0])) + malicious_ids = np.where(label_pred == 2)[0] + + return malicious_ids, clean_ids + +def pre_data_for_visualization(deltas_attacked, received_staleness): + # saved received local deltas for round x + #logging.info(f"starting preparing data for visualization") + flattened_deltas_attacked = flatten_weights(deltas_attacked) + # list to torch tensor + received_staleness = torch.tensor(received_staleness) + + model_path = Config().params["model_path"] + model_name = Config().trainer.model_name + + try: + if not os.path.exists(model_path): + os.makedirs(model_path) + except FileExistsError: + pass + + try: + # List all files and directories in the given folder + items = os.listdir(model_path) + + # Count the number of files (ignoring directories) + file_count = sum(1 for item in items if os.path.isfile(os.path.join(model_path, item))) + + file_count = str(file_count + 1) # plus one so can be directly used in the following code when create folder for each communication round + + except Exception as e: + pass + + #logging.info(f"saving reveived deltas...") + file_path = f"{model_path}/"+ file_count + ".pkl" + with open(file_path, "wb") as file: + pickle.dump(flattened_deltas_attacked, file) + pickle.dump(received_staleness,file) + + + logging.info("[Server #%d] Model saved to %s at round %s.", os.getpid(), file_path, file_count) + + + +def async_filter(baseline_weights,weights_attacked,deltas_attacked,received_ids,received_staleness): + # first group clients based on their staleness + staleness_bound = Config().server.staleness_bound + flattened_weights_attacked = flatten_weights(weights_attacked) + + # only for visualization + #pre_data_for_visualization(deltas_attacked, received_staleness) + + file_path = "./async_record"+ str(os.getpid()) + ".pkl" + if os.path.exists(file_path): + logging.info(f"loading parameters from file.") + with open(file_path, "rb") as file: + global_weights_record = pickle.load(file) + global_num_record = pickle.load(file) + else: + global_weights_record = [] + global_num_record = [] + + weight_groups = {i: [] for i in range(20)} + id_groups= {i:[] for i in range(20)} + for i, (staleness,weights) in enumerate(zip (received_staleness, flattened_weights_attacked)): + weight_groups[staleness].append(weights) + id_groups[staleness].append(i) + + # calcuate cos_similarity within a group and identify statistical outliers + all_mali_ids = [] + avg_current = torch.zeros_like(torch.mean(weights,dim=0)) + num_current = 0 + for staleness, weights in weight_groups.items(): + if len(weights)>=3: + weights = torch.stack(weights) + # find out avg at the same round + if staleness == 0: + avg = torch.mean(weights,dim=0) + 1e-10 + avg_current = avg + num_current = len(weight_groups[0]) + else: + avg = (global_weights_record[-1*staleness]*global_num_record[-1*staleness] + torch.mean(weights,dim=0)*len(weight_groups[staleness])) / (global_num_record[-1*staleness]+len(weight_groups[staleness])) + # update record + global_weights_record[-1*staleness] = avg + global_num_record[-1*staleness] += len(weight_groups[staleness]) + + similarity = F.cosine_similarity(weights, avg).numpy() + logging.info(f"Group %d cosine similarity: %s", staleness, similarity) + # whether or not to normalization is a question + + distance = torch.norm((avg - weights), dim=1).numpy() + + distance = distance / np.sum(distance) #normalization + + logging.info("applying 3 clustering for comparison") + malicious_ids2, clean_ids = detection_cos(distance) + + malicious_ids=malicious_ids2 + + malicious_ids = [id_groups[staleness][id] for id in malicious_ids.tolist()] + logging.info(f"malicious in this group: %s", malicious_ids) + all_mali_ids.extend(malicious_ids) + + # save into local file + global_weights_record.append(avg_current) + global_num_record.append(num_current) + file_path = "./async_record"+ str(os.getpid()) + ".pkl" + with open(file_path, "wb") as file: + pickle.dump(global_weights_record, file) + pickle.dump(global_num_record,file) + + # remove suspecious weights + clean_weights = [] + for i, weight in enumerate(weights_attacked): + if i not in all_mali_ids: + clean_weights.append(weight) + + return all_mali_ids,clean_weights + +def fl_detector(baseline_weights, weights_attacked, deltas_attacked, received_ids,received_staleness): + #https://arxiv.org/pdf/2207.09209.pdf + # torch.set_printoptions(threshold=10**8) + #malicious_ids_list = [] # for test case only, will be remove after finished + + clean_weights = weights_attacked + + flattened_weights_attacked = flatten_weights(weights_attacked) + flattened_baseline_weights = flatten_weight(baseline_weights) + flattened_deltas_attacked = flatten_weights(deltas_attacked) + id_temp = [x - 1 for x in received_ids] + local_update_current = torch.stack([x[1] for x in sorted(zip(id_temp, flattened_deltas_attacked))]) + + window_size = 0 + malicious_ids = [] + + file_path = "./record"+ str(os.getpid()) + ".pkl" + if os.path.exists(file_path): + logging.info(f"loading parameters from file.") + with open(file_path, "rb") as file: + # download dict from file + # below are records for fldetector + global_weights_record = pickle.load(file) + gradients_record = pickle.load(file) + last_weights = pickle.load(file) + last_gradients = pickle.load(file) + malicious_score_dict = pickle.load(file) + + # get weights for received clients at this round only + malicious_score = [] + #last_move = [] + for received_id in received_ids: + #last_move.append(last_move_dict[received_id]) # for avg + malicious_score.append(list(filter(lambda x: x is not None, malicious_score_dict[received_id]))) # for fldetector + + moving_avg = torch.mean(flattened_deltas_attacked,dim=0) + 1e-10 # for avg + + + if len(global_weights_record) >= window_size + 1: + + """Below are fldetector""" + # Make predication by Cauchy mean value theorem in fldetector + hvp = lbfgs( + flattened_weights_attacked, + global_weights_record, + gradients_record, + last_weights, + ) + + pred_grad = torch.add(last_gradients, hvp) + + # Calculate distance for scoring + distance1 = torch.norm( + (pred_grad - flattened_deltas_attacked), dim=1 + ).numpy() + logging.info(f"distance in fldetectors before normalization: %s", distance1) + # Normalize distance + distance1 = distance1 / np.sum(distance1) + logging.info(f"the fldetector distance after normalization: %s", distance1) + + # add new distance score into malicious score record + # moving averaging + malicious_score_current = [] + for scores, dist in zip(malicious_score, distance1): + scores.append(dist) + score = sum(scores)/len(scores) + malicious_score_current.append(score) + logging.info(f"the malicious score current round: %s", malicious_score_current) + # cluserting and detection (smaller score represents benign clients) + malicious_ids, clean_ids = detection( + np.array(malicious_score_current) + ) # np.sum(malicious_score[-10:], axis=0)) + logging.info(f"fldetector malicious ids: %s", malicious_ids) + + clean_weights = [] + for i, weight in enumerate(weights_attacked): + if i not in malicious_ids: + clean_weights.append(weight) + + else: + logging.info(f"initializing fl parameter record") + global_weights_record = [] + gradients_record = [] + last_gradients = torch.zeros(len(flattened_baseline_weights)) + last_weights = torch.zeros(len(flattened_baseline_weights)) + malicious_score_dict = {client_id + 1: [] for client_id in range(Config().clients.total_clients)}#len(received_ids)*[0] + distance1 = len(received_ids)*[None] # none to keep update consistent + + # update record + for index, received_id in enumerate(received_ids): #self_consistency_pre): + malicious_score_dict[received_id].append(distance1[index]) # distance should be initialized (this one is for fldetector) + + global_weights_record.append(flattened_baseline_weights - last_weights) + gradients_record.append( + torch.mean(flattened_deltas_attacked, dim=0) - last_gradients + ) + last_weights = flattened_baseline_weights + last_gradients= torch.mean(flattened_deltas_attacked, dim=0) + + # save into local file + file_path = "./record"+ str(os.getpid()) + ".pkl" + with open(file_path, "wb") as file: + pickle.dump(global_weights_record, file) + pickle.dump(gradients_record, file) + pickle.dump(last_weights, file) + pickle.dump(last_gradients, file) + pickle.dump(malicious_score_dict, file) + logging.info(f"malicious_ids: %s", malicious_ids) + return malicious_ids, clean_weights + +registered_detectors = { + "FLDetector": fl_detector, + "AsyncFilter":async_filter, +}