From f000351d87769c93ab2fdfb7a3a0a6fe1cb1ae13 Mon Sep 17 00:00:00 2001
From: Andreas Hellander
Date: Wed, 24 Jan 2024 11:26:02 +0100
Subject: [PATCH] Improve aggregator interface

---
 .../combiner/aggregators/aggregatorbase.py | 25 +++++++++++++++----
 .../network/combiner/aggregators/fedavg.py | 10 +++++---
 .../network/combiner/aggregators/fedopt.py | 10 +++++---
 3 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/fedn/fedn/network/combiner/aggregators/aggregatorbase.py b/fedn/fedn/network/combiner/aggregators/aggregatorbase.py
index 27a885e18..51a31f1ce 100644
--- a/fedn/fedn/network/combiner/aggregators/aggregatorbase.py
+++ b/fedn/fedn/network/combiner/aggregators/aggregatorbase.py
@@ -89,23 +89,38 @@ def _validate_model_update(self, model_update):
             return False
         return True
 
-    def next_model_update(self, helper):
+    def next_model_update(self):
         """ Get the next model update from the queue.
 
         :param helper: A helper object.
         :type helper: object
-        :return: A tuple containing the model update, metadata and model id.
-        :rtype: tuple
+        :return: The model update.
+        :rtype: protobuf
         """
         model_update = self.model_updates.get(block=False)
+        return model_update
+
+    def load_model_update(self, model_update, helper):
+        """ Load the memory representation of the model update.
+
+        Load the model update parameters and the
+        associated metadata into memory.
+
+        :param model_update: The model update.
+        :type model_update: protobuf
+        :param helper: A helper object.
+        :type helper: fedn.utils.helpers.helperbase.Helper
+        :return: A tuple of (parameters, metadata)
+        :rtype: tuple
+        """
         model_id = model_update.model_update_id
-        model_next = self.control.load_model_update(helper, model_id)
+        model = self.control.load_model_update(helper, model_id)
         # Get relevant metadata
         data = json.loads(model_update.meta)['training_metadata']
         config = json.loads(json.loads(model_update.meta)['config'])
         data['round_id'] = config['round_id']
 
-        return model_next, data, model_id, model_update
+        return model, data
 
     def get_state(self):
         """ Get the state of the aggregator's queue, including the number of model updates."""
diff --git a/fedn/fedn/network/combiner/aggregators/fedavg.py b/fedn/fedn/network/combiner/aggregators/fedavg.py
index 9505d545a..0335af761 100644
--- a/fedn/fedn/network/combiner/aggregators/fedavg.py
+++ b/fedn/fedn/network/combiner/aggregators/fedavg.py
@@ -56,9 +56,13 @@ def combine_models(self, helper=None, delete_models=True):
         while not self.model_updates.empty():
             try:
                 # Get next model from queue
-                model_next, metadata, model_id, model_update = self.next_model_update(helper)
+                model_update = self.next_model_update()
+
+                # Load model parameters and metadata
+                model_next, metadata = self.load_model_update(model_update, helper)
+
                 logger.info(
-                    "AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_id, metadata))
+                    "AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_update.model_update_id, metadata))
 
                 # Increment total number of examples
                 total_examples += metadata['num_examples']
@@ -72,7 +76,7 @@ def combine_models(self, helper=None, delete_models=True):
                 nr_aggregated_models += 1
                 # Delete model from storage
                 if delete_models:
-                    self.modelservice.models.delete(model_id)
+                    self.modelservice.models.delete(model_update.model_update_id)
                     logger.info(
                         "AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_id))
                 self.model_updates.task_done()
diff --git a/fedn/fedn/network/combiner/aggregators/fedopt.py b/fedn/fedn/network/combiner/aggregators/fedopt.py
index d94910b53..8f7ccc2df 100644
--- a/fedn/fedn/network/combiner/aggregators/fedopt.py
+++ b/fedn/fedn/network/combiner/aggregators/fedopt.py
@@ -68,9 +68,13 @@ def combine_models(self, helper=None, delete_models=True):
         while not self.model_updates.empty():
             try:
                 # Get next model from queue
-                model_next, metadata, model_id, model_update = self.next_model_update(helper)
+                model_update = self.next_model_update()
+
+                # Load model parameters and metadata
+                model_next, metadata = self.load_model_update(model_update, helper)
+
                 logger.info(
-                    "AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_id, metadata))
+                    "AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_update.model_update_id, metadata))
 
                 print("***** ", model_update, flush=True)
                 # Increment total number of examples
@@ -89,7 +93,7 @@ def combine_models(self, helper=None, delete_models=True):
                 nr_aggregated_models += 1
                 # Delete model from storage
                 if delete_models:
-                    self.modelservice.models.delete(model_id)
+                    self.modelservice.models.delete(model_update.model_update_id)
                     logger.info(
                         "AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_id))
                 self.model_updates.task_done()
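
Note (not part of the patch): below is a minimal, self-contained sketch of how the new two-step interface is meant to be driven from an aggregator's combine_models loop, fetching the raw update with next_model_update() and only then materializing parameters and metadata with load_model_update(). The SketchAggregator class and the SimpleNamespace stand-ins for the control object, helper and incoming model update are hypothetical and not part of the FedN API; only the call pattern mirrors the patch.

import json
import queue
from types import SimpleNamespace


class SketchAggregator:
    """Illustrates the two-step update handling introduced by the patch."""

    def __init__(self, control):
        self.model_updates = queue.Queue()
        self.control = control

    def next_model_update(self):
        # Step 1: fetch the next raw model update (a protobuf message in FedN) from the queue.
        return self.model_updates.get(block=False)

    def load_model_update(self, model_update, helper):
        # Step 2: resolve the update's parameters via the control object and
        # unpack the training metadata carried in the update's meta field.
        model = self.control.load_model_update(helper, model_update.model_update_id)
        data = json.loads(model_update.meta)['training_metadata']
        config = json.loads(json.loads(model_update.meta)['config'])
        data['round_id'] = config['round_id']
        return model, data


# Hypothetical stand-ins so the sketch runs on its own.
control = SimpleNamespace(load_model_update=lambda helper, model_id: [0.0, 1.0])
helper = None
aggregator = SketchAggregator(control)
aggregator.model_updates.put(SimpleNamespace(
    model_update_id='update-1',
    meta=json.dumps({
        'training_metadata': {'num_examples': 100},
        'config': json.dumps({'round_id': 1}),
    }),
))

while not aggregator.model_updates.empty():
    model_update = aggregator.next_model_update()
    parameters, metadata = aggregator.load_model_update(model_update, helper)
    print(model_update.model_update_id, metadata)

One apparent benefit of the split is that an aggregator can inspect or filter raw updates from the queue before paying the cost of loading their parameters into memory.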