-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b8fb56a
commit 7449c66
Showing
20 changed files
with
537 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,4 @@ discover_port: 8092 | |
name: hooks | ||
host: hooks | ||
port: 12081 | ||
max_clients: 30 | ||
|
||
|
||
max_clients: 30 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import numpy as np | ||
|
||
|
||
class FunctionProvider: | ||
def __init__(self) -> None: | ||
pass | ||
|
||
def aggregate(self, parameters): | ||
if len(parameters) == 0: | ||
return [] | ||
num_clients = len(parameters) | ||
|
||
summed_parameters = [np.zeros_like(param) for param in parameters[0]] | ||
|
||
for client_params in parameters: | ||
for i, param in enumerate(client_params): | ||
summed_parameters[i] += param | ||
|
||
averaged_parameters = [param / num_clients for param in summed_parameters] | ||
|
||
return averaged_parameters |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import traceback | ||
|
||
import numpy as np | ||
|
||
from fedn.common.log_config import logger | ||
from fedn.network.combiner.aggregators.aggregatorbase import AggregatorBase | ||
|
||
|
||
class Aggregator(AggregatorBase): | ||
"""Custom aggregator provided from user defined code. | ||
:param id: A reference to id of :class: `fedn.network.combiner.Combiner` | ||
:type id: str | ||
:param storage: Model repository for :class: `fedn.network.combiner.Combiner` | ||
:type storage: class: `fedn.common.storage.s3.s3repo.S3ModelRepository` | ||
:param server: A handle to the Combiner class :class: `fedn.network.combiner.Combiner` | ||
:type server: class: `fedn.network.combiner.Combiner` | ||
:param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService` | ||
:type modelservice: class: `fedn.network.combiner.modelservice.ModelService` | ||
:param control: A handle to the :class: `fedn.network.combiner.roundhandler.RoundHandler` | ||
:type control: class: `fedn.network.combiner.roundhandler.RoundHandler` | ||
""" | ||
|
||
def __init__(self, storage, server, modelservice, round_handler): | ||
"""Constructor method""" | ||
super().__init__(storage, server, modelservice, round_handler) | ||
|
||
self.name = "custom" | ||
self.code_set = False | ||
|
||
def combine_models(self, helper=None, delete_models=True, parameters=None): | ||
"""Aggregate all model updates with custom aggregator. | ||
:param helper: An instance of :class: `fedn.utils.helpers.helpers.HelperBase`, ML framework specific helper, defaults to None | ||
:type helper: class: `fedn.utils.helpers.helpers.HelperBase`, optional | ||
:param time_window: The time window for model aggregation, defaults to 180 | ||
:type time_window: int, optional | ||
:param max_nr_models: The maximum number of updates aggregated, defaults to 100 | ||
:type max_nr_models: int, optional | ||
:param delete_models: Delete models from storage after aggregation, defaults to True | ||
:type delete_models: bool, optional | ||
:return: The global model and metadata | ||
:rtype: tuple | ||
""" | ||
data = {} | ||
data["time_model_load"] = 0.0 | ||
data["time_model_aggregation"] = 0.0 | ||
|
||
model = None | ||
nr_aggregated_models = 0 | ||
total_examples = 0 | ||
|
||
logger.info("AGGREGATOR({}): Aggregating model updates... ".format(self.name)) | ||
if not self.code_set: | ||
self.round_handler.set_function_provider() | ||
self.code_set = True | ||
while not self.model_updates.empty(): | ||
try: | ||
# Get next model from queue | ||
logger.info("AGGREGATOR({}): Getting next model update from queue.".format(self.name)) | ||
model_update = self.next_model_update() | ||
|
||
# Load model parameters and metadata | ||
logger.info("AGGREGATOR({}): Loading model metadata {}.".format(self.name, model_update.model_update_id)) | ||
model_next, metadata = self.load_model_update(model_update, helper) | ||
|
||
logger.info("AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_update.model_update_id, metadata)) | ||
|
||
# Increment total number of examples | ||
total_examples += metadata["num_examples"] | ||
|
||
nr_aggregated_models += 1 | ||
# Delete model from storage | ||
if delete_models: | ||
self.modelservice.temp_model_storage.delete(model_update.model_update_id) | ||
logger.info("AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_update.model_update_id)) | ||
|
||
self.model_updates.task_done() | ||
if not self.model_updates.empty(): | ||
self.round_handler.combiner_hook_client.call_function("store_parameters", model_next, helper) | ||
else: | ||
model = self.round_handler.combiner_hook_client.call_function("aggregate", model_next, helper) | ||
except Exception as e: | ||
tb = traceback.format_exc() | ||
logger.error(f"AGGREGATOR({self.name}): Error encoutered while processing model update: {e}") | ||
logger.error(tb) | ||
self.model_updates.task_done() | ||
|
||
data["nr_aggregated_models"] = nr_aggregated_models | ||
|
||
logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models)) | ||
return model, data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import os | ||
|
||
import grpc | ||
|
||
import fedn.network.grpc.fedn_pb2 as fedn | ||
import fedn.network.grpc.fedn_pb2_grpc as rpc | ||
from fedn.common.log_config import logger | ||
from fedn.network.combiner.modelservice import load_model_from_BytesIO, serialize_model_to_BytesIO | ||
|
||
|
||
class CombinerHookClient: | ||
def __init__(self): | ||
logger.info("Starting hook client") | ||
self.hook_service_host = os.getenv("HOOK_SERVICE_HOST", "hook:12081") | ||
self.channel = grpc.insecure_channel(self.hook_service_host) | ||
self.stub = rpc.FunctionServiceStub(self.channel) | ||
|
||
def call_function_service(self, task, payload): | ||
request = fedn.FunctionRequest(task=task, payload_string=payload) if task == "setup" else fedn.FunctionRequest(task=task, payload_bytes=payload) | ||
try: | ||
response = self.stub.ExecuteFunction(request) | ||
return response | ||
except grpc.RpcError as e: | ||
logger.info(f"RPC failed: {e}") | ||
return None | ||
|
||
# Example method to trigger function execution | ||
def set_function_provider(self, class_code): | ||
if not isinstance(class_code, str): | ||
raise TypeError("class_code must be of type string") | ||
self.call_function_service("setup", class_code) | ||
|
||
def call_function(self, task, payload, helper): | ||
if task == "aggregate": | ||
payload = serialize_model_to_BytesIO(payload, helper).getvalue() | ||
response = self.call_function_service(task, payload) | ||
return load_model_from_BytesIO(response.result_bytes, helper) | ||
if task == "store_parameters": | ||
payload = serialize_model_to_BytesIO(payload, helper).getvalue() | ||
response = self.call_function_service(task, payload) |
Oops, something went wrong.