-
Notifications
You must be signed in to change notification settings - Fork 0
/
dnc.py
56 lines (45 loc) · 1.66 KB
/
dnc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import numpy as np
from typing import List, Tuple
from flwr.common import (
FitRes,
parameters_to_ndarrays,
)
from flwr.server.client_proxy import ClientProxy
import flwr as fl
class Dnc:
    r"""Divide-and-Conquer (DnC) robust aggregator from the paper
    `Manipulating the Byzantine: Optimizing Model Poisoning Attacks and
    Defenses for Federated Learning <https://par.nsf.gov/servlets/purl/10286354>`_.

    Each client's parameters are flattened into a single vector. On every
    iteration a random subset of at most ``sub_dim`` coordinates is sampled,
    the sampled sub-updates are centered on their mean, and each client is
    scored by the squared projection of its centered sub-update onto the top
    right-singular vector. The ``int(filter_frac * num_byzantine)`` clients
    with the highest scores are dropped. The union of the clients kept across
    all iterations is averaged.

    Args:
        num_byzantine: Expected number of Byzantine (malicious) clients.
        sub_dim: Maximum number of coordinates sampled per iteration; must
            be >= 1.
        num_iters: Number of random sub-sampling rounds; must be >= 1.
        filter_frac: Multiplier on ``num_byzantine`` that gives how many
            clients are removed per iteration.

    Raises:
        ValueError: If ``num_iters`` or ``sub_dim`` is smaller than 1.
    """

    def __init__(
        self, num_byzantine, *, sub_dim=10000, num_iters=1, filter_frac=1.0
    ) -> None:
        super().__init__()
        # Fail fast: num_iters < 1 would leave no clients to average, and
        # sub_dim < 1 would sample no coordinates for the SVD.
        if num_iters < 1:
            raise ValueError(f"num_iters must be >= 1, got {num_iters}")
        if sub_dim < 1:
            raise ValueError(f"sub_dim must be >= 1, got {sub_dim}")
        self.num_byzantine = num_byzantine
        self.sub_dim = sub_dim
        self.num_iters = num_iters
        self.filter_frac = filter_frac

    @staticmethod
    def _flatten_updates(results) -> np.ndarray:
        """Return an (n_clients, n_coords) matrix of flattened client parameters."""
        # parameters_to_ndarrays yields one array per layer, and layers have
        # different shapes. Each client's layers are therefore raveled and
        # concatenated into one row before stacking.
        rows = []
        for _, fit_res in results:
            layers = parameters_to_ndarrays(fit_res.parameters)
            rows.append(np.concatenate([np.ravel(layer) for layer in layers]))
        return np.stack(rows)

    def _select_benign(self, updates: np.ndarray) -> List[int]:
        """Return the sorted row indices of ``updates`` that survive DnC filtering."""
        num_clients, num_coords = updates.shape
        num_removed = int(self.filter_frac * self.num_byzantine)
        # Keep at least one client. A zero slice bound would select nothing
        # (NaN mean), and a negative one would count from the end and keep an
        # arbitrary subset.
        num_kept = max(num_clients - num_removed, 1)

        benign_ids = set()
        for _ in range(self.num_iters):
            # Sample coordinates of the flattened model vector (not client ids).
            indices = np.random.permutation(num_coords)[: self.sub_dim]
            sub_updates = updates[:, indices]
            centered = sub_updates - sub_updates.mean(axis=0)
            # Rows of Vh are right-singular vectors ordered by descending
            # singular value, so vh[0] is the top principal direction.
            _, _, vh = np.linalg.svd(centered, full_matrices=False)
            scores = (centered @ vh[0]) ** 2
            benign_ids.update(np.argsort(scores)[:num_kept].tolist())
        return sorted(benign_ids)

    def dnc_aggregate(
        self,
        results: List[Tuple[ClientProxy, FitRes]]
    ) -> np.ndarray:
        """Aggregate client fit results with DnC and return the benign mean.

        Args:
            results: ``(client, fit_res)`` pairs; only ``fit_res.parameters``
                is read. Every client must send the same total number of
                parameters, in the same layer order.

        Returns:
            A 1-D array: the mean of the flattened parameter vectors of the
            clients kept by DnC. Layers are concatenated in the order returned
            by ``parameters_to_ndarrays``. NOTE(review): callers that need
            per-layer arrays (e.g. for ``ndarrays_to_parameters``) must split
            and reshape this vector by layer sizes.

        Raises:
            ValueError: If ``results`` is empty, or if clients' flattened
                parameter vectors differ in length.
        """
        if not results:
            raise ValueError("dnc_aggregate requires at least one client result")
        updates = self._flatten_updates(results)
        benign_ids = self._select_benign(updates)
        return updates[benign_ids].mean(axis=0)