diff --git a/fedn/network/combiner/modelservice.py b/fedn/network/combiner/modelservice.py index ef5f9a75..b2a11861 100644 --- a/fedn/network/combiner/modelservice.py +++ b/fedn/network/combiner/modelservice.py @@ -9,7 +9,7 @@ from fedn.common.log_config import logger from fedn.network.storage.models.tempmodelstorage import TempModelStorage -CHUNK_SIZE = 2 * 1024 * 1024 +CHUNK_SIZE = 1 * 1024 * 1024 def upload_request_generator(mdl, id): diff --git a/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py index fa3d83e8..604a7724 100644 --- a/fedn/network/combiner/roundhandler.py +++ b/fedn/network/combiner/roundhandler.py @@ -131,7 +131,8 @@ def _training_round(self, config: dict, clients: list, provided_functions: dict) :return: an aggregated model and associated metadata :rtype: model, dict """ - logger.info("ROUNDHANDLER: Initiating training round, participating clients: {}".format(clients)) + logger.info( + "ROUNDHANDLER: Initiating training round, participating clients: {}".format(clients)) meta = {} meta["nr_expected_updates"] = len(clients) @@ -142,11 +143,14 @@ def _training_round(self, config: dict, clients: list, provided_functions: dict) model_id = config["model_id"] if provided_functions.get("client_settings", False): - global_model_bytes = self.modelservice.temp_model_storage.get(model_id) - client_settings = self.hook_interface.client_settings(global_model_bytes) + global_model_bytes = self.modelservice.temp_model_storage.get( + model_id) + client_settings = self.hook_interface.client_settings( + global_model_bytes) config["client_settings"] = client_settings # Request model updates from all active clients. - self.server.request_model_update(session_id=session_id, model_id=model_id, config=config, clients=clients) + self.server.request_model_update( + session_id=session_id, model_id=model_id, config=config, clients=clients) # If buffer_size is -1 (default), the round terminates when/if all clients have completed. if int(config["buffer_size"]) == -1: @@ -161,7 +165,8 @@ def _training_round(self, config: dict, clients: list, provided_functions: dict) data = None try: helper = get_helper(config["helper_type"]) - logger.info("Config delete_models_storage: {}".format(config["delete_models_storage"])) + logger.info("Config delete_models_storage: {}".format( + config["delete_models_storage"])) if config["delete_models_storage"] == "True": delete_models = True else: @@ -173,10 +178,13 @@ def _training_round(self, config: dict, clients: list, provided_functions: dict) else: parameters = None if provided_functions.get("aggregate", False): - previous_model_bytes = self.modelservice.temp_model_storage.get(model_id) - model, data = self.hook_interface.aggregate(previous_model_bytes, self.update_handler, helper, delete_models=delete_models) + previous_model_bytes = self.modelservice.temp_model_storage.get( + model_id) + model, data = self.hook_interface.aggregate( + previous_model_bytes, self.update_handler, helper, delete_models=delete_models) else: - model, data = self.aggregator.combine_models(helper=helper, delete_models=delete_models, parameters=parameters) + model, data = self.aggregator.combine_models( + helper=helper, delete_models=delete_models, parameters=parameters) except Exception as e: logger.warning("AGGREGATION FAILED AT COMBINER! {}".format(e)) raise @@ -195,7 +203,8 @@ def _validation_round(self, session_id, model_id, clients): :param model_id: The ID of the model to validate :type model_id: str """ - self.server.request_model_validation(session_id, model_id, clients=clients) + self.server.request_model_validation( + session_id, model_id, clients=clients) def _prediction_round(self, prediction_id: str, model_id: str, clients: list): """Send model prediction requests to clients. @@ -207,7 +216,8 @@ def _prediction_round(self, prediction_id: str, model_id: str, clients: list): :param model_id: The ID of the model to use for prediction :type model_id: str """ - self.server.request_model_prediction(prediction_id, model_id, clients=clients) + self.server.request_model_prediction( + prediction_id, model_id, clients=clients) def stage_model(self, model_id, timeout_retry=3, retry=2): """Download a model from persistent storage and set in modelservice. @@ -221,7 +231,8 @@ def stage_model(self, model_id, timeout_retry=3, retry=2): """ # If the model is already in memory at the server we do not need to do anything. if self.modelservice.temp_model_storage.exist(model_id): - logger.info("Model already exists in memory, skipping model staging.") + logger.info( + "Model already exists in memory, skipping model staging.") return logger.info("Model Staging, fetching model from storage...") # If not, download it and stage it in memory at the combiner. @@ -232,11 +243,13 @@ def stage_model(self, model_id, timeout_retry=3, retry=2): if model: break except Exception: - logger.warning("Could not fetch model from storage backend, retrying.") + logger.warning( + "Could not fetch model from storage backend, retrying.") time.sleep(timeout_retry) tries += 1 if tries > retry: - logger.error("Failed to stage model {} from storage backend!".format(model_id)) + logger.error( + "Failed to stage model {} from storage backend!".format(model_id)) raise self.modelservice.set_model(model, model_id) @@ -256,7 +269,8 @@ def _assign_round_clients(self, n, type="trainers"): elif type == "trainers": clients = self.server.get_active_trainers() else: - logger.error("(ERROR): {} is not a supported type of client".format(type)) + logger.error( + "(ERROR): {} is not a supported type of client".format(type)) # If the number of requested trainers exceeds the number of available, use all available. n = min(n, len(clients)) @@ -278,7 +292,8 @@ def _check_nr_round_clients(self, config): """ active = self.server.nr_active_trainers() if active >= int(config["clients_required"]): - logger.info("Number of clients required ({0}) to start round met {1}.".format(config["clients_required"], active)) + logger.info("Number of clients required ({0}) to start round met {1}.".format( + config["clients_required"], active)) return True else: logger.info("Too few clients to start round.") @@ -290,9 +305,11 @@ def execute_validation_round(self, session_id, model_id): :param round_config: The round config object. :type round_config: dict """ - logger.info("COMBINER orchestrating validation of model {}".format(model_id)) + logger.info( + "COMBINER orchestrating validation of model {}".format(model_id)) self.stage_model(model_id) - validators = self._assign_round_clients(self.server.max_clients, type="validators") + validators = self._assign_round_clients( + self.server.max_clients, type="validators") self._validation_round(session_id, model_id, validators) def execute_prediction_round(self, prediction_id: str, model_id: str) -> None: @@ -301,10 +318,12 @@ def execute_prediction_round(self, prediction_id: str, model_id: str) -> None: :param round_config: The round config object. :type round_config: dict """ - logger.info("COMBINER orchestrating prediction using model {}".format(model_id)) + logger.info( + "COMBINER orchestrating prediction using model {}".format(model_id)) self.stage_model(model_id) # TODO: Implement prediction client type - clients = self._assign_round_clients(self.server.max_clients, type="validators") + clients = self._assign_round_clients( + self.server.max_clients, type="validators") self._prediction_round(prediction_id, model_id, clients) def execute_training_round(self, config): @@ -315,7 +334,8 @@ def execute_training_round(self, config): :return: metadata about the training round. :rtype: dict """ - logger.info("Processing training round, job_id {}".format(config["_job_id"])) + logger.info("Processing training round, job_id {}".format( + config["_job_id"])) data = {} data["config"] = config @@ -324,17 +344,20 @@ def execute_training_round(self, config): # Download model to update and set in temp storage. self.stage_model(config["model_id"]) - provided_functions = self.hook_interface.provided_functions(self.server_functions) + provided_functions = self.hook_interface.provided_functions( + self.server_functions) if provided_functions.get("client_selection", False): - clients = self.hook_interface.client_selection(clients=self.server.get_active_trainers()) + clients = self.hook_interface.client_selection( + clients=self.server.get_active_trainers()) else: clients = self._assign_round_clients(self.server.max_clients) model, meta = self._training_round(config, clients, provided_functions) data["data"] = meta if model is None: - logger.warning("\t Failed to update global model in round {0}!".format(config["round_id"])) + logger.warning( + "\t Failed to update global model in round {0}!".format(config["round_id"])) if model is not None: helper = get_helper(config["helper_type"]) @@ -343,7 +366,8 @@ def execute_training_round(self, config): a.close() data["model_id"] = model_id - logger.info("TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format(model_id, config["_job_id"])) + logger.info("TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format( + model_id, config["_job_id"])) # Delete temp model self.modelservice.temp_model_storage.delete(config["model_id"]) @@ -369,11 +393,14 @@ def run(self, polling_interval=1.0): session_id = round_config["session_id"] model_id = round_config["model_id"] tic = time.time() - round_meta = self.execute_training_round(round_config) - round_meta["time_exec_training"] = time.time() - tic + round_meta = self.execute_training_round( + round_config) + round_meta["time_exec_training"] = time.time() - \ + tic round_meta["status"] = "Success" round_meta["name"] = self.server.id - self.server.statestore.set_round_combiner_data(round_meta) + self.server.statestore.set_round_combiner_data( + round_meta) elif round_config["task"] == "validation": session_id = round_config["session_id"] model_id = round_config["model_id"] @@ -381,7 +408,8 @@ def run(self, polling_interval=1.0): elif round_config["task"] == "prediction": prediction_id = round_config["prediction_id"] model_id = round_config["model_id"] - self.execute_prediction_round(prediction_id, model_id) + self.execute_prediction_round( + prediction_id, model_id) else: logger.warning("config contains unkown task type.") else: diff --git a/fedn/network/grpc/server.py b/fedn/network/grpc/server.py index a581c16b..7f610932 100644 --- a/fedn/network/grpc/server.py +++ b/fedn/network/grpc/server.py @@ -33,7 +33,6 @@ def __init__(self, servicer, config: ServerConfig): KEEPALIVE_TIMEOUT_MS = 20 * 1000 # max idle time before server terminates the connection (5 minutes) MAX_CONNECTION_IDLE_MS = 5 * 60 * 1000 - MAX_MESSAGE_LENGTH = 100 * 1024 * 1024 self.server = grpc.server( futures.ThreadPoolExecutor(max_workers=350), @@ -42,8 +41,6 @@ def __init__(self, servicer, config: ServerConfig): ("grpc.keepalive_time_ms", KEEPALIVE_TIME_MS), ("grpc.keepalive_timeout_ms", KEEPALIVE_TIMEOUT_MS), ("grpc.max_connection_idle_ms", MAX_CONNECTION_IDLE_MS), - ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH), - ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH), ], ) self.certificate = None