Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/SK-592 | Refactor gRPC server #490

Merged
merged 8 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/mnist-keras/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
tensorflow==2.9.3
tensorflow==2.13.1
fire==0.3.1
docker==6.1.1
11 changes: 1 addition & 10 deletions fedn/fedn/common/net/grpc/fedn.proto
Original file line number Diff line number Diff line change
Expand Up @@ -195,17 +195,10 @@ message ControlResponse {
repeated Parameter parameter = 2;
}

message ReportResponse {
Client sender = 1;
repeated Parameter parameter = 2;
}

service Control {
rpc Start(ControlRequest) returns (ControlResponse);
rpc Stop(ControlRequest) returns (ControlResponse);
rpc Configure(ControlRequest) returns (ReportResponse);
rpc FlushAggregationQueue(ControlRequest) returns (ControlResponse);
rpc Report(ControlRequest) returns (ReportResponse);
rpc FlushAggregationQueue(ControlRequest) returns (ControlResponse);
}

service Reducer {
Expand Down Expand Up @@ -253,9 +246,7 @@ service Combiner {
rpc ModelValidationRequestStream (ClientAvailableMessage) returns (stream ModelValidationRequest);
rpc ModelValidationStream (ClientAvailableMessage) returns (stream ModelValidation);

rpc SendModelUpdateRequest (ModelUpdateRequest) returns (Response);
rpc SendModelUpdate (ModelUpdate) returns (Response);
rpc SendModelValidationRequest (ModelValidationRequest) returns (Response);
rpc SendModelValidation (ModelValidation) returns (Response);

}
Expand Down
382 changes: 74 additions & 308 deletions fedn/fedn/common/net/grpc/fedn_pb2.py

Large diffs are not rendered by default.

995 changes: 434 additions & 561 deletions fedn/fedn/common/net/grpc/fedn_pb2_grpc.py

Large diffs are not rendered by default.

22 changes: 12 additions & 10 deletions fedn/fedn/common/net/grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@
import grpc

import fedn.common.net.grpc.fedn_pb2_grpc as rpc
from fedn.common.log_config import (logger, set_log_level_from_string,
set_log_stream)


class Server:
"""

Server class for gRPC server.
"""

def __init__(self, servicer, modelservicer, config):

set_log_level_from_string(config.get('verbosity', "INFO"))
set_log_stream(config.get('logfile', None))

self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=350))
self.certificate = None

Expand All @@ -27,24 +32,21 @@ def __init__(self, servicer, modelservicer, config):
rpc.add_ControlServicer_to_server(servicer, self.server)

if config['secure']:
print("Creating secure gRPCS server using certificate: {config['certificate']}", flush=True)
logger.info(f'Creating secure gRPCS server using certificate: {config["certificate"]}')
server_credentials = grpc.ssl_server_credentials(
((config['key'], config['certificate'],),))
self.server.add_secure_port(
'[::]:' + str(config['port']), server_credentials)
else:
print("Creating insecure gRPC server", flush=True)
logger.info("Creating insecure gRPC server")
self.server.add_insecure_port('[::]:' + str(config['port']))

def start(self):
"""

"""
print("Server started", flush=True)
""" Start the gRPC server."""
logger.info("gRPC Server started")
self.server.start()

def stop(self):
"""

"""
""" Stop the gRPC server."""
logger.info("gRPC Server stopped")
self.server.stop(0)
6 changes: 3 additions & 3 deletions fedn/fedn/common/tracer/mongotracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def set_round_data(self, round_id, round_data):
self.rounds.update_one({'round_id': round_id}, {
'$set': {'round_data': round_data}}, True)

def update_client_status(self, client_name, status):
def update_client_status(self, clients, status):
""" Update client status in statestore.
:param client_name: The client name
:type client_name: str
Expand All @@ -119,7 +119,7 @@ def update_client_status(self, client_name, status):
:return: None
"""
datetime_now = datetime.now()
filter_query = {"name": client_name}
filter_query = {"name": {"$in": clients}}

update_query = {"$set": {"last_seen": datetime_now, "status": status}}
self.clients.update_one(filter_query, update_query)
self.clients.update_many(filter_query, update_query)
50 changes: 10 additions & 40 deletions fedn/fedn/network/api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,6 @@ def _to_dict(self):
data = {"name": self.name}
return data

def _get_combiner_report(self, combiner_id):
"""Get report response from combiner.

:param combiner_id: The combiner id to get report response from.
:type combiner_id: str
::return: The report response from combiner.
::rtype: dict
"""
# Get CombinerInterface (fedn.network.combiner.inferface.CombinerInterface) for combiner_id
combiner = self.control.network.get_combiner(combiner_id)
report = combiner.report
return report

def _allowed_file_extension(
self, filename, ALLOWED_EXTENSIONS={"gz", "bz2", "tar", "zip", "tgz"}
):
Expand Down Expand Up @@ -91,30 +78,6 @@ def get_clients(self, limit=None, skip=None, status=False):

return jsonify(result)

def get_active_clients(self, combiner_id):
"""Get all active clients, i.e that are assigned to a combiner.
A report request to the combiner is neccessary to determine if a client is active or not.

:param combiner_id: The combiner id to get active clients for.
:type combiner_id: str
:return: All active clients as a json response.
:rtype: :class:`flask.Response`
"""
# Get combiner interface object
combiner = self.control.network.get_combiner(combiner_id)
if combiner is None:
return (
jsonify(
{
"success": False,
"message": f"Combiner {combiner_id} not found.",
}
),
404,
)
response = combiner.list_active_clients()
return response

def get_all_combiners(self, limit=None, skip=None):
"""Get all combiners from the statestore.

Expand Down Expand Up @@ -154,7 +117,6 @@ def get_combiner(self, combiner_id):
"fqdn": object["fqdn"],
"parent_reducer": object["parent"]["name"],
"port": object["port"],
"report": object["report"],
"updated_at": object["updated_at"],
}
payload[id] = info
Expand Down Expand Up @@ -832,12 +794,20 @@ def start_session(
{"success": False, "message": "A session is already running."}
)

# Check that initial (seed) model is set
if not self.statestore.get_initial_model():
return jsonify(
{
"success": False,
"message": "No initial model set. Set initial model before starting session.",
}
)

# Check available clients per combiner
clients_available = 0
for combiner in self.control.network.get_combiners():
try:
combiner_state = combiner.report()
nr_active_clients = combiner_state["nr_active_clients"]
nr_active_clients = len(combiner.list_active_clients())
clients_available = clients_available + int(nr_active_clients)
except CombinerUnavailableError as e:
# TODO: Handle unavailable combiner, stop session or continue?
Expand Down
18 changes: 1 addition & 17 deletions fedn/fedn/network/api/network.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import base64

from fedn.network.combiner.interfaces import (CombinerInterface,
CombinerUnavailableError)
from fedn.network.combiner.interfaces import CombinerInterface
from fedn.network.loadbalancer.leastpacked import LeastPacked

__all__ = 'Network',
Expand Down Expand Up @@ -154,18 +153,3 @@ def get_client_info(self):
:rtype: list(ObjectId)
"""
return self.statestore.list_clients()

def describe(self):
""" Describe the network.

:return: The network description
:rtype: dict
"""
network = []
for combiner in self.get_combiners():
try:
network.append(combiner.report())
except CombinerUnavailableError:
# TODO, do better here.
pass
return network
60 changes: 39 additions & 21 deletions fedn/fedn/network/clients/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,29 +441,37 @@ def _listen_to_model_update_request_stream(self):
# Add client to metadata
self._add_grpc_metadata('client', self.name)

while True:
while self._attached:
try:
for request in self.combinerStub.ModelUpdateRequestStream(r, metadata=self.metadata):
if request:
logger.debug("Received model update request from combiner: {}.".format(request))
if request.sender.role == fedn.COMBINER:
# Process training request
self._send_status("Received model update request.", log_level=fedn.Status.AUDIT,
type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request)
logger.info("Received model update request.")

self.inbox.put(('train', request))

if not self._attached:
return
except grpc.RpcError as e:
_ = e.code()
except grpc.RpcError:
# TODO: make configurable
timeout = 5
time.sleep(timeout)
except Exception:
raise
# Handle gRPC errors
status_code = e.code()
if status_code == grpc.StatusCode.UNAVAILABLE:
logger.warning("GRPC server unavailable during model update request stream. Retrying.")
# Retry after a delay
time.sleep(5)
else:
# Log the error and continue
logger.error(f"An error occurred during model update request stream: {e}")

if not self._attached:
return
except Exception as ex:
# Handle other exceptions
logger.error(f"An error occurred during model update request stream: {ex}")

# Detach if not attached
if not self._attached:
return

def _listen_to_model_validation_request_stream(self):
"""Subscribe to the model validation request stream.
Expand All @@ -479,17 +487,27 @@ def _listen_to_model_validation_request_stream(self):
try:
for request in self.combinerStub.ModelValidationRequestStream(r, metadata=self.metadata):
# Process validation request
_ = request.model_id
self._send_status("Recieved model validation request.", log_level=fedn.Status.AUDIT,
type=fedn.StatusType.MODEL_VALIDATION_REQUEST, request=request)
model_id = request.model_id
self._send_status("Received model validation request for model_id {}".format(model_id),
log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_VALIDATION_REQUEST,
request=request)
logger.info("Received model validation request for model_id {}".format(model_id))
self.inbox.put(('validate', request))

except grpc.RpcError:
# TODO: make configurable
timeout = 5
time.sleep(timeout)
except Exception:
raise
except grpc.RpcError as e:
# Handle gRPC errors
status_code = e.code()
if status_code == grpc.StatusCode.UNAVAILABLE:
logger.warning("GRPC server unavailable during model validation request stream. Retrying.")
# Retry after a delay
time.sleep(5)
else:
# Log the error and continue
logger.error(f"An error occurred during model validation request stream: {e}")

except Exception as ex:
# Handle other exceptions
logger.error(f"An error occurred during model validation request stream: {ex}")

if not self._attached:
return
Expand Down
12 changes: 5 additions & 7 deletions fedn/fedn/network/combiner/aggregators/aggregatorbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import queue
from abc import ABC, abstractmethod

import fedn.common.net.grpc.fedn_pb2 as fedn
from fedn.common.log_config import logger

AGGREGATOR_PLUGIN_PATH = "fedn.network.combiner.aggregators.{}"

Expand Down Expand Up @@ -60,19 +60,17 @@ def on_model_update(self, model_update):
:type model_id: str
"""
try:
self.server.report_status("AGGREGATOR({}): callback received model update {}".format(self.name, model_update.model_update_id),
log_level=fedn.Status.INFO)
logger.info("AGGREGATOR({}): callback received model update {}".format(self.name, model_update.model_update_id))

# Validate the update and metadata
valid_update = self._validate_model_update(model_update)
if valid_update:
# Push the model update to the processing queue
self.model_updates.put(model_update)
else:
self.server.report_status("AGGREGATOR({}): Invalid model update, skipping.".format(self.name))
logger.warning("AGGREGATOR({}): Invalid model update, skipping.".format(self.name))
except Exception as e:
self.server.report_status("AGGREGATOR({}): Failed to receive model update! {}".format(self.name, e),
log_level=fedn.Status.WARNING)
logger.error("AGGREGATOR({}): Failed to receive model update! {}".format(self.name, e))
pass

def _validate_model_update(self, model_update):
Expand All @@ -86,7 +84,7 @@ def _validate_model_update(self, model_update):
# TODO: Validate the metadata to check that it contains all variables assumed by the aggregator.
data = json.loads(model_update.meta)['training_metadata']
if 'num_examples' not in data.keys():
self.server.report_status("AGGREGATOR({}): Model validation failed, num_examples missing in metadata.".format(self.name))
logger.error("AGGREGATOR({}): Model validation failed, num_examples missing in metadata.".format(self.name))
return False
return True

Expand Down
13 changes: 6 additions & 7 deletions fedn/fedn/network/combiner/aggregators/fedavg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import fedn.common.net.grpc.fedn_pb2 as fedn
from fedn.common.log_config import logger
from fedn.network.combiner.aggregators.aggregatorbase import AggregatorBase


Expand Down Expand Up @@ -50,14 +50,14 @@ def combine_models(self, helper=None, time_window=180, max_nr_models=100, delete
nr_aggregated_models = 0
total_examples = 0

self.server.report_status(
logger.info(
"AGGREGATOR({}): Aggregating model updates... ".format(self.name))

while not self.model_updates.empty():
try:
# Get next model from queue
model_next, metadata, model_id = self.next_model_update(helper)
self.server.report_status(
logger.info(
"AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_id, metadata))

# Increment total number of examples
Expand All @@ -73,16 +73,15 @@ def combine_models(self, helper=None, time_window=180, max_nr_models=100, delete
# Delete model from storage
if delete_models:
self.modelservice.models.delete(model_id)
self.server.report_status(
logger.info(
"AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_id))
self.model_updates.task_done()
except Exception as e:
self.server.report_status(
logger.error(
"AGGREGATOR({}): Error encoutered while processing model update {}, skipping this update.".format(self.name, e))
self.model_updates.task_done()

data['nr_aggregated_models'] = nr_aggregated_models

self.server.report_status("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models),
log_level=fedn.Status.INFO)
logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models))
return model, data
Loading
Loading