Commit: working
viktorvaladi committed Jul 26, 2024
1 parent b8fb56a commit 7449c66
Showing 20 changed files with 537 additions and 91 deletions.
4 changes: 1 addition & 3 deletions config/settings-hooks.yaml.template
@@ -5,6 +5,4 @@ discover_port: 8092
name: hooks
host: hooks
port: 12081
max_clients: 30


max_clients: 30
12 changes: 6 additions & 6 deletions docker-compose.yaml
@@ -87,6 +87,7 @@ services:
environment:
- PYTHONUNBUFFERED=0
- GET_HOSTS_FROM=dns
- HOOK_SERVICE_HOST=hook:12081
build:
context: .
args:
@@ -112,10 +113,11 @@
retries: 5
depends_on:
- api-server

- hooks
# Hooks
hooks:
container_name: hook
environment:
- PYTHONUNBUFFERED=0
- GET_HOSTS_FROM=dns
build:
context: .
@@ -124,10 +126,10 @@
GRPC_HEALTH_PROBE_VERSION: v0.4.24
working_dir: /app
volumes:
- ${HOST_REPO_DIR:-.}/user-hooks:/app/user-hooks
- ${HOST_REPO_DIR:-.}/fedn:/app/fedn
entrypoint: [ "sh", "-c" ]
command:
- "/venv/bin/pip install --no-cache-dir -e . && /venv/bin/fedn hooks start --init config/settings-user-hooks.yaml"
- "/venv/bin/pip install --no-cache-dir -e . && /venv/bin/fedn hooks start"
ports:
- 12081:12081
healthcheck:
@@ -140,8 +142,6 @@
interval: 20s
timeout: 10s
retries: 5
depends_on:
- combiner

# Client
client:
21 changes: 21 additions & 0 deletions examples/mnist-pytorch/client/aggregator.py
@@ -0,0 +1,21 @@
import numpy as np


class FunctionProvider:
def __init__(self) -> None:
pass

def aggregate(self, parameters):
if len(parameters) == 0:
return []
num_clients = len(parameters)

summed_parameters = [np.zeros_like(param) for param in parameters[0]]

for client_params in parameters:
for i, param in enumerate(client_params):
summed_parameters[i] += param

averaged_parameters = [param / num_clients for param in summed_parameters]

return averaged_parameters
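
The FunctionProvider above computes an unweighted element-wise mean of the client parameters. A quick usage sketch (toy arrays, not part of the commit), assuming the file is importable as aggregator:

import numpy as np

from aggregator import FunctionProvider  # assumes examples/mnist-pytorch/client is on the path

# Two hypothetical clients, each contributing two parameter arrays
client_a = [np.array([1.0, 2.0]), np.array([[1.0]])]
client_b = [np.array([3.0, 4.0]), np.array([[3.0]])]

provider = FunctionProvider()
averaged = provider.aggregate([client_a, client_b])

print(averaged[0])  # [2. 3.]
print(averaged[1])  # [[2.]]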
53 changes: 5 additions & 48 deletions fedn/cli/hooks_cmd.py
@@ -1,12 +1,8 @@
import uuid

import click
import requests

from fedn.network.combiner.combiner import Combiner
from fedn.network.hooks.hooks import serve

from .main import main
from .shared import CONTROLLER_DEFAULTS, apply_config, get_api_url, get_token, print_response


@main.group("hooks")
@@ -17,47 +13,8 @@ def hooks_cmd(ctx):


@hooks_cmd.command("start")
@click.option("-d", "--discoverhost", required=False, help="Hostname for discovery services (reducer).")
@click.option("-p", "--discoverport", required=False, help="Port for discovery services (reducer).")
@click.option("-t", "--token", required=False, help="Set token provided by reducer if enabled")
@click.option("-n", "--name", required=False, default="combiner" + str(uuid.uuid4())[:8], help="Set name for combiner.")
@click.option("-h", "--host", required=False, default="combiner", help="Set hostname.")
@click.option("-i", "--port", required=False, default=12080, help="Set port.")
@click.option("-f", "--fqdn", required=False, default=None, help="Set fully qualified domain name")
@click.option("-s", "--secure", is_flag=True, help="Enable SSL/TLS encrypted gRPC channels.")
@click.option("-v", "--verify", is_flag=True, help="Verify SSL/TLS for REST discovery service (reducer)")
@click.option("-c", "--max_clients", required=False, default=30, help="The maximal number of client connections allowed.")
@click.option("-in", "--init", required=False, default=None, help="Path to configuration file to (re)init combiner.")
@click.pass_context
def start_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, secure, verify, max_clients, init):
""":param ctx:
:param discoverhost:
:param discoverport:
:param token:
:param name:
:param hostname:
:param port:
:param secure:
:param max_clients:
:param init:
"""
config = {
"discover_host": discoverhost,
"discover_port": discoverport,
"token": token,
"host": host,
"port": port,
"fqdn": fqdn,
"name": name,
"secure": secure,
"verify": verify,
"max_clients": max_clients,
}
click.echo("started hooks container")
if init:
apply_config(init, config)
click.echo(f"\nCombiner configuration loaded from file: {init}")
click.echo("Values set in file override defaults and command line arguments...\n")

# combiner = Combiner(config)
# combiner.run()
def start_cmd(ctx):
""":param ctx:"""
click.echo("Started hooks container")
serve()
10 changes: 10 additions & 0 deletions fedn/network/api/client.py
@@ -574,6 +574,7 @@ def start_session(
helper: str = "",
min_clients: int = 1,
requested_clients: int = 8,
function_provider_path: str = None,
):
"""Start a new session.
@@ -617,6 +618,7 @@
"helper": helper,
"min_clients": min_clients,
"requested_clients": requested_clients,
"function_provider": function_provider_path if function_provider_path is None else self._read_function_provider(function_provider_path),
},
verify=self.verify,
headers=self.headers,
@@ -787,3 +789,11 @@ def get_validations_count(self):
_json = response.json()

return _json

def _read_function_provider(self, path):
# Open the file in read mode
with open(path, "r") as file:
file_contents = file.read()

# Return the file contents as a string
return file_contents
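
A minimal sketch of starting a session with user-supplied aggregation code through the API client. The controller address and file path are placeholders; APIClient and start_session already exist in fedn.network.api.client, and the function_provider_path argument is new in this commit.

from fedn.network.api.client import APIClient

# Hypothetical controller address (matches the default discover_port 8092)
client = APIClient(host="localhost", port=8092)

# The file is read client-side and sent as the "function_provider" field of the session request
session = client.start_session(
    helper="numpyhelper",
    function_provider_path="client/aggregator.py",
)
print(session)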
19 changes: 13 additions & 6 deletions fedn/network/api/interface.py
@@ -11,6 +11,7 @@
from fedn.common.config import get_controller_config, get_network_config
from fedn.common.log_config import logger
from fedn.network.combiner.interfaces import CombinerInterface, CombinerUnavailableError
from fedn.network.combiner.modelservice import load_model_from_BytesIO
from fedn.network.state import ReducerState, ReducerStateToString
from fedn.utils.checksum import sha
from fedn.utils.plots import Plot
@@ -649,13 +650,17 @@ def set_initial_model(self, file):
:rtype: :class:`flask.Response`
"""
try:
object = BytesIO()
object.seek(0, 0)
file.seek(0)
object.write(file.read())
helper = self.control.get_helper()
object.seek(0)
model = helper.load(object)

# Read file data into a BytesIO object
file_bytes = BytesIO()
file.seek(0)
file_bytes.write(file.read())
file_bytes.seek(0)

# Load the model using the load_model_from_BytesIO function
model_bytes = file_bytes.read()
model = load_model_from_BytesIO(model_bytes, helper)
self.control.commit(file.filename, model)
except Exception as e:
logger.debug(e)
@@ -966,6 +971,7 @@ def start_session(
helper="",
min_clients=1,
requested_clients=8,
function_provider=None,
):
"""Start a session.
@@ -1068,6 +1074,7 @@
"task": (""),
"validate": validate,
"helper_type": helper,
"function_provider": function_provider,
}

# Start session
3 changes: 1 addition & 2 deletions fedn/network/api/server.py
@@ -5,9 +5,8 @@
from fedn.common.config import get_controller_config
from fedn.network.api.auth import jwt_auth_required
from fedn.network.api.interface import API
from fedn.network.api.shared import control, statestore
from fedn.network.api.v1 import _routes
from fedn.network.api.shared import statestore, control


custom_url_prefix = os.environ.get("FEDN_CUSTOM_URL_PREFIX", False)
api = API(statestore, control)
93 changes: 93 additions & 0 deletions fedn/network/combiner/aggregators/custom.py
@@ -0,0 +1,93 @@
import traceback

import numpy as np

from fedn.common.log_config import logger
from fedn.network.combiner.aggregators.aggregatorbase import AggregatorBase


class Aggregator(AggregatorBase):
"""Custom aggregator provided from user defined code.
:param id: A reference to id of :class: `fedn.network.combiner.Combiner`
:type id: str
:param storage: Model repository for :class: `fedn.network.combiner.Combiner`
:type storage: class: `fedn.common.storage.s3.s3repo.S3ModelRepository`
:param server: A handle to the Combiner class :class: `fedn.network.combiner.Combiner`
:type server: class: `fedn.network.combiner.Combiner`
:param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService`
:type modelservice: class: `fedn.network.combiner.modelservice.ModelService`
:param control: A handle to the :class: `fedn.network.combiner.roundhandler.RoundHandler`
:type control: class: `fedn.network.combiner.roundhandler.RoundHandler`
"""

def __init__(self, storage, server, modelservice, round_handler):
"""Constructor method"""
super().__init__(storage, server, modelservice, round_handler)

self.name = "custom"
self.code_set = False

def combine_models(self, helper=None, delete_models=True, parameters=None):
"""Aggregate all model updates with custom aggregator.
:param helper: An instance of :class: `fedn.utils.helpers.helpers.HelperBase`, ML framework specific helper, defaults to None
:type helper: class: `fedn.utils.helpers.helpers.HelperBase`, optional
:param delete_models: Delete models from storage after aggregation, defaults to True
:type delete_models: bool, optional
:return: The global model and metadata
:rtype: tuple
"""
data = {}
data["time_model_load"] = 0.0
data["time_model_aggregation"] = 0.0

model = None
nr_aggregated_models = 0
total_examples = 0

logger.info("AGGREGATOR({}): Aggregating model updates... ".format(self.name))
if not self.code_set:
self.round_handler.set_function_provider()
self.code_set = True
while not self.model_updates.empty():
try:
# Get next model from queue
logger.info("AGGREGATOR({}): Getting next model update from queue.".format(self.name))
model_update = self.next_model_update()

# Load model parameters and metadata
logger.info("AGGREGATOR({}): Loading model metadata {}.".format(self.name, model_update.model_update_id))
model_next, metadata = self.load_model_update(model_update, helper)

logger.info("AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_update.model_update_id, metadata))

# Increment total number of examples
total_examples += metadata["num_examples"]

nr_aggregated_models += 1
# Delete model from storage
if delete_models:
self.modelservice.temp_model_storage.delete(model_update.model_update_id)
logger.info("AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_update.model_update_id))

self.model_updates.task_done()
if not self.model_updates.empty():
self.round_handler.combiner_hook_client.call_function("store_parameters", model_next, helper)
else:
model = self.round_handler.combiner_hook_client.call_function("aggregate", model_next, helper)
except Exception as e:
tb = traceback.format_exc()
logger.error(f"AGGREGATOR({self.name}): Error encoutered while processing model update: {e}")
logger.error(tb)
self.model_updates.task_done()

data["nr_aggregated_models"] = nr_aggregated_models

logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models))
return model, data
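
The custom aggregator streams updates to the hooks service one at a time: store_parameters for every update except the last, then aggregate for the final one. A minimal sketch of the buffering this relies on, assuming the hooks service collects stored parameters into a list before invoking the user's FunctionProvider (the actual service in fedn.network.hooks.hooks is not shown in this diff):

class HooksBufferSketch:
    """Illustrative only; mirrors the store_parameters/aggregate call pattern above."""

    def __init__(self, function_provider):
        self.function_provider = function_provider
        self.buffered_parameters = []

    def store_parameters(self, client_parameters):
        # Called once per model update except the last one
        self.buffered_parameters.append(client_parameters)

    def aggregate(self, final_client_parameters):
        # Called for the last update; hands the full list to the user's aggregate()
        self.buffered_parameters.append(final_client_parameters)
        result = self.function_provider.aggregate(self.buffered_parameters)
        self.buffered_parameters = []
        return result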
21 changes: 20 additions & 1 deletion fedn/network/combiner/combiner.py
@@ -466,15 +466,34 @@ def SetAggregator(self, control: fedn.ControlRequest, context):
logger.debug("grpc.Combiner.SetAggregator: Called")
for parameter in control.parameter:
aggregator = parameter.value

status = self.round_handler.set_aggregator(aggregator)

response = fedn.ControlResponse()
if status:
response.message = "Success"
else:
response.message = "Failed"
return response

def SetFunctionProvider(self, control: fedn.ControlRequest, context):
"""Set a function provider.
:param control: the control request
:type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest`
:param context: the context (unused)
:type context: :class:`grpc._server._Context`
:return: the control response
:rtype: :class:`fedn.network.grpc.fedn_pb2.ControlResponse`
"""
logger.debug("grpc.Combiner.SetFunctionProvider: Called")
for parameter in control.parameter:
function_provider = parameter.value

self.round_handler.function_provider_code = function_provider

response = fedn.ControlResponse()
response.message = "Success"
logger.info(f"set function provider response {response}")
return response

def FlushAggregationQueue(self, control: fedn.ControlRequest, context):
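
For reference, a sketch of how a caller might build the control request this endpoint consumes. The parameter key is an assumption, since SetFunctionProvider only reads parameter.value:

import fedn.network.grpc.fedn_pb2 as fedn

def build_set_function_provider_request(function_provider_code: str) -> fedn.ControlRequest:
    # Pack user-defined aggregation code into a ControlRequest (sketch)
    request = fedn.ControlRequest()
    p = request.parameter.add()
    p.key = "function_provider"  # hypothetical key; the handler above only uses p.value
    p.value = function_provider_code
    return request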
40 changes: 40 additions & 0 deletions fedn/network/combiner/hook_client.py
@@ -0,0 +1,40 @@
import os

import grpc

import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.common.log_config import logger
from fedn.network.combiner.modelservice import load_model_from_BytesIO, serialize_model_to_BytesIO


class CombinerHookClient:
def __init__(self):
logger.info("Starting hook client")
self.hook_service_host = os.getenv("HOOK_SERVICE_HOST", "hook:12081")
self.channel = grpc.insecure_channel(self.hook_service_host)
self.stub = rpc.FunctionServiceStub(self.channel)

def call_function_service(self, task, payload):
request = fedn.FunctionRequest(task=task, payload_string=payload) if task == "setup" else fedn.FunctionRequest(task=task, payload_bytes=payload)
try:
response = self.stub.ExecuteFunction(request)
return response
except grpc.RpcError as e:
logger.info(f"RPC failed: {e}")
return None

# Send the user-defined function provider code to the hooks service
def set_function_provider(self, class_code):
if not isinstance(class_code, str):
raise TypeError("class_code must be of type string")
self.call_function_service("setup", class_code)

def call_function(self, task, payload, helper):
if task == "aggregate":
payload = serialize_model_to_BytesIO(payload, helper).getvalue()
response = self.call_function_service(task, payload)
return load_model_from_BytesIO(response.result_bytes, helper)
if task == "store_parameters":
payload = serialize_model_to_BytesIO(payload, helper).getvalue()
response = self.call_function_service(task, payload)
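
A usage sketch of the new hook client from the combiner side. The helper name, file path, and model updates are placeholders; in practice these calls are made by the custom aggregator shown earlier, and the hooks service must be reachable at HOOK_SERVICE_HOST (default hook:12081):

import numpy as np

from fedn.network.combiner.hook_client import CombinerHookClient
from fedn.utils.helpers.helpers import get_helper

client = CombinerHookClient()

# Send the user-defined FunctionProvider source code to the hooks service
with open("client/aggregator.py", "r") as f:  # hypothetical path
    client.set_function_provider(f.read())

helper = get_helper("numpyhelper")  # assumes the numpy helper for (de)serialization
update_a = [np.array([1.0, 2.0])]   # placeholder model updates
update_b = [np.array([3.0, 4.0])]

# Buffer all but the last update, then aggregate on the final call
client.call_function("store_parameters", update_a, helper)
aggregated = client.call_function("aggregate", update_b, helper)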