From 512a4ca3625e072e0c609685b56a307cbc33b683 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Thu, 23 May 2024 17:14:15 +0200 Subject: [PATCH] initial inference + RoundConfig --- .../combiner/aggregators/aggregatorbase.py | 8 +- fedn/network/combiner/combiner.py | 97 +++++++++++------ fedn/network/combiner/interfaces.py | 3 +- fedn/network/combiner/roundhandler.py | 100 ++++++++++++++++-- fedn/network/controller/control.py | 45 +++++++- fedn/network/controller/controlbase.py | 16 +-- .../storage/statestore/mongostatestore.py | 5 +- 7 files changed, 213 insertions(+), 61 deletions(-) diff --git a/fedn/network/combiner/aggregators/aggregatorbase.py b/fedn/network/combiner/aggregators/aggregatorbase.py index 0a9c33f43..f6931645a 100644 --- a/fedn/network/combiner/aggregators/aggregatorbase.py +++ b/fedn/network/combiner/aggregators/aggregatorbase.py @@ -86,9 +86,11 @@ def _validate_model_update(self, model_update): :return: True if the model update is valid, False otherwise. :rtype: bool """ - data = json.loads(model_update.meta)["training_metadata"] - if "num_examples" not in data.keys(): - logger.error("AGGREGATOR({}): Model validation failed, num_examples missing in metadata.".format(self.name)) + try: + data = json.loads(model_update.meta)["training_metadata"] + num_examples = data["num_examples"] + except KeyError as e: + logger.error("AGGREGATOR({}): Invalid model update, missing metadata.".format(self.name)) return False return True diff --git a/fedn/network/combiner/combiner.py b/fedn/network/combiner/combiner.py index 8eacd917e..a053cc3ca 100644 --- a/fedn/network/combiner/combiner.py +++ b/fedn/network/combiner/combiner.py @@ -12,10 +12,11 @@ import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc -from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream +from fedn.common.log_config import (logger, set_log_level_from_string, + set_log_stream) from fedn.network.combiner.connect import ConnectorCombiner, Status from fedn.network.combiner.modelservice import ModelService -from fedn.network.combiner.roundhandler import RoundHandler +from fedn.network.combiner.roundhandler import RoundConfig, RoundHandler from fedn.network.grpc.server import Server from fedn.network.storage.s3.repository import Repository from fedn.network.storage.statestore.mongostatestore import MongoStateStore @@ -161,7 +162,7 @@ def __whoami(self, client, instance): client.role = role_to_proto_role(instance.role) return client - def request_model_update(self, config, clients=[]): + def request_model_update(self, session_id, model_id, config, clients=[]): """Ask clients to update the current global model. :param config: the model configuration to send to clients @@ -170,32 +171,14 @@ def request_model_update(self, config, clients=[]): :type clients: list """ - # The request to be added to the client queue - request = fedn.TaskRequest() - request.model_id = config["model_id"] - request.correlation_id = str(uuid.uuid4()) - request.timestamp = str(datetime.now()) - request.data = json.dumps(config) - request.type = fedn.StatusType.MODEL_UPDATE - request.session_id = config["session_id"] - - request.sender.name = self.id - request.sender.role = fedn.COMBINER - - if len(clients) == 0: - clients = self.get_active_trainers() - - for client in clients: - request.receiver.name = client - request.receiver.role = fedn.WORKER - self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE) + request, clients = self._send_request_type(fedn.StatusType.MODEL_UPDATE, session_id, model_id, config, clients) if len(clients) < 20: logger.info("Sent model update request for model {} to clients {}".format(request.model_id, clients)) else: logger.info("Sent model update request for model {} to {} clients".format(request.model_id, len(clients))) - def request_model_validation(self, model_id, config, clients=[]): + def request_model_validation(self, session_id, model_id, clients=[]): """Ask clients to validate the current global model. :param model_id: the model id to validate @@ -206,30 +189,76 @@ def request_model_validation(self, model_id, config, clients=[]): :type clients: list """ - # The request to be added to the client queue + request, clients = self._send_request_type(fedn.StatusType.MODEL_VALIDATION, session_id, model_id, clients) + + if len(clients) < 20: + logger.info("Sent model validation request for model {} to clients {}".format(request.model_id, clients)) + else: + logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients))) + + def request_model_inference(self, session_id: str, model_id: str, clients: list=[]) -> None: + """Ask clients to perform inference on the model. + + :param model_id: the model id to perform inference on + :type model_id: str + :param config: the model configuration to send to clients + :type config: dict + :param clients: the clients to send the request to + :type clients: list + + """ + request, clients = self._send_request_type(fedn.StatusType.INFERENCE, session_id, model_id, clients) + + if len(clients) < 20: + logger.info("Sent model inference request for model {} to clients {}".format(request.model_id, clients)) + else: + logger.info("Sent model inference request for model {} to {} clients".format(request.model_id, len(clients))) + + def _send_request_type(self, request_type, session_id, model_id, config=None, clients=[]): + """Send a request of a specific type to clients. + + :param request_type: the type of request + :type request_type: :class:`fedn.network.grpc.fedn_pb2.StatusType` + :param model_id: the model id to send in the request + :type model_id: str + :param config: the model configuration to send to clients + :type config: dict + :param clients: the clients to send the request to + :type clients: list + :return: the request and the clients + :rtype: tuple + """ request = fedn.TaskRequest() request.model_id = model_id request.correlation_id = str(uuid.uuid4()) request.timestamp = str(datetime.now()) - # request.is_inference = (config['task'] == 'inference') - request.type = fedn.StatusType.MODEL_VALIDATION + request.type = request_type + request.session_id = session_id request.sender.name = self.id request.sender.role = fedn.COMBINER - request.session_id = config["session_id"] - if len(clients) == 0: - clients = self.get_active_validators() + if request_type == fedn.StatusType.MODEL_UPDATE: + request.data = json.dumps(config) + if len(clients) == 0: + clients = self.get_active_trainers() + elif request_type == fedn.StatusType.MODEL_VALIDATION: + if len(clients) == 0: + clients = self.get_active_validators() + elif request_type == fedn.StatusType.INFERENCE: + request.data = json.dumps(config) + if len(clients) == 0: + # TODO: add inference clients type + clients = self.get_active_validators() + + # TODO: if inference, request.data should be user-defined data/parameters for client in clients: request.receiver.name = client request.receiver.role = fedn.WORKER self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE) - if len(clients) < 20: - logger.info("Sent model validation request for model {} to clients {}".format(request.model_id, clients)) - else: - logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients))) + return request, clients def get_active_trainers(self): """Get a list of active trainers. @@ -410,7 +439,7 @@ def Start(self, control: fedn.ControlRequest, context): """ logger.info("grpc.Combiner.Start: Starting round") - config = {} + config = RoundConfig() for parameter in control.parameter: config.update({parameter.key: parameter.value}) diff --git a/fedn/network/combiner/interfaces.py b/fedn/network/combiner/interfaces.py index bf10a00f1..935b75442 100644 --- a/fedn/network/combiner/interfaces.py +++ b/fedn/network/combiner/interfaces.py @@ -8,6 +8,7 @@ import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc +from fedn.network.combiner.roundhandler import RoundConfig class CombinerUnavailableError(Exception): @@ -202,7 +203,7 @@ def set_aggregator(self, aggregator): else: raise - def submit(self, config): + def submit(self, config: RoundConfig): """Submit a compute plan to the combiner. :param config: The job configuration. diff --git a/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py index 4edc04b6e..6af9366e4 100644 --- a/fedn/network/combiner/roundhandler.py +++ b/fedn/network/combiner/roundhandler.py @@ -4,14 +4,64 @@ import sys import time import uuid +from typing import TypedDict from fedn.common.log_config import logger from fedn.network.combiner.aggregators.aggregatorbase import get_aggregator -from fedn.network.combiner.modelservice import load_model_from_BytesIO, serialize_model_to_BytesIO +from fedn.network.combiner.modelservice import (load_model_from_BytesIO, + serialize_model_to_BytesIO) from fedn.utils.helpers.helpers import get_helper from fedn.utils.parameters import Parameters +class RoundConfig(TypedDict): + """Round configuration. + + :param _job_id: A universally unique identifier for the round. Set by Combiner. + :type _job_id: str + :param committed_at: The time the round was committed. Set by Controller. + :type committed_at: str + :param task: The task to perform in the round. Set by Controller. Supported tasks are "training", "validation", and "inference". + :type task: str + :param round_id: The round identifier as str(int) + :type round_id: str + :param round_timeout: The round timeout in seconds. Set by user interfaces or Controller. + :type round_timeout: str + :param rounds: The number of rounds. Set by user interfaces. + :param model_id: The model identifier. Set by user interfaces or Controller (get_latest_model). + :type model_id: str + :param model_version: The model version. Currently not used. + :type model_version: str + :param model_type: The model type. Currently not used. + :type model_type: str + :param model_size: The size of the model. Currently not used. + :type model_size: int + :param model_parameters: The model parameters. Currently not used. + :type model_parameters: dict + :param model_metadata: The model metadata. Currently not used. + :type model_metadata: dict + :param session_id: The session identifier. Set by (Controller?). + :type session_id: str + :param helper_type: The helper type. + :type helper_type: str + :param aggregator: The aggregator type. + :type aggregator: str + """ + _job_id: str + committed_at: str + task: str + round_id: str + round_timeout: str + rounds: int + model_id: str + model_version: str + model_type: str + model_size: int + model_parameters: dict + model_metadata: dict + session_id: str + helper_type: str + aggregator: str class ModelUpdateError(Exception): pass @@ -42,7 +92,7 @@ def __init__(self, storage, server, modelservice): def set_aggregator(self, aggregator): self.aggregator = get_aggregator(aggregator, self.storage, self.server, self.modelservice, self) - def push_round_config(self, round_config): + def push_round_config(self, round_config: RoundConfig) -> str: """Add a round_config (job description) to the inbox. :param round_config: A dict containing the round configuration (from global controller). @@ -144,8 +194,11 @@ def _training_round(self, config, clients): meta["nr_required_updates"] = int(config["clients_required"]) meta["timeout"] = float(config["round_timeout"]) + session_id = config["session_id"] + model_id = config["model_id"] + # Request model updates from all active clients. - self.server.request_model_update(config, clients=clients) + self.server.request_model_update(session_id=session_id, model_id=model_id, config=config, clients=clients) # If buffer_size is -1 (default), the round terminates when/if all clients have completed. if int(config["buffer_size"]) == -1: @@ -182,7 +235,7 @@ def _training_round(self, config, clients): meta["aggregation_time"] = data return model, meta - def _validation_round(self, config, clients, model_id): + def _validation_round(self, session_id, model_id, clients): """Send model validation requests to clients. :param config: The round config object (passed to the client). @@ -192,7 +245,19 @@ def _validation_round(self, config, clients, model_id): :param model_id: The ID of the model to validate :type model_id: str """ - self.server.request_model_validation(model_id, config, clients) + self.server.request_model_validation(session_id, model_id, clients=clients) + + def _inference_round(self, session_id: str, model_id: str, clients: list): + """Send model inference requests to clients. + + :param config: The round config object (passed to the client). + :type config: dict + :param clients: clients to send inference requests to + :type clients: list + :param model_id: The ID of the model to use for inference + :type model_id: str + """ + self.server.request_model_inference(session_id, model_id, clients=clients) def stage_model(self, model_id, timeout_retry=3, retry=2): """Download a model from persistent storage and set in modelservice. @@ -271,17 +336,28 @@ def _check_nr_round_clients(self, config): logger.info("Too few clients to start round.") return False - def execute_validation_round(self, round_config): + def execute_validation_round(self, session_id, model_id): """Coordinate validation rounds as specified in config. :param round_config: The round config object. :type round_config: dict """ - model_id = round_config["model_id"] logger.info("COMBINER orchestrating validation of model {}".format(model_id)) self.stage_model(model_id) validators = self._assign_round_clients(self.server.max_clients, type="validators") - self._validation_round(round_config, validators, model_id) + self._validation_round(session_id, model_id, validators) + + def execute_inference_round(self, session_id: str, model_id: str) -> None: + """Coordinate inference rounds as specified in config. + + :param round_config: The round config object. + :type round_config: dict + """ + logger.info("COMBINER orchestrating inference using model {}".format(model_id)) + self.stage_model(model_id) + # TODO: Implement inference client type + clients = self._assign_round_clients(self.server.max_clients, type="validators") + self._inference_round(session_id, model_id, clients) def execute_training_round(self, config): """Coordinates clients to execute training tasks. @@ -330,6 +406,8 @@ def run(self, polling_interval=1.0): while True: try: round_config = self.round_configs.get(block=False) + session_id = round_config["session_id"] + model_id = round_config["model_id"] # Check that the minimum allowed number of clients are connected ready = self._check_nr_round_clients(round_config) @@ -343,8 +421,10 @@ def run(self, polling_interval=1.0): round_meta["status"] = "Success" round_meta["name"] = self.server.id self.server.statestore.set_round_combiner_data(round_meta) - elif round_config["task"] == "validation" or round_config["task"] == "inference": - self.execute_validation_round(round_config) + elif round_config["task"] == "validation": + self.execute_validation_round(session_id, model_id) + elif round_config["task"] == "inference": + logger.info("Inference task not yet implemented.") else: logger.warning("config contains unkown task type.") else: diff --git a/fedn/network/controller/control.py b/fedn/network/controller/control.py index 7919fc620..b634e3a9f 100644 --- a/fedn/network/controller/control.py +++ b/fedn/network/controller/control.py @@ -2,12 +2,15 @@ import datetime import time import uuid +from typing import TypedDict -from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_random +from tenacity import (retry, retry_if_exception_type, stop_after_delay, + wait_random) from fedn.common.log_config import logger from fedn.network.combiner.interfaces import CombinerUnavailableError from fedn.network.combiner.modelservice import load_model_from_BytesIO +from fedn.network.combiner.roundhandler import RoundConfig from fedn.network.controller.controlbase import ControlBase from fedn.network.state import ReducerState @@ -78,7 +81,7 @@ def __init__(self, statestore): super().__init__(statestore) self.name = "DefaultControl" - def start_session(self, session_id: str, rounds: int): + def start_session(self, session_id: str, rounds: int) -> None: if self._state == ReducerState.instructing: logger.info("Controller already in INSTRUCTING state. A session is in progress.") return @@ -132,7 +135,7 @@ def start_session(self, session_id: str, rounds: int): self.set_session_status(session_id, "Finished") self._state = ReducerState.idle - def session(self, config): + def session(self, config: RoundConfig) -> None: """Execute a new training session. A session consists of one or several global rounds. All rounds in the same session have the same round_config. @@ -182,8 +185,42 @@ def session(self, config): # TODO: Report completion of session self.set_session_status(config["session_id"], "Finished") self._state = ReducerState.idle + + def inference_session(self, config: RoundConfig) -> None: + """Execute a new inference session. - def round(self, session_config, round_id): + :param config: The round config. + :type config: InferenceConfig + :return: None + """ + + if self._state == ReducerState.instructing: + logger.info("Controller already in INSTRUCTING state. A session is in progress.") + return + + if len(self.network.get_combiners()) < 1: + logger.warning("Inference round cannot start, no combiners connected!") + return + + if not config["model_id"]: + config["model_id"]= self.statestore.get_latest_model() + + config["committed_at"] = datetime.datetime.now() + config["task"] = "inference" + config["rounds"] = str(1) + + participating_combiners = self.get_participating_combiners(config) + + # Check if the policy to start the round is met, Default is number of combiners > 0 + round_start = self.evaluate_round_start_policy(participating_combiners) + + if round_start: + logger.info("Inference round start policy met, {} participating combiners.".format(len(participating_combiners))) + for combiner, _ in participating_combiners: + combiner.submit(config) + logger.info("Inference round submitted to combiner {}".format(combiner)) + + def round(self, session_config: RoundConfig, round_id: str): """Execute one global round. : param session_config: The session config. diff --git a/fedn/network/controller/controlbase.py b/fedn/network/controller/controlbase.py index d667e01c4..ba5c72276 100644 --- a/fedn/network/controller/controlbase.py +++ b/fedn/network/controller/controlbase.py @@ -7,6 +7,7 @@ from fedn.common.log_config import logger from fedn.network.api.network import Network from fedn.network.combiner.interfaces import CombinerUnavailableError +from fedn.network.combiner.roundhandler import RoundConfig from fedn.network.state import ReducerState from fedn.network.storage.s3.repository import Repository @@ -163,7 +164,7 @@ def get_compute_package(self, compute_package=""): else: return None - def create_session(self, config, status="Initialized"): + def create_session(self, config: RoundConfig, status: str="Initialized") -> None: """Initialize a new session in backend db.""" if "session_id" not in config.keys(): session_id = uuid.uuid4() @@ -209,7 +210,7 @@ def set_round_status(self, round_id, status): """ self.statestore.set_round_status(round_id, status) - def set_round_config(self, round_id, round_config): + def set_round_config(self, round_id, round_config: RoundConfig): """Upate round in backend db. :param round_id: The round unique identifier @@ -223,7 +224,7 @@ def request_model_updates(self, combiners): """Ask Combiner server to produce a model update. :param combiners: A list of combiners - :type combiners: tuple (combiner, comboner_round_config) + :type combiners: tuple (combiner, combiner_round_config) """ cl = [] for combiner, combiner_round_config in combiners: @@ -273,22 +274,23 @@ def get_participating_combiners(self, combiner_round_config): self._handle_unavailable_combiner(combiner) continue - is_participating = self.evaluate_round_participation_policy(combiner_round_config, nr_active_clients) + clients_required = int(combiner_round_config["clients_required"]) + is_participating = self.evaluate_round_participation_policy(clients_required, nr_active_clients) if is_participating: combiners.append((combiner, combiner_round_config)) return combiners - def evaluate_round_participation_policy(self, compute_plan, nr_active_clients): + def evaluate_round_participation_policy(self, clients_required: int, nr_active_clients: int) -> bool: """Evaluate policy for combiner round-participation. A combiner participates if it is responsive and reports enough active clients to participate in the round. """ - if int(compute_plan["clients_required"]) <= nr_active_clients: + if clients_required <= nr_active_clients: return True else: return False - def evaluate_round_start_policy(self, combiners): + def evaluate_round_start_policy(self, combiners: list): """Check if the policy to start a round is met. :param combiners: A list of combiners diff --git a/fedn/network/storage/statestore/mongostatestore.py b/fedn/network/storage/statestore/mongostatestore.py index 6bf3be4ff..724077984 100644 --- a/fedn/network/storage/statestore/mongostatestore.py +++ b/fedn/network/storage/statestore/mongostatestore.py @@ -6,6 +6,7 @@ from google.protobuf.json_format import MessageToDict from fedn.common.log_config import logger +from fedn.network.combiner.roundhandler import RoundConfig from fedn.network.state import ReducerStateToString, StringToReducerState @@ -859,7 +860,7 @@ def create_round(self, round_data): # TODO: Add check if round_id already exists self.rounds.insert_one(round_data) - def set_session_config(self, id, config): + def set_session_config(self, id: str, config: RoundConfig) -> None: """Set the session configuration. :param id: The session id @@ -886,7 +887,7 @@ def set_round_combiner_data(self, data): """ self.rounds.update_one({"round_id": str(data["round_id"])}, {"$push": {"combiners": data}}, True) - def set_round_config(self, round_id, round_config): + def set_round_config(self, round_id, round_config: RoundConfig): """Set round configuration. :param round_id: The round unique identifier