Skip to content

Commit

Permalink
Implemented limited merge Danner et al. 2023
Browse files Browse the repository at this point in the history
  • Loading branch information
makgyver committed Dec 23, 2023
1 parent 0cfe326 commit ca86164
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 0 deletions.
51 changes: 51 additions & 0 deletions gossipy/model/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,54 @@ def _merge(self,
# Gets the maximum number of updates from the merged models
self.n_updates = max(self.n_updates, n_up)

class LimitedMergeMixin():

def __init__(self, age_diff_threshold: int=1):
self.L = age_diff_threshold

def _merge(self, other_model_handler: Union[TorchModelHandler, Iterable[TorchModelHandler]]) -> None:
dict_params1 = self.model.state_dict()

if isinstance(other_model_handler, TorchModelHandler):
dict_params2 = other_model_handler.model.state_dict()
n_up = other_model_handler.n_updates
else:
raise ValueError("Invalid type for other_model_handler: %s" %type(other_model_handler))

if self.n_updates > n_up + self.L:
self.model.load_state_dict(dict_params1)
elif n_up > self.n_updates + self.L:
self.model.load_state_dict(dict_params2)
else:
div = self.n_updates + n_up
for key in dict_params1:
dict_params1[key] = (self.n_updates / div) * dict_params1[key] + (n_up / div) * dict_params2[key]

self.model.load_state_dict(dict_params1)

self.n_updates = max(self.n_updates, n_up)


# Danner et al. Improving Gossip Learning via Limited Model Merging (ICCCI 2023)
class LimitedMergeTMH(LimitedMergeMixin, TorchModelHandler):
def __init__(self,
net: TorchModel,
optimizer: torch.optim.Optimizer,
optimizer_params: Dict[str, Any],
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
local_epochs: int=1,
batch_size: int=32,
create_model_mode: CreateModelMode=CreateModelMode.MERGE_UPDATE,
age_diff_threshold: int=1,
copy_model=True):
LimitedMergeMixin.__init__(self, age_diff_threshold)
TorchModelHandler.__init__(self,
net,
optimizer,
optimizer_params,
criterion,
local_epochs,
batch_size,
create_model_mode,
copy_model)

64 changes: 64 additions & 0 deletions main_danner_2023.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import torch
from networkx import to_numpy_array
from networkx.generators.random_graphs import random_regular_graph
from gossipy import set_seed
from gossipy.core import UniformDelay, AntiEntropyProtocol, CreateModelMode, StaticP2PNetwork
from gossipy.node import GossipNode
from gossipy.model.handler import LimitedMergeTMH
from gossipy.model.nn import LogisticRegression
from gossipy.data import load_classification_dataset, DataDispatcher
from gossipy.data.handler import ClassificationDataHandler
from gossipy.simul import GossipSimulator, SimulationReport
from gossipy.utils import plot_evaluation

# AUTHORSHIP
__version__ = "0.0.1"
__author__ = "Mirko Polato"
__copyright__ = "Copyright 2022, gossipy"
__license__ = "MIT"
__maintainer__ = "Mirko Polato, PhD"
__email__ = "[email protected]"
__status__ = "Development"
#


set_seed(98765)
X, y = load_classification_dataset("spambase", as_tensor=True)
data_handler = ClassificationDataHandler(X, y, test_size=.1)
dispatcher = DataDispatcher(data_handler, n=100, eval_on_user=False, auto_assign=True)
topology = StaticP2PNetwork(100, to_numpy_array(random_regular_graph(20, 100, seed=42)))
net = LogisticRegression(data_handler.Xtr.shape[1], 2)

nodes = GossipNode.generate(
data_dispatcher=dispatcher,
p2p_net=topology,
model_proto=LimitedMergeTMH(
net=net,
optimizer=torch.optim.SGD,
optimizer_params={
"lr": 1,
"weight_decay": .001
},
criterion=torch.nn.CrossEntropyLoss(),
),
round_len=100,
sync=True
)

simulator = GossipSimulator(
nodes=nodes,
data_dispatcher=dispatcher,
delta=100,
protocol=AntiEntropyProtocol.PUSH,
delay=UniformDelay(0,10),
online_prob=.2, #Approximates the average online rate of the STUNner's smartphone traces
drop_prob=.1, #Simulate the possibility of message dropping,
sampling_eval=.1
)

report = SimulationReport()
simulator.add_receiver(report)
simulator.init_nodes(seed=42)
simulator.start(n_rounds=1000)

plot_evaluation([[ev for _, ev in report.get_evaluation(False)]], "Overall test results")

0 comments on commit ca86164

Please sign in to comment.