diff --git a/.ci/tests/examples/wait_for.py b/.ci/tests/examples/wait_for.py index 7fa75506d..ccd76859d 100644 --- a/.ci/tests/examples/wait_for.py +++ b/.ci/tests/examples/wait_for.py @@ -18,7 +18,7 @@ def _retry(try_func, **func_args): for _ in range(RETRIES): is_success = try_func(**func_args) if is_success: - _eprint('Sucess.') + _eprint('Success.') return True _eprint(f'Sleeping for {SLEEP}.') sleep(SLEEP) @@ -30,28 +30,38 @@ def _test_rounds(n_rounds): client = pymongo.MongoClient( "mongodb://fedn_admin:password@localhost:6534") collection = client['fedn-network']['control']['rounds'] - query = {'reducer.status': 'Success'} + query = {'status': 'Finished'} n = collection.count_documents(query) client.close() _eprint(f'Succeded rounds: {n}.') return n == n_rounds -def _test_nodes(n_nodes, node_type, reducer_host='localhost', reducer_port='8090'): +def _test_nodes(n_nodes, node_type, reducer_host='localhost', reducer_port='8092'): try: - resp = requests.get( - f'http://{reducer_host}:{reducer_port}/netgraph', verify=False) + + endpoint = "list_clients" if node_type == "client" else "list_combiners" + + response = requests.get( + f'http://{reducer_host}:{reducer_port}/{endpoint}', verify=False) + + if response.status_code == 200: + + data = json.loads(response.content) + + count = 0 + if node_type == "client": + arr = data.get('result') + count = sum(element.get('status') == "online" for element in arr) + else: + count = data.get('count') + + _eprint(f'Active {node_type}s: {count}.') + return count == n_nodes + except Exception as e: - _eprint(f'Reques exception econuntered: {e}.') + _eprint(f'Request exception enconuntered: {e}.') return False - if resp.status_code == 200: - gr = json.loads(resp.content) - n = sum(values.get('type') == node_type and values.get( - 'status') == 'active' for values in gr['nodes']) - _eprint(f'Active {node_type}s: {n}.') - return n == n_nodes - _eprint(f'Reducer returned {resp.status_code}.') - return False def rounds(n_rounds=3): diff --git a/README.rst b/README.rst index 42503985a..2afb60ebc 100644 --- a/README.rst +++ b/README.rst @@ -101,7 +101,7 @@ To connect a client that uses the data partition 'data/clients/1/mnist.pt': -v $PWD/data/clients/1:/var/data \ -e ENTRYPOINT_OPTS=--data_path=/var/data/mnist.pt \ --network=fedn_default \ - ghcr.io/scaleoutsystems/fedn/fedn:develop-mnist-pytorch run client -in client.yaml --name client1 + ghcr.io/scaleoutsystems/fedn/fedn:master-mnist-pytorch run client -in client.yaml --name client1 You are now ready to start training the model at http://localhost:8090/control. diff --git a/examples/mnist-keras/bin/build.sh b/examples/mnist-keras/bin/build.sh index 18cdb5128..44eda61df 100755 --- a/examples/mnist-keras/bin/build.sh +++ b/examples/mnist-keras/bin/build.sh @@ -5,4 +5,4 @@ set -e client/entrypoint init_seed # Make compute package -tar -czvf package.tgz client \ No newline at end of file +tar -czvf package.tgz client diff --git a/fedn/fedn/common/storage/s3/miniorepo.py b/fedn/fedn/common/storage/s3/miniorepo.py index 9341704e6..154cea7e9 100644 --- a/fedn/fedn/common/storage/s3/miniorepo.py +++ b/fedn/fedn/common/storage/s3/miniorepo.py @@ -62,11 +62,13 @@ def __init__(self, config): self.create_bucket(self.bucket) def create_bucket(self, bucket_name): - """ + """ Create a new bucket. If bucket exists, do nothing. - :param bucket_name: + :param bucket_name: The name of the bucket + :type bucket_name: str """ found = self.client.bucket_exists(bucket_name) + if not found: try: self.client.make_bucket(bucket_name) diff --git a/fedn/fedn/common/tracer/mongotracer.py b/fedn/fedn/common/tracer/mongotracer.py index 92af569ea..aa5c0810b 100644 --- a/fedn/fedn/common/tracer/mongotracer.py +++ b/fedn/fedn/common/tracer/mongotracer.py @@ -1,4 +1,5 @@ import uuid +from datetime import datetime from google.protobuf.json_format import MessageToDict @@ -18,6 +19,7 @@ def __init__(self, mongo_config, network_id): self.rounds = self.mdb['control.rounds'] self.sessions = self.mdb['control.sessions'] self.validations = self.mdb['control.validations'] + self.clients = self.mdb['network.clients'] except Exception as e: print("FAILED TO CONNECT TO MONGO, {}".format(e), flush=True) self.status = None @@ -50,18 +52,26 @@ def drop_status(self): if self.status: self.status.drop() - def new_session(self, id=None): - """ Create a new session. """ + def create_session(self, id=None): + """ Create a new session. + + :param id: The ID of the created session. + :type id: uuid, str + + """ if not id: id = uuid.uuid4() data = {'session_id': str(id)} self.sessions.insert_one(data) - def new_round(self, id): - """ Create a new session. """ + def create_round(self, round_data): + """ Create a new round. - data = {'round_id': str(id)} - self.rounds.insert_one(data) + :param round_data: Dictionary with round data. + :type round_data: dict + """ + # TODO: Add check if round_id already exists + self.rounds.insert_one(round_data) def set_session_config(self, id, config): self.sessions.update_one({'session_id': str(id)}, { @@ -70,15 +80,46 @@ def set_session_config(self, id, config): def set_round_combiner_data(self, data): """ - :param round_meta: + :param data: The combiner data + :type data: dict """ self.rounds.update_one({'round_id': str(data['round_id'])}, { '$push': {'combiners': data}}, True) - def set_round_data(self, round_data): + def set_round_config(self, round_id, round_config): """ :param round_meta: """ - self.rounds.update_one({'round_id': str(round_data['round_id'])}, { - '$push': {'reducer': round_data}}, True) + self.rounds.update_one({'round_id': round_id}, { + '$set': {'round_config': round_config}}, True) + + def set_round_status(self, round_id, round_status): + """ + + :param round_meta: + """ + self.rounds.update_one({'round_id': round_id}, { + '$set': {'status': round_status}}, True) + + def set_round_data(self, round_id, round_data): + """ + + :param round_meta: + """ + self.rounds.update_one({'round_id': round_id}, { + '$set': {'round_data': round_data}}, True) + + def update_client_status(self, client_name, status): + """ Update client status in statestore. + :param client_name: The client name + :type client_name: str + :param status: The client status + :type status: str + :return: None + """ + datetime_now = datetime.now() + filter_query = {"name": client_name} + + update_query = {"$set": {"last_seen": datetime_now, "status": status}} + self.clients.update_one(filter_query, update_query) diff --git a/fedn/fedn/network/api/interface.py b/fedn/fedn/network/api/interface.py index f05222e87..0821ed176 100644 --- a/fedn/fedn/network/api/interface.py +++ b/fedn/fedn/network/api/interface.py @@ -1,44 +1,41 @@ import base64 import copy -import json import os import threading from io import BytesIO -from bson import json_util from flask import jsonify, send_from_directory from werkzeug.utils import secure_filename from fedn.common.config import get_controller_config, get_network_config from fedn.network.combiner.interfaces import (CombinerInterface, CombinerUnavailableError) +from fedn.network.dashboard.plots import Plot from fedn.network.state import ReducerState, ReducerStateToString from fedn.utils.checksum import sha -__all__ = 'API', +__all__ = ("API",) class API: - """ The API class is a wrapper for the statestore. It is used to expose the statestore to the network API. """ + """The API class is a wrapper for the statestore. It is used to expose the statestore to the network API.""" def __init__(self, statestore, control): self.statestore = statestore self.control = control - self.name = 'api' + self.name = "api" def _to_dict(self): - """ Convert the object to a dict. + """Convert the object to a dict. ::return: The object as a dict. ::rtype: dict """ - data = { - 'name': self.name - } + data = {"name": self.name} return data def _get_combiner_report(self, combiner_id): - """ Get report response from combiner. + """Get report response from combiner. :param combiner_id: The combiner id to get report response from. :type combiner_id: str @@ -50,8 +47,10 @@ def _get_combiner_report(self, combiner_id): report = combiner.report return report - def _allowed_file_extension(self, filename, ALLOWED_EXTENSIONS={'gz', 'bz2', 'tar', 'zip', 'tgz'}): - """ Check if file extension is allowed. + def _allowed_file_extension( + self, filename, ALLOWED_EXTENSIONS={"gz", "bz2", "tar", "zip", "tgz"} + ): + """Check if file extension is allowed. :param filename: The filename to check. :type filename: str @@ -59,32 +58,40 @@ def _allowed_file_extension(self, filename, ALLOWED_EXTENSIONS={'gz', 'bz2', 'ta :rtype: bool """ - return '.' in filename and \ - filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + return ( + "." in filename + and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS + ) - def get_all_clients(self): - """ Get all clients from the statestore. + def get_clients(self, limit=None, skip=None, status=False): + """Get all clients from the statestore. :return: All clients as a json response. :rtype: :class:`flask.Response` """ # Will return list of ObjectId - clients_objects = self.statestore.list_clients() - payload = {} - for object in clients_objects: - id = object['name'] - info = {"combiner": object['combiner'], - "combiner_preferred": object['combiner_preferred'], - "ip": object['ip'], - "updated_at": object['updated_at'], - "status": object['status'], - } - payload[id] = info + response = self.statestore.list_clients(limit, skip, status) + + arr = [] + + for element in response["result"]: + obj = { + "id": element["name"], + "combiner": element["combiner"], + "combiner_preferred": element["combiner_preferred"], + "ip": element["ip"], + "status": element["status"], + "last_seen": element["last_seen"] if "last_seen" in element else "", + } - return jsonify(payload) + arr.append(obj) + + result = {"result": arr, "count": response["count"]} + + return jsonify(result) def get_active_clients(self, combiner_id): - """ Get all active clients, i.e that are assigned to a combiner. + """Get all active clients, i.e that are assigned to a combiner. A report request to the combiner is neccessary to determine if a client is active or not. :param combiner_id: The combiner id to get active clients for. @@ -95,34 +102,42 @@ def get_active_clients(self, combiner_id): # Get combiner interface object combiner = self.control.network.get_combiner(combiner_id) if combiner is None: - return jsonify({"success": False, "message": f"Combiner {combiner_id} not found."}), 404 + return ( + jsonify( + { + "success": False, + "message": f"Combiner {combiner_id} not found.", + } + ), + 404, + ) response = combiner.list_active_clients() return response - def get_all_combiners(self): - """ Get all combiners from the statestore. + def get_all_combiners(self, limit=None, skip=None): + """Get all combiners from the statestore. :return: All combiners as a json response. :rtype: :class:`flask.Response` """ # Will return list of ObjectId - combiner_objects = self.statestore.get_combiners() - payload = {} - for object in combiner_objects: - id = object['name'] - info = {"address": object['address'], - "fqdn": object['fqdn'], - "parent_reducer": object['parent']["name"], - "port": object['port'], - "report": object['report'], - "updated_at": object['updated_at'], - } - payload[id] = info + projection = {"name": True, "updated_at": True} + response = self.statestore.get_combiners(limit, skip, projection=projection) + arr = [] + for element in response["result"]: + obj = { + "name": element["name"], + "updated_at": element["updated_at"], + } - return jsonify(payload) + arr.append(obj) + + result = {"result": arr, "count": response["count"]} + + return jsonify(result) def get_combiner(self, combiner_id): - """ Get a combiner from the statestore. + """Get a combiner from the statestore. :param combiner_id: The combiner id to get. :type combiner_id: str @@ -132,36 +147,42 @@ def get_combiner(self, combiner_id): # Will return ObjectId object = self.statestore.get_combiner(combiner_id) payload = {} - id = object['name'] - info = {"address": object['address'], - "fqdn": object['fqdn'], - "parent_reducer": object['parent']["name"], - "port": object['port'], - "report": object['report'], - "updated_at": object['updated_at'], - } + id = object["name"] + info = { + "address": object["address"], + "fqdn": object["fqdn"], + "parent_reducer": object["parent"]["name"], + "port": object["port"], + "report": object["report"], + "updated_at": object["updated_at"], + } payload[id] = info return jsonify(payload) - def get_all_sessions(self): - """ Get all sessions from the statestore. + def get_all_sessions(self, limit=None, skip=None): + """Get all sessions from the statestore. :return: All sessions as a json response. :rtype: :class:`flask.Response` """ - sessions_objects = self.statestore.get_sessions() - if sessions_objects is None: - return jsonify({"success": False, "message": "No sessions found."}), 404 - payload = {} - for object in sessions_objects: - id = object['session_id'] - info = object['session_config'][0] - payload[id] = info - return jsonify(payload) + sessions_object = self.statestore.get_sessions(limit, skip) + if sessions_object is None: + return ( + jsonify({"success": False, "message": "No sessions found."}), + 404, + ) + arr = [] + for element in sessions_object["result"]: + obj = element["session_config"][0] + arr.append(obj) + + result = {"result": arr, "count": sessions_object["count"]} + + return jsonify(result) def get_session(self, session_id): - """ Get a session from the statestore. + """Get a session from the statestore. :param session_id: The session id to get. :type session_id: str @@ -170,15 +191,23 @@ def get_session(self, session_id): """ session_object = self.statestore.get_session(session_id) if session_object is None: - return jsonify({"success": False, "message": f"Session {session_id} not found."}), 404 + return ( + jsonify( + { + "success": False, + "message": f"Session {session_id} not found.", + } + ), + 404, + ) payload = {} - id = session_object['session_id'] - info = session_object['session_config'][0] + id = session_object["session_id"] + info = session_object["session_config"][0] payload[id] = info return jsonify(payload) def set_compute_package(self, file, helper_type): - """ Set the compute package in the statestore. + """Set the compute package in the statestore. :param file: The compute package to set. :type file: file @@ -189,24 +218,42 @@ def set_compute_package(self, file, helper_type): if file and self._allowed_file_extension(file.filename): filename = secure_filename(file.filename) # TODO: make configurable, perhaps in config.py or package.py - file_path = os.path.join( - '/app/client/package/', filename) + file_path = os.path.join("/app/client/package/", filename) file.save(file_path) - if self.control.state() == ReducerState.instructing or self.control.state() == ReducerState.monitoring: - return jsonify({"success": False, "message": "Reducer is in instructing or monitoring state." - "Cannot set compute package."}), 400 + if ( + self.control.state() == ReducerState.instructing + or self.control.state() == ReducerState.monitoring + ): + return ( + jsonify( + { + "success": False, + "message": "Reducer is in instructing or monitoring state." + "Cannot set compute package.", + } + ), + 400, + ) self.control.set_compute_package(filename, file_path) self.statestore.set_helper(helper_type) success = self.statestore.set_compute_package(filename) if not success: - return jsonify({"success": False, "message": "Failed to set compute package."}), 400 + return ( + jsonify( + { + "success": False, + "message": "Failed to set compute package.", + } + ), + 400, + ) return jsonify({"success": True, "message": "Compute package set."}) def _get_compute_package_name(self): - """ Get the compute package name from the statestore. + """Get the compute package name from the statestore. :return: The compute package name. :rtype: str @@ -217,32 +264,38 @@ def _get_compute_package_name(self): return None, message else: try: - name = package_objects['filename'] + name = package_objects["filename"] except KeyError as e: message = "No compute package found. Key error." print(e) return None, message - return name, 'success' + return name, "success" def get_compute_package(self): - """ Get the compute package from the statestore. + """Get the compute package from the statestore. :return: The compute package as a json response. :rtype: :class:`flask.Response` """ package_object = self.statestore.get_compute_package() if package_object is None: - return jsonify({"success": False, "message": "No compute package found."}), 404 + return ( + jsonify( + {"success": False, "message": "No compute package found."} + ), + 404, + ) payload = {} - id = str(package_object['_id']) - info = {"filename": package_object['filename'], - "helper": package_object['helper'], - } + id = str(package_object["_id"]) + info = { + "filename": package_object["filename"], + "helper": package_object["helper"], + } payload[id] = info return jsonify(payload) def download_compute_package(self, name): - """ Download the compute package. + """Download the compute package. :return: The compute package as a json object. :rtype: :class:`flask.Response` @@ -255,23 +308,27 @@ def download_compute_package(self, name): mutex = threading.Lock() mutex.acquire() # TODO: make configurable, perhaps in config.py or package.py - return send_from_directory('/app/client/package/', name, as_attachment=True) + return send_from_directory( + "/app/client/package/", name, as_attachment=True + ) except Exception: try: data = self.control.get_compute_package(name) # TODO: make configurable, perhaps in config.py or package.py - file_path = os.path.join('/app/client/package/', name) - with open(file_path, 'wb') as fh: + file_path = os.path.join("/app/client/package/", name) + with open(file_path, "wb") as fh: fh.write(data) # TODO: make configurable, perhaps in config.py or package.py - return send_from_directory('/app/client/package/', name, as_attachment=True) + return send_from_directory( + "/app/client/package/", name, as_attachment=True + ) except Exception: raise finally: mutex.release() def _create_checksum(self, name=None): - """ Create the checksum of the compute package. + """Create the checksum of the compute package. :param name: The name of the compute package. :type name: str @@ -282,17 +339,19 @@ def _create_checksum(self, name=None): if name is None: name, message = self._get_compute_package_name() if name is None: - return False, message, '' - file_path = os.path.join('/app/client/package/', name) # TODO: make configurable, perhaps in config.py or package.py + return False, message, "" + file_path = os.path.join( + "/app/client/package/", name + ) # TODO: make configurable, perhaps in config.py or package.py try: sum = str(sha(file_path)) except FileNotFoundError: - sum = '' - message = 'File not found.' + sum = "" + message = "File not found." return True, message, sum def get_checksum(self, name): - """ Get the checksum of the compute package. + """Get the checksum of the compute package. :param name: The name of the compute package. :type name: str @@ -303,66 +362,75 @@ def get_checksum(self, name): success, message, sum = self._create_checksum(name) if not success: return jsonify({"success": False, "message": message}), 404 - payload = {'checksum': sum} + payload = {"checksum": sum} return jsonify(payload) def get_controller_status(self): - """ Get the status of the controller. + """Get the status of the controller. :return: The status of the controller as a json object. :rtype: :py:class:`flask.Response` """ - return jsonify({'state': ReducerStateToString(self.control.state())}) + return jsonify({"state": ReducerStateToString(self.control.state())}) def get_events(self, **kwargs): - """ Get the events of the federated network. + """Get the events of the federated network. :return: The events as a json object. :rtype: :py:class:`flask.Response` """ - event_objects = self.statestore.get_events(**kwargs) - if event_objects is None: - return jsonify({"success": False, "message": "No events found."}), 404 - json_docs = [] - for doc in self.statestore.get_events(**kwargs): - json_doc = json.dumps(doc, default=json_util.default) - json_docs.append(json_doc) + response = self.statestore.get_events(**kwargs) - json_docs.reverse() - return jsonify({'events': json_docs}) + result = response["result"] + if result is None: + return ( + jsonify({"success": False, "message": "No events found."}), + 404, + ) + + events = [] + for evt in result: + events.append(evt) + + return jsonify({"result": events, "count": response["count"]}) def get_all_validations(self, **kwargs): - """ Get all validations from the statestore. + """Get all validations from the statestore. :return: All validations as a json response. :rtype: :class:`flask.Response` """ validations_objects = self.statestore.get_validations(**kwargs) if validations_objects is None: - return jsonify( - { - "success": False, - "message": "No validations found.", - "filter_used": kwargs - } - ), 404 + return ( + jsonify( + { + "success": False, + "message": "No validations found.", + "filter_used": kwargs, + } + ), + 404, + ) payload = {} for object in validations_objects: - id = str(object['_id']) + id = str(object["_id"]) info = { - 'model_id': object['modelId'], - 'data': object['data'], - 'timestamp': object['timestamp'], - 'meta': object['meta'], - 'sender': object['sender'], - 'receiver': object['receiver'], + "model_id": object["modelId"], + "data": object["data"], + "timestamp": object["timestamp"], + "meta": object["meta"], + "sender": object["sender"], + "receiver": object["receiver"], } payload[id] = info return jsonify(payload) - def add_combiner(self, combiner_id, secure_grpc, address, remote_addr, fqdn, port): - """ Add a combiner to the network. + def add_combiner( + self, combiner_id, secure_grpc, address, remote_addr, fqdn, port + ): + """Add a combiner to the network. :param combiner_id: The combiner id to add. :type combiner_id: str @@ -383,18 +451,20 @@ def add_combiner(self, combiner_id, secure_grpc, address, remote_addr, fqdn, por """ # TODO: Any more required check for config? Formerly based on status: "retry" if not self.control.idle(): - return jsonify({ - 'success': False, - 'status': 'retry', - 'message': 'Conroller is not in idle state, try again later. ' - } + return jsonify( + { + "success": False, + "status": "retry", + "message": "Conroller is not in idle state, try again later. ", + } ) # Check if combiner already exists combiner = self.control.network.get_combiner(combiner_id) if not combiner: - if secure_grpc == 'True': + if secure_grpc == "True": certificate, key = self.certificate_manager.get_or_create( - address).get_keypair_raw() + address + ).get_keypair_raw() _ = base64.b64encode(certificate) _ = base64.b64encode(key) @@ -410,29 +480,32 @@ def add_combiner(self, combiner_id, secure_grpc, address, remote_addr, fqdn, por port=port, certificate=copy.deepcopy(certificate), key=copy.deepcopy(key), - ip=remote_addr) + ip=remote_addr, + ) self.control.network.add_combiner(combiner_interface) # Check combiner now exists combiner = self.control.network.get_combiner(combiner_id) if not combiner: - return jsonify({'success': False, 'message': 'Combiner not added.'}) + return jsonify( + {"success": False, "message": "Combiner not added."} + ) payload = { - 'success': True, - 'message': 'Combiner added successfully.', - 'status': 'added', - 'storage': self.statestore.get_storage_backend(), - 'statestore': self.statestore.get_config(), - 'certificate': combiner.get_certificate(), - 'key': combiner.get_key() + "success": True, + "message": "Combiner added successfully.", + "status": "added", + "storage": self.statestore.get_storage_backend(), + "statestore": self.statestore.get_config(), + "certificate": combiner.get_certificate(), + "key": combiner.get_key(), } return jsonify(payload) def add_client(self, client_id, preferred_combiner, remote_addr): - """ Add a client to the network. + """Add a client to the network. :param client_id: The client id to add. :type client_id: str @@ -444,26 +517,46 @@ def add_client(self, client_id, preferred_combiner, remote_addr): # Check if package has been set package_object = self.statestore.get_compute_package() if package_object is None: - return jsonify({'success': False, 'status': 'retry', 'message': 'No compute package found. Set package in controller.'}), 203 + return ( + jsonify( + { + "success": False, + "status": "retry", + "message": "No compute package found. Set package in controller.", + } + ), + 203, + ) # Assign client to combiner if preferred_combiner: combiner = self.control.network.get_combiner(preferred_combiner) if combiner is None: - return jsonify({'success': False, - 'message': f'Combiner {preferred_combiner} not found or unavailable.'}), 400 + return ( + jsonify( + { + "success": False, + "message": f"Combiner {preferred_combiner} not found or unavailable.", + } + ), + 400, + ) else: combiner = self.control.network.find_available_combiner() if combiner is None: - return jsonify({'success': False, - 'message': 'No combiner available.'}), 400 + return ( + jsonify( + {"success": False, "message": "No combiner available."} + ), + 400, + ) client_config = { - 'name': client_id, - 'combiner_preferred': preferred_combiner, - 'combiner': combiner.name, - 'ip': remote_addr, - 'status': 'available' + "name": client_id, + "combiner_preferred": preferred_combiner, + "combiner": combiner.name, + "ip": remote_addr, + "status": "available", } # Add client to network self.control.network.add_client(client_config) @@ -471,38 +564,36 @@ def add_client(self, client_id, preferred_combiner, remote_addr): # Setup response containing information about the combiner for assinging the client if combiner.certificate: cert_b64 = base64.b64encode(combiner.certificate) - cert = str(cert_b64).split('\'')[1] + cert = str(cert_b64).split("'")[1] else: cert = None payload = { - 'status': 'assigned', - 'host': combiner.address, - 'fqdn': combiner.fqdn, - 'package': 'remote', # TODO: Make this configurable - 'ip': combiner.ip, - 'port': combiner.port, - 'certificate': cert, - 'helper_type': self.control.statestore.get_helper() + "status": "assigned", + "host": combiner.address, + "fqdn": combiner.fqdn, + "package": "remote", # TODO: Make this configurable + "ip": combiner.ip, + "port": combiner.port, + "certificate": cert, + "helper_type": self.control.statestore.get_helper(), } print("Seding payload: ", payload, flush=True) return jsonify(payload) def get_initial_model(self): - """ Get the initial model from the statestore. + """Get the initial model from the statestore. :return: The initial model as a json response. :rtype: :class:`flask.Response` """ model_id = self.statestore.get_initial_model() - payload = { - 'model_id': model_id - } + payload = {"model_id": model_id} return jsonify(payload) def set_initial_model(self, file): - """ Add an initial model to the network. + """Add an initial model to the network. :param file: The initial model to add. :type file: file @@ -520,27 +611,47 @@ def set_initial_model(self, file): self.control.commit(file.filename, model) except Exception as e: print(e, flush=True) - return jsonify({'success': False, 'message': e}) + return jsonify({"success": False, "message": e}) - return jsonify({'success': True, 'message': 'Initial model added successfully.'}) + return jsonify( + {"success": True, "message": "Initial model added successfully."} + ) def get_latest_model(self): - """ Get the latest model from the statestore. + """Get the latest model from the statestore. :return: The initial model as a json response. :rtype: :class:`flask.Response` """ if self.statestore.get_latest_model(): model_id = self.statestore.get_latest_model() - payload = { - 'model_id': model_id - } + payload = {"model_id": model_id} return jsonify(payload) else: - return jsonify({'success': False, 'message': 'No initial model set.'}) + return jsonify( + {"success": False, "message": "No initial model set."} + ) + + def get_models(self, session_id=None, limit=None, skip=None): + result = self.statestore.list_models(session_id, limit, skip) + + if result is None: + return ( + jsonify({"success": False, "message": "No models found."}), + 404, + ) + + arr = [] + + for model in result["result"]: + arr.append(model) + + result = {"result": arr, "count": result["count"]} + + return jsonify(result) def get_model_trail(self): - """ Get the model trail for a given session. + """Get the model trail for a given session. :param session: The session id to get the model trail for. :type session: str @@ -551,38 +662,41 @@ def get_model_trail(self): if model_info: return jsonify(model_info) else: - return jsonify({'success': False, 'message': 'No model trail available.'}) + return jsonify( + {"success": False, "message": "No model trail available."} + ) def get_all_rounds(self): - """ Get all rounds. + """Get all rounds. :return: The rounds as json response. :rtype: :class:`flask.Response` """ rounds_objects = self.statestore.get_rounds() if rounds_objects is None: - jsonify({'success': False, 'message': 'No rounds available.'}) + jsonify({"success": False, "message": "No rounds available."}) payload = {} for object in rounds_objects: - id = object['round_id'] - if 'reducer' in object.keys(): - reducer = object['reducer'] + id = object["round_id"] + if "reducer" in object.keys(): + reducer = object["reducer"] else: reducer = None - if 'combiners' in object.keys(): - combiners = object['combiners'] + if "combiners" in object.keys(): + combiners = object["combiners"] else: combiners = None - info = {'reducer': reducer, - 'combiners': combiners, - } + info = { + "reducer": reducer, + "combiners": combiners, + } payload[id] = info else: return jsonify(payload) def get_round(self, round_id): - """ Get a round. + """Get a round. :param round_id: The round id to get. :type round_id: str @@ -591,38 +705,99 @@ def get_round(self, round_id): """ round_object = self.statestore.get_round(round_id) if round_object is None: - return jsonify({'success': False, 'message': 'Round not found.'}) + return jsonify({"success": False, "message": "Round not found."}) payload = { 'round_id': round_object['round_id'], - 'reducer': round_object['reducer'], 'combiners': round_object['combiners'], } return jsonify(payload) def get_client_config(self, checksum=True): - """ Get the client config. + """Get the client config. :return: The client config as json response. :rtype: :py:class:`flask.Response` """ config = get_controller_config() network_id = get_network_config() - port = config['port'] - host = config['host'] + port = config["port"] + host = config["host"] payload = { - 'network_id': network_id, - 'discover_host': host, - 'discover_port': port, + "network_id": network_id, + "discover_host": host, + "discover_port": port, } if checksum: success, _, checksum_str = self._create_checksum() if success: - payload['checksum'] = checksum_str + payload["checksum"] = checksum_str return jsonify(payload) - def start_session(self, session_id, rounds=5, round_timeout=180, round_buffer_size=-1, delete_models=False, - validate=True, helper='keras', min_clients=1, requested_clients=8): - """ Start a session. + def get_plot_data(self, feature=None): + """Get plot data. + + :return: The plot data as json response. + :rtype: :py:class:`flask.Response` + """ + + plot = Plot(self.control.statestore) + + try: + valid_metrics = plot.fetch_valid_metrics() + feature = feature or valid_metrics[0] + box_plot = plot.create_box_plot(feature) + except Exception as e: + valid_metrics = None + box_plot = None + print(e, flush=True) + + result = { + "valid_metrics": valid_metrics, + "box_plot": box_plot, + } + + return jsonify(result) + + def list_combiners_data(self, combiners): + """Get combiners data. + + :param combiners: The combiners to get data for. + :type combiners: list + :return: The combiners data as json response. + :rtype: :py:class:`flask.Response` + """ + + response = self.statestore.list_combiners_data(combiners) + + arr = [] + + # order list by combiner name + for element in response: + + obj = { + "combiner": element["_id"], + "count": element["count"], + } + + arr.append(obj) + + result = {"result": arr} + + return jsonify(result) + + def start_session( + self, + session_id, + rounds=5, + round_timeout=180, + round_buffer_size=-1, + delete_models=False, + validate=True, + helper="keras", + min_clients=1, + requested_clients=8, + ): + """Start a session. :param session_id: The session id to start. :type session_id: str @@ -646,18 +821,22 @@ def start_session(self, session_id, rounds=5, round_timeout=180, round_buffer_si # Check if session already exists session = self.statestore.get_session(session_id) if session: - return jsonify({'success': False, 'message': 'Session already exists.'}) + return jsonify( + {"success": False, "message": "Session already exists."} + ) # Check if session is running if self.control.state() == ReducerState.monitoring: - return jsonify({'success': False, 'message': 'A session is already running.'}) + return jsonify( + {"success": False, "message": "A session is already running."} + ) # Check available clients per combiner clients_available = 0 for combiner in self.control.network.get_combiners(): try: combiner_state = combiner.report() - nr_active_clients = combiner_state['nr_active_clients'] + nr_active_clients = combiner_state["nr_active_clients"] clients_available = clients_available + int(nr_active_clients) except CombinerUnavailableError as e: # TODO: Handle unavailable combiner, stop session or continue? @@ -665,11 +844,16 @@ def start_session(self, session_id, rounds=5, round_timeout=180, round_buffer_si continue if clients_available < min_clients: - return jsonify({'success': False, 'message': 'Not enough clients available to start session.'}) + return jsonify( + { + "success": False, + "message": "Not enough clients available to start session.", + } + ) # Check if validate is string and convert to bool if isinstance(validate, str): - if validate.lower() == 'true': + if validate.lower() == "true": validate = True else: validate = False @@ -678,22 +862,30 @@ def start_session(self, session_id, rounds=5, round_timeout=180, round_buffer_si model_id = self.statestore.get_latest_model() # Setup session config - session_config = {'session_id': session_id, - 'round_timeout': round_timeout, - 'buffer_size': round_buffer_size, - 'model_id': model_id, - 'rounds': rounds, - 'delete_models_storage': delete_models, - 'clients_required': min_clients, - 'clients_requested': requested_clients, - 'task': (''), - 'validate': validate, - 'helper_type': helper - } + session_config = { + "session_id": session_id, + "round_timeout": round_timeout, + "buffer_size": round_buffer_size, + "model_id": model_id, + "rounds": rounds, + "delete_models_storage": delete_models, + "clients_required": min_clients, + "clients_requested": requested_clients, + "task": (""), + "validate": validate, + "helper_type": helper, + } # Start session - threading.Thread(target=self.control.session, - args=(session_config,)).start() + threading.Thread( + target=self.control.session, args=(session_config,) + ).start() # Return success response - return jsonify({'success': True, 'message': 'Session started successfully.', "config": session_config}) + return jsonify( + { + "success": True, + "message": "Session started successfully.", + "config": session_config, + } + ) diff --git a/fedn/fedn/network/api/network.py b/fedn/fedn/network/api/network.py index 26b366a76..6fcaad053 100644 --- a/fedn/fedn/network/api/network.py +++ b/fedn/fedn/network/api/network.py @@ -46,7 +46,7 @@ def get_combiners(self): """ data = self.statestore.get_combiners() combiners = [] - for c in data: + for c in data["result"]: if c['certificate']: cert = base64.b64decode(c['certificate']) key = base64.b64decode(c['key']) diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py index 4e0e93775..cfb91bece 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/fedn/network/api/server.py @@ -10,18 +10,16 @@ network_id = get_network_config() modelstorage_config = get_modelstorage_config() statestore = MongoStateStore( - network_id, - statestore_config['mongo_config'], - modelstorage_config + network_id, statestore_config["mongo_config"], modelstorage_config ) control = Control(statestore=statestore) api = API(statestore, control) app = Flask(__name__) -@app.route('/get_model_trail', methods=['GET']) +@app.route("/get_model_trail", methods=["GET"]) def get_model_trail(): - """ Get the model trail for a given session. + """Get the model trail for a given session. param: session: The session id to get the model trail for. type: session: str return: The model trail for the given session as a json object. @@ -30,9 +28,29 @@ def get_model_trail(): return api.get_model_trail() -@app.route('/delete_model_trail', methods=['GET', 'POST']) +@app.route("/list_models", methods=["GET"]) +def list_models(): + """Get models from the statestore. + param: + session_id: The session id to get the model trail for. + limit: The maximum number of models to return. + type: limit: int + param: skip: The number of models to skip. + type: skip: int + Returns: + _type_: json + """ + + session_id = request.args.get("session_id", None) + limit = request.args.get("limit", None) + skip = request.args.get("skip", None) + + return api.get_models(session_id, limit, skip) + + +@app.route("/delete_model_trail", methods=["GET", "POST"]) def delete_model_trail(): - """ Delete the model trail for a given session. + """Delete the model trail for a given session. param: session: The session id to delete the model trail for. type: session: str return: The response from the statestore. @@ -41,78 +59,93 @@ def delete_model_trail(): return jsonify({"message": "Not implemented"}), 501 -@app.route('/list_clients', methods=['GET']) +@app.route("/list_clients", methods=["GET"]) def list_clients(): - """ Get all clients from the statestore. + """Get all clients from the statestore. return: All clients as a json object. rtype: json """ - return api.get_all_clients() + limit = request.args.get("limit", None) + skip = request.args.get("skip", None) + status = request.args.get("status", None) + + return api.get_clients(limit, skip, status) -@app.route('/get_active_clients', methods=['GET']) + +@app.route("/get_active_clients", methods=["GET"]) def get_active_clients(): - """ Get all active clients from the statestore. + """Get all active clients from the statestore. param: combiner_id: The combiner id to get active clients for. type: combiner_id: str return: All active clients as a json object. rtype: json """ - combiner_id = request.args.get('combiner', None) + combiner_id = request.args.get("combiner", None) if combiner_id is None: - return jsonify({"success": False, "message": "Missing combiner id."}), 400 + return ( + jsonify({"success": False, "message": "Missing combiner id."}), + 400, + ) return api.get_active_clients(combiner_id) -@app.route('/list_combiners', methods=['GET']) +@app.route("/list_combiners", methods=["GET"]) def list_combiners(): - """ Get all combiners in the network. + """Get all combiners in the network. return: All combiners as a json object. rtype: json """ - return api.get_all_combiners() + limit = request.args.get("limit", None) + skip = request.args.get("skip", None) + + return api.get_all_combiners(limit, skip) -@app.route('/get_combiner', methods=['GET']) + +@app.route("/get_combiner", methods=["GET"]) def get_combiner(): - """ Get a combiner from the statestore. + """Get a combiner from the statestore. param: combiner_id: The combiner id to get. type: combiner_id: str return: The combiner as a json object. rtype: json """ - combiner_id = request.args.get('combiner', None) + combiner_id = request.args.get("combiner", None) if combiner_id is None: - return jsonify({"success": False, "message": "Missing combiner id."}), 400 + return ( + jsonify({"success": False, "message": "Missing combiner id."}), + 400, + ) return api.get_combiner(combiner_id) -@app.route('/list_rounds', methods=['GET']) +@app.route("/list_rounds", methods=["GET"]) def list_rounds(): - """ Get all rounds from the statestore. + """Get all rounds from the statestore. return: All rounds as a json object. rtype: json """ return api.get_all_rounds() -@app.route('/get_round', methods=['GET']) +@app.route("/get_round", methods=["GET"]) def get_round(): - """ Get a round from the statestore. + """Get a round from the statestore. param: round_id: The round id to get. type: round_id: str return: The round as a json object. rtype: json """ - round_id = request.args.get('round_id', None) + round_id = request.args.get("round_id", None) if round_id is None: return jsonify({"success": False, "message": "Missing round id."}), 400 return api.get_round(round_id) -@app.route('/start_session', methods=['GET', 'POST']) +@app.route("/start_session", methods=["GET", "POST"]) def start_session(): - """ Start a new session. + """Start a new session. return: The response from control. rtype: json """ @@ -120,30 +153,36 @@ def start_session(): return api.start_session(**json_data) -@app.route('/list_sessions', methods=['GET']) +@app.route("/list_sessions", methods=["GET"]) def list_sessions(): - """ Get all sessions from the statestore. + """Get all sessions from the statestore. return: All sessions as a json object. rtype: json """ - return api.get_all_sessions() + limit = request.args.get("limit", None) + skip = request.args.get("skip", None) + + return api.get_all_sessions(limit, skip) -@app.route('/get_session', methods=['GET']) +@app.route("/get_session", methods=["GET"]) def get_session(): - """ Get a session from the statestore. + """Get a session from the statestore. param: session_id: The session id to get. type: session_id: str return: The session as a json object. rtype: json """ - session_id = request.args.get('session_id', None) + session_id = request.args.get("session_id", None) if session_id is None: - return jsonify({"success": False, "message": "Missing session id."}), 400 + return ( + jsonify({"success": False, "message": "Missing session id."}), + 400, + ) return api.get_session(session_id) -@app.route('/set_package', methods=['POST']) +@app.route("/set_package", methods=["POST"]) def set_package(): """ Set the compute package in the statestore. Usage with curl: @@ -157,64 +196,68 @@ def set_package(): return: The response from the statestore. rtype: json """ - helper_type = request.form.get('helper', None) + helper_type = request.form.get("helper", None) if helper_type is None: - return jsonify({"success": False, "message": "Missing helper type."}), 400 + return ( + jsonify({"success": False, "message": "Missing helper type."}), + 400, + ) try: - file = request.files['file'] + file = request.files["file"] except KeyError: return jsonify({"success": False, "message": "Missing file."}), 400 return api.set_compute_package(file=file, helper_type=helper_type) -@app.route('/get_package', methods=['GET']) +@app.route("/get_package", methods=["GET"]) def get_package(): - """ Get the compute package from the statestore. + """Get the compute package from the statestore. return: The compute package as a json object. rtype: json """ return api.get_compute_package() -@app.route('/download_package', methods=['GET']) +@app.route("/download_package", methods=["GET"]) def download_package(): - """ Download the compute package. + """Download the compute package. return: The compute package as a json object. rtype: json """ - name = request.args.get('name', None) + name = request.args.get("name", None) return api.download_compute_package(name) -@app.route('/get_package_checksum', methods=['GET']) +@app.route("/get_package_checksum", methods=["GET"]) def get_package_checksum(): - name = request.args.get('name', None) + name = request.args.get("name", None) return api.get_checksum(name) -@app.route('/get_latest_model', methods=['GET']) +@app.route("/get_latest_model", methods=["GET"]) def get_latest_model(): - """ Get the latest model from the statestore. + """Get the latest model from the statestore. return: The initial model as a json object. rtype: json """ return api.get_latest_model() + # Get initial model endpoint -@app.route('/get_initial_model', methods=['GET']) +@app.route("/get_initial_model", methods=["GET"]) def get_initial_model(): - """ Get the initial model from the statestore. + """Get the initial model from the statestore. return: The initial model as a json object. rtype: json """ return api.get_initial_model() -@app.route('/set_initial_model', methods=['POST']) +@app.route("/set_initial_model", methods=["POST"]) def set_initial_model(): - """ Set the initial model in the statestore and upload to model repository. + """Set the initial model in the statestore and upload to model repository. Usage with curl: curl -k -X POST -F file=@seed.npz @@ -226,45 +269,46 @@ def set_initial_model(): rtype: json """ try: - file = request.files['file'] + file = request.files["file"] except KeyError: return jsonify({"success": False, "message": "Missing file."}), 400 return api.set_initial_model(file) -@app.route('/get_controller_status', methods=['GET']) +@app.route("/get_controller_status", methods=["GET"]) def get_controller_status(): - """ Get the status of the controller. + """Get the status of the controller. return: The status as a json object. rtype: json """ return api.get_controller_status() -@app.route('/get_client_config', methods=['GET']) +@app.route("/get_client_config", methods=["GET"]) def get_client_config(): - """ Get the client configuration. + """Get the client configuration. return: The client configuration as a json object. rtype: json """ - checksum = request.args.get('checksum', True) + checksum = request.args.get("checksum", True) return api.get_client_config(checksum) -@app.route('/get_events', methods=['GET']) +@app.route("/get_events", methods=["GET"]) def get_events(): - """ Get the events from the statestore. + """Get the events from the statestore. return: The events as a json object. rtype: json """ # TODO: except filter with request.get_json() kwargs = request.args.to_dict() + return api.get_events(**kwargs) -@app.route('/list_validations', methods=['GET']) +@app.route("/list_validations", methods=["GET"]) def list_validations(): - """ Get all validations from the statestore. + """Get all validations from the statestore. return: All validations as a json object. rtype: json """ @@ -273,9 +317,9 @@ def list_validations(): return api.get_all_validations(**kwargs) -@app.route('/add_combiner', methods=['POST']) +@app.route("/add_combiner", methods=["POST"]) def add_combiner(): - """ Add a combiner to the network. + """Add a combiner to the network. return: The response from the statestore. rtype: json """ @@ -284,13 +328,13 @@ def add_combiner(): try: response = api.add_combiner(**json_data, remote_addr=remote_addr) except TypeError as e: - return jsonify({'success': False, 'message': str(e)}), 400 + return jsonify({"success": False, "message": str(e)}), 400 return response -@app.route('/add_client', methods=['POST']) +@app.route("/add_client", methods=["POST"]) def add_client(): - """ Add a client to the network. + """Add a client to the network. return: The response from control. rtype: json """ @@ -300,12 +344,45 @@ def add_client(): try: response = api.add_client(**json_data, remote_addr=remote_addr) except TypeError as e: - return jsonify({'success': False, 'message': str(e)}), 400 + return jsonify({"success": False, "message": str(e)}), 400 + return response + + +@app.route("/list_combiners_data", methods=["POST"]) +def list_combiners_data(): + """List data from combiners. + return: The response from control. + rtype: json + """ + + json_data = request.get_json() + + # expects a list of combiner names (strings) in an array + combiners = json_data.get("combiners", None) + + try: + response = api.list_combiners_data(combiners) + except TypeError as e: + return jsonify({"success": False, "message": str(e)}), 400 + return response + + +@app.route("/get_plot_data", methods=["GET"]) +def get_plot_data(): + """Get plot data from the statestore. + rtype: json + """ + + try: + feature = request.args.get("feature", None) + response = api.get_plot_data(feature=feature) + except TypeError as e: + return jsonify({"success": False, "message": str(e)}), 400 return response -if __name__ == '__main__': +if __name__ == "__main__": config = get_controller_config() - port = config['port'] - debug = config['debug'] - app.run(debug=debug, port=port, host='0.0.0.0') + port = config["port"] + debug = config["debug"] + app.run(debug=debug, port=port, host="0.0.0.0") diff --git a/fedn/fedn/network/combiner/server.py b/fedn/fedn/network/combiner/server.py index 11d874ea6..7a9c87ff9 100644 --- a/fedn/fedn/network/combiner/server.py +++ b/fedn/fedn/network/combiner/server.py @@ -98,9 +98,11 @@ def __init__(self, config): break if status == Status.UnAuthorized: print(response, flush=True) + print("Status.UnAuthorized", flush=True) sys.exit("Exiting: Unauthorized") if status == Status.UnMatchedConfig: print(response, flush=True) + print("Status.UnMatchedConfig", flush=True) sys.exit("Exiting: Missing config") cert = announce_config['certificate'] @@ -712,12 +714,16 @@ def ModelUpdateRequestStream(self, response, context): self._send_status(status) + self.tracer.update_client_status(client.name, "online") + while context.is_active(): try: yield q.get(timeout=1.0) except queue.Empty: pass + self.tracer.update_client_status(client.name, "offline") + def ModelValidationStream(self, update, context): """ Model validation stream RPC endpoint. Update status for client is connecting to stream. diff --git a/fedn/fedn/network/controller/control.py b/fedn/fedn/network/controller/control.py index 5f5bc6634..615edb3b5 100644 --- a/fedn/fedn/network/controller/control.py +++ b/fedn/fedn/network/controller/control.py @@ -3,16 +3,19 @@ import time import uuid +from tenacity import (retry, retry_if_exception_type, stop_after_delay, + wait_random) + from fedn.network.combiner.interfaces import CombinerUnavailableError from fedn.network.controller.controlbase import ControlBase from fedn.network.state import ReducerState class UnsupportedStorageBackend(Exception): - """ Exception class for when storage backend is not supported. Passes """ + """Exception class for when storage backend is not supported. Passes""" def __init__(self, message): - """ Constructor method. + """Constructor method. :param message: The exception message. :type message: str @@ -23,46 +26,60 @@ def __init__(self, message): class MisconfiguredStorageBackend(Exception): - """ Exception class for when storage backend is misconfigured. + """Exception class for when storage backend is misconfigured. :param message: The exception message. :type message: str """ def __init__(self, message): - """ Constructor method.""" + """Constructor method.""" self.message = message super().__init__(self.message) class NoModelException(Exception): - """ Exception class for when model is None + """Exception class for when model is None :param message: The exception message. :type message: str """ def __init__(self, message): - """ Constructor method.""" + """Constructor method.""" + self.message = message + super().__init__(self.message) + + +class CombinersNotDoneException(Exception): + """ Exception class for when model is None """ + + def __init__(self, message): + """ Constructor method. + + :param message: The exception message. + :type message: str + + """ self.message = message super().__init__(self.message) class Control(ControlBase): - """ Controller, implementing the overall global training, validation and inference logic. + """Controller, implementing the overall global training, validation and inference logic. :param statestore: A StateStorage instance. :type statestore: class: `fedn.network.statestorebase.StateStorageBase` """ def __init__(self, statestore): - """ Constructor method.""" + """Constructor method.""" super().__init__(statestore) self.name = "DefaultControl" def session(self, config): - """ Execute a new training session. A session consists of one + """Execute a new training session. A session consists of one or several global rounds. All rounds in the same session have the same round_config. @@ -72,7 +89,10 @@ def session(self, config): """ if self._state == ReducerState.instructing: - print("Controller already in INSTRUCTING state. A session is in progress.", flush=True) + print( + "Controller already in INSTRUCTING state. A session is in progress.", + flush=True, + ) return if not self.statestore.get_latest_model(): @@ -80,13 +100,16 @@ def session(self, config): return self._state = ReducerState.instructing - - # Must be called to set info in the db - config['committed_at'] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - self.new_session(config) + config["committed_at"] = datetime.datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ) + self.create_session(config) if not self.statestore.get_latest_model(): - print("No model in model chain, please provide a seed model!", flush=True) + print( + "No model in model chain, please provide a seed model!", + flush=True, + ) self._state = ReducerState.monitoring last_round = int(self.get_latest_round_id()) @@ -96,169 +119,205 @@ def session(self, config): combiner.flush_model_update_queue() # Execute the rounds in this session - for round in range(1, int(config['rounds'] + 1)): + for round in range(1, int(config["rounds"] + 1)): # Increment the round number - if last_round: current_round = last_round + round else: current_round = round try: - _, round_data = self.round(config, current_round) + _, round_data = self.round(config, str(current_round)) except TypeError as e: - print("Could not unpack data from round: {0}".format(e), flush=True) - - print("CONTROL: Round completed with status {}".format( - round_data['status']), flush=True) - - self.tracer.set_round_data(round_data) + print( + "Could not unpack data from round: {0}".format(e), + flush=True, + ) + + print( + "CONTROL: Round completed with status {}".format( + round_data["status"] + ), + flush=True, + ) # TODO: Report completion of session self._state = ReducerState.idle def round(self, session_config, round_id): - """ Execute a single global round. + """ Execute one global round. + + : param session_config: The session config. + : type session_config: dict + : param round_id: The round id. + : type round_id: str - :param session_config: The session config. - :type session_config: dict - :param round_id: The round id. - :type round_id: str(int) """ - round_data = {'round_id': round_id} + self.create_round({'round_id': round_id, 'status': "Pending"}) if len(self.network.get_combiners()) < 1: - print("REDUCER: No combiners connected!", flush=True) - round_data['status'] = 'Failed' - return None, round_data + print("CONTROLLER: Round cannot start, no combiners connected!", flush=True) + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - # 1. Assemble round config for this global round, - # and check which combiners are able to participate - # in the round. + # Assemble round config for this global round round_config = copy.deepcopy(session_config) - round_config['rounds'] = 1 - round_config['round_id'] = round_id - round_config['task'] = 'training' - round_config['model_id'] = self.statestore.get_latest_model() - round_config['helper_type'] = self.statestore.get_helper() + round_config["rounds"] = 1 + round_config["round_id"] = round_id + round_config["task"] = "training" + round_config["model_id"] = self.statestore.get_latest_model() + round_config["helper_type"] = self.statestore.get_helper() + + self.set_round_config(round_id, round_config) + + # Get combiners that are able to participate in round, given round_config + participating_combiners = self.get_participating_combiners(round_config) - combiners = self.get_participating_combiners(round_config) - round_start = self.evaluate_round_start_policy(combiners) + # Check if the policy to start the round is met + round_start = self.evaluate_round_start_policy(participating_combiners) if round_start: - print("CONTROL: round start policy met, participating combiners {}".format( - combiners), flush=True) + print("CONTROL: round start policy met, {} participating combiners.".format( + len(participating_combiners)), flush=True) else: print("CONTROL: Round start policy not met, skipping round!", flush=True) - round_data['status'] = 'Failed' - return None + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) + + # Ask participating combiners to coordinate model updates + _ = self.request_model_updates(participating_combiners) + # TODO: Check response - round_data['round_config'] = round_config + # Wait until participating combiners have produced an updated global model, + # or round times out. + def do_if_round_times_out(result): + print("CONTROL: Round timed out!", flush=True) - # 2. Ask participating combiners to coordinate model updates - _ = self.request_model_updates(combiners) + @retry(wait=wait_random(min=1.0, max=2.0), + stop=stop_after_delay(session_config['round_timeout']), + retry_error_callback=do_if_round_times_out, + retry=retry_if_exception_type(CombinersNotDoneException)) + def combiners_done(): - # Wait until participating combiners have produced an updated global model. - wait = 0.0 - # dict to store combiners that have successfully produced an updated model - updated = {} - # wait until all combiners have produced an updated model or until round timeout - print("CONTROL: Fetching round config (ID: {round_id}) from statestore:".format( - round_id=round_id), flush=True) - while len(updated) < len(combiners): round = self.statestore.get_round(round_id) - if round: - print("CONTROL: Round found!", flush=True) - # For each combiner in the round, check if it has produced an updated model (status == 'Success') - for combiner in round['combiners']: - print(combiner, flush=True) - if combiner['status'] == 'Success': - if combiner['name'] not in updated.keys(): - # Add combiner to updated dict - updated[combiner['name']] = combiner['model_id'] - # Print combiner status - print("CONTROL: Combiner {name} status: {status}".format( - name=combiner['name'], status=combiner['status']), flush=True) - else: - # Print every 10 seconds based on value of wait - if wait % 10 == 0: - print("CONTROL: Waiting for round to complete...", flush=True) - if wait >= session_config['round_timeout']: - print("CONTROL: Round timeout! Exiting round...", flush=True) - break - # Update wait time used for timeout - time.sleep(1.0) - wait += 1.0 - - round_valid = self.evaluate_round_validity_policy(updated) + if 'combiners' not in round: + # TODO: use logger + print("CONTROL: Waiting for combiners to update model...", flush=True) + raise CombinersNotDoneException("Combiners have not yet reported.") + + if len(round['combiners']) < len(participating_combiners): + print("CONTROL: Waiting for combiners to update model...", flush=True) + raise CombinersNotDoneException("All combiners have not yet reported.") + + return True + + combiners_done() + + # Due to the distributed nature of the computation, there might be a + # delay before combiners have reported the round data to the db, + # so we need some robustness here. + @retry(wait=wait_random(min=0.1, max=1.0), + retry=retry_if_exception_type(KeyError)) + def check_combiners_done_reporting(): + round = self.statestore.get_round(round_id) + combiners = round['combiners'] + return combiners + + _ = check_combiners_done_reporting() + + round = self.statestore.get_round(round_id) + round_valid = self.evaluate_round_validity_policy(round) if not round_valid: print("REDUCER CONTROL: Round invalid!", flush=True) - round_data['status'] = 'Failed' - return None, round_data + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - print("CONTROL: Reducing models from combiners...", flush=True) - # 3. Reduce combiner models into a global model + print("CONTROL: Reducing combiner level models...", flush=True) + # Reduce combiner models into a new global model + round_data = {} try: - model, data = self.reduce(updated) + round = self.statestore.get_round(round_id) + model, data = self.reduce(round['combiners']) round_data['reduce'] = data print("CONTROL: Done reducing models from combiners!", flush=True) except Exception as e: print("CONTROL: Failed to reduce models from combiners: {}".format( e), flush=True) - round_data['status'] = 'Failed' - return None, round_data + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - # 6. Commit the global model to model trail + # Commit the new global model to the model trail if model is not None: - print("CONTROL: Committing global model to model trail...", flush=True) + print( + "CONTROL: Committing global model to model trail...", + flush=True, + ) tic = time.time() model_id = uuid.uuid4() - self.commit(model_id, model) - round_data['time_commit'] = time.time() - tic - print("CONTROL: Done committing global model to model trail!", flush=True) + session_id = ( + session_config["session_id"] + if "session_id" in session_config + else None + ) + self.commit(model_id, model, session_id) + round_data["time_commit"] = time.time() - tic + print( + "CONTROL: Done committing global model to model trail!", + flush=True, + ) else: - print("REDUCER: failed to update model in round with config {}".format( - session_config), flush=True) - round_data['status'] = 'Failed' - return None, round_data + print( + "REDUCER: failed to update model in round with config {}".format( + session_config + ), + flush=True, + ) + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - round_data['status'] = 'Success' + self.set_round_status(round_id, 'Success') # 4. Trigger participating combiner nodes to execute a validation round for the current model - validate = session_config['validate'] + validate = session_config["validate"] if validate: combiner_config = copy.deepcopy(session_config) - combiner_config['round_id'] = round_id - combiner_config['model_id'] = self.statestore.get_latest_model() - combiner_config['task'] = 'validation' - combiner_config['helper_type'] = self.statestore.get_helper() + combiner_config["round_id"] = round_id + combiner_config["model_id"] = self.statestore.get_latest_model() + combiner_config["task"] = "validation" + combiner_config["helper_type"] = self.statestore.get_helper() - validating_combiners = self._select_participating_combiners( + validating_combiners = self.get_participating_combiners( combiner_config) for combiner, combiner_config in validating_combiners: try: - print("CONTROL: Submitting validation round to combiner {}".format( - combiner), flush=True) + print( + "CONTROL: Submitting validation round to combiner {}".format( + combiner + ), + flush=True, + ) combiner.submit(combiner_config) except CombinerUnavailableError: self._handle_unavailable_combiner(combiner) pass - return model_id, round_data + self.set_round_data(round_id, round_data) + self.set_round_status(round_id, 'Finished') + return model_id, self.statestore.get_round(round_id) def reduce(self, combiners): - """ Combine updated models from Combiner nodes into one global model. + """Combine updated models from Combiner nodes into one global model. - :param combiners: dict of combiner names (key) and model IDs (value) to reduce - :type combiners: dict + : param combiners: dict of combiner names(key) and model IDs(value) to reduce + : type combiners: dict """ meta = {} - meta['time_fetch_model'] = 0.0 - meta['time_load_model'] = 0.0 - meta['time_aggregate_model'] = 0.0 + meta["time_fetch_model"] = 0.0 + meta["time_load_model"] = 0.0 + meta["time_aggregate_model"] = 0.0 i = 1 model = None @@ -267,19 +326,28 @@ def reduce(self, combiners): print("REDUCER: No combiners to reduce!", flush=True) return model, meta - for name, model_id in combiners.items(): - + for combiner in combiners: + name = combiner['name'] + model_id = combiner['model_id'] # TODO: Handle inactive RPC error in get_model and raise specific error - print("REDUCER: Fetching model ({model_id}) from combiner {name}".format( - model_id=model_id, name=name), flush=True) + print( + "REDUCER: Fetching model ({model_id}) from combiner {name}".format( + model_id=model_id, name=name + ), + flush=True, + ) try: tic = time.time() - combiner = self.get_combiner(name) - data = combiner.get_model(model_id) + combiner_interface = self.get_combiner(name) + data = combiner_interface.get_model(model_id) meta['time_fetch_model'] += (time.time() - tic) except Exception as e: - print("REDUCER: Failed to fetch model from combiner {}: {}".format( - name, e), flush=True) + print( + "REDUCER: Failed to fetch model from combiner {}: {}".format( + name, e + ), + flush=True, + ) data = None if data is not None: @@ -288,23 +356,23 @@ def reduce(self, combiners): helper = self.get_helper() data.seek(0) model_next = helper.load(data) - meta['time_load_model'] += (time.time() - tic) + meta["time_load_model"] += time.time() - tic tic = time.time() model = helper.increment_average(model, model_next, i, i) - meta['time_aggregate_model'] += (time.time() - tic) + meta["time_aggregate_model"] += time.time() - tic except Exception: tic = time.time() data.seek(0) model = helper.load(data) - meta['time_aggregate_model'] += (time.time() - tic) + meta["time_aggregate_model"] += time.time() - tic i = i + 1 return model, meta def infer_instruct(self, config): - """ Main entrypoint for executing the inference compute plan. + """Main entrypoint for executing the inference compute plan. - :param config: configuration for the inference round + : param config: configuration for the inference round """ # Check/set instucting state @@ -330,9 +398,9 @@ def infer_instruct(self, config): self.__state = ReducerState.idle def inference_round(self, config): - """ Execute an inference round. + """Execute an inference round. - :param config: configuration for the inference round + : param config: configuration for the inference round """ # Init meta @@ -345,21 +413,28 @@ def inference_round(self, config): # Setup combiner configuration combiner_config = copy.deepcopy(config) - combiner_config['model_id'] = self.statestore.get_latest_model() - combiner_config['task'] = 'inference' - combiner_config['helper_type'] = self.statestore.get_framework() + combiner_config["model_id"] = self.statestore.get_latest_model() + combiner_config["task"] = "inference" + combiner_config["helper_type"] = self.statestore.get_framework() # Select combiners - validating_combiners = self._select_round_combiners( + validating_combiners = self.get_participating_combiners( combiner_config) # Test round start policy round_start = self.check_round_start_policy(validating_combiners) if round_start: - print("CONTROL: round start policy met, participating combiners {}".format( - validating_combiners), flush=True) + print( + "CONTROL: round start policy met, participating combiners {}".format( + validating_combiners + ), + flush=True, + ) else: - print("CONTROL: Round start policy not met, skipping round!", flush=True) + print( + "CONTROL: Round start policy not met, skipping round!", + flush=True, + ) return None # Synch combiners with latest model and trigger inference diff --git a/fedn/fedn/network/controller/controlbase.py b/fedn/fedn/network/controller/controlbase.py index e38d31e38..fab6a2027 100644 --- a/fedn/fedn/network/controller/controlbase.py +++ b/fedn/fedn/network/controller/controlbase.py @@ -11,7 +11,7 @@ from fedn.network.state import ReducerState # Maximum number of tries to connect to statestore and retrieve storage configuration -MAX_TRIES_BACKEND = os.getenv('MAX_TRIES_BACKEND', 10) +MAX_TRIES_BACKEND = os.getenv("MAX_TRIES_BACKEND", 10) class UnsupportedStorageBackend(Exception): @@ -27,7 +27,7 @@ class MisconfiguredHelper(Exception): class ControlBase(ABC): - """ Base class and interface for a global controller. + """Base class and interface for a global controller. Override this class to implement a global training strategy (control). :param statestore: The statestore object. @@ -36,7 +36,7 @@ class ControlBase(ABC): @abstractmethod def __init__(self, statestore): - """ Constructor. """ + """Constructor.""" self._state = ReducerState.setup self.statestore = statestore @@ -52,26 +52,36 @@ def __init__(self, statestore): not_ready = False else: print( - "REDUCER CONTROL: Storage backend not configured, waiting...", flush=True) + "REDUCER CONTROL: Storage backend not configured, waiting...", + flush=True, + ) sleep(5) tries += 1 if tries > MAX_TRIES_BACKEND: raise Exception except Exception: print( - "REDUCER CONTROL: Failed to retrive storage configuration, exiting.", flush=True) + "REDUCER CONTROL: Failed to retrive storage configuration, exiting.", + flush=True, + ) raise MisconfiguredStorageBackend() - if storage_config['storage_type'] == 'S3': - self.model_repository = S3ModelRepository(storage_config['storage_config']) + if storage_config["storage_type"] == "S3": + self.model_repository = S3ModelRepository( + storage_config["storage_config"] + ) else: - print("REDUCER CONTROL: Unsupported storage backend, exiting.", flush=True) + print( + "REDUCER CONTROL: Unsupported storage backend, exiting.", + flush=True, + ) raise UnsupportedStorageBackend() # The tracer is a helper that manages state in the database backend statestore_config = statestore.get_config() self.tracer = MongoTracer( - statestore_config['mongo_config'], statestore_config['network_id']) + statestore_config["mongo_config"], statestore_config["network_id"] + ) if self.statestore.is_inited(): self._state = ReducerState.idle @@ -89,7 +99,7 @@ def reduce(self, combiners): pass def get_helper(self): - """ Get a helper instance from global config. + """Get a helper instance from global config. :return: Helper instance. :rtype: :class:`fedn.utils.plugins.helperbase.HelperBase` @@ -97,11 +107,15 @@ def get_helper(self): helper_type = self.statestore.get_helper() helper = fedn.utils.helpers.get_helper(helper_type) if not helper: - raise MisconfiguredHelper("Unsupported helper type {}, please configure compute_package.helper !".format(helper_type)) + raise MisconfiguredHelper( + "Unsupported helper type {}, please configure compute_package.helper !".format( + helper_type + ) + ) return helper def get_state(self): - """ Get the current state of the controller. + """Get the current state of the controller. :return: The current state. :rtype: :class:`fedn.network.state.ReducerState` @@ -109,7 +123,7 @@ def get_state(self): return self._state def idle(self): - """ Check if the controller is idle. + """Check if the controller is idle. :return: True if idle, False otherwise. :rtype: bool @@ -139,7 +153,7 @@ def get_latest_round_id(self): if not last_round: return 0 else: - return last_round['round_id'] + return last_round["round_id"] def get_latest_round(self): round = self.statestore.get_latest_round() @@ -153,70 +167,126 @@ def get_compute_package_name(self): definition = self.statestore.get_compute_package() if definition: try: - package_name = definition['filename'] + package_name = definition["filename"] return package_name except (IndexError, KeyError): print( - "No context filename set for compute context definition", flush=True) + "No context filename set for compute context definition", + flush=True, + ) return None else: return None def set_compute_package(self, filename, path): - """ Persist the configuration for the compute package. """ + """Persist the configuration for the compute package.""" self.model_repository.set_compute_package(filename, path) self.statestore.set_compute_package(filename) - def get_compute_package(self, compute_package=''): + def get_compute_package(self, compute_package=""): """ :param compute_package: :return: """ - if compute_package == '': + if compute_package == "": compute_package = self.get_compute_package_name() if compute_package: return self.model_repository.get_compute_package(compute_package) else: return None - def new_session(self, config): + def create_session(self, config): """ Initialize a new session in backend db. """ if "session_id" not in config.keys(): session_id = uuid.uuid4() - config['session_id'] = str(session_id) + config["session_id"] = str(session_id) else: - session_id = config['session_id'] + session_id = config["session_id"] - self.tracer.new_session(id=session_id) + self.tracer.create_session(id=session_id) self.tracer.set_session_config(session_id, config) + def create_round(self, round_data): + """Initialize a new round in backend db. """ + + self.tracer.create_round(round_data) + + def set_round_data(self, round_id, round_data): + """ Set round data. + + :param round_id: The round unique identifier + :type round_id: str + :param round_data: The status + :type status: dict + """ + self.tracer.set_round_data(round_id, round_data) + + def set_round_status(self, round_id, status): + """ Set the round round stats. + + :param round_id: The round unique identifier + :type round_id: str + :param status: The status + :type status: str + """ + self.tracer.set_round_status(round_id, status) + + def set_round_config(self, round_id, round_config): + """ Upate round in backend db. + + :param round_id: The round unique identifier + :type round_id: str + :param round_config: The round configuration + :type round_config: dict + """ + self.tracer.set_round_config(round_id, round_config) + def request_model_updates(self, combiners): - """Call Combiner server RPC to get a model update. """ + """Ask Combiner server to produce a model update. + + :param combiners: A list of combiners + :type combiners: tuple (combiner, comboner_round_config) + """ cl = [] for combiner, combiner_round_config in combiners: response = combiner.submit(combiner_round_config) cl.append((combiner, response)) return cl - def commit(self, model_id, model=None): - """ Commit a model to the global model trail. The model commited becomes the lastest consensus model. """ + def commit(self, model_id, model=None, session_id=None): + """Commit a model to the global model trail. The model commited becomes the lastest consensus model. + + :param model_id: Unique identifier for the model to commit. + :type model_id: str (uuid) + :param model: The model object to commit + :type model: BytesIO + :param session_id: Unique identifier for the session + :type session_id: str + """ helper = self.get_helper() if model is not None: - print("CONTROL: Saving model file temporarily to disk...", flush=True) + print( + "CONTROL: Saving model file temporarily to disk...", flush=True + ) outfile_name = helper.save(model) print("CONTROL: Uploading model to Minio...", flush=True) model_id = self.model_repository.set_model( - outfile_name, is_file=True) + outfile_name, is_file=True + ) print("CONTROL: Deleting temporary model file...", flush=True) os.unlink(outfile_name) - print("CONTROL: Committing model {} to global model trail in statestore...".format( - model_id), flush=True) - self.statestore.set_latest_model(model_id) + print( + "CONTROL: Committing model {} to global model trail in statestore...".format( + model_id + ), + flush=True, + ) + self.statestore.set_latest_model(model_id, session_id) def get_combiner(self, name): for combiner in self.network.get_combiners(): @@ -226,7 +296,7 @@ def get_combiner(self, name): def get_participating_combiners(self, combiner_round_config): """Assemble a list of combiners able to participate in a round as - descibed by combiner_round_config. + descibed by combiner_round_config. """ combiners = [] for combiner in self.network.get_combiners(): @@ -238,70 +308,75 @@ def get_participating_combiners(self, combiner_round_config): if combiner_state is not None: is_participating = self.evaluate_round_participation_policy( - combiner_round_config, combiner_state) + combiner_round_config, combiner_state + ) if is_participating: combiners.append((combiner, combiner_round_config)) return combiners - def evaluate_round_participation_policy(self, compute_plan, combiner_state): - """ Evaluate policy for combiner round-participation. - A combiner participates if it is responsive and reports enough - active clients to participate in the round. + def evaluate_round_participation_policy( + self, compute_plan, combiner_state + ): + """Evaluate policy for combiner round-participation. + A combiner participates if it is responsive and reports enough + active clients to participate in the round. """ - if compute_plan['task'] == 'training': - nr_active_clients = int(combiner_state['nr_active_trainers']) - elif compute_plan['task'] == 'validation': - nr_active_clients = int(combiner_state['nr_active_validators']) + if compute_plan["task"] == "training": + nr_active_clients = int(combiner_state["nr_active_trainers"]) + elif compute_plan["task"] == "validation": + nr_active_clients = int(combiner_state["nr_active_validators"]) else: print("Invalid task type!", flush=True) return False - if int(compute_plan['clients_required']) <= nr_active_clients: + if int(compute_plan["clients_required"]) <= nr_active_clients: return True else: return False def evaluate_round_start_policy(self, combiners): - """ Check if the policy to start a round is met. """ - if len(combiners) > 0: + """Check if the policy to start a round is met. + :param combiners: A list of combiners + :type combiners: list + :return: True if the round policy is mer, otherwise False + :rtype: bool + """ + if len(combiners) > 0: return True else: return False - def evaluate_round_validity_policy(self, combiners): - """ Check if the round should be seen as valid. + def evaluate_round_validity_policy(self, round): + """ Check if the round is valid. - At the end of the round, before committing a model to the global model trail, - we check if the round validity policy has been met. This can involve - e.g. asserting that a certain number of combiners have reported in an - updated model, or that criteria on model performance have been met. - """ - if combiners.keys() == []: - return False - else: - return True + At the end of the round, before committing a model to the global model trail, + we check if the round validity policy has been met. This can involve + e.g. asserting that a certain number of combiners have reported in an + updated model, or that criteria on model performance have been met. - def _select_participating_combiners(self, compute_plan): - participating_combiners = [] - for combiner in self.network.get_combiners(): + :param round: The round object + :rtype round: dict + :return: True if the policy is met, otherwise False + :rtype: bool + """ + model_ids = [] + for combiner in round['combiners']: try: - combiner_state = combiner.report() - except CombinerUnavailableError: - self._handle_unavailable_combiner(combiner) - combiner_state = None + model_ids.append(combiner['model_id']) + except KeyError: + pass - if combiner_state: - is_participating = self.evaluate_round_participation_policy( - compute_plan, combiner_state) - if is_participating: - participating_combiners.append((combiner, compute_plan)) - return participating_combiners + if len(model_ids) == 0: + return False + + return True def state(self): - """ + """ Get the current state of the controller - :return: + :return: The state + :rype: str """ return self._state diff --git a/fedn/fedn/network/dashboard/restservice.py b/fedn/fedn/network/dashboard/restservice.py index 6e44897a1..808bcb272 100644 --- a/fedn/fedn/network/dashboard/restservice.py +++ b/fedn/fedn/network/dashboard/restservice.py @@ -23,8 +23,8 @@ from fedn.network.state import ReducerState, ReducerStateToString from fedn.utils.checksum import sha -UPLOAD_FOLDER = '/app/client/package/' -ALLOWED_EXTENSIONS = {'gz', 'bz2', 'tar', 'zip', 'tgz'} +UPLOAD_FOLDER = "/app/client/package/" +ALLOWED_EXTENSIONS = {"gz", "bz2", "tar", "zip", "tgz"} def allowed_file(filename): @@ -33,8 +33,10 @@ def allowed_file(filename): :param filename: :return: """ - return '.' in filename and \ - filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + return ( + "." in filename + and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS + ) def encode_auth_token(secret_key): @@ -43,16 +45,17 @@ def encode_auth_token(secret_key): """ try: payload = { - 'exp': datetime.datetime.utcnow() + datetime.timedelta(days=90, seconds=0), - 'iat': datetime.datetime.utcnow(), - 'status': 'Success' + "exp": datetime.datetime.utcnow() + + datetime.timedelta(days=90, seconds=0), + "iat": datetime.datetime.utcnow(), + "status": "Success", } - token = jwt.encode( - payload, - secret_key, - algorithm='HS256' + token = jwt.encode(payload, secret_key, algorithm="HS256") + print( + "\n\n\nSECURE MODE ENABLED, USE TOKEN TO ACCESS REDUCER: **** {} ****\n\n\n".format( + token + ) ) - print('\n\n\nSECURE MODE ENABLED, USE TOKEN TO ACCESS REDUCER: **** {} ****\n\n\n'.format(token)) return token except Exception as e: return e @@ -64,56 +67,49 @@ def decode_auth_token(auth_token, secret): :return: string """ try: - payload = jwt.decode( - auth_token, - secret, - algorithms=['HS256'] - ) + payload = jwt.decode(auth_token, secret, algorithms=["HS256"]) return payload["status"] except jwt.ExpiredSignatureError as e: print(e) - return 'Token has expired.' + return "Token has expired." except jwt.InvalidTokenError as e: print(e) - return 'Invalid token.' + return "Invalid token." class ReducerRestService: - """ - - """ + """ """ def __init__(self, config, control, statestore, certificate_manager): - print("config object!: \n\n\n\n{}".format(config)) - if config['host']: - self.host = config['host'] + if config["host"]: + self.host = config["host"] else: self.host = None - self.name = config['name'] + self.name = config["name"] - self.port = config['port'] - self.network_id = config['name'] + '-network' + self.port = config["port"] + self.network_id = config["name"] + "-network" - if 'token' in config.keys(): + if "token" in config.keys(): self.token_auth_enabled = True else: self.token_auth_enabled = False - if 'secret_key' in config.keys(): - self.SECRET_KEY = config['secret_key'] + if "secret_key" in config.keys(): + self.SECRET_KEY = config["secret_key"] else: self.SECRET_KEY = None - if 'use_ssl' in config.keys(): - self.use_ssl = config['use_ssl'] + if "use_ssl" in config.keys(): + self.use_ssl = config["use_ssl"] self.remote_compute_package = config["remote_compute_package"] if self.remote_compute_package: - self.package = 'remote' + self.package = "remote" else: - self.package = 'local' + self.package = "local" self.control = control self.statestore = statestore @@ -125,9 +121,7 @@ def to_dict(self): :return: """ - data = { - 'name': self.name - } + data = {"name": self.name} return data def check_compute_package(self): @@ -165,24 +159,40 @@ def check_configured_response(self): :rtype: json """ if self.control.state() == ReducerState.setup: - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Controller is not configured."}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Controller is not configured.", + } + ) if not self.check_compute_package(): - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Compute package is not configured. Please upload the compute package."}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Compute package is not configured. Please upload the compute package.", + } + ) if not self.check_initial_model(): - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Initial model is not configured. Please upload the model."}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Initial model is not configured. Please upload the model.", + } + ) if not self.control.idle(): - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Controller is not in idle state, try again later. "}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Controller is not in idle state, try again later. ", + } + ) return None def check_configured(self): @@ -192,17 +202,29 @@ def check_configured(self): :return: Rendered html template or None """ if not self.check_compute_package(): - return render_template('setup.html', client=self.name, state=ReducerStateToString(self.control.state()), - logs=None, refresh=False, - message='Please set the compute package') + return render_template( + "setup.html", + client=self.name, + state=ReducerStateToString(self.control.state()), + logs=None, + refresh=False, + message="Please set the compute package", + ) if self.control.state() == ReducerState.setup: - return render_template('setup.html', client=self.name, state=ReducerStateToString(self.control.state()), - logs=None, refresh=True, - message='Warning. Reducer is not base-configured. please do so with config file.') + return render_template( + "setup.html", + client=self.name, + state=ReducerStateToString(self.control.state()), + logs=None, + refresh=True, + message="Warning. Reducer is not base-configured. please do so with config file.", + ) if not self.check_initial_model(): - return render_template('setup_model.html', message="Please set the initial model.") + return render_template( + "setup_model.html", message="Please set the initial model." + ) return None @@ -216,31 +238,37 @@ def authorize(self, r, secret): """ try: # Get token - if 'Authorization' in r.headers: # header auth - request_token = r.headers.get('Authorization').split()[1] - elif 'token' in r.args: # args auth - request_token = str(r.args.get('token')) - elif 'fedn_token' in r.cookies: - request_token = r.cookies.get('fedn_token') + if "Authorization" in r.headers: # header auth + request_token = r.headers.get("Authorization").split()[1] + elif "token" in r.args: # args auth + request_token = str(r.args.get("token")) + elif "fedn_token" in r.cookies: + request_token = r.cookies.get("fedn_token") else: # no token provided - print('Authorization failed. No token provided.', flush=True) + print("Authorization failed. No token provided.", flush=True) abort(401) # Log token and secret print( - f'Secret: {secret}. Request token: {request_token}.', flush=True) + f"Secret: {secret}. Request token: {request_token}.", + flush=True, + ) # Authenticate status = decode_auth_token(request_token, secret) - if status == 'Success': + if status == "Success": return True else: - print('Authorization failed. Status: "{}"'.format( - status), flush=True) + print( + 'Authorization failed. Status: "{}"'.format(status), + flush=True, + ) abort(401) except Exception as e: - print('Authorization failed. Expection encountered: "{}".'.format( - e), flush=True) + print( + 'Authorization failed. Expection encountered: "{}".'.format(e), + flush=True, + ) abort(401) def run(self): @@ -250,10 +278,10 @@ def run(self): """ app = Flask(__name__) - app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER - app.config['SECRET_KEY'] = self.SECRET_KEY + app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER + app.config["SECRET_KEY"] = self.SECRET_KEY - @app.route('/') + @app.route("/") def index(): """ @@ -261,7 +289,7 @@ def index(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) # Render template not_configured_template = self.check_configured() @@ -269,29 +297,37 @@ def index(): template = not_configured_template else: events = self.control.get_events() - message = request.args.get('message', None) - message_type = request.args.get('message_type', None) - template = render_template('events.html', client=self.name, state=ReducerStateToString(self.control.state()), - events=events, - logs=None, refresh=True, configured=True, message=message, message_type=message_type) + message = request.args.get("message", None) + message_type = request.args.get("message_type", None) + template = render_template( + "events.html", + client=self.name, + state=ReducerStateToString(self.control.state()), + events=events, + logs=None, + refresh=True, + configured=True, + message=message, + message_type=message_type, + ) # Set token cookie in response if needed response = make_response(template) - if 'token' in request.args: # args auth - response.set_cookie('fedn_token', str(request.args['token'])) + if "token" in request.args: # args auth + response.set_cookie("fedn_token", str(request.args["token"])) # Return response return response - @app.route('/status') + @app.route("/status") def status(): """ :return: """ - return {'state': ReducerStateToString(self.control.state())} + return {"state": ReducerStateToString(self.control.state())} - @app.route('/netgraph') + @app.route("/netgraph") def netgraph(): """ Creates nodes and edges for network graph @@ -299,16 +335,18 @@ def netgraph(): :return: nodes and edges as keys :rtype: dict """ - result = {'nodes': [], 'edges': []} - - result['nodes'].append({ - "id": "reducer", - "label": "Reducer", - "role": 'reducer', - "status": 'active', - "name": 'reducer', # TODO: get real host name - "type": 'reducer', - }) + result = {"nodes": [], "edges": []} + + result["nodes"].append( + { + "id": "reducer", + "label": "Reducer", + "role": "reducer", + "status": "active", + "name": "reducer", # TODO: get real host name + "type": "reducer", + } + ) combiner_info = combiner_status() client_info = client_status() @@ -319,49 +357,55 @@ def netgraph(): for combiner in combiner_info: print("combiner info {}".format(combiner_info), flush=True) try: - result['nodes'].append({ - "id": combiner['name'], # "n{}".format(count), - "label": "Combiner ({} clients)".format(combiner['nr_active_clients']), - "role": 'combiner', - "status": 'active', # TODO: Hard-coded, combiner_info does not contain status - "name": combiner['name'], - "type": 'combiner', - }) + result["nodes"].append( + { + "id": combiner["name"], # "n{}".format(count), + "label": "Combiner ({} clients)".format( + combiner["nr_active_clients"] + ), + "role": "combiner", + "status": "active", # TODO: Hard-coded, combiner_info does not contain status + "name": combiner["name"], + "type": "combiner", + } + ) except Exception as err: print(err) - for client in client_info['active_clients']: + for client in client_info["active_clients"]: try: - if client['status'] != 'offline': - result['nodes'].append({ - "id": str(client['_id']), - "label": "Client", - "role": client['role'], - "status": client['status'], - "name": client['name'], - "combiner": client['combiner'], - "type": 'client', - }) + if client["status"] != "offline": + result["nodes"].append( + { + "id": str(client["_id"]), + "label": "Client", + "role": client["role"], + "status": client["status"], + "name": client["name"], + "combiner": client["combiner"], + "type": "client", + } + ) except Exception as err: print(err) count = 0 - for node in result['nodes']: + for node in result["nodes"]: try: - if node['type'] == 'combiner': - result['edges'].append( + if node["type"] == "combiner": + result["edges"].append( { "id": "e{}".format(count), - "source": node['id'], - "target": 'reducer', + "source": node["id"], + "target": "reducer", } ) - elif node['type'] == 'client': - result['edges'].append( + elif node["type"] == "client": + result["edges"].append( { "id": "e{}".format(count), - "source": node['combiner'], - "target": node['id'], + "source": node["combiner"], + "target": node["id"], } ) except Exception: @@ -369,59 +413,75 @@ def netgraph(): count = count + 1 return result - @app.route('/networkgraph') + @app.route("/networkgraph") def network_graph(): - try: plot = Plot(self.control.statestore) result = netgraph() - df_nodes = pd.DataFrame(result['nodes']) - df_edges = pd.DataFrame(result['edges']) + df_nodes = pd.DataFrame(result["nodes"]) + df_edges = pd.DataFrame(result["edges"]) graph = plot.make_netgraph_plot(df_edges, df_nodes) return json.dumps(json_item(graph, "myplot")) except Exception: raise # return '' - @app.route('/events') + @app.route("/events") def events(): """ :return: """ + response = self.control.get_events() + events = [] + + result = response["result"] + + for evt in result: + events.append(evt) + + return jsonify({"result": events, "count": response["count"]}) + json_docs = [] for doc in self.control.get_events(): json_doc = json.dumps(doc, default=json_util.default) json_docs.append(json_doc) json_docs.reverse() - return {'events': json_docs} - @app.route('/add') + return {"events": json_docs} + + @app.route("/add") def add(): - """ Add a combiner to the network. """ + """Add a combiner to the network.""" print("Adding combiner to network:", flush=True) if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) if self.control.state() == ReducerState.setup: - return jsonify({'status': 'retry'}) - - name = request.args.get('name', None) - address = str(request.args.get('address', None)) - fqdn = str(request.args.get('fqdn', None)) - port = request.args.get('port', None) - secure_grpc = request.args.get('secure', None) - - if port is None or address is None or name is None or secure_grpc is None: + return jsonify({"status": "retry"}) + + name = request.args.get("name", None) + address = str(request.args.get("address", None)) + fqdn = str(request.args.get("fqdn", None)) + port = request.args.get("port", None) + secure_grpc = request.args.get("secure", None) + + if ( + port is None + or address is None + or name is None + or secure_grpc is None + ): return "Please specify correct parameters." # Try to retrieve combiner from db combiner = self.control.network.get_combiner(name) if not combiner: - if secure_grpc == 'True': + if secure_grpc == "True": certificate, key = self.certificate_manager.get_or_create( - address).get_keypair_raw() + address + ).get_keypair_raw() _ = base64.b64encode(certificate) _ = base64.b64encode(key) @@ -437,23 +497,24 @@ def add(): port=port, certificate=copy.deepcopy(certificate), key=copy.deepcopy(key), - ip=request.remote_addr) + ip=request.remote_addr, + ) self.control.network.add_combiner(combiner) combiner = self.control.network.get_combiner(name) ret = { - 'status': 'added', - 'storage': self.control.statestore.get_storage_backend(), - 'statestore': self.control.statestore.get_config(), - 'certificate': combiner.get_certificate(), - 'key': combiner.get_key() + "status": "added", + "storage": self.control.statestore.get_storage_backend(), + "statestore": self.control.statestore.get_config(), + "certificate": combiner.get_certificate(), + "key": combiner.get_key(), } return jsonify(ret) - @app.route('/eula', methods=['GET', 'POST']) + @app.route("/eula", methods=["GET", "POST"]) def eula(): """ @@ -462,9 +523,9 @@ def eula(): for r in request.headers: print("header contains: {}".format(r), flush=True) - return render_template('eula.html', configured=True) + return render_template("eula.html", configured=True) - @app.route('/models', methods=['GET', 'POST']) + @app.route("/models", methods=["GET", "POST"]) def models(): """ @@ -472,13 +533,12 @@ def models(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) - if request.method == 'POST': + if request.method == "POST": # upload seed file - uploaded_seed = request.files['seed'] + uploaded_seed = request.files["seed"] if uploaded_seed: - a = BytesIO() a.seek(0, 0) uploaded_seed.seek(0) @@ -504,23 +564,31 @@ def models(): h_latest_model_id = self.statestore.get_latest_model() model_info = self.control.get_model_info() - return render_template('models.html', box_plot=box_plot, metrics=valid_metrics, h_latest_model_id=h_latest_model_id, seed=True, - model_info=model_info, configured=True) + return render_template( + "models.html", + box_plot=box_plot, + metrics=valid_metrics, + h_latest_model_id=h_latest_model_id, + seed=True, + model_info=model_info, + configured=True, + ) seed = True - return redirect(url_for('models', seed=seed)) + return redirect(url_for("models", seed=seed)) - @app.route('/delete_model_trail', methods=['GET', 'POST']) + @app.route("/delete_model_trail", methods=["GET", "POST"]) def delete_model_trail(): """ :return: """ - if request.method == 'POST': - + if request.method == "POST": statestore_config = self.control.statestore.get_config() self.tracer = MongoTracer( - statestore_config['mongo_config'], statestore_config['network_id']) + statestore_config["mongo_config"], + statestore_config["network_id"], + ) try: self.control.drop_models() except Exception: @@ -528,28 +596,28 @@ def delete_model_trail(): # drop objects in minio self.control.delete_bucket_objects() - return redirect(url_for('models')) + return redirect(url_for("models")) seed = True - return redirect(url_for('models', seed=seed)) + return redirect(url_for("models", seed=seed)) - @app.route('/drop_control', methods=['GET', 'POST']) + @app.route("/drop_control", methods=["GET", "POST"]) def drop_control(): """ :return: """ - if request.method == 'POST': + if request.method == "POST": self.control.statestore.drop_control() - return redirect(url_for('control')) - return redirect(url_for('control')) + return redirect(url_for("control")) + return redirect(url_for("control")) # http://localhost:8090/control?rounds=4&model_id=879fa112-c861-4cb1-a25d-775153e5b548 - @app.route('/control', methods=['GET', 'POST']) + @app.route("/control", methods=["GET", "POST"]) def control(): - """ Main page for round control. Configure, start and stop training sessions. """ + """Main page for round control. Configure, start and stop training sessions.""" # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) not_configured = self.check_configured() if not_configured: @@ -560,60 +628,88 @@ def control(): if self.remote_compute_package: try: - self.current_compute_context = self.control.get_compute_package_name() + self.current_compute_context = ( + self.control.get_compute_package_name() + ) except Exception: self.current_compute_context = None else: self.current_compute_context = "None:Local" if self.control.state() == ReducerState.monitoring: return redirect( - url_for('index', state=state, refresh=refresh, message="Reducer is in monitoring state")) - - if request.method == 'POST': + url_for( + "index", + state=state, + refresh=refresh, + message="Reducer is in monitoring state", + ) + ) + + if request.method == "POST": # Get session configuration - round_timeout = float(request.form.get('timeout', 180)) - buffer_size = int(request.form.get('buffer_size', -1)) - rounds = int(request.form.get('rounds', 1)) - delete_models = request.form.get('delete_models', True) - task = (request.form.get('task', '')) - clients_required = request.form.get('clients_required', 1) - clients_requested = request.form.get('clients_requested', 8) + round_timeout = float(request.form.get("timeout", 180)) + buffer_size = int(request.form.get("buffer_size", -1)) + rounds = int(request.form.get("rounds", 1)) + delete_models = request.form.get("delete_models", True) + task = request.form.get("task", "") + clients_required = request.form.get("clients_required", 1) + clients_requested = request.form.get("clients_requested", 8) # checking if there are enough clients connected to start! clients_available = 0 for combiner in self.control.network.get_combiners(): try: combiner_state = combiner.report() - nac = combiner_state['nr_active_clients'] + nac = combiner_state["nr_active_clients"] clients_available = clients_available + int(nac) except Exception: pass if clients_available < clients_required: - return redirect(url_for('index', state=state, - message="Not enough clients available to start rounds! " - "check combiner client capacity", - message_type='warning')) + return redirect( + url_for( + "index", + state=state, + message="Not enough clients available to start rounds! " + "check combiner client capacity", + message_type="warning", + ) + ) - validate = request.form.get('validate', False) - if validate == 'False': + validate = request.form.get("validate", False) + if validate == "False": validate = False - helper_type = request.form.get('helper', 'keras') + helper_type = request.form.get("helper", "keras") # self.control.statestore.set_framework(helper_type) latest_model_id = self.statestore.get_latest_model() - config = {'round_timeout': round_timeout, 'buffer_size': buffer_size, - 'model_id': latest_model_id, 'rounds': rounds, 'delete_models_storage': delete_models, - 'clients_required': clients_required, - 'clients_requested': clients_requested, 'task': task, - 'validate': validate, 'helper_type': helper_type} - - threading.Thread(target=self.control.session, - args=(config,)).start() + config = { + "round_timeout": round_timeout, + "buffer_size": buffer_size, + "model_id": latest_model_id, + "rounds": rounds, + "delete_models_storage": delete_models, + "clients_required": clients_required, + "clients_requested": clients_requested, + "task": task, + "validate": validate, + "helper_type": helper_type, + } + + threading.Thread( + target=self.control.session, args=(config,) + ).start() - return redirect(url_for('index', state=state, refresh=refresh, message="Sent execution plan.", - message_type='SUCCESS')) + return redirect( + url_for( + "index", + state=state, + refresh=refresh, + message="Sent execution plan.", + message_type="SUCCESS", + ) + ) else: seed_model_id = None @@ -624,42 +720,53 @@ def control(): except Exception: pass - return render_template('index.html', latest_model_id=latest_model_id, - compute_package=self.current_compute_context, - seed_model_id=seed_model_id, - helper=self.control.statestore.get_helper(), validate=True, configured=True) - - @app.route('/assign') + return render_template( + "index.html", + latest_model_id=latest_model_id, + compute_package=self.current_compute_context, + seed_model_id=seed_model_id, + helper=self.control.statestore.get_helper(), + validate=True, + configured=True, + ) + + @app.route("/assign") def assign(): - """Handle client assignment requests. """ + """Handle client assignment requests.""" if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) response = self.check_configured_response() if response: return response - name = request.args.get('name', None) - combiner_preferred = request.args.get('combiner', None) + name = request.args.get("name", None) + combiner_preferred = request.args.get("combiner", None) if combiner_preferred: - combiner = self.control.network.get_combiner(combiner_preferred) + combiner = self.control.network.get_combiner( + combiner_preferred + ) else: combiner = self.control.network.find_available_combiner() if combiner is None: - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Failed to assign to a combiner, try again later."}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Failed to assign to a combiner, try again later.", + } + ) client = { - 'name': name, - 'combiner_preferred': combiner_preferred, - 'combiner': combiner.name, - 'ip': request.remote_addr, - 'status': 'available' + "name": name, + "combiner_preferred": combiner_preferred, + "combiner": combiner.name, + "ip": request.remote_addr, + "status": "available", } # Add client to database @@ -668,25 +775,25 @@ def assign(): # Return connection information to client if combiner.certificate: cert_b64 = base64.b64encode(combiner.certificate) - cert = str(cert_b64).split('\'')[1] + cert = str(cert_b64).split("'")[1] else: cert = None response = { - 'status': 'assigned', - 'host': combiner.address, - 'fqdn': combiner.fqdn, - 'package': self.package, - 'ip': combiner.ip, - 'port': combiner.port, - 'certificate': cert, - 'model_type': self.control.statestore.get_helper() + "status": "assigned", + "host": combiner.address, + "fqdn": combiner.fqdn, + "package": self.package, + "ip": combiner.ip, + "port": combiner.port, + "certificate": cert, + "model_type": self.control.statestore.get_helper(), } return jsonify(response) def combiner_status(): - """ Get current status reports from all combiners registered in the network. + """Get current status reports from all combiners registered in the network. :return: """ @@ -711,67 +818,90 @@ def client_status(): all_active_validators = [] for client in combiner_info: - active_trainers_str = client['active_trainers'] - active_validators_str = client['active_validators'] + active_trainers_str = client["active_trainers"] + active_validators_str = client["active_validators"] active_trainers_str = re.sub( - '[^a-zA-Z0-9-:\n\.]', '', active_trainers_str).replace('name:', ' ') # noqa: W605 + "[^a-zA-Z0-9-:\n\.]", "", active_trainers_str # noqa: W605 + ).replace( + "name:", " " + ) active_validators_str = re.sub( - '[^a-zA-Z0-9-:\n\.]', '', active_validators_str).replace('name:', ' ') # noqa: W605 + "[^a-zA-Z0-9-:\n\.]", "", active_validators_str # noqa: W605 + ).replace( + "name:", " " + ) all_active_trainers.extend( - ' '.join(active_trainers_str.split(" ")).split()) + " ".join(active_trainers_str.split(" ")).split() + ) all_active_validators.extend( - ' '.join(active_validators_str.split(" ")).split()) + " ".join(active_validators_str.split(" ")).split() + ) active_trainers_list = [ - client for client in client_info if client['name'] in all_active_trainers] + client + for client in client_info + if client["name"] in all_active_trainers + ] active_validators_list = [ - cl for cl in client_info if cl['name'] in all_active_validators] + cl + for cl in client_info + if cl["name"] in all_active_validators + ] all_clients = [cl for cl in client_info] for client in all_clients: - status = 'offline' - role = 'None' + status = "offline" + role = "None" self.control.network.update_client_data( - client, status, role) + client, status, role + ) - all_active_clients = active_validators_list + active_trainers_list + all_active_clients = ( + active_validators_list + active_trainers_list + ) for client in all_active_clients: - status = 'active' - if client in active_trainers_list and client in active_validators_list: - role = 'trainer-validator' + status = "active" + if ( + client in active_trainers_list + and client in active_validators_list + ): + role = "trainer-validator" elif client in active_trainers_list: - role = 'trainer' + role = "trainer" elif client in active_validators_list: - role = 'validator' + role = "validator" else: - role = 'unknown' + role = "unknown" self.control.network.update_client_data( - client, status, role) - - return {'active_clients': all_clients, - 'active_trainers': active_trainers_list, - 'active_validators': active_validators_list - } + client, status, role + ) + + return { + "active_clients": all_clients, + "active_trainers": active_trainers_list, + "active_validators": active_validators_list, + } except Exception: pass - return {'active_clients': [], - 'active_trainers': [], - 'active_validators': [] - } + return { + "active_clients": [], + "active_trainers": [], + "active_validators": [], + } - @app.route('/metric_type', methods=['GET', 'POST']) + @app.route("/metric_type", methods=["GET", "POST"]) def change_features(): """ :return: """ - feature = request.args['selected'] + feature = request.args["selected"] plot = Plot(self.control.statestore) graphJSON = plot.create_box_plot(feature) return graphJSON - @app.route('/dashboard') + @app.route("/dashboard") def dashboard(): """ @@ -779,7 +909,7 @@ def dashboard(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) not_configured = self.check_configured() if not_configured: @@ -793,16 +923,18 @@ def dashboard(): clients_plot = plot.create_client_plot() client_histogram_plot = plot.create_client_histogram_plot() - return render_template('dashboard.html', show_plot=True, - table_plot=table_plot, - timeline_plot=timeline_plot, - clients_plot=clients_plot, - client_histogram_plot=client_histogram_plot, - combiners_plot=combiners_plot, - configured=True - ) - - @app.route('/network') + return render_template( + "dashboard.html", + show_plot=True, + table_plot=table_plot, + timeline_plot=timeline_plot, + clients_plot=clients_plot, + client_histogram_plot=client_histogram_plot, + combiners_plot=combiners_plot, + configured=True, + ) + + @app.route("/network") def network(): """ @@ -810,7 +942,7 @@ def network(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) not_configured = self.check_configured() if not_configured: @@ -821,17 +953,19 @@ def network(): combiner_info = combiner_status() active_clients = client_status() # print(combiner_info, flush=True) - return render_template('network.html', network_plot=True, - round_time_plot=round_time_plot, - mem_cpu_plot=mem_cpu_plot, - combiner_info=combiner_info, - active_clients=active_clients['active_clients'], - active_trainers=active_clients['active_trainers'], - active_validators=active_clients['active_validators'], - configured=True - ) - - @app.route('/config/download', methods=['GET']) + return render_template( + "network.html", + network_plot=True, + round_time_plot=round_time_plot, + mem_cpu_plot=mem_cpu_plot, + combiner_info=combiner_info, + active_clients=active_clients["active_clients"], + active_trainers=active_clients["active_trainers"], + active_validators=active_clients["active_validators"], + configured=True, + ) + + @app.route("/config/download", methods=["GET"]) def config_download(): """ @@ -839,8 +973,8 @@ def config_download(): """ chk_string = "" name = self.control.get_compute_package_name() - if name is None or name == '': - chk_string = '' + if name is None or name == "": + chk_string = "" else: file_path = os.path.join(UPLOAD_FOLDER, name) print("trying to get {}".format(file_path)) @@ -848,7 +982,7 @@ def config_download(): try: sum = str(sha(file_path)) except FileNotFoundError: - sum = '' + sum = "" chk_string = "checksum: {}".format(sum) network_id = self.network_id @@ -857,20 +991,24 @@ def config_download(): ctx = """network_id: {network_id} discover_host: {discover_host} discover_port: {discover_port} -{chk_string}""".format(network_id=network_id, - discover_host=discover_host, - discover_port=discover_port, - chk_string=chk_string) +{chk_string}""".format( + network_id=network_id, + discover_host=discover_host, + discover_port=discover_port, + chk_string=chk_string, + ) obj = BytesIO() - obj.write(ctx.encode('UTF-8')) + obj.write(ctx.encode("UTF-8")) obj.seek(0) - return send_file(obj, - as_attachment=True, - download_name='client.yaml', - mimetype='application/x-yaml') - - @app.route('/context', methods=['GET', 'POST']) + return send_file( + obj, + as_attachment=True, + download_name="client.yaml", + mimetype="application/x-yaml", + ) + + @app.route("/context", methods=["GET", "POST"]) def context(): """ @@ -878,78 +1016,85 @@ def context(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) # if reset is not empty then allow context re-set - reset = request.args.get('reset', None) + reset = request.args.get("reset", None) if reset: - return render_template('context.html') + return render_template("context.html") - if request.method == 'POST': + if request.method == "POST": + if "file" not in request.files: + flash("No file part") + return redirect(url_for("context")) - if 'file' not in request.files: - flash('No file part') - return redirect(url_for('context')) - - file = request.files['file'] - helper_type = request.form.get('helper', 'kerashelper') + file = request.files["file"] + helper_type = request.form.get("helper", "kerashelper") # if user does not select file, browser also # submit an empty part without filename - if file.filename == '': - flash('No selected file') - return redirect(url_for('context')) + if file.filename == "": + flash("No selected file") + return redirect(url_for("context")) if file and allowed_file(file.filename): filename = secure_filename(file.filename) file_path = os.path.join( - app.config['UPLOAD_FOLDER'], filename) + app.config["UPLOAD_FOLDER"], filename + ) file.save(file_path) - if self.control.state() == ReducerState.instructing or self.control.state() == ReducerState.monitoring: + if ( + self.control.state() == ReducerState.instructing + or self.control.state() == ReducerState.monitoring + ): return "Not allowed to change context while execution is ongoing." self.control.set_compute_package(filename, file_path) self.control.statestore.set_helper(helper_type) - return redirect(url_for('control')) + return redirect(url_for("control")) - name = request.args.get('name', '') + name = request.args.get("name", "") - if name == '': + if name == "": name = self.control.get_compute_package_name() - if name is None or name == '': - return render_template('context.html') + if name is None or name == "": + return render_template("context.html") # There is a potential race condition here, if one client requests a package and at # the same time another one triggers a fetch from Minio and writes to disk. try: mutex = Lock() mutex.acquire() - return send_from_directory(app.config['UPLOAD_FOLDER'], name, as_attachment=True) + return send_from_directory( + app.config["UPLOAD_FOLDER"], name, as_attachment=True + ) except Exception: try: data = self.control.get_compute_package(name) - file_path = os.path.join(app.config['UPLOAD_FOLDER'], name) - with open(file_path, 'wb') as fh: + file_path = os.path.join(app.config["UPLOAD_FOLDER"], name) + with open(file_path, "wb") as fh: fh.write(data) - return send_from_directory(app.config['UPLOAD_FOLDER'], name, as_attachment=True) + return send_from_directory( + app.config["UPLOAD_FOLDER"], name, as_attachment=True + ) except Exception: raise finally: mutex.release() - return render_template('context.html') + return render_template("context.html") - @app.route('/checksum', methods=['GET', 'POST']) + @app.route("/checksum", methods=["GET", "POST"]) def checksum(): """ :return: """ # sum = '' - name = request.args.get('name', None) - if name == '' or name is None: + name = request.args.get("name", None) + if name == "" or name is None: name = self.control.get_compute_package_name() - if name is None or name == '': + if name is None or name == "": return jsonify({}) file_path = os.path.join(UPLOAD_FOLDER, name) @@ -958,13 +1103,13 @@ def checksum(): try: sum = str(sha(file_path)) except FileNotFoundError: - sum = '' + sum = "" - data = {'checksum': sum} + data = {"checksum": sum} return jsonify(data) - @app.route('/infer', methods=['POST']) + @app.route("/infer", methods=["POST"]) def infer(): """ @@ -972,7 +1117,7 @@ def infer(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) # Check configured not_configured = self.check_configured() @@ -982,7 +1127,9 @@ def infer(): # Check compute context if self.remote_compute_context: try: - self.current_compute_context = self.control.get_compute_package() + self.current_compute_context = ( + self.control.get_compute_package() + ) except Exception as e: print(e, flush=True) self.current_compute_context = None @@ -992,27 +1139,43 @@ def infer(): # Redirect if in monitoring state if self.control.state() == ReducerState.monitoring: return redirect( - url_for('index', state=ReducerStateToString(self.control.state()), refresh=True, message="Reducer is in monitoring state")) + url_for( + "index", + state=ReducerStateToString(self.control.state()), + refresh=True, + message="Reducer is in monitoring state", + ) + ) # POST params - timeout = int(request.form.get('timeout', 180)) - helper_type = request.form.get('helper', 'keras') - clients_required = request.form.get('clients_required', 1) - clients_requested = request.form.get('clients_requested', 8) + timeout = int(request.form.get("timeout", 180)) + helper_type = request.form.get("helper", "keras") + clients_required = request.form.get("clients_required", 1) + clients_requested = request.form.get("clients_requested", 8) # Start inference request - config = {'round_timeout': timeout, - 'model_id': self.statestore.get_latest_model(), - 'clients_required': clients_required, - 'clients_requested': clients_requested, - 'task': 'inference', - 'helper_type': helper_type} - threading.Thread(target=self.control.infer_instruct, - args=(config,)).start() + config = { + "round_timeout": timeout, + "model_id": self.statestore.get_latest_model(), + "clients_required": clients_required, + "clients_requested": clients_requested, + "task": "inference", + "helper_type": helper_type, + } + threading.Thread( + target=self.control.infer_instruct, args=(config,) + ).start() # Redirect - return redirect(url_for('index', state=ReducerStateToString(self.control.state()), refresh=True, message="Sent execution plan (inference).", - message_type='SUCCESS')) + return redirect( + url_for( + "index", + state=ReducerStateToString(self.control.state()), + refresh=True, + message="Sent execution plan (inference).", + message_type="SUCCESS", + ) + ) if not self.host: bind = "0.0.0.0" diff --git a/fedn/fedn/network/dashboard/templates/events.html b/fedn/fedn/network/dashboard/templates/events.html index d3c34beb5..1fb5fac74 100644 --- a/fedn/fedn/network/dashboard/templates/events.html +++ b/fedn/fedn/network/dashboard/templates/events.html @@ -3,41 +3,44 @@ {% block content %} -
-
-
Events
-
-
- - - - + + + -
- -
+ }); + +
+
+ -{% endblock %} +{% endblock %} \ No newline at end of file diff --git a/fedn/fedn/network/statestore/mongostatestore.py b/fedn/fedn/network/statestore/mongostatestore.py index f991701d4..19d514f59 100644 --- a/fedn/fedn/network/statestore/mongostatestore.py +++ b/fedn/fedn/network/statestore/mongostatestore.py @@ -10,7 +10,7 @@ class MongoStateStore(StateStoreBase): - """ Statestore implementation using MongoDB. + """Statestore implementation using MongoDB. :param network_id: The network id. :type network_id: str @@ -21,7 +21,7 @@ class MongoStateStore(StateStoreBase): """ def __init__(self, network_id, config, model_storage_config): - """ Constructor.""" + """Constructor.""" self.__inited = False try: self.config = config @@ -29,19 +29,19 @@ def __init__(self, network_id, config, model_storage_config): self.mdb = connect_to_mongodb(self.config, self.network_id) # FEDn network - self.network = self.mdb['network'] - self.reducer = self.network['reducer'] - self.combiners = self.network['combiners'] - self.clients = self.network['clients'] - self.storage = self.network['storage'] + self.network = self.mdb["network"] + self.reducer = self.network["reducer"] + self.combiners = self.network["combiners"] + self.clients = self.network["clients"] + self.storage = self.network["storage"] # Control - self.control = self.mdb['control'] - self.package = self.control['package'] - self.state = self.control['state'] - self.model = self.control['model'] - self.sessions = self.control['sessions'] - self.rounds = self.control['rounds'] + self.control = self.mdb["control"] + self.package = self.control["package"] + self.state = self.control["state"] + self.model = self.control["model"] + self.sessions = self.control["sessions"] + self.rounds = self.control["rounds"] # Logging self.status = self.control["status"] @@ -62,7 +62,7 @@ def __init__(self, network_id, config, model_storage_config): self.__inited = True def is_inited(self): - """ Check if the statestore is intialized. + """Check if the statestore is intialized. :return: True if initialized, else False. :rtype: bool @@ -76,105 +76,160 @@ def get_config(self): :rtype: dict """ data = { - 'type': 'MongoDB', - 'mongo_config': self.config, - 'network_id': self.network_id + "type": "MongoDB", + "mongo_config": self.config, + "network_id": self.network_id, } return data def state(self): - """ Get the current state. + """Get the current state. :return: The current state. :rtype: str """ - return StringToReducerState(self.state.find_one()['current_state']) + return StringToReducerState(self.state.find_one()["current_state"]) def transition(self, state): - """ Transition to a new state. + """Transition to a new state. :param state: The new state. :type state: str :return: """ - old_state = self.state.find_one({'state': 'current_state'}) + old_state = self.state.find_one({"state": "current_state"}) if old_state != state: - return self.state.update_one({'state': 'current_state'}, {'$set': {'state': ReducerStateToString(state)}}, True) + return self.state.update_one( + {"state": "current_state"}, + {"$set": {"state": ReducerStateToString(state)}}, + True, + ) else: - print("Not updating state, already in {}".format( - ReducerStateToString(state))) + print( + "Not updating state, already in {}".format( + ReducerStateToString(state) + ) + ) + + def get_sessions(self, limit=None, skip=None, sort_key="_id", sort_order=pymongo.DESCENDING): + """Get all sessions. + + :param limit: The maximum number of sessions to return. + :type limit: int + :param skip: The number of sessions to skip. + :type skip: int + :param sort_key: The key to sort by. + :type sort_key: str + :param sort_order: The sort order. + :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING + :return: Dictionary of sessions in result (array of session objects) and count. + """ - def get_sessions(self): - """ Get all sessions. + result = None - :return: All sessions. - :rtype: ObjectID - """ - return self.sessions.find() + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + + result = self.sessions.find().limit(limit).skip(skip).sort( + sort_key, sort_order + ) + else: + result = self.sessions.find().sort( + sort_key, sort_order + ) + + count = self.sessions.count_documents({}) + + return { + "result": result, + "count": count, + } def get_session(self, session_id): - """ Get session with id. + """Get session with id. :param session_id: The session id. :type session_id: str :return: The session. :rtype: ObjectID """ - return self.sessions.find_one({'session_id': session_id}) + return self.sessions.find_one({"session_id": session_id}) - def set_latest_model(self, model_id): - """ Set the latest model id. + def set_latest_model(self, model_id, session_id=None): + """Set the latest model id. :param model_id: The model id. :type model_id: str :return: """ - self.model.update_one({'key': 'current_model'}, { - '$set': {'model': model_id}}, True) - self.model.update_one({'key': 'model_trail'}, {'$push': {'model': model_id, 'committed_at': str(datetime.now())}}, - True) + committed_at = datetime.now() + + self.model.insert_one( + { + "key": "models", + "model": model_id, + "session_id": session_id, + "committed_at": committed_at, + } + ) + + self.model.update_one( + {"key": "current_model"}, {"$set": {"model": model_id}}, True + ) + self.model.update_one( + {"key": "model_trail"}, + { + "$push": { + "model": model_id, + "committed_at": str(committed_at), + } + }, + True, + ) def get_initial_model(self): - """ Return model_id for the initial model in the model trail + """Return model_id for the initial model in the model trail :return: The initial model id. None if no model is found. :rtype: str """ - result = self.model.find_one({'key': 'model_trail'}, sort=[ - ("committed_at", pymongo.ASCENDING)]) + result = self.model.find_one( + {"key": "model_trail"}, sort=[("committed_at", pymongo.ASCENDING)] + ) if result is None: return None try: - model_id = result['model'] - if model_id == '' or model_id == ' ': + model_id = result["model"] + if model_id == "" or model_id == " ": return None return model_id[0] except (KeyError, IndexError): return None def get_latest_model(self): - """ Return model_id for the latest model in the model_trail + """Return model_id for the latest model in the model_trail :return: The latest model id. None if no model is found. :rtype: str """ - result = self.model.find_one({'key': 'current_model'}) + result = self.model.find_one({"key": "current_model"}) if result is None: return None try: - model_id = result['model'] - if model_id == '' or model_id == ' ': + model_id = result["model"] + if model_id == "" or model_id == " ": return None return model_id except (KeyError, IndexError): return None def get_latest_round(self): - """ Get the id of the most recent round. + """Get the id of the most recent round. :return: The id of the most recent round. :rtype: ObjectId @@ -183,7 +238,7 @@ def get_latest_round(self): return self.rounds.find_one(sort=[("_id", pymongo.DESCENDING)]) def get_round(self, id): - """ Get round with id. + """Get round with id. :param id: id of round to get :type id: int @@ -191,10 +246,10 @@ def get_round(self, id): :rtype: ObjectId """ - return self.rounds.find_one({'round_id': str(id)}) + return self.rounds.find_one({"round_id": str(id)}) def get_rounds(self): - """ Get all rounds. + """Get all rounds. :return: All rounds. :rtype: ObjectId @@ -203,7 +258,7 @@ def get_rounds(self): return self.rounds.find() def get_validations(self, **kwargs): - """ Get validations from the database. + """Get validations from the database. :param kwargs: query to filter validations :type kwargs: dict @@ -215,7 +270,7 @@ def get_validations(self, **kwargs): return result def set_compute_package(self, filename): - """ Set the active compute package in statestore. + """Set the active compute package in statestore. :param filename: The filename of the compute package. :type filename: str @@ -223,66 +278,139 @@ def set_compute_package(self, filename): :rtype: bool """ self.control.package.update_one( - {'key': 'active'}, {'$set': {'filename': filename, 'committed_at': str(datetime.now())}}, True) - self.control.package.update_one({'key': 'package_trail'}, - {'$push': {'filename': filename, 'committed_at': str(datetime.now())}}, True) + {"key": "active"}, + { + "$set": { + "filename": filename, + "committed_at": str(datetime.now()), + } + }, + True, + ) + self.control.package.update_one( + {"key": "package_trail"}, + { + "$push": { + "filename": filename, + "committed_at": str(datetime.now()), + } + }, + True, + ) return True def get_compute_package(self): - """ Get the active compute package. + """Get the active compute package. :return: The active compute package. :rtype: ObjectID """ - ret = self.control.package.find({'key': 'active'}) + ret = self.control.package.find({"key": "active"}) try: retcheck = ret[0] - if retcheck is None or retcheck == '' or retcheck == ' ': # ugly check for empty string + if ( + retcheck is None or retcheck == "" or retcheck == " " + ): # ugly check for empty string return None return retcheck except (KeyError, IndexError): return None def set_helper(self, helper): - """ Set the active helper package in statestore. + """Set the active helper package in statestore. :param helper: The name of the helper package. See helper.py for available helpers. :type helper: str :return: """ - self.control.package.update_one({'key': 'active'}, - {'$set': {'helper': helper}}, True) + self.control.package.update_one( + {"key": "active"}, {"$set": {"helper": helper}}, True + ) def get_helper(self): - """ Get the active helper package. + """Get the active helper package. :return: The active helper set for the package. :rtype: str """ - ret = self.control.package.find_one({'key': 'active'}) + ret = self.control.package.find_one({"key": "active"}) # if local compute package used, then 'package' is None # if not ret: # get framework from round_config instead # ret = self.control.config.find_one({'key': 'round_config'}) try: - retcheck = ret['helper'] - if retcheck == '' or retcheck == ' ': # ugly check for empty string + retcheck = ret["helper"] + if ( + retcheck == "" or retcheck == " " + ): # ugly check for empty string return None return retcheck except (KeyError, IndexError): return None + def list_models( + self, + session_id=None, + limit=None, + skip=None, + sort_key="committed_at", + sort_order=pymongo.DESCENDING, + ): + """List all models in the statestore. + + :param session_id: The session id. + :type session_id: str + :param limit: The maximum number of models to return. + :type limit: int + :param skip: The number of models to skip. + :type skip: int + :return: List of models. + :rtype: list + """ + result = None + + find_option = ( + {"key": "models"} + if session_id is None + else {"key": "models", "session_id": session_id} + ) + + projection = {"_id": False, "key": False} + + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + + result = ( + self.model.find(find_option, projection) + .limit(limit) + .skip(skip) + .sort(sort_key, sort_order) + ) + + else: + result = self.model.find(find_option, projection).sort( + sort_key, sort_order + ) + + count = self.model.count_documents(find_option) + + return { + "result": result, + "count": count, + } + def get_model_trail(self): - """ Get the model trail. + """Get the model trail. :return: dictionary of model_id: committed_at :rtype: dict """ - result = self.model.find_one({'key': 'model_trail'}) + result = self.model.find_one({"key": "model_trail"}) try: if result is not None: - committed_at = result['committed_at'] - model = result['model'] + committed_at = result["committed_at"] + model = result["model"] model_dictionary = dict(zip(model, committed_at)) return model_dictionary else: @@ -291,7 +419,7 @@ def get_model_trail(self): return None def get_events(self, **kwargs): - """ Get events from the database. + """Get events from the database. :param kwargs: query to filter events :type kwargs: dict @@ -299,51 +427,83 @@ def get_events(self, **kwargs): :rtype: ObjectId """ # check if kwargs is empty + + result = None + count = None + projection = {"_id": False} + if not kwargs: - return self.control.status.find() + result = self.control.status.find({}, projection).sort( + "timestamp", pymongo.DESCENDING + ) + count = self.control.status.count_documents({}) else: - result = self.control.status.find(kwargs) - return result + limit = kwargs.pop("limit", None) + skip = kwargs.pop("skip", None) + + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + result = ( + self.control.status.find(kwargs, projection) + .sort("timestamp", pymongo.DESCENDING) + .limit(limit) + .skip(skip) + ) + else: + result = self.control.status.find(kwargs, projection).sort( + "timestamp", pymongo.DESCENDING + ) + + count = self.control.status.count_documents(kwargs) + + return { + "result": result, + "count": count, + } def get_storage_backend(self): - """ Get the storage backend. + """Get the storage backend. :return: The storage backend. :rtype: ObjectID """ try: ret = self.storage.find( - {'status': 'enabled'}, projection={'_id': False}) + {"status": "enabled"}, projection={"_id": False} + ) return ret[0] except (KeyError, IndexError): return None def set_storage_backend(self, config): - """ Set the storage backend. + """Set the storage backend. :param config: The storage backend configuration. :type config: dict :return: """ config = copy.deepcopy(config) - config['updated_at'] = str(datetime.now()) - config['status'] = 'enabled' + config["updated_at"] = str(datetime.now()) + config["status"] = "enabled" self.storage.update_one( - {'storage_type': config['storage_type']}, {'$set': config}, True) + {"storage_type": config["storage_type"]}, {"$set": config}, True + ) def set_reducer(self, reducer_data): - """ Set the reducer in the statestore. + """Set the reducer in the statestore. :param reducer_data: dictionary of reducer config. :type reducer_data: dict :return: """ - reducer_data['updated_at'] = str(datetime.now()) - self.reducer.update_one({'name': reducer_data['name']}, { - '$set': reducer_data}, True) + reducer_data["updated_at"] = str(datetime.now()) + self.reducer.update_one( + {"name": reducer_data["name"]}, {"$set": reducer_data}, True + ) def get_reducer(self): - """ Get reducer.config. + """Get reducer.config. return: reducer config. rtype: ObjectId @@ -355,67 +515,99 @@ def get_reducer(self): return None def get_combiner(self, name): - """ Get combiner by name. + """Get combiner by name. + :param name: name of combiner to get. + :type name: str :return: The combiner. :rtype: ObjectId """ try: - ret = self.combiners.find_one({'name': name}) + ret = self.combiners.find_one({"name": name}) return ret except Exception: return None - def get_combiners(self): - """ Get all combiners. - - :return: list of combiners. - :rtype: list + def get_combiners(self, limit=None, skip=None, sort_key="updated_at", sort_order=pymongo.DESCENDING, projection={}): + """Get all combiners. + + :param limit: The maximum number of combiners to return. + :type limit: int + :param skip: The number of combiners to skip. + :type skip: int + :param sort_key: The key to sort by. + :type sort_key: str + :param sort_order: The sort order. + :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING + :param projection: The projection. + :type projection: dict + :return: Dictionary of combiners in result and count. + :rtype: dict """ + + result = None + count = None + try: - ret = self.combiners.find() - return list(ret) + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + result = self.combiners.find({}, projection).limit(limit).skip(skip).sort(sort_key, sort_order) + else: + result = self.combiners.find({}, projection).sort(sort_key, sort_order) + + count = self.combiners.count_documents({}) + except Exception: return None + return { + "result": result, + "count": count, + } + def set_combiner(self, combiner_data): - """ Set combiner in statestore. + """Set combiner in statestore. :param combiner_data: dictionary of combiner config :type combiner_data: dict :return: """ - combiner_data['updated_at'] = str(datetime.now()) - self.combiners.update_one({'name': combiner_data['name']}, { - '$set': combiner_data}, True) + combiner_data["updated_at"] = str(datetime.now()) + self.combiners.update_one( + {"name": combiner_data["name"]}, {"$set": combiner_data}, True + ) def delete_combiner(self, combiner): - """ Delete a combiner from statestore. + """Delete a combiner from statestore. :param combiner: name of combiner to delete. :type combiner: str :return: """ try: - self.combiners.delete_one({'name': combiner}) + self.combiners.delete_one({"name": combiner}) except Exception: - print("WARNING, failed to delete combiner: {}".format( - combiner), flush=True) + print( + "WARNING, failed to delete combiner: {}".format(combiner), + flush=True, + ) def set_client(self, client_data): - """ Set client in statestore. + """Set client in statestore. :param client_data: dictionary of client config. :type client_data: dict :return: """ - client_data['updated_at'] = str(datetime.now()) - self.clients.update_one({'name': client_data['name']}, { - '$set': client_data}, True) + client_data["updated_at"] = str(datetime.now()) + self.clients.update_one( + {"name": client_data["name"]}, {"$set": client_data}, True + ) def get_client(self, name): - """ Get client by name. + """Get client by name. :param name: name of client to get. :type name: str @@ -423,7 +615,7 @@ def get_client(self, name): :rtype: ObjectId """ try: - ret = self.clients.find({'key': name}) + ret = self.clients.find({"key": name}) if list(ret) == []: return None else: @@ -431,20 +623,77 @@ def get_client(self, name): except Exception: return None - def list_clients(self): + def list_clients(self, limit=None, skip=None, status=None, sort_key="last_seen", sort_order=pymongo.DESCENDING): """List all clients registered on the network. - :return: list of clients. + :param limit: The maximum number of clients to return. + :type limit: int + :param skip: The number of clients to skip. + :type skip: int + :param status: online | offline + :type status: str + :param sort_key: The key to sort by. + """ + + result = None + count = None + + try: + find = {} if status is None else {"status": status} + projection = {"_id": False, "updated_at": False} + + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + result = self.clients.find(find, projection).limit(limit).skip(skip).sort(sort_key, sort_order) + else: + result = self.clients.find(find, projection).sort(sort_key, sort_order) + + count = self.clients.count_documents(find) + + except Exception as e: + print("ERROR: {}".format(e), flush=True) + + return { + "result": result, + "count": count, + } + + def list_combiners_data(self, combiners, sort_key="count", sort_order=pymongo.DESCENDING): + """List all combiner data. + + :param combiners: list of combiners to get data for. + :type combiners: list + :param sort_key: The key to sort by. + :type sort_key: str + :param sort_order: The sort order. + :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING + :return: list of combiner data. :rtype: list(ObjectId) """ + + result = None + try: - ret = self.clients.find() - return list(ret) - except Exception: - return None + + pipeline = [ + {"$match": {"combiner": {"$in": combiners}, "status": "online"}}, + {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, + {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}} + ] if combiners is not None else [ + {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, + {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}} + ] + + result = self.clients.aggregate(pipeline) + + except Exception as e: + print("ERROR: {}".format(e), flush=True) + + return result def update_client_status(self, client_data, status, role): - """ Set or update client status. + """Set or update client status. :param client_data: dictionary of client config. :type client_data: dict @@ -454,10 +703,7 @@ def update_client_status(self, client_data, status, role): :type role: str :return: """ - self.clients.update_one({"name": client_data['name']}, - {"$set": - { - "status": status, - "role": role - } - }) + self.clients.update_one( + {"name": client_data["name"]}, + {"$set": {"status": status, "role": role}}, + ) diff --git a/fedn/setup.py b/fedn/setup.py index f7c9fa0bb..4a149d184 100644 --- a/fedn/setup.py +++ b/fedn/setup.py @@ -2,7 +2,7 @@ setup( name='fedn', - version='0.5.0-dev', + version='0.5.0', description="""Scaleout Federated Learning""", author='Scaleout Systems AB', author_email='contact@scaleoutsystems.com',