Skip to content

Commit

Permalink
Improve aggrgator interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Andreas Hellander committed Jan 24, 2024
1 parent db1eb74 commit f000351
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
25 changes: 20 additions & 5 deletions fedn/fedn/network/combiner/aggregators/aggregatorbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,23 +89,38 @@ def _validate_model_update(self, model_update):
return False
return True

def next_model_update(self, helper):
def next_model_update(self):
""" Get the next model update from the queue.
:param helper: A helper object.
:type helper: object
:return: A tuple containing the model update, metadata and model id.
:rtype: tuple
:return: The model update.
:rtype: protobuf
"""
model_update = self.model_updates.get(block=False)
return model_update

def load_model_update(self, model_update, helper):
""" Load the memory representation of the model update.
Load the model update paramters and the
associate metadata into memory.
:param model_update: The model update.
:type model_update: protobuf
:param helper: A helper object.
:type helper: fedn.utils.helpers.helperbase.Helper
:return: A tuple of (parameters, metadata)
:rtype: tuple
"""
model_id = model_update.model_update_id
model_next = self.control.load_model_update(helper, model_id)
model = self.control.load_model_update(helper, model_id)
# Get relevant metadata
data = json.loads(model_update.meta)['training_metadata']
config = json.loads(json.loads(model_update.meta)['config'])
data['round_id'] = config['round_id']

return model_next, data, model_id, model_update
return model, data

def get_state(self):
""" Get the state of the aggregator's queue, including the number of model updates."""
Expand Down
10 changes: 7 additions & 3 deletions fedn/fedn/network/combiner/aggregators/fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,13 @@ def combine_models(self, helper=None, delete_models=True):
while not self.model_updates.empty():
try:
# Get next model from queue
model_next, metadata, model_id, model_update = self.next_model_update(helper)
model_update = self.next_model_update()

# Load model parameters and metadata
model_next, metadata = self.load_model_update(model_update, helper)

logger.info(
"AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_id, metadata))
"AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_update.model_update_id, metadata))

# Increment total number of examples
total_examples += metadata['num_examples']
Expand All @@ -72,7 +76,7 @@ def combine_models(self, helper=None, delete_models=True):
nr_aggregated_models += 1
# Delete model from storage
if delete_models:
self.modelservice.models.delete(model_id)
self.modelservice.models.delete(model_update.model_update_id)
logger.info(
"AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_id))
self.model_updates.task_done()
Expand Down
10 changes: 7 additions & 3 deletions fedn/fedn/network/combiner/aggregators/fedopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,13 @@ def combine_models(self, helper=None, delete_models=True):
while not self.model_updates.empty():
try:
# Get next model from queue
model_next, metadata, model_id, model_update = self.next_model_update(helper)
model_update = self.next_model_update()

# Load model paratmeters and metadata
model_next, metadata = self.load_model_update(model_update, helper)

logger.info(
"AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_id, metadata))
"AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_update.model_update_id, metadata))
print("***** ", model_update, flush=True)

# Increment total number of examples
Expand All @@ -89,7 +93,7 @@ def combine_models(self, helper=None, delete_models=True):
nr_aggregated_models += 1
# Delete model from storage
if delete_models:
self.modelservice.models.delete(model_id)
self.modelservice.models.delete(model_update.model_id)
logger.info(
"AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_id))
self.model_updates.task_done()
Expand Down

0 comments on commit f000351

Please sign in to comment.