Skip to content

Commit

Permalink
Latest
Browse files (browse the repository at this point in the history)
  • Loading branch information
Andreas Hellander committed Dec 16, 2024
1 parent ce258e6 commit e772691
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 32 deletions.
2 changes: 1 addition & 1 deletion fedn/network/combiner/modelservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fedn.common.log_config import logger
from fedn.network.storage.models.tempmodelstorage import TempModelStorage

CHUNK_SIZE = 2 * 1024 * 1024
CHUNK_SIZE = 1 * 1024 * 1024


def upload_request_generator(mdl, id):
Expand Down
84 changes: 56 additions & 28 deletions fedn/network/combiner/roundhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def _training_round(self, config: dict, clients: list, provided_functions: dict)
:return: an aggregated model and associated metadata
:rtype: model, dict
"""
logger.info("ROUNDHANDLER: Initiating training round, participating clients: {}".format(clients))
logger.info(
"ROUNDHANDLER: Initiating training round, participating clients: {}".format(clients))

meta = {}
meta["nr_expected_updates"] = len(clients)
Expand All @@ -142,11 +143,14 @@ def _training_round(self, config: dict, clients: list, provided_functions: dict)
model_id = config["model_id"]

if provided_functions.get("client_settings", False):
global_model_bytes = self.modelservice.temp_model_storage.get(model_id)
client_settings = self.hook_interface.client_settings(global_model_bytes)
global_model_bytes = self.modelservice.temp_model_storage.get(
model_id)
client_settings = self.hook_interface.client_settings(
global_model_bytes)
config["client_settings"] = client_settings
# Request model updates from all active clients.
self.server.request_model_update(session_id=session_id, model_id=model_id, config=config, clients=clients)
self.server.request_model_update(
session_id=session_id, model_id=model_id, config=config, clients=clients)

# If buffer_size is -1 (default), the round terminates when/if all clients have completed.
if int(config["buffer_size"]) == -1:
Expand All @@ -161,7 +165,8 @@ def _training_round(self, config: dict, clients: list, provided_functions: dict)
data = None
try:
helper = get_helper(config["helper_type"])
logger.info("Config delete_models_storage: {}".format(config["delete_models_storage"]))
logger.info("Config delete_models_storage: {}".format(
config["delete_models_storage"]))
if config["delete_models_storage"] == "True":
delete_models = True
else:
Expand All @@ -173,10 +178,13 @@ def _training_round(self, config: dict, clients: list, provided_functions: dict)
else:
parameters = None
if provided_functions.get("aggregate", False):
previous_model_bytes = self.modelservice.temp_model_storage.get(model_id)
model, data = self.hook_interface.aggregate(previous_model_bytes, self.update_handler, helper, delete_models=delete_models)
previous_model_bytes = self.modelservice.temp_model_storage.get(
model_id)
model, data = self.hook_interface.aggregate(
previous_model_bytes, self.update_handler, helper, delete_models=delete_models)
else:
model, data = self.aggregator.combine_models(helper=helper, delete_models=delete_models, parameters=parameters)
model, data = self.aggregator.combine_models(
helper=helper, delete_models=delete_models, parameters=parameters)
except Exception as e:
logger.warning("AGGREGATION FAILED AT COMBINER! {}".format(e))
raise
Expand All @@ -195,7 +203,8 @@ def _validation_round(self, session_id, model_id, clients):
:param model_id: The ID of the model to validate
:type model_id: str
"""
self.server.request_model_validation(session_id, model_id, clients=clients)
self.server.request_model_validation(
session_id, model_id, clients=clients)

def _prediction_round(self, prediction_id: str, model_id: str, clients: list):
"""Send model prediction requests to clients.
Expand All @@ -207,7 +216,8 @@ def _prediction_round(self, prediction_id: str, model_id: str, clients: list):
:param model_id: The ID of the model to use for prediction
:type model_id: str
"""
self.server.request_model_prediction(prediction_id, model_id, clients=clients)
self.server.request_model_prediction(
prediction_id, model_id, clients=clients)

def stage_model(self, model_id, timeout_retry=3, retry=2):
"""Download a model from persistent storage and set in modelservice.
Expand All @@ -221,7 +231,8 @@ def stage_model(self, model_id, timeout_retry=3, retry=2):
"""
# If the model is already in memory at the server we do not need to do anything.
if self.modelservice.temp_model_storage.exist(model_id):
logger.info("Model already exists in memory, skipping model staging.")
logger.info(
"Model already exists in memory, skipping model staging.")
return
logger.info("Model Staging, fetching model from storage...")
# If not, download it and stage it in memory at the combiner.
Expand All @@ -232,11 +243,13 @@ def stage_model(self, model_id, timeout_retry=3, retry=2):
if model:
break
except Exception:
logger.warning("Could not fetch model from storage backend, retrying.")
logger.warning(
"Could not fetch model from storage backend, retrying.")
time.sleep(timeout_retry)
tries += 1
if tries > retry:
logger.error("Failed to stage model {} from storage backend!".format(model_id))
logger.error(
"Failed to stage model {} from storage backend!".format(model_id))
raise

self.modelservice.set_model(model, model_id)
Expand All @@ -256,7 +269,8 @@ def _assign_round_clients(self, n, type="trainers"):
elif type == "trainers":
clients = self.server.get_active_trainers()
else:
logger.error("(ERROR): {} is not a supported type of client".format(type))
logger.error(
"(ERROR): {} is not a supported type of client".format(type))

# If the number of requested trainers exceeds the number of available, use all available.
n = min(n, len(clients))
Expand All @@ -278,7 +292,8 @@ def _check_nr_round_clients(self, config):
"""
active = self.server.nr_active_trainers()
if active >= int(config["clients_required"]):
logger.info("Number of clients required ({0}) to start round met {1}.".format(config["clients_required"], active))
logger.info("Number of clients required ({0}) to start round met {1}.".format(
config["clients_required"], active))
return True
else:
logger.info("Too few clients to start round.")
Expand All @@ -290,9 +305,11 @@ def execute_validation_round(self, session_id, model_id):
:param round_config: The round config object.
:type round_config: dict
"""
logger.info("COMBINER orchestrating validation of model {}".format(model_id))
logger.info(
"COMBINER orchestrating validation of model {}".format(model_id))
self.stage_model(model_id)
validators = self._assign_round_clients(self.server.max_clients, type="validators")
validators = self._assign_round_clients(
self.server.max_clients, type="validators")
self._validation_round(session_id, model_id, validators)

def execute_prediction_round(self, prediction_id: str, model_id: str) -> None:
Expand All @@ -301,10 +318,12 @@ def execute_prediction_round(self, prediction_id: str, model_id: str) -> None:
:param round_config: The round config object.
:type round_config: dict
"""
logger.info("COMBINER orchestrating prediction using model {}".format(model_id))
logger.info(
"COMBINER orchestrating prediction using model {}".format(model_id))
self.stage_model(model_id)
# TODO: Implement prediction client type
clients = self._assign_round_clients(self.server.max_clients, type="validators")
clients = self._assign_round_clients(
self.server.max_clients, type="validators")
self._prediction_round(prediction_id, model_id, clients)

def execute_training_round(self, config):
Expand All @@ -315,7 +334,8 @@ def execute_training_round(self, config):
:return: metadata about the training round.
:rtype: dict
"""
logger.info("Processing training round, job_id {}".format(config["_job_id"]))
logger.info("Processing training round, job_id {}".format(
config["_job_id"]))

data = {}
data["config"] = config
Expand All @@ -324,17 +344,20 @@ def execute_training_round(self, config):
# Download model to update and set in temp storage.
self.stage_model(config["model_id"])

provided_functions = self.hook_interface.provided_functions(self.server_functions)
provided_functions = self.hook_interface.provided_functions(
self.server_functions)

if provided_functions.get("client_selection", False):
clients = self.hook_interface.client_selection(clients=self.server.get_active_trainers())
clients = self.hook_interface.client_selection(
clients=self.server.get_active_trainers())
else:
clients = self._assign_round_clients(self.server.max_clients)
model, meta = self._training_round(config, clients, provided_functions)
data["data"] = meta

if model is None:
logger.warning("\t Failed to update global model in round {0}!".format(config["round_id"]))
logger.warning(
"\t Failed to update global model in round {0}!".format(config["round_id"]))

if model is not None:
helper = get_helper(config["helper_type"])
Expand All @@ -343,7 +366,8 @@ def execute_training_round(self, config):
a.close()
data["model_id"] = model_id

logger.info("TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format(model_id, config["_job_id"]))
logger.info("TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format(
model_id, config["_job_id"]))

# Delete temp model
self.modelservice.temp_model_storage.delete(config["model_id"])
Expand All @@ -369,19 +393,23 @@ def run(self, polling_interval=1.0):
session_id = round_config["session_id"]
model_id = round_config["model_id"]
tic = time.time()
round_meta = self.execute_training_round(round_config)
round_meta["time_exec_training"] = time.time() - tic
round_meta = self.execute_training_round(
round_config)
round_meta["time_exec_training"] = time.time() - \
tic
round_meta["status"] = "Success"
round_meta["name"] = self.server.id
self.server.statestore.set_round_combiner_data(round_meta)
self.server.statestore.set_round_combiner_data(
round_meta)
elif round_config["task"] == "validation":
session_id = round_config["session_id"]
model_id = round_config["model_id"]
self.execute_validation_round(session_id, model_id)
elif round_config["task"] == "prediction":
prediction_id = round_config["prediction_id"]
model_id = round_config["model_id"]
self.execute_prediction_round(prediction_id, model_id)
self.execute_prediction_round(
prediction_id, model_id)
else:
logger.warning("config contains unkown task type.")
else:
Expand Down
3 changes: 0 additions & 3 deletions fedn/network/grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def __init__(self, servicer, config: ServerConfig):
KEEPALIVE_TIMEOUT_MS = 20 * 1000
# max idle time before server terminates the connection (5 minutes)
MAX_CONNECTION_IDLE_MS = 5 * 60 * 1000
MAX_MESSAGE_LENGTH = 100 * 1024 * 1024

self.server = grpc.server(
futures.ThreadPoolExecutor(max_workers=350),
Expand All @@ -42,8 +41,6 @@ def __init__(self, servicer, config: ServerConfig):
("grpc.keepalive_time_ms", KEEPALIVE_TIME_MS),
("grpc.keepalive_timeout_ms", KEEPALIVE_TIMEOUT_MS),
("grpc.max_connection_idle_ms", MAX_CONNECTION_IDLE_MS),
('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
],
)
self.certificate = None
Expand Down

0 comments on commit e772691

Please sign in to comment.