diff --git a/gossipy/model/handler.py b/gossipy/model/handler.py
index 413f648..863fd77 100644
--- a/gossipy/model/handler.py
+++ b/gossipy/model/handler.py
@@ -687,3 +687,54 @@ def _merge(self,
         # Gets the maximum number of updates from the merged models
         self.n_updates = max(self.n_updates, n_up)
 
+class LimitedMergeMixin():
+
+    def __init__(self, age_diff_threshold: int=1):
+        self.L = age_diff_threshold
+
+    def _merge(self, other_model_handler: Union[TorchModelHandler, Iterable[TorchModelHandler]]) -> None:
+        dict_params1 = self.model.state_dict()
+
+        if isinstance(other_model_handler, TorchModelHandler):
+            dict_params2 = other_model_handler.model.state_dict()
+            n_up = other_model_handler.n_updates
+        else:
+            raise ValueError("Invalid type for other_model_handler: %s" %type(other_model_handler))
+
+        if self.n_updates > n_up + self.L:
+            self.model.load_state_dict(dict_params1)
+        elif n_up > self.n_updates + self.L:
+            self.model.load_state_dict(dict_params2)
+        else:
+            div = self.n_updates + n_up
+            for key in dict_params1:
+                dict_params1[key] = (self.n_updates / div) * dict_params1[key] + (n_up / div) * dict_params2[key]
+
+            self.model.load_state_dict(dict_params1)
+
+        self.n_updates = max(self.n_updates, n_up)
+
+
+# Danner et al. Improving Gossip Learning via Limited Model Merging (ICCCI 2023)
+class LimitedMergeTMH(LimitedMergeMixin, TorchModelHandler):
+    def __init__(self,
+                 net: TorchModel,
+                 optimizer: torch.optim.Optimizer,
+                 optimizer_params: Dict[str, Any],
+                 criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+                 local_epochs: int=1,
+                 batch_size: int=32,
+                 create_model_mode: CreateModelMode=CreateModelMode.MERGE_UPDATE,
+                 age_diff_threshold: int=1,
+                 copy_model=True):
+        LimitedMergeMixin.__init__(self, age_diff_threshold)
+        TorchModelHandler.__init__(self,
+                                   net,
+                                   optimizer,
+                                   optimizer_params,
+                                   criterion,
+                                   local_epochs,
+                                   batch_size,
+                                   create_model_mode,
+                                   copy_model)
+
\ No newline at end of file
diff --git a/main_danner_2023.py b/main_danner_2023.py
new file mode 100644
index 0000000..ce2afc2
--- /dev/null
+++ b/main_danner_2023.py
@@ -0,0 +1,64 @@
+import torch
+from networkx import to_numpy_array
+from networkx.generators.random_graphs import random_regular_graph
+from gossipy import set_seed
+from gossipy.core import UniformDelay, AntiEntropyProtocol, CreateModelMode, StaticP2PNetwork
+from gossipy.node import GossipNode
+from gossipy.model.handler import LimitedMergeTMH
+from gossipy.model.nn import LogisticRegression
+from gossipy.data import load_classification_dataset, DataDispatcher
+from gossipy.data.handler import ClassificationDataHandler
+from gossipy.simul import GossipSimulator, SimulationReport
+from gossipy.utils import plot_evaluation
+
+# AUTHORSHIP
+__version__ = "0.0.1"
+__author__ = "Mirko Polato"
+__copyright__ = "Copyright 2022, gossipy"
+__license__ = "MIT"
+__maintainer__ = "Mirko Polato, PhD"
+__email__ = "mak1788@gmail.com"
+__status__ = "Development"
+#
+
+
+set_seed(98765)
+X, y = load_classification_dataset("spambase", as_tensor=True)
+data_handler = ClassificationDataHandler(X, y, test_size=.1)
+dispatcher = DataDispatcher(data_handler, n=100, eval_on_user=False, auto_assign=True)
+topology = StaticP2PNetwork(100, to_numpy_array(random_regular_graph(20, 100, seed=42)))
+net = LogisticRegression(data_handler.Xtr.shape[1], 2)
+
+nodes = GossipNode.generate(
+    data_dispatcher=dispatcher,
+    p2p_net=topology,
+    model_proto=LimitedMergeTMH(
+        net=net,
+        optimizer=torch.optim.SGD,
+        optimizer_params={
+            "lr": 1,
+            "weight_decay": .001
+        },
+        criterion=torch.nn.CrossEntropyLoss(),
+    ),
+    round_len=100,
+    sync=True
+)
+
+simulator = GossipSimulator(
+    nodes=nodes,
+    data_dispatcher=dispatcher,
+    delta=100,
+    protocol=AntiEntropyProtocol.PUSH,
+    delay=UniformDelay(0, 10),
+    online_prob=.2,  # Approximates the average online rate of the STUNner smartphone traces
+    drop_prob=.1,  # Simulates the possibility of message dropping
+    sampling_eval=.1
+)
+
+report = SimulationReport()
+simulator.add_receiver(report)
+simulator.init_nodes(seed=42)
+simulator.start(n_rounds=1000)
+
+plot_evaluation([[ev for _, ev in report.get_evaluation(False)]], "Overall test results")
\ No newline at end of file
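
Reviewer sketch (not part of the patch): the limited-merge rule in LimitedMergeMixin._merge keeps the newer model outright when the two update counts differ by more than the threshold L, and otherwise takes an update-count-weighted average of the parameters. The standalone snippet below illustrates that rule on plain state dicts; the name merge_limited and the toy tensors are illustrative only, and the zero-age edge case is left unhandled, as in the patch.

import torch

def merge_limited(params_a, params_b, age_a, age_b, L=1):
    # If one model is more than L updates ahead, keep it and discard the other.
    if age_a > age_b + L:
        return {k: v.clone() for k, v in params_a.items()}
    if age_b > age_a + L:
        return {k: v.clone() for k, v in params_b.items()}
    # Otherwise merge: weight each model by its share of the total update count.
    total = age_a + age_b
    return {k: (age_a / total) * params_a[k] + (age_b / total) * params_b[k]
            for k in params_a}

a = {"w": torch.tensor([1.0, 1.0])}
b = {"w": torch.tensor([0.0, 0.0])}
# Ages 5 vs 2 with L=1: the gap exceeds L, so the newer model (a) is kept as-is.
print(merge_limited(a, b, age_a=5, age_b=2))  # {'w': tensor([1., 1.])}
# Ages 3 vs 2 with L=1: within the threshold, so the result is the 3/5-2/5 average.
print(merge_limited(a, b, age_a=3, age_b=2))  # {'w': tensor([0.6000, 0.6000])}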