From 6e4158314bdb78031ef7f35a79819d7b1947dd1f Mon Sep 17 00:00:00 2001 From: Niklas Date: Fri, 11 Oct 2024 13:27:32 +0200 Subject: [PATCH] Feature/SK-1071 | Initial refactor of client (#722) --- fedn/cli/client_cmd.py | 176 +++++++++++- fedn/network/clients/client_api.py | 321 ++++++++++++++++++++++ fedn/network/clients/client_v2.py | 330 +++++++++++++++++++++++ fedn/network/clients/grpc_handler.py | 345 ++++++++++++++++++++++++ fedn/network/clients/package_runtime.py | 165 ++++++++++++ 5 files changed, 1335 insertions(+), 2 deletions(-) create mode 100644 fedn/network/clients/client_api.py create mode 100644 fedn/network/clients/client_v2.py create mode 100644 fedn/network/clients/grpc_handler.py create mode 100644 fedn/network/clients/package_runtime.py diff --git a/fedn/cli/client_cmd.py b/fedn/cli/client_cmd.py index fb124dcf6..ccf95cff7 100644 --- a/fedn/cli/client_cmd.py +++ b/fedn/cli/client_cmd.py @@ -7,6 +7,8 @@ from fedn.cli.shared import CONTROLLER_DEFAULTS, apply_config, get_api_url, get_token, print_response from fedn.common.exceptions import InvalidClientConfig from fedn.network.clients.client import Client +from fedn.network.clients.client_v2 import Client as ClientV2 +from fedn.network.clients.client_v2 import ClientOptions def validate_client_config(config): @@ -29,7 +31,7 @@ def validate_client_config(config): @main.group("client") @click.pass_context def client_cmd(ctx): - """:param ctx:""" + """- Commands for listing and running clients.""" pass @@ -92,7 +94,7 @@ def list_clients(ctx, protocol: str, host: str, port: str, token: str = None, n_ @click.option("--reconnect-after-missed-heartbeat", required=False, default=30) @click.option("--verbosity", required=False, default="INFO", type=click.Choice(["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], case_sensitive=False)) @click.pass_context -def client_cmd( +def client_start_cmd( ctx, discoverhost, discoverport, @@ -179,3 +181,173 @@ def client_cmd( client = Client(config) client.run() + +def 
_validate_client_params(config: dict): + api_url = config["api_url"] + if api_url is None or api_url == "": + click.echo("Error: Missing required parameter: --api_url") + return False + + return True + +def _complement_client_params(config: dict): + api_url = config["api_url"] + if not api_url.startswith("http://") and not api_url.startswith("https://"): + if "localhost" in api_url or "127.0.0.1" in api_url: + config["api_url"] = "http://" + api_url + else: + config["api_url"] = "https://" + api_url + result = config["api_url"] + + click.echo(f"Protocol missing, complementing api_url with protocol: {result}") + +@client_cmd.command("start-v2") +@click.option("-u", "--api_url", required=False, help="Hostname for fedn api.") +@click.option("-p", "--api_port", required=False, help="Port for discovery services (reducer).") +@click.option("--token", required=False, help="Set token provided by reducer if enabled") +@click.option("-n", "--name", required=False, default="client" + str(uuid.uuid4())[:8]) +@click.option("-i", "--client_id", required=False) +@click.option("--local-package", is_flag=True, help="Enable local compute package") +@click.option("-c", "--preferred-combiner", type=str, required=False, default="", help="name of the preferred combiner") +@click.option( + "--combiner", + type=str, + required=False, + default=None, + help="Skip combiner assignment from discover service and attach directly to combiner host." 
+) +@click.option("--combiner-port", type=str, required=False, default=None, help="Combiner port, need to be used with --combiner") +@click.option("-va", "--validator", required=False, default=True) +@click.option("-tr", "--trainer", required=False, default=True) +@click.option("-h", "--helper_type", required=False, default=None) +@click.option("-in", "--init", required=False, default=None, help="Set to a filename to (re)init client from file state.") +@click.pass_context +def client_start_v2_cmd( + ctx, + api_url: str, + api_port: int, + token: str, + name: str, + client_id: str, + local_package: bool, + preferred_combiner: str, + combiner: str, + combiner_port: int, + validator: bool, + trainer: bool, + helper_type: str, + init: str +): + click.echo( + click.style("\n*** fedn client start-v2 is experimental ***\n", blink=True, bold=True, fg="red") + ) + + package = "local" if local_package else "remote" + + config = { + "api_url": None, + "api_port": None, + "token": None, + "name": None, + "client_id": None, + "preferred_combiner": None, + "combiner": None, + "combiner_port": None, + "validator": None, + "trainer": None, + "package_checksum": None, + "helper_type": None, + # to cater for old inputs + "discover_host": None, + "discover_port": None, + } + + if init: + apply_config(init, config) + click.echo(f"\nClient configuration loaded from file: {init}") + + # to cater for old inputs + if config["discover_host"] is not None: + config["api_url"] = config["discover_host"] + + if config["discover_port"] is not None: + config["api_port"] = config["discover_port"] + + if api_url and api_url != "": + config["api_url"] = api_url + if config["api_url"] and config["api_url"] != "": + click.echo(f"Input param api_url: {api_url} overrides value from file") + + if api_port: + config["api_port"] = api_port + if config["api_port"]: + click.echo(f"Input param api_port: {api_port} overrides value from file") + + if token and token != "": + config["token"] = token + if 
config["token"]: + click.echo(f"Input param token: {token} overrides value from file") + + if name and name != "": + config["name"] = name + if config["name"]: + click.echo(f"Input param name: {name} overrides value from file") + + if client_id and client_id != "": + config["client_id"] = client_id + if config["client_id"]: + click.echo(f"Input param client_id: {client_id} overrides value from file") + + if preferred_combiner and preferred_combiner != "": + config["preferred_combiner"] = preferred_combiner + if config["preferred_combiner"]: + click.echo(f"Input param preferred_combiner: {preferred_combiner} overrides value from file") + + if combiner and combiner != "": + config["combiner"] = combiner + if config["combiner"]: + click.echo(f"Input param combiner: {combiner} overrides value from file") + + if combiner_port: + config["combiner_port"] = combiner_port + if config["combiner_port"]: + click.echo(f"Input param combiner_port: {combiner_port} overrides value from file") + + if validator is not None: + config["validator"] = validator + if config["validator"] is not None: + click.echo(f"Input param validator: {validator} overrides value from file") + + if trainer is not None: + config["trainer"] = trainer + if config["trainer"] is not None: + click.echo(f"Input param trainer: {trainer} overrides value from file") + + if helper_type and helper_type != "": + config["helper_type"] = helper_type + if config["helper_type"]: + click.echo(f"Input param helper_type: {helper_type} overrides value from file") + + if not _validate_client_params(config): + return + + api_url = _complement_client_params(config) + + client_options = ClientOptions( + name=config["name"], + package=package, + preferred_combiner=config["preferred_combiner"], + id=config["client_id"], + ) + client = ClientV2( + api_url=config["api_url"], + api_port=config["api_port"], + client_obj=client_options, + combiner_host=config["combiner"], + combiner_port=config["combiner_port"], + 
token=config["token"], + package_checksum=config["package_checksum"], + helper_type=config["helper_type"], + ) + + client.start() diff --git a/fedn/network/clients/client_api.py b/fedn/network/clients/client_api.py new file mode 100644 index 000000000..332331309 --- /dev/null +++ b/fedn/network/clients/client_api.py @@ -0,0 +1,321 @@ +import enum +import os +import time +from io import BytesIO +from typing import Any, Tuple + +import requests + +import fedn.network.grpc.fedn_pb2 as fedn +from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_PACKAGE_EXTRACT_DIR +from fedn.common.log_config import logger +from fedn.network.clients.grpc_handler import GrpcHandler +from fedn.network.clients.package_runtime import PackageRuntime +from fedn.utils.dispatcher import Dispatcher + + +class GrpcConnectionOptions: + def __init__(self, status: str, host: str, fqdn: str, package: str, ip: str, port: int, helper_type: str): + self.status = status + self.host = host + self.fqdn = fqdn + self.package = package + self.ip = ip + self.port = port + self.helper_type = helper_type + + +# Enum for respresenting the result of connecting to the FEDn API +class ConnectToApiResult(enum.Enum): + Assigned = 0 + ComputePackgeMissing = 1 + UnAuthorized = 2 + UnMatchedConfig = 3 + IncorrectUrl = 4 + UnknownError = 5 + + +def get_compute_package_dir_path(): + result = None + + if FEDN_PACKAGE_EXTRACT_DIR: + result = os.path.join(os.getcwd(), FEDN_PACKAGE_EXTRACT_DIR) + else: + dirname = +"compute-package-" + time.strftime("%Y%m%d-%H%M%S") + result = os.path.join(os.getcwd(), dirname) + + if not os.path.exists(result): + os.mkdir(result) + + return result + + +class ClientAPI: + def __init__(self): + self._subscribers = {"train": [], "validation": []} + path = get_compute_package_dir_path() + self._package_runtime = PackageRuntime(path) + + self.dispatcher: Dispatcher = None + self.grpc_handler: GrpcHandler = None + + def subscribe(self, event_type: str, callback: callable): + """Subscribe to a 
specific event.""" + if event_type in self._subscribers: + self._subscribers[event_type].append(callback) + else: + raise ValueError(f"Unsupported event type: {event_type}") + + def notify_subscribers(self, event_type: str, *args, **kwargs): + """Notify all subscribers about a specific event.""" + if event_type in self._subscribers: + for callback in self._subscribers[event_type]: + callback(*args, **kwargs) + else: + raise ValueError(f"Unsupported event type: {event_type}") + + def train(self, *args, **kwargs): + """Function to be triggered from the server via gRPC.""" + # Perform training logic here + logger.info("Training started") + + # Notify all subscribers about the train event + self.notify_subscribers("train", *args, **kwargs) + + def validate(self, *args, **kwargs): + """Function to be triggered from the server via gRPC.""" + # Perform validation logic here + logger.info("Validation started") + + # Notify all subscribers about the validation event + self.notify_subscribers("validation", *args, **kwargs) + + def connect_to_api(self, url: str, token: str, json: dict) -> Tuple[ConnectToApiResult, Any]: + # TODO: Use new API endpoint (v1) + url_endpoint = f"{url}add_client" + logger.info(f"Connecting to API endpoint: {url_endpoint}") + + try: + response = requests.post( + url=url_endpoint, + json=json, + allow_redirects=True, + headers={"Authorization": f"{FEDN_AUTH_SCHEME} {token}"}, + ) + + if response.status_code == 200: + logger.info("Connect to FEDn Api - Client assinged to controller") + json_response = response.json() + return ConnectToApiResult.Assigned, json_response + elif response.status_code == 203: + json_response = response.json() + logger.info("Connect to FEDn Api - Remote compute package missing.") + return ConnectToApiResult.ComputePackgeMissing, json_response + elif response.status_code == 401: + logger.warning("Connect to FEDn Api - Unauthorized") + return ConnectToApiResult.UnAuthorized, "Unauthorized" + elif response.status_code == 400: + 
json_response = response.json() + msg = json_response["message"] + logger.warning(f"Connect to FEDn Api - {msg}") + return ConnectToApiResult.UnMatchedConfig, msg + elif response.status_code == 404: + logger.warning("Connect to FEDn Api - Incorrect URL") + return ConnectToApiResult.IncorrectUrl, "Incorrect URL" + except Exception: + pass + + logger.warning("Connect to FEDn Api - Unknown error occurred") + return ConnectToApiResult.UnknownError, "Unknown error occurred" + + def download_compute_package(self, url: str, token: str, name: str = None) -> bool: + """Download compute package from controller + :param host: host of controller + :param port: port of controller + :param token: token for authentication + :param name: name of package + :return: True if download was successful, None otherwise + :rtype: bool + """ + return self._package_runtime.download_compute_package(url, token, name) + + def set_compute_package_checksum(self, url: str, token: str, name: str = None) -> bool: + """Get checksum of compute package from controller + :param host: host of controller + :param port: port of controller + :param token: token for authentication + :param name: name of package + :return: checksum of the compute package + :rtype: str + """ + return self._package_runtime.set_checksum(url, token, name) + + def unpack_compute_package(self) -> Tuple[bool, str]: + result, path = self._package_runtime.unpack_compute_package() + if result: + logger.info(f"Compute package unpacked to: {path}") + else: + logger.error("Error: Could not unpack compute package") + + return result, path + + def validate_compute_package(self, checksum: str) -> bool: + return self._package_runtime.validate(checksum) + + def set_dispatcher(self, path) -> bool: + result = self._package_runtime.get_dispatcher(path) + if result: + self.dispatcher = result + return True + else: + logger.error("Error: Could not set dispatcher") + return False + + def get_or_set_environment(self) -> bool: + try: + 
logger.info("Initiating Dispatcher with entrypoint set to: startup") + activate_cmd = self.dispatcher._get_or_create_python_env() + self.dispatcher.run_cmd("startup") + except KeyError: + logger.info("No startup command found in package. Continuing.") + return False + except Exception as e: + logger.error(f"Caught exception: {type(e).__name__}") + return False + + if activate_cmd: + logger.info("To activate the virtual environment, run: {}".format(activate_cmd)) + + return True + + # GRPC functions + def init_grpchandler(self, config: GrpcConnectionOptions, client_name: str, token: str): + try: + if "fqdn" in config and config["fqdn"] and len(config["fqdn"]) > 0: + host = config["fqdn"] + port = 443 + else: + host = config["host"] + port = config["port"] + combiner_name = config["host"] + + self.grpc_handler = GrpcHandler(host=host, port=port, name=client_name, token=token, combiner_name=combiner_name) + + logger.info("Successfully initialized GRPC connection") + return True + except Exception: + logger.error("Error: Could not initialize GRPC connection") + return False + + + def send_heartbeats(self, client_name: str, client_id: str, update_frequency: float = 2.0): + self.grpc_handler.send_heartbeats(client_name=client_name, client_id=client_id, update_frequency=update_frequency) + + def listen_to_task_stream(self, client_name: str, client_id: str): + self.grpc_handler.listen_to_task_stream(client_name=client_name, client_id=client_id, callback=self._task_stream_callback) + + def _task_stream_callback(self, request): + if request.type == fedn.StatusType.MODEL_UPDATE: + self.train(request) + elif request.type == fedn.StatusType.MODEL_VALIDATION: + self.validate(request) + + def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20) -> BytesIO: + return self.grpc_handler.get_model_from_combiner(id=id, client_name=client_name, timeout=timeout) + + def send_model_to_combiner(self, model: BytesIO, id: str): + return 
self.grpc_handler.send_model_to_combiner(model, id) + + def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None): + return self.grpc_handler.send_status(msg, log_level, type, request, sesssion_id, sender_name) + + def send_model_update(self, + sender_name: str, + sender_role: fedn.Role, + client_id: str, + model_id: str, + model_update_id: str, + receiver_name: str, + receiver_role: fedn.Role, + meta: dict + ) -> bool: + return self.grpc_handler.send_model_update( + sender_name=sender_name, + sender_role=sender_role, + client_id=client_id, + model_id=model_id, + model_update_id=model_update_id, + receiver_name=receiver_name, + receiver_role=receiver_role, + meta=meta + ) + + def send_model_validation(self, + sender_name: str, + receiver_name: str, + receiver_role: fedn.Role, + model_id: str, + metrics: dict, + correlation_id: str, + session_id: str + ) -> bool: + return self.grpc_handler.send_model_validation(sender_name, receiver_name, receiver_role, model_id, metrics, correlation_id, session_id) + + # Init functions + def init_remote_compute_package(self, url: str, token: str, package_checksum: str = None) -> bool: + result: bool = self.download_compute_package(url, token) + if not result: + logger.error("Could not download compute package") + return False + result: bool = self.set_compute_package_checksum(url, token) + if not result: + logger.error("Could not set checksum") + return False + + if package_checksum: + result: bool = self.validate_compute_package(package_checksum) + if not result: + logger.error("Could not validate compute package") + return False + + result, path = self.unpack_compute_package() + + if not result: + logger.error("Could not unpack compute package") + return False + + logger.info(f"Compute package unpacked to: {path}") + + result = self.set_dispatcher(path) + + if not result: + logger.error("Could not set dispatcher") + return False + + 
logger.info("Dispatcher set") + + result = self.get_or_set_environment() + + if not result: + logger.error("Could not set environment") + return False + + return True + + def init_local_compute_package(self): + path = os.path.join(os.getcwd(), "client") + result = self.set_dispatcher(path) + + if not result: + logger.error("Could not set dispatcher") + return False + + result = self.get_or_set_environment() + + if not result: + logger.error("Could not set environment") + return False + + logger.info("Dispatcher set") + + return True diff --git a/fedn/network/clients/client_v2.py b/fedn/network/clients/client_v2.py new file mode 100644 index 000000000..ab32f6116 --- /dev/null +++ b/fedn/network/clients/client_v2.py @@ -0,0 +1,330 @@ +import io +import json +import os +import threading +import time +import uuid +from typing import Tuple + +import fedn.network.grpc.fedn_pb2 as fedn +from fedn.common.config import FEDN_CUSTOM_URL_PREFIX +from fedn.common.log_config import logger +from fedn.network.clients.client_api import ClientAPI, ConnectToApiResult, GrpcConnectionOptions +from fedn.network.combiner.modelservice import get_tmp_path +from fedn.utils.helpers.helpers import get_helper + + +def get_url(api_url: str, api_port: int) -> str: + return f"{api_url}:{api_port}/{FEDN_CUSTOM_URL_PREFIX}" if api_port else f"{api_url}/{FEDN_CUSTOM_URL_PREFIX}" + + +class ClientOptions: + def __init__(self, name: str, package: str, preferred_combiner: str = None, id: str = None): + # check if name is a string and set. 
if not raise an error + self._validate(name, package) + self.name = name + self.package = package + self.preferred_combiner = preferred_combiner + self.client_id = id if id else str(uuid.uuid4()) + + def _validate(self, name: str, package: str): + if not isinstance(name, str) or len(name) == 0: + raise ValueError("Name must be a string") + if not isinstance(package, str) or len(package) == 0 or package not in ["local", "remote"]: + raise ValueError("Package must be either 'local' or 'remote'") + + # to json object + def to_json(self): + return { + "name": self.name, + "client_id": self.client_id, + "preferred_combiner": self.preferred_combiner, + "package": self.package, + } + + +class Client: + def __init__(self, + api_url: str, + api_port: int, + client_obj: ClientOptions, + combiner_host: str = None, + combiner_port: int = None, + token: str = None, + package_checksum: str = None, + helper_type: str = None + ): + self.api_url = api_url + self.api_port = api_port + self.combiner_host = combiner_host + self.combiner_port = combiner_port + self.token = token + self.client_obj = client_obj + self.package_checksum = package_checksum + self.helper_type = helper_type + + self.fedn_api_url = get_url(self.api_url, self.api_port) + + self.client_api: ClientAPI = ClientAPI() + + self.helper = None + + def _connect_to_api(self) -> Tuple[bool, dict]: + result = None + + while not result or result == ConnectToApiResult.ComputePackgeMissing: + if result == ConnectToApiResult.ComputePackgeMissing: + logger.info("Retrying in 3 seconds") + time.sleep(3) + result, response = self.client_api.connect_to_api(self.fedn_api_url, self.token, self.client_obj.to_json()) + + if result == ConnectToApiResult.Assigned: + return True, response + + return False, None + + def start(self): + if self.combiner_host and self.combiner_port: + combiner_config = { + "host": self.combiner_host, + "port": self.combiner_port, + } + else: + result, combiner_config = self._connect_to_api() + if not result: 
+ return + + if self.client_obj.package == "remote": + result = self.client_api.init_remote_compute_package(url=self.fedn_api_url, token=self.token, package_checksum=self.package_checksum) + + if not result: + return + else: + result = self.client_api.init_local_compute_package() + + if not result: + return + + self.set_helper(combiner_config) + + result: bool = self.client_api.init_grpchandler(config=combiner_config, client_name=self.client_obj.client_id, token=self.token) + + if not result: + return + + logger.info("-----------------------------") + + threading.Thread( + target=self.client_api.send_heartbeats, kwargs={"client_name": self.client_obj.name, "client_id": self.client_obj.client_id}, daemon=True + ).start() + + self.client_api.subscribe("train", self.on_train) + self.client_api.subscribe("validation", self.on_validation) + + threading.Thread( + target=self.client_api.listen_to_task_stream, kwargs={"client_name": self.client_obj.name, "client_id": self.client_obj.client_id}, daemon=True + ).start() + + stop_event = threading.Event() + try: + stop_event.wait() + except KeyboardInterrupt: + logger.info("Client stopped by user.") + + def set_helper(self, response: GrpcConnectionOptions = None): + helper_type = response.get("helper_type", None) + + helper_type_to_use = self.helper_type or helper_type or "numpyhelper" + + logger.info(f"Setting helper to: {helper_type_to_use}") + + # Priority: helper_type from constructor > helper_type from response > default helper_type + self.helper = get_helper(helper_type_to_use) + + def on_train(self, request): + logger.info("Received train request") + self._process_training_request(request) + + def on_validation(self, request): + logger.info("Received validation request") + self._process_validation_request(request) + + + def _process_training_request(self, request) -> Tuple[str, dict]: + """Process a training (model update) request. + + :param model_id: The model id of the model to be updated. 
+ :type model_id: str + :param session_id: The id of the current session + :type session_id: str + :return: The model id of the updated model, or None if the update failed. And a dict with metadata. + :rtype: tuple + """ + model_id: str = request.model_id + session_id: str = request.session_id + + self.client_api.send_status( + f"\t Starting processing of training request for model_id {model_id}", + sesssion_id=session_id, + sender_name=self.client_obj.name + ) + + try: + meta = {} + tic = time.time() + + model = self.client_api.get_model_from_combiner(id=str(model_id), client_name=self.client_obj.client_id) + + if model is None: + logger.error("Could not retrieve model from combiner. Aborting training request.") + return None, None + + meta["fetch_model"] = time.time() - tic + + inpath = self.helper.get_tmp_path() + + with open(inpath, "wb") as fh: + fh.write(model.getbuffer()) + + outpath = self.helper.get_tmp_path() + + tic = time.time() + + self.client_api.dispatcher.run_cmd("train {} {}".format(inpath, outpath)) + + meta["exec_training"] = time.time() - tic + + tic = time.time() + out_model = None + + with open(outpath, "rb") as fr: + out_model = io.BytesIO(fr.read()) + + # Stream model update to combiner server + updated_model_id = uuid.uuid4() + self.client_api.send_model_to_combiner(out_model, str(updated_model_id)) + meta["upload_model"] = time.time() - tic + + # Read the metadata file + with open(outpath + "-metadata", "r") as fh: + training_metadata = json.loads(fh.read()) + + logger.info("SETTING Training metadata: {}".format(training_metadata)) + meta["training_metadata"] = training_metadata + + os.unlink(inpath) + os.unlink(outpath) + os.unlink(outpath + "-metadata") + + except Exception as e: + logger.error("Could not process training request due to error: {}".format(e)) + updated_model_id = None + meta = {"status": "failed", "error": str(e)} + + if meta is not None: + processing_time = time.time() - tic + meta["processing_time"] = processing_time + 
meta["config"] = request.data + + if model_id is not None: + # Send model update to combiner + + self.client_api.send_model_update( + sender_name=self.client_obj.name, + sender_role=fedn.WORKER, + client_id=self.client_obj.client_id, + model_id=model_id, + model_update_id=str(updated_model_id), + receiver_name=request.sender.name, + receiver_role=request.sender.role, + meta=meta, + ) + + self.client_api.send_status( + "Model update completed.", + log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_UPDATE, + request=request, + sesssion_id=session_id, + sender_name=self.client_obj.name + ) + + def _process_validation_request(self, request): + """Process a validation request. + + :param model_id: The model id of the model to be validated. + :type model_id: str + :param session_id: The id of the current session. + :type session_id: str + :return: The validation metrics, or None if validation failed. + :rtype: dict + """ + model_id: str = request.model_id + session_id: str = request.session_id + cmd = "validate" + + self.client_api.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id, sender_name=self.client_obj.name) + + try: + model = self.client_api.get_model_from_combiner(id=str(model_id), client_name=self.client_obj.client_id) + if model is None: + logger.error("Could not retrieve model from combiner. 
Aborting validation request.") + return + inpath = self.helper.get_tmp_path() + + with open(inpath, "wb") as fh: + fh.write(model.getbuffer()) + + outpath = get_tmp_path() + self.client_api.dispatcher.run_cmd(f"{cmd} {inpath} {outpath}") + + with open(outpath, "r") as fh: + metrics = json.loads(fh.read()) + + os.unlink(inpath) + os.unlink(outpath) + + except Exception as e: + logger.warning("Validation failed with exception {}".format(e)) + + if metrics is not None: + # Send validation + validation = fedn.ModelValidation() + validation.sender.name = self.client_obj.name + validation.sender.role = fedn.WORKER + validation.receiver.name = request.sender.name + validation.receiver.role = request.sender.role + validation.model_id = str(request.model_id) + validation.data = json.dumps(metrics) + validation.timestamp.GetCurrentTime() + validation.correlation_id = request.correlation_id + validation.session_id = request.session_id + + # sender_name: str, sender_role: fedn.Role, model_id: str, model_update_id: str + result: bool = self.client_api.send_model_validation( + sender_name=self.client_obj.name, + receiver_name=request.sender.name, + receiver_role=request.sender.role, + model_id=str(request.model_id), + metrics=json.dumps(metrics), + correlation_id=request.correlation_id, + session_id=request.session_id, + ) + + if result: + self.client_api.send_status( + "Model validation completed.", + log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_VALIDATION, + request=validation, + sesssion_id=request.session_id, + sender_name=self.client_obj.name + ) + else: + self.client_api.send_status( + "Client {} failed to complete model validation.".format(self.name), + log_level=fedn.Status.WARNING, + request=request, + sesssion_id=request.session_id, + sender_name=self.client_obj.name + ) diff --git a/fedn/network/clients/grpc_handler.py b/fedn/network/clients/grpc_handler.py new file mode 100644 index 000000000..9b8550344 --- /dev/null +++ 
b/fedn/network/clients/grpc_handler.py @@ -0,0 +1,345 @@ +import json +import os +import socket +import time +from datetime import datetime +from io import BytesIO +from typing import Any, Callable + +import grpc +from cryptography.hazmat.primitives.serialization import Encoding +from google.protobuf.json_format import MessageToJson +from OpenSSL import SSL + +import fedn.network.grpc.fedn_pb2 as fedn +import fedn.network.grpc.fedn_pb2_grpc as rpc +from fedn.common.config import FEDN_AUTH_SCHEME +from fedn.common.log_config import logger +from fedn.network.combiner.modelservice import upload_request_generator + + +class GrpcAuth(grpc.AuthMetadataPlugin): + def __init__(self, key): + self._key = key + + def __call__(self, context, callback): + callback((("authorization", f"{FEDN_AUTH_SCHEME} {self._key}"),), None) + +def _get_ssl_certificate(domain, port=443): + context = SSL.Context(SSL.TLSv1_2_METHOD) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect((domain, port)) + ssl_sock = SSL.Connection(context, sock) + ssl_sock.set_tlsext_host_name(domain.encode()) + ssl_sock.set_connect_state() + ssl_sock.do_handshake() + cert = ssl_sock.get_peer_certificate() + ssl_sock.close() + sock.close() + cert = cert.to_cryptography().public_bytes(Encoding.PEM).decode() + return cert + +class GrpcHandler: + def __init__(self, host: str, port: int, name: str, token: str, combiner_name: str): + self.metadata = [ + ("client", name), + ("grpc-server", combiner_name), + ] + + if port == 443: + self._init_secure_channel(host, port, token) + else: + self._init_insecure_channel(host, port) + + self.connectorStub = rpc.ConnectorStub(self.channel) + self.combinerStub = rpc.CombinerStub(self.channel) + self.modelStub = rpc.ModelServiceStub(self.channel) + + def _init_secure_channel(self, host: str, port: int, token: str): + url = f"{host}:{port}" + logger.info(f"Connecting (GRPC) to {url}") + + if os.getenv("FEDN_GRPC_ROOT_CERT_PATH"): + logger.info("Using root 
certificate from environment variable for GRPC channel.") + with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], "rb") as f: + credentials = grpc.ssl_channel_credentials(f.read()) + self.channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) + return + + logger.info(f"Fetching SSL certificate for {host}") + cert = _get_ssl_certificate(host, port) + credentials = grpc.ssl_channel_credentials(cert.encode("utf-8")) + auth_creds = grpc.metadata_call_credentials(GrpcAuth(token)) + self.channel = grpc.secure_channel("{}:{}".format(host, str(port)), grpc.composite_channel_credentials(credentials, auth_creds)) + + def _init_insecure_channel(self, host: str, port: int): + url = f"{host}:{port}" + logger.info(f"Connecting (GRPC) to {url}") + self.channel = grpc.insecure_channel(url) + + def send_heartbeats(self, client_name: str, client_id: str, update_frequency: float = 2.0): + heartbeat = fedn.Heartbeat(sender=fedn.Client(name=client_name, role=fedn.WORKER, client_id=client_id)) + + send_hearbeat = True + while send_hearbeat: + try: + logger.info("Sending heartbeat to combiner") + self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata) + except grpc.RpcError as e: + return self._handle_grpc_error(e, "SendHeartbeat", lambda: self.send_heartbeats(client_name, client_id, update_frequency)) + except Exception as e: + logger.error(f"GRPC (SendHeartbeat): An error occurred: {e}") + self._disconnect() + send_hearbeat = False + + time.sleep(update_frequency) + + def listen_to_task_stream(self, client_name: str, client_id: str, callback: Callable[[Any], None]): + """Subscribe to the model update request stream. 
+ + :return: None + :rtype: None + """ + r = fedn.ClientAvailableMessage() + r.sender.name = client_name + r.sender.role = fedn.WORKER + r.sender.client_id = client_id + + try: + logger.info("Listening to task stream.") + for request in self.combinerStub.TaskStream(r, metadata=self.metadata): + if request.sender.role == fedn.COMBINER: + self.send_status( + "Received model update request.", + log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_UPDATE_REQUEST, + request=request, + sesssion_id=request.session_id, + sender_name=client_name + ) + + logger.info(f"Received task request of type {request.type} for model_id {request.model_id}") + + callback(request) + + except grpc.RpcError as e: + return self._handle_grpc_error(e, "TaskStream", lambda: self.listen_to_task_stream(client_name, client_id, callback)) + except Exception as e: + logger.error(f"GRPC (TaskStream): An error occurred: {e}") + self._disconnect() + + def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None): + """Send status message. + + :param msg: The message to send. + :type msg: str + :param log_level: The log level of the message. + :type log_level: fedn.Status.INFO, fedn.Status.WARNING, fedn.Status.ERROR + :param type: The type of the message. + :type type: str + :param request: The request message. 
+ :type request: fedn.Request + """ + status = fedn.Status() + status.timestamp.GetCurrentTime() + status.sender.name = sender_name + status.sender.role = fedn.WORKER + status.log_level = log_level + status.status = str(msg) + status.session_id = sesssion_id + + if type is not None: + status.type = type + + if request is not None: + status.data = MessageToJson(request) + + try: + logger.info("Sending status message to combiner.") + _ = self.connectorStub.SendStatus(status, metadata=self.metadata) + except grpc.RpcError as e: + return self._handle_grpc_error(e, "SendStatus", lambda: self.send_status(msg, log_level, type, request, sesssion_id, sender_name)) + except Exception as e: + logger.error(f"GRPC (SendStatus): An error occurred: {e}") + self._disconnect() + + def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20) -> BytesIO: + """Fetch a model from the assigned combiner. + Downloads the model update object via a gRPC streaming channel. + + :param id: The id of the model update object. + :type id: str + :return: The model update object. 
+ :rtype: BytesIO + """ + data = BytesIO() + time_start = time.time() + request = fedn.ModelRequest(id=id) + request.sender.name = client_name + request.sender.role = fedn.WORKER + + try: + logger.info("Downloading model from combiner.") + for part in self.modelStub.Download(request, metadata=self.metadata): + if part.status == fedn.ModelStatus.IN_PROGRESS: + data.write(part.data) + + if part.status == fedn.ModelStatus.OK: + return data + + if part.status == fedn.ModelStatus.FAILED: + return None + + if part.status == fedn.ModelStatus.UNKNOWN: + if time.time() - time_start >= timeout: + return None + continue + except grpc.RpcError as e: + return self._handle_grpc_error(e, "Download", lambda: self.get_model_from_combiner(id, client_name, timeout)) + except Exception as e: + logger.error(f"GRPC (Download): An error occurred: {e}") + self._disconnect() + + return data + + def send_model_to_combiner(self, model: BytesIO, id: str): + """Send a model update to the assigned combiner. + Uploads the model updated object via a gRPC streaming channel, Upload. + + :param model: The model update object. + :type model: BytesIO + :param id: The id of the model update object. + :type id: str + :return: The model update object. 
+ :rtype: BytesIO + """ + if not isinstance(model, BytesIO): + bt = BytesIO() + + for d in model.stream(32 * 1024): + bt.write(d) + else: + bt = model + + bt.seek(0, 0) + + try: + logger.info("Uploading model to combiner.") + result = self.modelStub.Upload(upload_request_generator(bt, id), metadata=self.metadata) + except grpc.RpcError as e: + return self._handle_grpc_error(e, "Upload", lambda: self.send_model_to_combiner(model, id)) + except Exception as e: + logger.error(f"GRPC (Upload): An error occurred: {e}") + self._disconnect() + + return result + + def send_model_update(self, + sender_name: str, + sender_role: fedn.Role, + client_id: str, + model_id: str, + model_update_id: str, + receiver_name: str, + receiver_role: fedn.Role, + meta: dict + ): + update = fedn.ModelUpdate() + update.sender.name = sender_name + update.sender.role = sender_role + update.sender.client_id = client_id + update.receiver.name = receiver_name + update.receiver.role = receiver_role + update.model_id = model_id + update.model_update_id = model_update_id + update.timestamp = str(datetime.now()) + update.meta = json.dumps(meta) + + try: + logger.info("Sending model update to combiner.") + _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) + except grpc.RpcError as e: + return self._handle_grpc_error( + e, + "SendModelUpdate", + lambda: self.send_model_update( + sender_name, + sender_role, + model_id, + model_update_id, + receiver_name, + receiver_role, + meta + ) + ) + except Exception as e: + logger.error(f"GRPC (SendModelUpdate): An error occurred: {e}") + self._disconnect() + + return True + + def send_model_validation(self, + sender_name: str, + receiver_name: str, + receiver_role: fedn.Role, + model_id: str, + metrics: str, + correlation_id: str, + session_id: str + ) -> bool: + validation = fedn.ModelValidation() + validation.sender.name = sender_name + validation.sender.role = fedn.WORKER + validation.receiver.name = receiver_name + validation.receiver.role = 
receiver_role + validation.model_id = model_id + validation.data = metrics + validation.timestamp.GetCurrentTime() + validation.correlation_id = correlation_id + validation.session_id = session_id + + + try: + logger.info("Sending model validation to combiner.") + _ = self.combinerStub.SendModelValidation(validation, metadata=self.metadata) + except grpc.RpcError as e: + return self._handle_grpc_error( + e, + "SendModelValidation", + lambda: self.send_model_validation( + sender_name, + receiver_name, + receiver_role, + model_id, + metrics, + correlation_id, + session_id + ) + ) + except Exception as e: + logger.error(f"GRPC (SendModelValidation): An error occurred: {e}") + self._disconnect() + + return True + + def _handle_grpc_error(self, e, method_name: str, sender_function: Callable): + status_code = e.code() + if status_code == grpc.StatusCode.UNAVAILABLE: + logger.warning(f"GRPC ({method_name}): server unavailable. Retrying in 5 seconds.") + time.sleep(5) + return sender_function() + elif status_code == grpc.StatusCode.CANCELLED: + logger.warning(f"GRPC ({method_name}): connection cancelled. Retrying in 5 seconds.") + time.sleep(5) + return sender_function() + if status_code == grpc.StatusCode.UNAUTHENTICATED: + details = e.details() + if details == "Token expired": + logger.warning(f"GRPC ({method_name}): Token expired.") + self._disconnect() + logger.error(f"GRPC ({method_name}): An error occurred: {e}") + + def _disconnect(self): + """Disconnect from the combiner.""" + self.channel.close() + logger.info("GRPC channel closed.") diff --git a/fedn/network/clients/package_runtime.py b/fedn/network/clients/package_runtime.py new file mode 100644 index 000000000..42006ce01 --- /dev/null +++ b/fedn/network/clients/package_runtime.py @@ -0,0 +1,165 @@ +# This file contains the PackageRuntime class, which is used to download, validate and unpack compute packages. 
#
#
import cgi  # NOTE(review): cgi is deprecated (PEP 594) and removed in Python 3.13; only parse_header is used here.
import os
import tarfile
from typing import Tuple

import requests

from fedn.common.config import FEDN_AUTH_SCHEME
from fedn.common.log_config import logger
from fedn.utils.checksum import sha
from fedn.utils.dispatcher import Dispatcher, _read_yaml_file


class PackageRuntime:
    """PackageRuntime is used to download, validate and unpack compute packages.

    :param package_path: path to compute package
    :type package_path: str
    """

    def __init__(self, package_path: str):
        # Default entry points, overwritten by the package's fedn.yaml in
        # get_dispatcher().
        self.dispatch_config = {
            "entry_points": {
                "predict": {"command": "python3 predict.py"},
                "train": {"command": "python3 train.py"},
                "validate": {"command": "python3 validate.py"},
            }
        }

        self.pkg_path = package_path  # directory to download/unpack into
        self.pkg_name = None  # tarball filename, set by download_compute_package
        self.checksum = None  # checksum reported by controller, set by set_checksum

    def download_compute_package(self, url: str, token: str, name: str = None) -> bool:
        """Download the compute package tarball from the controller.

        Stores the tarball in ``self.pkg_path`` and records its filename in
        ``self.pkg_name`` (taken from the Content-Disposition header).

        :param url: base url of the controller API.
        :param token: token for authentication.
        :param name: name of package, optional.
        :return: True if download was successful, None if no package was
            returned, False on any other failure.
        :rtype: bool
        """
        try:
            # TODO: use new endpoint (v1)
            path = f"{url}/download_package?name={name}" if name else f"{url}/download_package"

            # NOTE(review): verify=False disables TLS certificate validation —
            # confirm this is intentional for self-signed deployments.
            with requests.get(path, stream=True, verify=False, headers={"Authorization": f"{FEDN_AUTH_SCHEME} {token}"}) as r:
                if 200 <= r.status_code < 204:
                    params = cgi.parse_header(r.headers.get("Content-Disposition", ""))[-1]
                    try:
                        self.pkg_name = params["filename"]
                    except KeyError:
                        logger.error("No package returned.")
                        return None
                    r.raise_for_status()
                    with open(os.path.join(self.pkg_path, self.pkg_name), "wb") as f:
                        for chunk in r.iter_content(chunk_size=8192):
                            f.write(chunk)

            return True
        except Exception:
            return False

    def set_checksum(self, url: str, token: str, name: str = None) -> bool:
        """Fetch the package checksum from the controller and store it.

        On success the checksum is kept in ``self.checksum`` for later use by
        :meth:`validate`.

        :param url: base url of the controller API.
        :param token: token for authentication.
        :param name: name of package, optional.
        :return: True if the request succeeded (even if no checksum could be
            extracted), False on request failure.
        :rtype: bool
        """
        try:
            # TODO: use new endpoint (v1)
            path = f"{url}/get_package_checksum?name={name}" if name else f"{url}/get_package_checksum"

            with requests.get(path, verify=False, headers={"Authorization": f"{FEDN_AUTH_SCHEME} {token}"}) as r:
                if 200 <= r.status_code < 204:
                    data = r.json()
                    try:
                        self.checksum = data["checksum"]
                    except Exception:
                        logger.error("Could not extract checksum.")

            return True
        except Exception:
            return False

    def validate(self, expected_checksum) -> bool:
        """Validate the downloaded package against the controller's checksum.

        :param expected_checksum: checksum provided by the controller.
        :return: True if the stored, expected, and on-disk checksums all match.
        :rtype: bool
        """
        # crosscheck checksum and unpack if security checks are ok.
        file_checksum = str(sha(os.path.join(self.pkg_path, self.pkg_name)))

        if self.checksum == expected_checksum == file_checksum:
            logger.info("Package validated {}".format(self.checksum))
            return True
        else:
            return False

    def unpack_compute_package(self) -> Tuple[bool, str]:
        """Unpack the downloaded compute package tarball.

        Extracts into ``self.pkg_path``, deletes the tarball, and locates the
        directory containing fedn.yaml.

        :return: (True, path-containing-fedn.yaml) on success, (False, "") otherwise.
        :rtype: Tuple[bool, str]
        """
        if self.pkg_name:
            f = None
            # Bug fix: the original chain used three independent `if`s with the
            # `else` bound to the last one, so .tar.gz/.tgz packages fell into
            # the error branch and were never extracted.
            if self.pkg_name.endswith("tar.gz"):
                f = tarfile.open(os.path.join(self.pkg_path, self.pkg_name), "r:gz")
            elif self.pkg_name.endswith(".tgz"):
                f = tarfile.open(os.path.join(self.pkg_path, self.pkg_name), "r:gz")
            elif self.pkg_name.endswith("tar.bz2"):
                f = tarfile.open(os.path.join(self.pkg_path, self.pkg_name), "r:bz2")
        else:
            logger.error("Failed to unpack compute package, no pkg_name set."
                         "Has the reducer been configured with a compute package?")
            return False, ""

        try:
            if f:
                # NOTE(review): extractall on a downloaded tarball is vulnerable
                # to path traversal; consider tarfile's filter="data" (3.12+).
                f.extractall(self.pkg_path)
                logger.info("Successfully extracted compute package content in {}".format(self.pkg_path))
                # delete the tarball
                logger.info("Deleting temporary package tarball file.")
                f.close()
                os.remove(os.path.join(self.pkg_path, self.pkg_name))
                # search for file fedn.yaml in extracted package
                for root, dirs, files in os.walk(self.pkg_path):
                    if "fedn.yaml" in files:
                        # Get the path to where fedn.yaml is located
                        logger.info("Found fedn.yaml file in {}".format(root))
                        return True, root

                logger.error("No fedn.yaml file found in extracted package!")
                return False, ""
        except Exception:
            logger.error("Error extracting files.")
            # delete the tarball
            os.remove(os.path.join(self.pkg_path, self.pkg_name))
            return False, ""

    def get_dispatcher(self, run_path) -> Dispatcher:
        """Build a Dispatcher from the package's fedn.yaml.

        :param run_path: path to dispatch the compute package from.
        :type run_path: str
        :return: Dispatcher object, or None if fedn.yaml could not be read.
        :rtype: :class:`fedn.utils.dispatcher.Dispatcher`
        """
        try:
            self.dispatch_config = _read_yaml_file(os.path.join(run_path, "fedn.yaml"))
            dispatcher = Dispatcher(self.dispatch_config, run_path)

            return dispatcher
        except Exception:
            return None