Skip to content

Commit

Permalink
Resolve conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
Andreas Hellander committed Dec 7, 2023
2 parents 170aee2 + bf471e3 commit d12246d
Show file tree
Hide file tree
Showing 20 changed files with 812 additions and 1,369 deletions.
27 changes: 27 additions & 0 deletions .devcontainer/devcontainer.json.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"name": "devcontainer",
"dockerFile": "Dockerfile",
"context": "..",
"remoteUser": "default",
// "workspaceFolder": "/fedn",
// "workspaceMount": "source=/path/to/fedn,target=/fedn,type=bind,consistency=default",
"extensions": [
"ms-azuretools.vscode-docker",
"ms-python.python",
"exiasr.hadolint",
"yzhang.markdown-all-in-one",
"ms-python.isort"
],
"mounts": [
"source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind,consistency=default",
],
"runArgs": [
"--net=host"
],
"build": {
"args": {
"BASE_IMG": "python:3.9"
}
}
}

2 changes: 1 addition & 1 deletion examples/mnist-keras/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
tensorflow==2.9.3
tensorflow==2.13.1
fire==0.3.1
docker==6.1.1
11 changes: 1 addition & 10 deletions fedn/fedn/common/net/grpc/fedn.proto
Original file line number Diff line number Diff line change
Expand Up @@ -195,17 +195,10 @@ message ControlResponse {
repeated Parameter parameter = 2;
}

message ReportResponse {
Client sender = 1;
repeated Parameter parameter = 2;
}

service Control {
rpc Start(ControlRequest) returns (ControlResponse);
rpc Stop(ControlRequest) returns (ControlResponse);
rpc Configure(ControlRequest) returns (ReportResponse);
rpc FlushAggregationQueue(ControlRequest) returns (ControlResponse);
rpc Report(ControlRequest) returns (ReportResponse);
rpc FlushAggregationQueue(ControlRequest) returns (ControlResponse);
}

service Reducer {
Expand Down Expand Up @@ -253,9 +246,7 @@ service Combiner {
rpc ModelValidationRequestStream (ClientAvailableMessage) returns (stream ModelValidationRequest);
rpc ModelValidationStream (ClientAvailableMessage) returns (stream ModelValidation);

rpc SendModelUpdateRequest (ModelUpdateRequest) returns (Response);
rpc SendModelUpdate (ModelUpdate) returns (Response);
rpc SendModelValidationRequest (ModelValidationRequest) returns (Response);
rpc SendModelValidation (ModelValidation) returns (Response);

}
Expand Down
382 changes: 74 additions & 308 deletions fedn/fedn/common/net/grpc/fedn_pb2.py

Large diffs are not rendered by default.

995 changes: 434 additions & 561 deletions fedn/fedn/common/net/grpc/fedn_pb2_grpc.py

Large diffs are not rendered by default.

18 changes: 11 additions & 7 deletions fedn/fedn/common/net/grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
import grpc

import fedn.common.net.grpc.fedn_pb2_grpc as rpc
from fedn.common.log_config import (logger, set_log_level_from_string,
set_log_stream)


class Server:
""" Class for configuring and launching the gRPC server."""

def __init__(self, servicer, modelservicer, config):

set_log_level_from_string(config.get('verbosity', "INFO"))
set_log_stream(config.get('logfile', None))

self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=350))
self.certificate = None

Expand All @@ -25,22 +30,21 @@ def __init__(self, servicer, modelservicer, config):
rpc.add_ControlServicer_to_server(servicer, self.server)

if config['secure']:
print("Creating secure gRPCS server using certificate: {config['certificate']}", flush=True)
logger.info(f'Creating secure gRPCS server using certificate: {config["certificate"]}')
server_credentials = grpc.ssl_server_credentials(
((config['key'], config['certificate'],),))
self.server.add_secure_port(
'[::]:' + str(config['port']), server_credentials)
else:
print("Creating insecure gRPC server", flush=True)
logger.info("Creating insecure gRPC server")
self.server.add_insecure_port('[::]:' + str(config['port']))

def start(self):
""" Start gRPC server. """

print("Server started", flush=True)
""" Start the gRPC server."""
logger.info("gRPC Server started")
self.server.start()

def stop(self):
""" Stop gRPC server. """

""" Stop the gRPC server."""
logger.info("gRPC Server stopped")
self.server.stop(0)
50 changes: 10 additions & 40 deletions fedn/fedn/network/api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,6 @@ def _to_dict(self):
data = {"name": self.name}
return data

def _get_combiner_report(self, combiner_id):
"""Get report response from combiner.
:param combiner_id: The combiner id to get report response from.
:type combiner_id: str
::return: The report response from combiner.
::rtype: dict
"""
    # Get CombinerInterface (fedn.network.combiner.interfaces.CombinerInterface) for combiner_id
combiner = self.control.network.get_combiner(combiner_id)
report = combiner.report
return report

def _allowed_file_extension(
self, filename, ALLOWED_EXTENSIONS={"gz", "bz2", "tar", "zip", "tgz"}
):
Expand Down Expand Up @@ -91,30 +78,6 @@ def get_clients(self, limit=None, skip=None, status=False):

return jsonify(result)

def get_active_clients(self, combiner_id):
"""Get all active clients, i.e that are assigned to a combiner.
    A report request to the combiner is necessary to determine if a client is active or not.
:param combiner_id: The combiner id to get active clients for.
:type combiner_id: str
:return: All active clients as a json response.
:rtype: :class:`flask.Response`
"""
# Get combiner interface object
combiner = self.control.network.get_combiner(combiner_id)
if combiner is None:
return (
jsonify(
{
"success": False,
"message": f"Combiner {combiner_id} not found.",
}
),
404,
)
response = combiner.list_active_clients()
return response

def get_all_combiners(self, limit=None, skip=None):
"""Get all combiners from the statestore.
Expand Down Expand Up @@ -154,7 +117,6 @@ def get_combiner(self, combiner_id):
"fqdn": object["fqdn"],
"parent_reducer": object["parent"]["name"],
"port": object["port"],
"report": object["report"],
"updated_at": object["updated_at"],
}
payload[id] = info
Expand Down Expand Up @@ -832,12 +794,20 @@ def start_session(
{"success": False, "message": "A session is already running."}
)

# Check that initial (seed) model is set
if not self.statestore.get_initial_model():
return jsonify(
{
"success": False,
"message": "No initial model set. Set initial model before starting session.",
}
)

# Check available clients per combiner
clients_available = 0
for combiner in self.control.network.get_combiners():
try:
combiner_state = combiner.report()
nr_active_clients = combiner_state["nr_active_clients"]
nr_active_clients = len(combiner.list_active_clients())
clients_available = clients_available + int(nr_active_clients)
except CombinerUnavailableError as e:
# TODO: Handle unavailable combiner, stop session or continue?
Expand Down
18 changes: 1 addition & 17 deletions fedn/fedn/network/api/network.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import base64

from fedn.network.combiner.interfaces import (CombinerInterface,
CombinerUnavailableError)
from fedn.network.combiner.interfaces import CombinerInterface
from fedn.network.loadbalancer.leastpacked import LeastPacked

__all__ = 'Network',
Expand Down Expand Up @@ -154,18 +153,3 @@ def get_client_info(self):
:rtype: list(ObjectId)
"""
return self.statestore.list_clients()

def describe(self):
""" Describe the network.
:return: The network description
:rtype: dict
"""
network = []
for combiner in self.get_combiners():
try:
network.append(combiner.report())
except CombinerUnavailableError:
# TODO, do better here.
pass
return network
60 changes: 39 additions & 21 deletions fedn/fedn/network/clients/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,29 +441,37 @@ def _listen_to_model_update_request_stream(self):
# Add client to metadata
self._add_grpc_metadata('client', self.name)

while True:
while self._attached:
try:
for request in self.combinerStub.ModelUpdateRequestStream(r, metadata=self.metadata):
if request:
logger.debug("Received model update request from combiner: {}.".format(request))
if request.sender.role == fedn.COMBINER:
# Process training request
self._send_status("Received model update request.", log_level=fedn.Status.AUDIT,
type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request)
logger.info("Received model update request.")

self.inbox.put(('train', request))

if not self._attached:
return
except grpc.RpcError as e:
_ = e.code()
except grpc.RpcError:
# TODO: make configurable
timeout = 5
time.sleep(timeout)
except Exception:
raise
# Handle gRPC errors
status_code = e.code()
if status_code == grpc.StatusCode.UNAVAILABLE:
logger.warning("GRPC server unavailable during model update request stream. Retrying.")
# Retry after a delay
time.sleep(5)
else:
# Log the error and continue
logger.error(f"An error occurred during model update request stream: {e}")

if not self._attached:
return
except Exception as ex:
# Handle other exceptions
logger.error(f"An error occurred during model update request stream: {ex}")

# Detach if not attached
if not self._attached:
return

def _listen_to_model_validation_request_stream(self):
"""Subscribe to the model validation request stream.
Expand All @@ -479,17 +487,27 @@ def _listen_to_model_validation_request_stream(self):
try:
for request in self.combinerStub.ModelValidationRequestStream(r, metadata=self.metadata):
# Process validation request
_ = request.model_id
self._send_status("Recieved model validation request.", log_level=fedn.Status.AUDIT,
type=fedn.StatusType.MODEL_VALIDATION_REQUEST, request=request)
model_id = request.model_id
self._send_status("Received model validation request for model_id {}".format(model_id),
log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_VALIDATION_REQUEST,
request=request)
logger.info("Received model validation request for model_id {}".format(model_id))
self.inbox.put(('validate', request))

except grpc.RpcError:
# TODO: make configurable
timeout = 5
time.sleep(timeout)
except Exception:
raise
except grpc.RpcError as e:
# Handle gRPC errors
status_code = e.code()
if status_code == grpc.StatusCode.UNAVAILABLE:
logger.warning("GRPC server unavailable during model validation request stream. Retrying.")
# Retry after a delay
time.sleep(5)
else:
# Log the error and continue
logger.error(f"An error occurred during model validation request stream: {e}")

except Exception as ex:
# Handle other exceptions
logger.error(f"An error occurred during model validation request stream: {ex}")

if not self._attached:
return
Expand Down
12 changes: 5 additions & 7 deletions fedn/fedn/network/combiner/aggregators/aggregatorbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import queue
from abc import ABC, abstractmethod

import fedn.common.net.grpc.fedn_pb2 as fedn
from fedn.common.log_config import logger

AGGREGATOR_PLUGIN_PATH = "fedn.network.combiner.aggregators.{}"

Expand Down Expand Up @@ -60,19 +60,17 @@ def on_model_update(self, model_update):
:type model_id: str
"""
try:
self.server.report_status("AGGREGATOR({}): callback received model update {}".format(self.name, model_update.model_update_id),
log_level=fedn.Status.INFO)
logger.info("AGGREGATOR({}): callback received model update {}".format(self.name, model_update.model_update_id))

# Validate the update and metadata
valid_update = self._validate_model_update(model_update)
if valid_update:
# Push the model update to the processing queue
self.model_updates.put(model_update)
else:
self.server.report_status("AGGREGATOR({}): Invalid model update, skipping.".format(self.name))
logger.warning("AGGREGATOR({}): Invalid model update, skipping.".format(self.name))
except Exception as e:
self.server.report_status("AGGREGATOR({}): Failed to receive model update! {}".format(self.name, e),
log_level=fedn.Status.WARNING)
logger.error("AGGREGATOR({}): Failed to receive model update! {}".format(self.name, e))
pass

def _validate_model_update(self, model_update):
Expand All @@ -86,7 +84,7 @@ def _validate_model_update(self, model_update):
# TODO: Validate the metadata to check that it contains all variables assumed by the aggregator.
data = json.loads(model_update.meta)['training_metadata']
if 'num_examples' not in data.keys():
self.server.report_status("AGGREGATOR({}): Model validation failed, num_examples missing in metadata.".format(self.name))
logger.error("AGGREGATOR({}): Model validation failed, num_examples missing in metadata.".format(self.name))
return False
return True

Expand Down
13 changes: 6 additions & 7 deletions fedn/fedn/network/combiner/aggregators/fedavg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import fedn.common.net.grpc.fedn_pb2 as fedn
from fedn.common.log_config import logger
from fedn.network.combiner.aggregators.aggregatorbase import AggregatorBase


Expand Down Expand Up @@ -50,14 +50,14 @@ def combine_models(self, helper=None, time_window=180, max_nr_models=100, delete
nr_aggregated_models = 0
total_examples = 0

self.server.report_status(
logger.info(
"AGGREGATOR({}): Aggregating model updates... ".format(self.name))

while not self.model_updates.empty():
try:
# Get next model from queue
model_next, metadata, model_id = self.next_model_update(helper)
self.server.report_status(
logger.info(
"AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_id, metadata))

# Increment total number of examples
Expand All @@ -73,16 +73,15 @@ def combine_models(self, helper=None, time_window=180, max_nr_models=100, delete
# Delete model from storage
if delete_models:
self.modelservice.models.delete(model_id)
self.server.report_status(
logger.info(
"AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_id))
self.model_updates.task_done()
except Exception as e:
self.server.report_status(
logger.error(
"AGGREGATOR({}): Error encoutered while processing model update {}, skipping this update.".format(self.name, e))
self.model_updates.task_done()

data['nr_aggregated_models'] = nr_aggregated_models

self.server.report_status("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models),
log_level=fedn.Status.INFO)
logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models))
return model, data
Loading

0 comments on commit d12246d

Please sign in to comment.