Skip to content

Commit

Permalink
Feature/SK-946 | Graceful failing if new container is not present (#733)
Browse files Browse the repository at this point in the history
  • Loading branch information
viktorvaladi authored Nov 4, 2024
1 parent 6462978 commit c03519c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 20 deletions.
39 changes: 23 additions & 16 deletions fedn/network/combiner/hooks/hook_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,22 @@ class CombinerHookInterface:

def __init__(self):
"""Initialize CombinerHookInterface client."""
self.hook_service_host = os.getenv("HOOK_SERVICE_HOST", "hook:12081")
self.channel = grpc.insecure_channel(
self.hook_service_host,
options=[
("grpc.keepalive_time_ms", 30000), # 30 seconds ping interval
("grpc.keepalive_timeout_ms", 5000), # 5 seconds timeout for a response
("grpc.keepalive_permit_without_calls", 1), # allow keepalives even with no active calls
("grpc.enable_retries", 1), # automatic retries
("grpc.initial_reconnect_backoff_ms", 1000), # initial delay before retrying
("grpc.max_reconnect_backoff_ms", 5000), # maximum delay before retrying
],
)
self.stub = rpc.FunctionServiceStub(self.channel)
try:
self.hook_service_host = os.getenv("HOOK_SERVICE_HOST", "hook:12081")
self.channel = grpc.insecure_channel(
self.hook_service_host,
options=[
("grpc.keepalive_time_ms", 30000), # 30 seconds ping interval
("grpc.keepalive_timeout_ms", 5000), # 5 seconds timeout for a response
("grpc.keepalive_permit_without_calls", 1), # allow keepalives even with no active calls
("grpc.enable_retries", 1), # automatic retries
("grpc.initial_reconnect_backoff_ms", 1000), # initial delay before retrying
("grpc.max_reconnect_backoff_ms", 5000), # maximum delay before retrying
],
)
self.stub = rpc.FunctionServiceStub(self.channel)
except Exception as e:
logger.warning(f"Failed to initialize connection to hooks container with error {e}")

def provided_functions(self, server_functions: str):
"""Communicates to hook container and asks which functions are available.
Expand All @@ -39,10 +42,14 @@ def provided_functions(self, server_functions: str):
:return: dictionary specifing which functions are implemented.
:rtype: dict
"""
request = fedn.ProvidedFunctionsRequest(function_code=server_functions)
try:
request = fedn.ProvidedFunctionsRequest(function_code=server_functions)

response = self.stub.HandleProvidedFunctions(request)
return response.available_functions
response = self.stub.HandleProvidedFunctions(request)
return response.available_functions
except Exception as e:
logger.warning(f"Was not able to communicate to hooks container due to: {e}")
return {}

def client_settings(self, global_model) -> dict:
"""Communicates to hook container to get a client config.
Expand Down
8 changes: 4 additions & 4 deletions fedn/network/combiner/roundhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def push_round_config(self, round_config: RoundConfig) -> str:
raise
return round_config["_job_id"]

def _training_round(self, config, clients, provided_functions):
def _training_round(self, config: dict, clients: list, provided_functions: dict):
"""Send model update requests to clients and aggregate results.
:param config: The round config object (passed to the client).
Expand All @@ -141,7 +141,7 @@ def _training_round(self, config, clients, provided_functions):
session_id = config["session_id"]
model_id = config["model_id"]

if provided_functions["client_settings"]:
if provided_functions.get("client_settings", False):
global_model_bytes = self.modelservice.temp_model_storage.get(model_id)
client_settings = self.hook_interface.client_settings(global_model_bytes)
config["client_settings"] = client_settings
Expand Down Expand Up @@ -172,7 +172,7 @@ def _training_round(self, config, clients, provided_functions):
parameters = Parameters(dict_parameters)
else:
parameters = None
if provided_functions["aggregate"]:
if provided_functions.get("aggregate", False):
previous_model_bytes = self.modelservice.temp_model_storage.get(model_id)
model, data = self.hook_interface.aggregate(previous_model_bytes, self.update_handler, helper, delete_models=delete_models)
else:
Expand Down Expand Up @@ -326,7 +326,7 @@ def execute_training_round(self, config):

provided_functions = self.hook_interface.provided_functions(self.server_functions)

if provided_functions["client_selection"]:
if provided_functions.get("client_selection", False):
clients = self.hook_interface.client_selection(clients=self.server.get_active_trainers())
else:
clients = self._assign_round_clients(self.server.max_clients)
Expand Down

0 comments on commit c03519c

Please sign in to comment.