diff --git a/fedn/network/api/interface.py b/fedn/network/api/interface.py index 6031051c1..31b54c205 100644 --- a/fedn/network/api/interface.py +++ b/fedn/network/api/interface.py @@ -120,243 +120,6 @@ def get_combiner(self, combiner_id): return jsonify(payload) - def get_all_sessions(self, limit=None, skip=None): - """Get all sessions from the statestore. - - :return: All sessions as a json response. - :rtype: :class:`flask.Response` - """ - sessions_object = self.statestore.get_sessions(limit, skip) - if sessions_object is None: - return ( - jsonify({"success": False, "message": "No sessions found."}), - 404, - ) - arr = [] - for element in sessions_object["result"]: - obj = element["session_config"][0] - arr.append(obj) - - result = {"result": arr, "count": sessions_object["count"]} - - return jsonify(result) - - def get_session(self, session_id): - """Get a session from the statestore. - - :param session_id: The session id to get. - :type session_id: str - :return: The session info dict as a json response. - :rtype: :class:`flask.Response` - """ - session_object = self.statestore.get_session(session_id) - if session_object is None: - return ( - jsonify( - { - "success": False, - "message": f"Session {session_id} not found.", - } - ), - 404, - ) - payload = {} - id = session_object["session_id"] - info = session_object["session_config"][0] - status = session_object["status"] - payload[id] = info - payload["status"] = status - return jsonify(payload) - - def set_active_compute_package(self, id: str): - success = self.statestore.set_active_compute_package(id) - - if not success: - return ( - jsonify( - { - "success": False, - "message": "Failed to set compute package.", - } - ), - 400, - ) - - return jsonify({"success": True, "message": "Compute package set."}) - - def set_compute_package(self, file, helper_type: str, name: str = None, description: str = None): - """Set the compute package in the statestore. - - :param file: The compute package to set. - :type file: file - :return: A json response with success or failure message. - :rtype: :class:`flask.Response` - """ - if self.control.state() == ReducerState.instructing or self.control.state() == ReducerState.monitoring: - return ( - jsonify( - { - "success": False, - "message": "Reducer is in instructing or monitoring state." "Cannot set compute package.", - } - ), - 400, - ) - - if file is None: - return ( - jsonify( - { - "success": False, - "message": "No file provided.", - } - ), - 404, - ) - - success, extension = self._allowed_file_extension(file.filename) - - if not success: - return ( - jsonify( - { - "success": False, - "message": f"File extension {extension} not allowed.", - } - ), - 404, - ) - - file_name = file.filename - storage_file_name = secure_filename(f"{str(uuid.uuid4())}.{extension}") - - file_path = safe_join(FEDN_COMPUTE_PACKAGE_DIR, storage_file_name) - file.save(file_path) - - self.control.set_compute_package(storage_file_name, file_path) - success = self.statestore.set_compute_package(file_name, storage_file_name, helper_type, name, description) - - if not success: - return ( - jsonify( - { - "success": False, - "message": "Failed to set compute package.", - } - ), - 400, - ) - # Delete the file after it has been saved - os.remove(file_path) - return jsonify({"success": True, "message": "Compute package set."}) - - def _get_compute_package_name(self): - """Get the compute package name from the statestore. - - :return: The compute package name. - :rtype: str - """ - package_objects = self.statestore.get_compute_package() - if package_objects is None: - message = "No compute package found." - return None, message - else: - try: - name = package_objects["storage_file_name"] - except KeyError as e: - message = "No compute package found. Key error." - logger.debug(e) - return None, message - return name, "success" - - def get_compute_package(self): - """Get the compute package from the statestore. - - :return: The compute package as a json response. - :rtype: :class:`flask.Response` - """ - result = self.statestore.get_compute_package() - if result is None or "file_name" not in result: - return ( - jsonify({"success": False, "message": "No compute package found."}), - 404, - ) - - obj = { - "id": result["id"] if "id" in result else "", - "file_name": result["file_name"], - "helper": result["helper"], - "committed_at": result["committed_at"], - "storage_file_name": result["storage_file_name"] if "storage_file_name" in result else "", - "name": result["name"] if "name" in result else "", - "description": result["description"] if "description" in result else "", - } - - return jsonify(obj) - - def list_compute_packages(self, limit: str = None, skip: str = None, include_active: str = None): - """Get paginated list of compute packages from the statestore. - :param limit: The number of compute packages to return. - :type limit: str - :param skip: The number of compute packages to skip. - :type skip: str - :param include_active: Whether to include the active compute package or not. - :type include_active: str - :return: All compute packages as a json response. - :rtype: :class:`flask.Response` - """ - if limit is not None and skip is not None: - limit = int(limit) - skip = int(skip) - - include_active: bool = include_active == "true" - - result = self.statestore.list_compute_packages(limit, skip) - if result is None: - return ( - jsonify({"success": False, "message": "No compute packages found."}), - 404, - ) - - active_package_id: str = None - - if include_active: - active_package = self.statestore.get_compute_package() - - if active_package is not None: - active_package_id = active_package["id"] if "id" in active_package else "" - - if include_active: - arr = [ - { - "id": element["id"] if "id" in element else "", - "file_name": element["file_name"], - "helper": element["helper"], - "committed_at": element["committed_at"], - "storage_file_name": element["storage_file_name"] if "storage_file_name" in element else "", - "name": element["name"] if "name" in element else "", - "description": element["description"] if "description" in element else "", - "active": "id" in element and element["id"] == active_package_id, - } - for element in result["result"] - ] - else: - arr = [ - { - "id": element["id"] if "id" in element else "", - "file_name": element["file_name"], - "helper": element["helper"], - "committed_at": element["committed_at"], - "storage_file_name": element["storage_file_name"] if "storage_file_name" in element else "", - "name": element["name"] if "name" in element else "", - "description": element["description"] if "description" in element else "", - } - for element in result["result"] - ] - - result = {"result": arr, "count": result["count"]} - return jsonify(result) - def download_compute_package(self, name): """Download the compute package. @@ -386,6 +149,25 @@ def download_compute_package(self, name): finally: mutex.release() + def _get_compute_package_name(self): + """Get the compute package name from the statestore. + + :return: The compute package name. + :rtype: str + """ + package_objects = self.statestore.get_compute_package() + if package_objects is None: + message = "No compute package found." + return None, message + else: + try: + name = package_objects["storage_file_name"] + except KeyError as e: + message = "No compute package found. Key error." + logger.debug(e) + return None, message + return name, "success" + def _create_checksum(self, name=None): """Create the checksum of the compute package. @@ -429,59 +211,6 @@ def get_controller_status(self): """ return jsonify({"state": ReducerStateToString(self.control.state())}) - def get_events(self, **kwargs): - """Get the events of the federated network. - - :return: The events as a json object. - :rtype: :py:class:`flask.Response` - """ - response = self.statestore.get_events(**kwargs) - - result = response["result"] - if result is None: - return ( - jsonify({"success": False, "message": "No events found."}), - 404, - ) - - events = [] - for evt in result: - events.append(evt) - - return jsonify({"result": events, "count": response["count"]}) - - def get_all_validations(self, **kwargs): - """Get all validations from the statestore. - - :return: All validations as a json response. - :rtype: :class:`flask.Response` - """ - validations_objects = self.statestore.get_validations(**kwargs) - if validations_objects is None: - return ( - jsonify( - { - "success": False, - "message": "No validations found.", - "filter_used": kwargs, - } - ), - 404, - ) - payload = {} - for object in validations_objects: - id = str(object["_id"]) - info = { - "model_id": object["modelId"], - "data": object["data"], - "timestamp": object["timestamp"], - "meta": object["meta"], - "sender": object["sender"], - "receiver": object["receiver"], - } - payload[id] = info - return jsonify(payload) - def add_combiner(self, combiner_id, secure_grpc, address, remote_addr, fqdn, port): """Add a combiner to the network. @@ -533,7 +262,7 @@ def add_client(self, client_id, preferred_combiner, remote_addr, name, package): ), 203, ) - helper_type = self.control.statestore.get_helper() + helper_type = package_object["helper"] else: # Else package is "local": helper_type = "" @@ -582,16 +311,6 @@ def add_client(self, client_id, preferred_combiner, remote_addr, name, package): } return jsonify(payload) - def get_initial_model(self): - """Get the initial model from the statestore. - - :return: The initial model as a json response. - :rtype: :class:`flask.Response` - """ - model_id = self.statestore.get_initial_model() - payload = {"model_id": model_id} - return jsonify(payload) - def set_initial_model(self, file): """Add an initial model to the network. @@ -627,79 +346,6 @@ def set_initial_model(self, file): return jsonify({"success": True, "message": "Initial model added successfully."}) - def get_latest_model(self): - """Get the latest model from the statestore. - - :return: The initial model as a json response. - :rtype: :class:`flask.Response` - """ - if self.statestore.get_latest_model(): - model_id = self.statestore.get_latest_model() - payload = {"model_id": model_id} - return jsonify(payload) - else: - return jsonify({"success": False, "message": "No initial model set."}) - - def set_current_model(self, model_id: str): - """Set the active model in the statestore. - - :param model_id: The model id to set. - :type model_id: str - :return: A json response with success or failure message. - :rtype: :class:`flask.Response` - """ - success = self.statestore.set_current_model(model_id) - - if not success: - return ( - jsonify( - { - "success": False, - "message": "Failed to set active model.", - } - ), - 400, - ) - - return jsonify({"success": True, "message": "Active model set."}) - - def get_models(self, session_id: str = None, limit: str = None, skip: str = None, include_active: str = None): - result = self.statestore.list_models(session_id, limit, skip) - - if result is None: - return ( - jsonify({"success": False, "message": "No models found."}), - 404, - ) - - include_active: bool = include_active == "true" - - if include_active: - latest_model = self.statestore.get_latest_model() - - arr = [ - { - "committed_at": element["committed_at"], - "model": element["model"], - "session_id": element["session_id"], - "active": element["model"] == latest_model, - } - for element in result["result"] - ] - else: - arr = [ - { - "committed_at": element["committed_at"], - "model": element["model"], - "session_id": element["session_id"], - } - for element in result["result"] - ] - - result = {"result": arr, "count": result["count"]} - - return jsonify(result) - def get_model(self, model_id: str): result = self.statestore.get_model(model_id) @@ -718,134 +364,6 @@ def get_model(self, model_id: str): return jsonify(payload) - def get_model_trail(self): - """Get the model trail for a given session. - - :param session: The session id to get the model trail for. - :type session: str - :return: The model trail for the given session as a json response. - :rtype: :class:`flask.Response` - """ - model_info = self.statestore.get_model_trail() - if model_info: - return jsonify(model_info) - else: - return jsonify({"success": False, "message": "No model trail available."}) - - def get_model_ancestors(self, model_id: str, limit: str = None): - """Get the model ancestors for a given model. - - :param model_id: The model id to get the model ancestors for. - :type model_id: str - :param limit: The number of ancestors to return. - :type limit: str - :return: The model ancestors for the given model as a json response. - :rtype: :class:`flask.Response` - """ - if model_id is None: - return jsonify({"success": False, "message": "No model id provided."}) - - limit: int = int(limit) if limit is not None else 10 # if limit is None, default to 10 - - response = self.statestore.get_model_ancestors(model_id, limit) - if response: - arr: list = [] - - for element in response: - obj = { - "model": element["model"], - "committed_at": element["committed_at"], - "session_id": element["session_id"], - "parent_model": element["parent_model"], - } - arr.append(obj) - - result = {"result": arr} - - return jsonify(result) - else: - return jsonify({"success": False, "message": "No model ancestors available."}) - - def get_model_descendants(self, model_id: str, limit: str = None): - """Get the model descendants for a given model. - - :param model_id: The model id to get the model descendants for. - :type model_id: str - :param limit: The number of descendants to return. - :type limit: str - :return: The model descendants for the given model as a json response. - :rtype: :class:`flask.Response` - """ - if model_id is None: - return jsonify({"success": False, "message": "No model id provided."}) - - limit: int = int(limit) if limit is not None else 10 - - response: list = self.statestore.get_model_descendants(model_id, limit) - - if response: - arr: list = [] - - for element in response: - obj = { - "model": element["model"], - "committed_at": element["committed_at"], - "session_id": element["session_id"], - "parent_model": element["parent_model"], - } - arr.append(obj) - - result = {"result": arr} - - return jsonify(result) - else: - return jsonify({"success": False, "message": "No model descendants available."}) - - def get_all_rounds(self): - """Get all rounds. - - :return: The rounds as json response. - :rtype: :class:`flask.Response` - """ - rounds_objects = self.statestore.get_rounds() - if rounds_objects is None: - jsonify({"success": False, "message": "No rounds available."}) - payload = {} - for object in rounds_objects: - id = object["round_id"] - if "reducer" in object.keys(): - reducer = object["reducer"] - else: - reducer = None - if "combiners" in object.keys(): - combiners = object["combiners"] - else: - combiners = None - - info = { - "reducer": reducer, - "combiners": combiners, - } - payload[id] = info - return jsonify(payload) - - def get_round(self, round_id): - """Get a round. - - :param round_id: The round id to get. - :type round_id: str - :return: The round as json response. - :rtype: :class:`flask.Response` - """ - round_object = self.statestore.get_round(round_id) - if round_object is None: - return jsonify({"success": False, "message": "Round not found."}) - payload = { - "round_id": round_object["round_id"], - "combiners": round_object["combiners"], - } - return jsonify(payload) - def get_client_config(self, checksum=True): """Get the client config. @@ -891,135 +409,3 @@ def list_combiners_data(self, combiners): result = {"result": arr} return jsonify(result) - - def start_session( - self, - session_id, - aggregator="fedavg", - aggregator_kwargs=None, - model_id=None, - rounds=5, - round_timeout=180, - round_buffer_size=-1, - delete_models=True, - validate=True, - helper="", - min_clients=1, - requested_clients=8, - server_functions=None, - ): - """Start a session. - - :param session_id: The session id to start. - :type session_id: str - :param aggregator: The aggregator plugin to use. - :type aggregator: str - :param initial_model: The initial model for the session. - :type initial_model: str - :param rounds: The number of rounds to perform. - :type rounds: int - :param round_timeout: The round timeout to use in seconds. - :type round_timeout: int - :param round_buffer_size: The round buffer size to use. - :type round_buffer_size: int - :param delete_models: Whether to delete models after each round at combiner (save storage). - :type delete_models: bool - :param validate: Whether to validate the model after each round. - :type validate: bool - :param min_clients: The minimum number of clients required. - :type min_clients: int - :param requested_clients: The requested number of clients. - :type requested_clients: int - :return: A json response with success or failure message and session config. - :rtype: :class:`flask.Response` - """ - # Check if session already exists - session = self.statestore.get_session(session_id) - if session: - return jsonify({"success": False, "message": "Session already exists."}) - - # Check if session is running - if self.control.state() == ReducerState.monitoring: - return jsonify({"success": False, "message": "A session is already running."}) - - # Check if compute package is set - package = self.statestore.get_compute_package() - if not package: - return jsonify( - { - "success": False, - "message": "No compute package set. Set compute package before starting session.", - } - ) - if not helper: - # get helper from compute package - helper = package["helper"] - - # Check that initial (seed) model is set - if not self.statestore.get_initial_model(): - return jsonify( - { - "success": False, - "message": "No initial model set. Set initial model before starting session.", - } - ) - - # Check available clients per combiner - clients_available = 0 - for combiner in self.control.network.get_combiners(): - try: - nr_active_clients = len(combiner.list_active_clients()) - clients_available = clients_available + int(nr_active_clients) - except CombinerUnavailableError as e: - # TODO: Handle unavailable combiner, stop session or continue? - logger.error("COMBINER UNAVAILABLE: {}".format(e)) - continue - - if clients_available < min_clients: - return jsonify( - { - "success": False, - "message": "Not enough clients available to start session.", - } - ) - - # Check if validate is string and convert to bool - if isinstance(validate, str): - if validate.lower() == "true": - validate = True - else: - validate = False - - # Get lastest model as initial model for session - if not model_id: - model_id = self.statestore.get_latest_model() - - # Setup session config - session_config = { - "session_id": session_id if session_id else str(uuid.uuid4()), - "aggregator": aggregator, - "aggregator_kwargs": aggregator_kwargs, - "round_timeout": round_timeout, - "buffer_size": round_buffer_size, - "model_id": model_id, - "rounds": rounds, - "delete_models_storage": delete_models, - "clients_required": min_clients, - "clients_requested": requested_clients, - "task": (""), - "validate": validate, - "helper_type": helper, - "server_functions": server_functions, - } - - # Start session - threading.Thread(target=self.control.session, args=(session_config,)).start() - - # Return success response - return jsonify( - { - "success": True, - "message": "Session started successfully.", - "config": session_config, - } - ) diff --git a/fedn/network/api/server.py b/fedn/network/api/server.py index 5055653a6..001d38fba 100644 --- a/fedn/network/api/server.py +++ b/fedn/network/api/server.py @@ -11,7 +11,6 @@ from fedn.network.api.v1.graphql.schema import schema custom_url_prefix = os.environ.get("FEDN_CUSTOM_URL_PREFIX", False) -# statestore_config,modelstorage_config,network_id,control=set_statestore_config() api = API(statestore, control) app = Flask(__name__) for bp in _routes: @@ -55,89 +54,6 @@ def graphql_endpoint(): app.add_url_rule(f"{custom_url_prefix}/api/v1/graphql", view_func=graphql_endpoint, methods=["POST"]) -@app.route("/get_model_trail", methods=["GET"]) -@jwt_auth_required(role="admin") -def get_model_trail(): - """Get the model trail for a given session. - param: session: The session id to get the model trail for. - type: session: str - return: The model trail for the given session as a json object. - rtype: json - """ - return api.get_model_trail() - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/get_model_trail", view_func=get_model_trail, methods=["GET"]) - - -@app.route("/get_model_ancestors", methods=["GET"]) -@jwt_auth_required(role="admin") -def get_model_ancestors(): - """Get the ancestors of a model. - param: model: The model id to get the ancestors for. - type: model: str - param: limit: The maximum number of ancestors to return. - type: limit: int - return: A list of model objects that the model derives from. - rtype: json - """ - model = request.args.get("model", None) - limit = request.args.get("limit", None) - - return api.get_model_ancestors(model, limit) - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/get_model_ancestors", view_func=get_model_ancestors, methods=["GET"]) - - -@app.route("/get_model_descendants", methods=["GET"]) -@jwt_auth_required(role="admin") -def get_model_descendants(): - """Get the ancestors of a model. - param: model: The model id to get the child for. - type: model: str - param: limit: The maximum number of descendants to return. - type: limit: int - return: A list of model objects that are descendents of the provided model id. - rtype: json - """ - model = request.args.get("model", None) - limit = request.args.get("limit", None) - - return api.get_model_descendants(model, limit) - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/get_model_descendants", view_func=get_model_descendants, methods=["GET"]) - - -@app.route("/list_models", methods=["GET"]) -@jwt_auth_required(role="admin") -def list_models(): - """Get models from the statestore. - param: - session_id: The session id to get the model trail for. - limit: The maximum number of models to return. - type: limit: int - param: skip: The number of models to skip. - type: skip: int - Returns: - _type_: json - """ - session_id = request.args.get("session_id", None) - limit = request.args.get("limit", None) - skip = request.args.get("skip", None) - include_active = request.args.get("include_active", None) - - return api.get_models(session_id, limit, skip, include_active) - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/list_models", view_func=list_models, methods=["GET"]) - - @app.route("/get_model", methods=["GET"]) @jwt_auth_required(role="admin") def get_model(): @@ -253,139 +169,6 @@ def get_combiner(): app.add_url_rule(f"{custom_url_prefix}/get_combiner", view_func=get_combiner, methods=["GET"]) -@app.route("/list_rounds", methods=["GET"]) -@jwt_auth_required(role="admin") -def list_rounds(): - """Get all rounds from the statestore. - return: All rounds as a json object. - rtype: json - """ - return api.get_all_rounds() - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/list_rounds", view_func=list_rounds, methods=["GET"]) - - -@app.route("/get_round", methods=["GET"]) -@jwt_auth_required(role="admin") -def get_round(): - """Get a round from the statestore. - param: round_id: The round id to get. - type: round_id: str - return: The round as a json object. - rtype: json - """ - round_id = request.args.get("round_id", None) - if round_id is None: - return jsonify({"success": False, "message": "Missing round id."}), 400 - return api.get_round(round_id) - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/get_round", view_func=get_round, methods=["GET"]) - - -@app.route("/start_session", methods=["GET", "POST"]) -@jwt_auth_required(role="admin") -def start_session(): - """Start a new session. - return: The response from control. - rtype: json - """ - json_data = request.get_json() - return api.start_session(**json_data) - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/start_session", view_func=start_session, methods=["GET", "POST"]) - - -@app.route("/list_sessions", methods=["GET"]) -@jwt_auth_required(role="admin") -def list_sessions(): - """Get all sessions from the statestore. - return: All sessions as a json object. - rtype: json - """ - limit = request.args.get("limit", None) - skip = request.args.get("skip", None) - - return api.get_all_sessions(limit, skip) - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/list_sessions", view_func=list_sessions, methods=["GET"]) - - -@app.route("/get_session", methods=["GET"]) -@jwt_auth_required(role="admin") -def get_session(): - """Get a session from the statestore. - param: session_id: The session id to get. - type: session_id: str - return: The session as a json object. - rtype: json - """ - session_id = request.args.get("session_id", None) - if session_id is None: - return ( - jsonify({"success": False, "message": "Missing session id."}), - 400, - ) - return api.get_session(session_id) - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/get_session", view_func=get_session, methods=["GET"]) - - -@app.route("/set_active_package", methods=["PUT"]) -@jwt_auth_required(role="admin") -def set_active_package(): - id = request.args.get("id", None) - return api.set_active_compute_package(id) - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/set_active_package", view_func=set_active_package, methods=["PUT"]) - - -@app.route("/set_package", methods=["POST"]) -@jwt_auth_required(role="admin") -def set_package(): - """ Set the compute package in the statestore. - Usage with curl: - curl -k -X POST \ - -F file=@package.tgz \ - -F helper="kerashelper" \ - http://localhost:8092/set_package - - param: file: The compute package file to set. - type: file: file - return: The response from the statestore. - rtype: json - """ - helper_type = request.form.get("helper", None) - name = request.form.get("name", None) - description = request.form.get("description", None) - - if helper_type is None: - return ( - jsonify({"success": False, "message": "Missing helper type."}), - 400, - ) - try: - file = request.files["file"] - except KeyError: - return jsonify({"success": False, "message": "Missing file."}), 400 - return api.set_compute_package(file=file, helper_type=helper_type, name=name, description=description) - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/set_package", view_func=set_package, methods=["POST"]) - - @app.route("/get_package", methods=["GET"]) @jwt_auth_required(role="admin") def get_package(): @@ -400,24 +183,6 @@ def get_package(): app.add_url_rule(f"{custom_url_prefix}/get_package", view_func=get_package, methods=["GET"]) -@app.route("/list_compute_packages", methods=["GET"]) -@jwt_auth_required(role="admin") -def list_compute_packages(): - """Get the compute package from the statestore. - return: The compute package as a json object. - rtype: json - """ - limit = request.args.get("limit", None) - skip = request.args.get("skip", None) - include_active = request.args.get("include_active", None) - - return api.list_compute_packages(limit=limit, skip=skip, include_active=include_active) - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/list_compute_packages", view_func=list_compute_packages, methods=["GET"]) - - @app.route("/download_package", methods=["GET"]) @jwt_auth_required(role="client") def download_package(): @@ -444,60 +209,6 @@ def get_package_checksum(): app.add_url_rule(f"{custom_url_prefix}/get_package_checksum", view_func=get_package_checksum, methods=["GET"]) -@app.route("/get_latest_model", methods=["GET"]) -@jwt_auth_required(role="admin") -def get_latest_model(): - """Get the latest model from the statestore. - return: The initial model as a json object. - rtype: json - """ - return api.get_latest_model() - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/get_latest_model", view_func=get_latest_model, methods=["GET"]) - - -@app.route("/set_current_model", methods=["PUT"]) -@jwt_auth_required(role="admin") -def set_current_model(): - """Set the initial model in the statestore and upload to model repository. - Usage with curl: - curl -k -X PUT - -F id= - http://localhost:8092/set_current_model - - param: id: The model id to set. - type: id: str - return: boolean. - rtype: json - """ - id = request.args.get("id", None) - if id is None: - return jsonify({"success": False, "message": "Missing model id."}), 400 - return api.set_current_model(id) - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/set_current_model", view_func=set_current_model, methods=["PUT"]) - -# Get initial model endpoint - - -@app.route("/get_initial_model", methods=["GET"]) -@jwt_auth_required(role="admin") -def get_initial_model(): - """Get the initial model from the statestore. - return: The initial model as a json object. - rtype: json - """ - return api.get_initial_model() - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/get_initial_model", view_func=get_initial_model, methods=["GET"]) - - @app.route("/set_initial_model", methods=["POST"]) @jwt_auth_required(role="admin") def set_initial_model(): @@ -553,39 +264,6 @@ def get_client_config(): app.add_url_rule(f"{custom_url_prefix}/get_client_config", view_func=get_client_config, methods=["GET"]) -@app.route("/get_events", methods=["GET"]) -@jwt_auth_required(role="admin") -def get_events(): - """Get the events from the statestore. - return: The events as a json object. - rtype: json - """ - # TODO: except filter with request.get_json() - kwargs = request.args.to_dict() - - return api.get_events(**kwargs) - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/get_events", view_func=get_events, methods=["GET"]) - - -@app.route("/list_validations", methods=["GET"]) -@jwt_auth_required(role="admin") -def list_validations(): - """Get all validations from the statestore. - return: All validations as a json object. - rtype: json - """ - # TODO: except filter with request.get_json() - kwargs = request.args.to_dict() - return api.get_all_validations(**kwargs) - - -if custom_url_prefix: - app.add_url_rule(f"{custom_url_prefix}/list_validations", view_func=list_validations, methods=["GET"]) - - @app.route("/add_combiner", methods=["POST"]) @jwt_auth_required(role="combiner") def add_combiner(): diff --git a/fedn/network/api/shared.py b/fedn/network/api/shared.py index 9e0e5acbd..818f8f334 100644 --- a/fedn/network/api/shared.py +++ b/fedn/network/api/shared.py @@ -1,10 +1,53 @@ +import pymongo +from pymongo.database import Database + from fedn.common.config import get_modelstorage_config, get_network_config, get_statestore_config from fedn.network.controller.control import Control +from fedn.network.storage.s3.base import RepositoryBase +from fedn.network.storage.s3.miniorepository import MINIORepository +from fedn.network.storage.s3.repository import Repository from fedn.network.storage.statestore.mongostatestore import MongoStateStore +from fedn.network.storage.statestore.stores.client_store import ClientStore +from fedn.network.storage.statestore.stores.combiner_store import CombinerStore +from fedn.network.storage.statestore.stores.model_store import MongoDBModelStore +from fedn.network.storage.statestore.stores.package_store import MongoDBPackageStore +from fedn.network.storage.statestore.stores.round_store import RoundStore +from fedn.network.storage.statestore.stores.session_store import SessionStore +from fedn.network.storage.statestore.stores.status_store import StatusStore +from fedn.network.storage.statestore.stores.validation_store import ValidationStore statestore_config = get_statestore_config() modelstorage_config = get_modelstorage_config() network_id = get_network_config() statestore = MongoStateStore(network_id, statestore_config["mongo_config"]) statestore.set_storage_backend(modelstorage_config) -control = Control(statestore=statestore) + +mc = pymongo.MongoClient(**statestore_config["mongo_config"]) +mc.server_info() +mdb: Database = mc[network_id] + +client_store = ClientStore(mdb, "network.clients") +package_store = MongoDBPackageStore(mdb, "control.package") +session_store = SessionStore(mdb, "control.sessions") +model_store = MongoDBModelStore(mdb, "control.model") +combiner_store = CombinerStore(mdb, "network.combiners") +round_store = RoundStore(mdb, "control.rounds") +status_store = StatusStore(mdb, "control.status") +validation_store = ValidationStore(mdb, "control.validations") + +control = Control(statestore=statestore, session_store=session_store, model_store=model_store, round_store=round_store, package_store=package_store) + +minio_repository: RepositoryBase = None + +if modelstorage_config["storage_type"] == "S3": + minio_repository = MINIORepository(modelstorage_config["storage_config"]) + + +storage_collection = mdb["network.storage"] + +storage_config = storage_collection.find_one({"status": "enabled"}, projection={"_id": False}) + +repository: RepositoryBase = None + +if storage_config["storage_type"] == "S3": + repository = Repository(storage_config["storage_config"]) diff --git a/fedn/network/api/tests.py b/fedn/network/api/tests.py index fbd4f4972..7c405c3f9 100644 --- a/fedn/network/api/tests.py +++ b/fedn/network/api/tests.py @@ -41,50 +41,6 @@ def setUp(self, mock_mongo, mock_control): import fedn.network.api.server self.app = fedn.network.api.server.app.test_client() - def test_get_model_trail(self): - """ Test get_model_trail endpoint. """ - # Mock api.get_model_trail - model_id = "test" - time_stamp = time.time() - return_value = {model_id: time_stamp} - fedn.network.api.server.api.get_model_trail = MagicMock(return_value=return_value) - # Make request - response = self.app.get('/get_model_trail') - # Assert response - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, return_value) - # Assert api.get_model_trail was called - fedn.network.api.server.api.get_model_trail.assert_called_once_with() - - def test_get_latest_model(self): - """ Test get_latest_model endpoint. """ - # Mock api.get_latest_model - model_id = "test" - time_stamp = time.time() - return_value = {model_id: time_stamp} - fedn.network.api.server.api.get_latest_model = MagicMock(return_value=return_value) - # Make request - response = self.app.get('/get_latest_model') - # Assert response - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, return_value) - # Assert api.get_latest_model was called - fedn.network.api.server.api.get_latest_model.assert_called_once_with() - - def test_get_initial_model(self): - """ Test get_initial_model endpoint. """ - # Mock api.get_initial_model - model_id = "test" - time_stamp = time.time() - return_value = {model_id: time_stamp} - fedn.network.api.server.api.get_initial_model = MagicMock(return_value=return_value) - # Make request - response = self.app.get('/get_initial_model') - # Assert response - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, return_value) - # Assert api.get_initial_model was called - fedn.network.api.server.api.get_initial_model.assert_called_once_with() def test_set_initial_model(self): """ Test set_initial_model endpoint. """ @@ -160,44 +116,6 @@ def test_list_combiners(self): # Assert api.get_all_combiners was called fedn.network.api.server.api.get_all_combiners.assert_called_once_with() - def test_list_compute_packages(self): - """ Test list_compute_packages endpoint. """ - # Mock api.list_compute_packages - return_value = {"test": "test"} - fedn.network.api.server.api.list_compute_packages = MagicMock(return_value=return_value) - # Make request - response = self.app.get('/list_combiners') - # Assert response - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, return_value) - # Assert api.list_compute_packages was called - fedn.network.api.server.api.list_compute_packages.assert_called_once_with() - - def test_list_rounds(self): - """ Test list_rounds endpoint. """ - # Mock api.get_all_rounds - return_value = {"test": "test"} - fedn.network.api.server.api.get_all_rounds = MagicMock(return_value=return_value) - # Make request - response = self.app.get('/list_rounds') - # Assert response - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, return_value) - # Assert api.get_all_rounds was called - fedn.network.api.server.api.get_all_rounds.assert_called_once_with() - - def test_get_round(self): - """ Test get_round endpoint. """ - # Mock api.get_round - return_value = {"test": "test"} - fedn.network.api.server.api.get_round = MagicMock(return_value=return_value) - # Make request - response = self.app.get('/get_round?round_id=test') - # Assert response - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, return_value) - # Assert api.get_round was called - fedn.network.api.server.api.get_round.assert_called_once_with("test") def test_get_combiner(self): """ Test get_combiner endpoint. """ @@ -240,19 +158,6 @@ def test_add_combiner(self): fqdn='test', ) - def test_get_events(self): - """ Test get_events endpoint. """ - # Mock api.get_events - return_value = {"test": "test"} - fedn.network.api.server.api.get_events = MagicMock(return_value=return_value) - # Make request - response = self.app.get('/get_events') - # Assert response - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, return_value) - # Assert api.get_events was called - fedn.network.api.server.api.get_events.assert_called_once() - def test_get_status(self): """ Test get_status endpoint. """ # Mock api.get_status @@ -304,18 +209,6 @@ def test_list_sessions(self): # Assert api.list_sessions was called fedn.network.api.server.api.get_all_sessions.assert_called_once() - def test_list_models(self): - """ Test list_models endpoint. """ - # Mock api.list_models - return_value = {"test": "test"} - fedn.network.api.server.api.get_models = MagicMock(return_value=return_value) - # Make request - response = self.app.get('/list_models') - # Assert response - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, return_value) - # Assert api.list_models was called - fedn.network.api.server.api.get_models.assert_called_once() def test_get_package(self): """ Test get_package endpoint. """ diff --git a/fedn/network/api/v1/client_routes.py b/fedn/network/api/v1/client_routes.py index e1eb7ef5a..fcdacb395 100644 --- a/fedn/network/api/v1/client_routes.py +++ b/fedn/network/api/v1/client_routes.py @@ -1,7 +1,8 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import api_version, client_store, get_post_data_to_kwargs, get_typed_list_headers +from fedn.network.api.shared import client_store +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers from fedn.network.storage.statestore.stores.shared import EntityNotFound bp = Blueprint("client", __name__, url_prefix=f"/api/{api_version}/clients") diff --git a/fedn/network/api/v1/combiner_routes.py b/fedn/network/api/v1/combiner_routes.py index ce012645e..0bd4b545e 100644 --- a/fedn/network/api/v1/combiner_routes.py +++ b/fedn/network/api/v1/combiner_routes.py @@ -1,7 +1,8 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import api_version, client_store, combiner_store, get_post_data_to_kwargs, get_typed_list_headers +from fedn.network.api.shared import client_store, combiner_store +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers from fedn.network.storage.statestore.stores.shared import EntityNotFound bp = Blueprint("combiner", __name__, url_prefix=f"/api/{api_version}/combiners") diff --git a/fedn/network/api/v1/graphql/schema.py b/fedn/network/api/v1/graphql/schema.py index 0436c73ea..23871067b 100644 --- a/fedn/network/api/v1/graphql/schema.py +++ b/fedn/network/api/v1/graphql/schema.py @@ -1,7 +1,7 @@ import graphene import pymongo -from fedn.network.api.v1.shared import model_store, session_store, status_store, validation_store +from fedn.network.api.shared import model_store, session_store, status_store, validation_store class ActorType(graphene.ObjectType): diff --git a/fedn/network/api/v1/helper_routes.py b/fedn/network/api/v1/helper_routes.py index 03cfed7bb..5a8f7e357 100644 --- a/fedn/network/api/v1/helper_routes.py +++ b/fedn/network/api/v1/helper_routes.py @@ -1,13 +1,13 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import api_version, package_store +from fedn.network.api.shared import package_store +from fedn.network.api.v1.shared import api_version from fedn.network.storage.statestore.stores.shared import EntityNotFound bp = Blueprint("helper", __name__, url_prefix=f"/api/{api_version}/helpers") - @bp.route("/active", methods=["GET"]) @jwt_auth_required(role="admin") def get_active_helper(): @@ -25,7 +25,6 @@ def get_active_helper(): description: An unexpected error occurred """ try: - active_package = package_store.get_active() response = active_package["helper"] @@ -36,6 +35,7 @@ def get_active_helper(): except Exception: return jsonify({"message": "An unexpected error occurred"}), 500 + @bp.route("/active", methods=["PUT"]) @jwt_auth_required(role="admin") def set_active_helper(): diff --git a/fedn/network/api/v1/model_routes.py b/fedn/network/api/v1/model_routes.py index 76e854494..d349289d2 100644 --- a/fedn/network/api/v1/model_routes.py +++ b/fedn/network/api/v1/model_routes.py @@ -4,8 +4,8 @@ from flask import Blueprint, jsonify, request, send_file from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.shared import modelstorage_config -from fedn.network.api.v1.shared import api_version, get_limit, get_post_data_to_kwargs, get_reverse, get_typed_list_headers, minio_repository, model_store +from fedn.network.api.shared import minio_repository, model_store, modelstorage_config +from fedn.network.api.v1.shared import api_version, get_limit, get_post_data_to_kwargs, get_reverse, get_typed_list_headers from fedn.network.storage.statestore.stores.shared import EntityNotFound bp = Blueprint("model", __name__, url_prefix=f"/api/{api_version}/models") diff --git a/fedn/network/api/v1/package_routes.py b/fedn/network/api/v1/package_routes.py index 4ed138369..918d09e8c 100644 --- a/fedn/network/api/v1/package_routes.py +++ b/fedn/network/api/v1/package_routes.py @@ -5,7 +5,8 @@ from fedn.common.config import FEDN_COMPUTE_PACKAGE_DIR from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, package_store, repository +from fedn.network.api.shared import package_store, repository +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers from fedn.network.storage.statestore.stores.shared import EntityNotFound bp = Blueprint("package", __name__, url_prefix=f"/api/{api_version}/packages") diff --git a/fedn/network/api/v1/prediction_routes.py b/fedn/network/api/v1/prediction_routes.py index 0ea34224a..c625fd797 100644 --- a/fedn/network/api/v1/prediction_routes.py +++ b/fedn/network/api/v1/prediction_routes.py @@ -3,16 +3,14 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.shared import control -from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb -from fedn.network.storage.statestore.stores.model_store import ModelStore +from fedn.network.api.shared import control, mdb, model_store +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers from fedn.network.storage.statestore.stores.prediction_store import PredictionStore from fedn.network.storage.statestore.stores.shared import EntityNotFound bp = Blueprint("prediction", __name__, url_prefix=f"/api/{api_version}/predict") prediction_store = PredictionStore(mdb, "control.predictions") -model_store = ModelStore(mdb, "control.model") @bp.route("/start", methods=["POST"]) diff --git a/fedn/network/api/v1/round_routes.py b/fedn/network/api/v1/round_routes.py index 2c0f6cc9a..052cb7b93 100644 --- a/fedn/network/api/v1/round_routes.py +++ b/fedn/network/api/v1/round_routes.py @@ -1,7 +1,8 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, round_store +from fedn.network.api.shared import round_store +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers from fedn.network.storage.statestore.stores.shared import EntityNotFound bp = Blueprint("round", __name__, url_prefix=f"/api/{api_version}/rounds") diff --git a/fedn/network/api/v1/session_routes.py b/fedn/network/api/v1/session_routes.py index 52c68fb63..bf69d3bbd 100644 --- a/fedn/network/api/v1/session_routes.py +++ b/fedn/network/api/v1/session_routes.py @@ -3,8 +3,8 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.shared import control -from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, model_store, session_store +from fedn.network.api.shared import control, model_store, session_store +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers from fedn.network.combiner.interfaces import CombinerUnavailableError from fedn.network.state import ReducerState from fedn.network.storage.statestore.stores.shared import EntityNotFound diff --git a/fedn/network/api/v1/shared.py b/fedn/network/api/v1/shared.py index 0f6267b38..844f31d03 100644 --- a/fedn/network/api/v1/shared.py +++ b/fedn/network/api/v1/shared.py @@ -1,49 +1,8 @@ from typing import Tuple import pymongo -from pymongo.database import Database - -from fedn.network.api.shared import modelstorage_config, network_id, statestore_config -from fedn.network.storage.s3.base import RepositoryBase -from fedn.network.storage.s3.miniorepository import MINIORepository -from fedn.network.storage.s3.repository import Repository -from fedn.network.storage.statestore.stores.client_store import ClientStore -from fedn.network.storage.statestore.stores.combiner_store import CombinerStore -from fedn.network.storage.statestore.stores.model_store import ModelStore -from fedn.network.storage.statestore.stores.package_store import PackageStore -from fedn.network.storage.statestore.stores.round_store import RoundStore -from fedn.network.storage.statestore.stores.session_store import SessionStore -from fedn.network.storage.statestore.stores.status_store import StatusStore -from fedn.network.storage.statestore.stores.validation_store import ValidationStore api_version = "v1" -mc = pymongo.MongoClient(**statestore_config["mongo_config"]) -mc.server_info() -mdb: Database = mc[network_id] - -client_store = ClientStore(mdb, "network.clients") -package_store = PackageStore(mdb, "control.package") -session_store = SessionStore(mdb, "control.sessions") -model_store = ModelStore(mdb, "control.model") -combiner_store = CombinerStore(mdb, "network.combiners") -round_store = RoundStore(mdb, "control.rounds") -status_store = StatusStore(mdb, "control.status") -validation_store = ValidationStore(mdb, "control.validations") - -minio_repository: RepositoryBase = None - -if modelstorage_config["storage_type"] == "S3": - minio_repository = MINIORepository(modelstorage_config["storage_config"]) - - -storage_collection = mdb["network.storage"] - -storage_config = storage_collection.find_one({"status": "enabled"}, projection={"_id": False}) - -repository: RepositoryBase = None - -if storage_config["storage_type"] == "S3": - repository = Repository(storage_config["storage_config"]) def is_positive_integer(s): diff --git a/fedn/network/api/v1/status_routes.py b/fedn/network/api/v1/status_routes.py index 0cb1f8194..712863106 100644 --- a/fedn/network/api/v1/status_routes.py +++ b/fedn/network/api/v1/status_routes.py @@ -1,7 +1,8 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, status_store +from fedn.network.api.shared import status_store +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers from fedn.network.storage.statestore.stores.shared import EntityNotFound bp = Blueprint("status", __name__, url_prefix=f"/api/{api_version}/statuses") diff --git a/fedn/network/api/v1/validation_routes.py b/fedn/network/api/v1/validation_routes.py index 8fd5f2bb7..f0f349097 100644 --- a/fedn/network/api/v1/validation_routes.py +++ b/fedn/network/api/v1/validation_routes.py @@ -1,7 +1,8 @@ from flask import Blueprint, jsonify, request from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, validation_store +from fedn.network.api.shared import validation_store +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers from fedn.network.storage.statestore.stores.shared import EntityNotFound bp = Blueprint("validation", __name__, url_prefix=f"/api/{api_version}/validations") diff --git a/fedn/network/combiner/combiner.py b/fedn/network/combiner/combiner.py index e30860b4e..f10ea42c9 100644 --- a/fedn/network/combiner/combiner.py +++ b/fedn/network/combiner/combiner.py @@ -807,3 +807,4 @@ def run(self): except (KeyboardInterrupt, SystemExit): pass self.server.stop() + self.server.stop() diff --git a/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py index fa3d83e8f..d628ddec2 100644 --- a/fedn/network/combiner/roundhandler.py +++ b/fedn/network/combiner/roundhandler.py @@ -31,7 +31,7 @@ class RoundConfig(TypedDict): :param round_timeout: The round timeout in seconds. Set by user interfaces or Controller. :type round_timeout: str :param rounds: The number of rounds. Set by user interfaces. - :param model_id: The model identifier. Set by user interfaces or Controller (get_latest_model). + :param model_id: The model identifier. Set by user interfaces or Controller. :type model_id: str :param model_version: The model version. Currently not used. :type model_version: str diff --git a/fedn/network/controller/control.py b/fedn/network/controller/control.py index c2b430fd5..f3f9749ce 100644 --- a/fedn/network/controller/control.py +++ b/fedn/network/controller/control.py @@ -11,6 +11,10 @@ from fedn.network.combiner.roundhandler import RoundConfig from fedn.network.controller.controlbase import ControlBase from fedn.network.state import ReducerState +from fedn.network.storage.statestore.stores.model_store import ModelStore +from fedn.network.storage.statestore.stores.package_store import PackageStore +from fedn.network.storage.statestore.stores.round_store import RoundStore +from fedn.network.storage.statestore.stores.session_store import SessionStore class UnsupportedStorageBackend(Exception): @@ -88,9 +92,9 @@ class Control(ControlBase): :type statestore: class: `fedn.network.statestorebase.StateStorageBase` """ - def __init__(self, statestore): + def __init__(self, statestore, session_store: SessionStore, model_store: ModelStore, round_store: RoundStore, package_store: PackageStore): """Constructor method.""" - super().__init__(statestore) + super().__init__(statestore, session_store, model_store, round_store, package_store) self.name = "DefaultControl" def start_session(self, session_id: str, rounds: int, round_timeout: int) -> None: @@ -98,13 +102,22 @@ def start_session(self, session_id: str, rounds: int, round_timeout: int) -> Non logger.info("Controller already in INSTRUCTING state. A session is in progress.") return - if not self.statestore.get_latest_model(): + model_set: bool = False + + try: + active_model_id = self.model_store.get_active() + if active_model_id not in ["", " "]: + model_set = True + except Exception: + logger.error("Failed to get active model") + + if not model_set: logger.warning("No model in model chain, please provide a seed model!") return self._state = ReducerState.instructing - session = self.statestore.get_session(session_id) + session = self.session_store.get(session_id) if not session: logger.error("Session not found.") @@ -121,7 +134,7 @@ def start_session(self, session_id: str, rounds: int, round_timeout: int) -> Non self._state = ReducerState.monitoring - last_round = int(self.get_latest_round_id()) + last_round = self.get_latest_round_id() aggregator = session_config["aggregator"] @@ -150,7 +163,7 @@ def start_session(self, session_id: str, rounds: int, round_timeout: int) -> Non logger.info("Round completed with status {}".format(round_data["status"])) - session_config["model_id"] = self.statestore.get_latest_model() + session_config["model_id"] = self.model_store.get_active() if self.get_session_status(session_id) == "Started": self.set_session_status(session_id, "Finished") @@ -158,63 +171,6 @@ def start_session(self, session_id: str, rounds: int, round_timeout: int) -> Non self.set_session_config(session_id, session_config) - def session(self, config: RoundConfig) -> None: - """Execute a new training session. A session consists of one - or several global rounds. All rounds in the same session - have the same round_config. - - :param config: The session config. - :type config: dict - - """ - if self._state == ReducerState.instructing: - logger.info("Controller already in INSTRUCTING state. A session is in progress.") - return - - if not self.statestore.get_latest_model(): - logger.warning("No model in model chain, please provide a seed model!") - return - - self._state = ReducerState.instructing - config["committed_at"] = datetime.datetime.now() - - self.create_session(config) - - self._state = ReducerState.monitoring - - last_round = int(self.get_latest_round_id()) - - for combiner in self.network.get_combiners(): - combiner.set_aggregator(config["aggregator"]) - if config["server_functions"] is not None: - combiner.set_server_functions(config["server_functions"]) - - self.set_session_status(config["session_id"], "Started") - # Execute the rounds in this session - for round in range(1, int(config["rounds"] + 1)): - # Increment the round number - if last_round: - current_round = last_round + round - else: - current_round = round - - try: - if self.get_session_status(config["session_id"]) == "Terminated": - logger.info("Session terminated.") - break - _, round_data = self.round(config, str(current_round)) - except TypeError as e: - logger.error("Failed to execute round: {0}".format(e)) - - logger.info("Round completed with status {}".format(round_data["status"])) - - config["model_id"] = self.statestore.get_latest_model() - - # TODO: Report completion of session - if self.get_session_status(config["session_id"]) == "Started": - self.set_session_status(config["session_id"], "Finished") - self._state = ReducerState.idle - def prediction_session(self, config: RoundConfig) -> None: """Execute a new prediction session. @@ -231,7 +187,7 @@ def prediction_session(self, config: RoundConfig) -> None: return if "model_id" not in config.keys(): - config["model_id"] = self.statestore.get_latest_model() + config["model_id"] = self.model_store.get_active() config["committed_at"] = datetime.datetime.now() config["task"] = "prediction" @@ -264,7 +220,7 @@ def round(self, session_config: RoundConfig, round_id: str): if len(self.network.get_combiners()) < 1: logger.warning("Round cannot start, no combiners connected!") self.set_round_status(round_id, "Failed") - return None, self.statestore.get_round(round_id) + return None, self.round_store.get(round_id) # Assemble round config for this global round round_config = copy.deepcopy(session_config) @@ -286,7 +242,7 @@ def round(self, session_config: RoundConfig, round_id: str): else: logger.warning("Round start policy not met, skipping round!") self.set_round_status(round_id, "Failed") - return None, self.statestore.get_round(round_id) + return None, self.round_store.get(round_id) # Ask participating combiners to coordinate model updates _ = self.request_model_updates(participating_combiners) @@ -305,7 +261,7 @@ def do_if_round_times_out(result): retry=retry_if_exception_type(CombinersNotDoneException), ) def combiners_done(): - round = self.statestore.get_round(round_id) + round = self.round_store.get(round_id) session_status = self.get_session_status(session_id) if session_status == "Terminated": self.set_round_status(round_id, "Terminated") @@ -322,38 +278,38 @@ def combiners_done(): combiners_are_done = combiners_done() if not combiners_are_done: - return None, self.statestore.get_round(round_id) + return None, self.round_store.get(round_id) # Due to the distributed nature of the computation, there might be a # delay before combiners have reported the round data to the db, # so we need some robustness here. @retry(wait=wait_random(min=0.1, max=1.0), retry=retry_if_exception_type(KeyError)) def check_combiners_done_reporting(): - round = self.statestore.get_round(round_id) + round = self.round_store.get(round_id) combiners = round["combiners"] return combiners _ = check_combiners_done_reporting() - round = self.statestore.get_round(round_id) + round = self.round_store.get(round_id) round_valid = self.evaluate_round_validity_policy(round) if not round_valid: logger.error("Round failed. Invalid - evaluate_round_validity_policy: False") self.set_round_status(round_id, "Failed") - return None, self.statestore.get_round(round_id) + return None, self.round_store.get(round_id) logger.info("Reducing combiner level models...") # Reduce combiner models into a new global model round_data = {} try: - round = self.statestore.get_round(round_id) + round = self.round_store.get(round_id) model, data = self.reduce(round["combiners"]) round_data["reduce"] = data logger.info("Done reducing models from combiners!") except Exception as e: logger.error("Failed to reduce models from combiners, reason: {}".format(e)) self.set_round_status(round_id, "Failed") - return None, self.statestore.get_round(round_id) + return None, self.round_store.get(round_id) # Commit the new global model to the model trail if model is not None: @@ -367,7 +323,7 @@ def check_combiners_done_reporting(): else: logger.error("Failed to commit model to global model trail.") self.set_round_status(round_id, "Failed") - return None, self.statestore.get_round(round_id) + return None, self.round_store.get(round_id) self.set_round_status(round_id, "Success") @@ -376,9 +332,18 @@ def check_combiners_done_reporting(): if validate: combiner_config = copy.deepcopy(session_config) combiner_config["round_id"] = round_id - combiner_config["model_id"] = self.statestore.get_latest_model() + combiner_config["model_id"] = self.model_store.get_active() combiner_config["task"] = "validation" - combiner_config["helper_type"] = self.statestore.get_helper() + + helper_type: str = None + + try: + active_package = self.package_store.get_active() + helper_type = active_package["helper"] + except Exception: + logger.error("Failed to get active helper") + + combiner_config["helper_type"] = helper_type validating_combiners = self.get_participating_combiners(combiner_config) @@ -392,7 +357,7 @@ def check_combiners_done_reporting(): self.set_round_data(round_id, round_data) self.set_round_status(round_id, "Finished") - return model_id, self.statestore.get_round(round_id) + return model_id, self.round_store.get(round_id) def reduce(self, combiners): """Combine updated models from Combiner nodes into one global model. @@ -483,7 +448,7 @@ def prediction_round(self, config): # Setup combiner configuration combiner_config = copy.deepcopy(config) - combiner_config["model_id"] = self.statestore.get_latest_model() + combiner_config["model_id"] = self.model_store.get_active() combiner_config["task"] = "prediction" combiner_config["helper_type"] = self.statestore.get_framework() diff --git a/fedn/network/controller/controlbase.py b/fedn/network/controller/controlbase.py index 397a117bb..a381476f7 100644 --- a/fedn/network/controller/controlbase.py +++ b/fedn/network/controller/controlbase.py @@ -3,6 +3,8 @@ from abc import ABC, abstractmethod from time import sleep +import pymongo + import fedn.utils.helpers.helpers from fedn.common.log_config import logger from fedn.network.api.network import Network @@ -10,6 +12,10 @@ from fedn.network.combiner.roundhandler import RoundConfig from fedn.network.state import ReducerState from fedn.network.storage.s3.repository import Repository +from fedn.network.storage.statestore.stores.model_store import ModelStore +from fedn.network.storage.statestore.stores.package_store import PackageStore +from fedn.network.storage.statestore.stores.round_store import RoundStore +from fedn.network.storage.statestore.stores.session_store import SessionStore # Maximum number of tries to connect to statestore and retrieve storage configuration MAX_TRIES_BACKEND = os.getenv("MAX_TRIES_BACKEND", 10) @@ -36,10 +42,14 @@ class ControlBase(ABC): """ @abstractmethod - def __init__(self, statestore): + def __init__(self, statestore, session_store: SessionStore, model_store: ModelStore, round_store: RoundStore, package_store: PackageStore): """Constructor.""" self._state = ReducerState.setup + self.session_store = session_store + self.model_store = model_store + self.round_store = round_store + self.package_store = package_store self.statestore = statestore if self.statestore.is_inited(): self.network = Network(self, statestore) @@ -70,10 +80,6 @@ def __init__(self, statestore): if self.statestore.is_inited(): self._state = ReducerState.idle - @abstractmethod - def session(self, config): - pass - @abstractmethod def round(self, config, round_number): pass @@ -88,7 +94,14 @@ def get_helper(self): :return: Helper instance. :rtype: :class:`fedn.utils.plugins.helperbase.HelperBase` """ - helper_type = self.statestore.get_helper() + helper_type: str = None + + try: + active_package = self.package_store.get_active() + helper_type = active_package["helper"] + except Exception: + logger.error("Failed to get active helper") + helper = fedn.utils.helpers.helpers.get_helper(helper_type) if not helper: raise MisconfiguredHelper("Unsupported helper type {}, please configure compute_package.helper !".format(helper_type)) @@ -113,29 +126,17 @@ def idle(self): else: return False - def get_model_info(self): - """:return:""" - return self.statestore.get_model_trail() - - # TODO: remove use statestore.get_events() instead - def get_events(self): - """:return:""" - return self.statestore.get_events() - - def get_latest_round_id(self): - last_round = self.statestore.get_latest_round() - if not last_round: - return 0 + def get_latest_round_id(self) -> int: + response = self.round_store.list(limit=1, skip=0, sort_key="_id", sort_order=pymongo.DESCENDING) + if response and "result" in response and len(response["result"]) > 0: + round_id: str = response["result"][0]["round_id"] + return int(round_id) else: - return last_round["round_id"] - - def get_latest_round(self): - round = self.statestore.get_latest_round() - return round + return 0 def get_compute_package_name(self): """:return:""" - definition = self.statestore.get_compute_package() + definition = self.package_store.get_active() if definition: try: package_name = definition["storage_file_name"] @@ -161,18 +162,6 @@ def get_compute_package(self, compute_package=""): else: return None - def create_session(self, config: RoundConfig, status: str = "Initialized") -> None: - """Initialize a new session in backend db.""" - if "session_id" not in config.keys(): - session_id = uuid.uuid4() - config["session_id"] = str(session_id) - else: - session_id = config["session_id"] - - self.statestore.create_session(id=session_id) - self.statestore.set_session_config(session_id, config) - self.statestore.set_session_status(session_id, status) - def set_session_status(self, session_id, status): """Set the round round stats. @@ -183,7 +172,7 @@ def set_session_status(self, session_id, status): """ self.statestore.set_session_status(session_id, status) - def get_session_status(self, session_id): + def get_session_status(self, session_id: str): """Get the status of a session. :param session_id: The session unique identifier @@ -191,7 +180,8 @@ def get_session_status(self, session_id): :return: The status :rtype: str """ - return self.statestore.get_session_status(session_id) + session = self.session_store.get(session_id) + return session["status"] def set_session_config(self, session_id: str, config: dict): """Set the model id for a session. @@ -352,3 +342,4 @@ def state(self): :rype: str """ return self._state + return self._state diff --git a/fedn/network/storage/statestore/mongostatestore.py b/fedn/network/storage/statestore/mongostatestore.py index 316cd4965..217dad779 100644 --- a/fedn/network/storage/statestore/mongostatestore.py +++ b/fedn/network/storage/statestore/mongostatestore.py @@ -90,19 +90,6 @@ def is_inited(self): """ return self.__inited - def get_config(self): - """Retrive the statestore config. - - :return: The statestore config. - :rtype: dict - """ - data = { - "type": "MongoDB", - "mongo_config": self.config, - "network_id": self.network_id, - } - return data - def state(self): """Get the current state. @@ -111,74 +98,6 @@ def state(self): """ return StringToReducerState(self.state.find_one()["current_state"]) - def transition(self, state): - """Transition to a new state. - - :param state: The new state. - :type state: str - :return: - """ - old_state = self.state.find_one({"state": "current_state"}) - if old_state != state: - return self.state.update_one( - {"state": "current_state"}, - {"$set": {"state": ReducerStateToString(state)}}, - True, - ) - else: - logger.info("Not updating state, already in {}".format(ReducerStateToString(state))) - - def get_sessions(self, limit=None, skip=None, sort_key="_id", sort_order=pymongo.DESCENDING): - """Get all sessions. - - :param limit: The maximum number of sessions to return. - :type limit: int - :param skip: The number of sessions to skip. - :type skip: int - :param sort_key: The key to sort by. - :type sort_key: str - :param sort_order: The sort order. - :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING - :return: Dictionary of sessions in result (array of session objects) and count. - """ - result = None - - if limit is not None and skip is not None: - limit = int(limit) - skip = int(skip) - - result = self.sessions.find().limit(limit).skip(skip).sort(sort_key, sort_order) - else: - result = self.sessions.find().sort(sort_key, sort_order) - - count = self.sessions.count_documents({}) - - return { - "result": result, - "count": count, - } - - def get_session(self, session_id): - """Get session with id. - - :param session_id: The session id. - :type session_id: str - :return: The session. - :rtype: ObjectID - """ - return self.sessions.find_one({"session_id": session_id}) - - def get_session_status(self, session_id): - """Get the session status. - - :param session_id: The session id. - :type session_id: str - :return: The session status. - :rtype: str - """ - session = self.sessions.find_one({"session_id": session_id}) - return session["status"] - def set_latest_model(self, model_id, session_id=None): """Set the latest model id. @@ -218,157 +137,6 @@ def set_latest_model(self, model_id, session_id=None): True, ) - def get_initial_model(self): - """Return model_id for the initial model in the model trail - - :return: The initial model id. None if no model is found. - :rtype: str - """ - result = self.model.find_one({"key": "model_trail"}, sort=[("committed_at", pymongo.ASCENDING)]) - if result is None: - return None - - try: - model_id = result["model"] - if model_id == "" or model_id == " ": - return None - return model_id[0] - except (KeyError, IndexError): - return None - - def get_latest_model(self): - """Return model_id for the latest model in the model_trail - - :return: The latest model id. None if no model is found. - :rtype: str - """ - result = self.model.find_one({"key": "current_model"}) - if result is None: - return None - - try: - model_id = result["model"] - if model_id == "" or model_id == " ": - return None - return model_id - except (KeyError, IndexError): - return None - - def set_current_model(self, model_id: str): - """Set the current model in statestore. - - :param model_id: The model id. - :type model_id: str - :return: - """ - try: - committed_at = datetime.now() - - existing_model = self.model.find_one({"key": "models", "model": model_id}) - - if existing_model is not None: - self.model.update_one({"key": "current_model"}, {"$set": {"model": model_id, "committed_at": committed_at, "session_id": None}}, True) - - return True - except Exception as e: - logger.error("ERROR: {}".format(e)) - - return False - - def get_latest_round(self): - """Get the id of the most recent round. - - :return: The id of the most recent round. - :rtype: ObjectId - """ - return self.rounds.find_one(sort=[("_id", pymongo.DESCENDING)]) - - def get_round(self, id): - """Get round with id. - - :param id: id of round to get - :type id: int - :return: round with id, reducer and combiners - :rtype: ObjectId - """ - return self.rounds.find_one({"round_id": str(id)}) - - def get_rounds(self): - """Get all rounds. - - :return: All rounds. - :rtype: ObjectId - """ - return self.rounds.find() - - def get_validations(self, **kwargs): - """Get validations from the database. - - :param kwargs: query to filter validations - :type kwargs: dict - :return: validations matching query - :rtype: ObjectId - """ - result = self.control.validations.find(kwargs) - return result - - def set_active_compute_package(self, id: str): - """Set the active compute package in statestore. - - :param id: The id of the compute package (not document _id). - :type id: str - :return: True if successful. - :rtype: bool - """ - try: - find = {"id": id} - projection = {"_id": False, "key": False} - - doc = self.control.package.find_one(find, projection) - - if doc is None: - return False - - doc["key"] = "active" - - self.control.package.replace_one({"key": "active"}, doc) - - except Exception as e: - logger.error("ERROR: {}".format(e)) - return False - - return True - - def set_compute_package(self, file_name: str, storage_file_name: str, helper_type: str, name: str = None, description: str = None): - """Set the active compute package in statestore. - - :param file_name: The file_name of the compute package. - :type file_name: str - :return: True if successful. - :rtype: bool - """ - obj = { - "file_name": file_name, - "storage_file_name": storage_file_name, - "helper": helper_type, - "committed_at": datetime.now(), - "name": name, - "description": description, - "id": str(uuid.uuid4()), - } - - self.control.package.update_one( - {"key": "active"}, - {"$set": obj}, - True, - ) - - trail_obj = {**{"key": "package_trail"}, **obj} - - self.control.package.insert_one(trail_obj) - - return True - def get_compute_package(self): """Get the active compute package. @@ -384,185 +152,6 @@ def get_compute_package(self): logger.error("ERROR: {}".format(e)) return None - def list_compute_packages(self, limit: int = None, skip: int = None, sort_key="committed_at", sort_order=pymongo.DESCENDING): - """List compute packages in the statestore (paginated). - - :param limit: The maximum number of compute packages to return. - :type limit: int - :param skip: The number of compute packages to skip. - :type skip: int - :param sort_key: The key to sort by. - :type sort_key: str - :param sort_order: The sort order. - :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING - :return: Dictionary of compute packages in result and count. - :rtype: dict - """ - result = None - count = None - - find_option = {"key": "package_trail"} - projection = {"key": False, "_id": False} - - try: - if limit is not None and skip is not None: - result = self.control.package.find(find_option, projection).limit(limit).skip(skip).sort(sort_key, sort_order) - else: - result = self.control.package.find(find_option, projection).sort(sort_key, sort_order) - - count = self.control.package.count_documents(find_option) - - except Exception as e: - logger.error("ERROR: {}".format(e)) - return None - - return { - "result": result or [], - "count": count or 0, - } - - def set_helper(self, helper): - """Set the active helper package in statestore. - - :param helper: The name of the helper package. See helper.py for available helpers. - :type helper: str - :return: - """ - self.control.package.update_one({"key": "active"}, {"$set": {"helper": helper}}, True) - - def get_helper(self): - """Get the active helper package. - - :return: The active helper set for the package. - :rtype: str - """ - ret = self.control.package.find_one({"key": "active"}) - # if local compute package used, then 'package' is None - # if not ret: - # get framework from round_config instead - # ret = self.control.config.find_one({'key': 'round_config'}) - try: - retcheck = ret["helper"] - if retcheck == "" or retcheck == " ": # ugly check for empty string - return None - return retcheck - except (KeyError, IndexError): - return None - - def list_models( - self, - session_id=None, - limit=None, - skip=None, - sort_key="committed_at", - sort_order=pymongo.DESCENDING, - ): - """List all models in the statestore. - - :param session_id: The session id. - :type session_id: str - :param limit: The maximum number of models to return. - :type limit: int - :param skip: The number of models to skip. - :type skip: int - :return: List of models. - :rtype: list - """ - result = None - - find_option = {"key": "models"} if session_id is None else {"key": "models", "session_id": session_id} - - projection = {"_id": False, "key": False} - - if limit is not None and skip is not None: - limit = int(limit) - skip = int(skip) - - result = self.model.find(find_option, projection).limit(limit).skip(skip).sort(sort_key, sort_order) - - else: - result = self.model.find(find_option, projection).sort(sort_key, sort_order) - - count = self.model.count_documents(find_option) - - return { - "result": result, - "count": count, - } - - def get_model_trail(self): - """Get the model trail. - - :return: dictionary of model_id: committed_at - :rtype: dict - """ - # TODO Make it so that model order from db is preserved. - result = self.model.find_one({"key": "model_trail"}) - try: - if result is not None: - committed_at = result["committed_at"] - model = result["model"] - model_dictionary = dict(zip(model, committed_at)) - return model_dictionary - else: - return None - except (KeyError, IndexError): - return None - - def get_model_ancestors(self, model_id: str, limit: int): - """Get the model ancestors. - - :param model_id: The model id. - :type model_id: str - :param limit: The maximum number of ancestors to return. - :type limit: int - :return: List of model ancestors. - :rtype: list - """ - model = self.model.find_one({"key": "models", "model": model_id}) - current_model_id = model["parent_model"] if model is not None else None - result = [] - - for _ in range(limit): - if current_model_id is None: - break - - model = self.model.find_one({"key": "models", "model": current_model_id}) - - if model is not None: - result.append(model) - current_model_id = model["parent_model"] - - return result - - def get_model_descendants(self, model_id: str, limit: int): - """Get the model descendants. - - :param model_id: The model id. - :type model_id: str - :param limit: The maximum number of descendants to return. - :type limit: int - :return: List of model descendants. - :rtype: list - """ - model: object = self.model.find_one({"key": "models", "model": model_id}) - current_model_id: str = model["model"] if model is not None else None - result: list = [] - - for _ in range(limit): - if current_model_id is None: - break - - model: str = self.model.find_one({"key": "models", "parent_model": current_model_id}) - - if model is not None: - result.append(model) - current_model_id = model["model"] - - result.reverse() - - return result - def get_model(self, model_id): """Get model with id. @@ -573,41 +162,6 @@ def get_model(self, model_id): """ return self.model.find_one({"key": "models", "model": model_id}) - def get_events(self, **kwargs): - """Get events from the database. - - :param kwargs: query to filter events - :type kwargs: dict - :return: events matching query - :rtype: ObjectId - """ - # check if kwargs is empty - - result = None - count = None - projection = {"_id": False} - - if not kwargs: - result = self.control.status.find({}, projection).sort("timestamp", pymongo.DESCENDING) - count = self.control.status.count_documents({}) - else: - limit = kwargs.pop("limit", None) - skip = kwargs.pop("skip", None) - - if limit is not None and skip is not None: - limit = int(limit) - skip = int(skip) - result = self.control.status.find(kwargs, projection).sort("timestamp", pymongo.DESCENDING).limit(limit).skip(skip) - else: - result = self.control.status.find(kwargs, projection).sort("timestamp", pymongo.DESCENDING) - - count = self.control.status.count_documents(kwargs) - - return { - "result": result, - "count": count, - } - def get_storage_backend(self): """Get the storage backend. diff --git a/fedn/network/storage/statestore/statestorebase.py b/fedn/network/storage/statestore/statestorebase.py index 7c6681682..54d8bb0d8 100644 --- a/fedn/network/storage/statestore/statestorebase.py +++ b/fedn/network/storage/statestore/statestorebase.py @@ -12,15 +12,6 @@ def state(self): """Return the current state of the statestore.""" pass - @abstractmethod - def transition(self, state): - """Transition the statestore to a new state. - - :param state: The new state. - :type state: str - """ - pass - @abstractmethod def set_latest_model(self, model_id): """Set the latest model id in the statestore. @@ -30,15 +21,6 @@ def set_latest_model(self, model_id): """ pass - @abstractmethod - def get_latest_model(self): - """Get the latest model id from the statestore. - - :return: The model object. - :rtype: ObjectId - """ - pass - @abstractmethod def is_inited(self): """Check if the statestore is initialized. diff --git a/fedn/network/storage/statestore/stores/model_store.py b/fedn/network/storage/statestore/stores/model_store.py index d6b96121b..37799f3d8 100644 --- a/fedn/network/storage/statestore/stores/model_store.py +++ b/fedn/network/storage/statestore/stores/model_store.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from datetime import datetime from typing import Any, Dict, List, Tuple @@ -5,7 +6,7 @@ from bson import ObjectId from pymongo.database import Database -from fedn.network.storage.statestore.stores.store import MongoDBStore +from fedn.network.storage.statestore.stores.store import MongoDBStore, SQLStore, Store from .shared import EntityNotFound, from_document @@ -20,7 +21,25 @@ def __init__(self, id: str, key: str, model: str, parent_model: str, session_id: self.committed_at = committed_at -class ModelStore(MongoDBStore[Model]): +class ModelStore(Store[Model]): + @abstractmethod + def list_descendants(self, id: str, limit: int) -> List[Model]: + pass + + @abstractmethod + def list_ancestors(self, id: str, limit: int, include_self: bool = False, reverse: bool = False) -> List[Model]: + pass + + @abstractmethod + def get_active(self) -> str: + pass + + @abstractmethod + def set_active(self, id: str) -> bool: + pass + + +class MongoDBModelStore(ModelStore, MongoDBStore[Model]): def __init__(self, database: Database, collection: str): super().__init__(database, collection) @@ -219,3 +238,48 @@ def set_active(self, id: str) -> bool: self.database[self.collection].update_one({"key": "current_model"}, {"$set": {"model": model["model"]}}) return True + + +class SQLModelStore(ModelStore, SQLStore[Model]): + def __init__(self, database: Database, table: str): + super().__init__(database, table) + + def create_table(self): + table_name = super().table_name + if not table_name.isidentifier(): + raise ValueError(f"Invalid table name: {table_name}") + + query = f""" + CREATE TABLE IF NOT EXISTS {table_name} ( + id VARCHAR(255) PRIMARY KEY, + model VARCHAR(255), + parent_model VARCHAR(255), + session_id VARCHAR(255), + committed_at TIMESTAMP + ) + """ + self.cursor.execute(query) + + def update(self, id, item): + super().cursor.execute( + "UPDATE ? SET model = ?, parent_model = ?, session_id = ?, committed_at = ? WHERE id = ?", + (super().table_name, item.model, item.parent_model, item.session_id, item.committed_at, id), + ) + + def add(self, item): + super().cursor.execute( + "INSERT INTO ? (model, parent_model, session_id, committed_at) VALUES (?, ?, ?, ?)", + (super().table_name, item.model, item.parent_model, item.session_id, item.committed_at), + ) + + def list_descendants(self, id: str, limit: int) -> List[Model]: + raise NotImplementedError("List descendants not implemented for SQLModelStore") + + def list_ancestors(self, id: str, limit: int, include_self: bool = False, reverse: bool = False) -> List[Model]: + raise NotImplementedError("List ancestors not implemented for SQLModelStore") + + def get_active(self) -> str: + raise NotImplementedError("Get active not implemented for SQLModelStore") + + def set_active(self, id: str) -> bool: + raise NotImplementedError("Set active not implemented for SQLModelStore") diff --git a/fedn/network/storage/statestore/stores/package_store.py b/fedn/network/storage/statestore/stores/package_store.py index 44dece2ab..410685669 100644 --- a/fedn/network/storage/statestore/stores/package_store.py +++ b/fedn/network/storage/statestore/stores/package_store.py @@ -7,7 +7,7 @@ from pymongo.database import Database from werkzeug.utils import secure_filename -from fedn.network.storage.statestore.stores.store import MongoDBStore +from fedn.network.storage.statestore.stores.store import MongoDBStore, SQLStore, Store from .shared import EntityNotFound @@ -46,7 +46,11 @@ def __init__( self.active = active -class PackageStore(MongoDBStore[Package]): +class PackageStore(Store[Package]): + pass + + +class MongoDBPackageStore(PackageStore, MongoDBStore[Package]): def __init__(self, database: Database, collection: str): super().__init__(database, collection) @@ -238,3 +242,88 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI def count(self, **kwargs) -> int: kwargs["key"] = "package_trail" return super().count(**kwargs) + + +class SQLPackageStore(PackageStore, SQLStore[Package]): + def __init__(self, db_name: str, table_name: str): + super().__init__(db_name=db_name, table_name=table_name) + # self.table_name = table_name + + def _validate(self, item: Package) -> Tuple[bool, str]: + if "file_name" not in item or not item["file_name"]: + return False, "File name is required" + + if not self._allowed_file_extension(item["file_name"]): + return False, "File extension not allowed" + + if "helper" not in item or not item["helper"]: + return False, "Helper is required" + + return True, "" + + def _complement(self, item: Package): + if "committed_at" not in item or item.committed_at is None: + item["committed_at"] = datetime.now() + + extension = item["file_name"].rsplit(".", 1)[1].lower() + + if "storage_file_name" not in item or item.storage_file_name is None: + storage_file_name = secure_filename(f"{str(uuid.uuid4())}.{extension}") + item["storage_file_name"] = storage_file_name + + def create_table(self): + table_name = super().table_name + if not table_name.isidentifier(): + raise ValueError(f"Invalid table name: {table_name}") + + query = """ + CREATE TABLE IF NOT EXISTS ? ( + id VARCHAR(255) PRIMARY KEY, + active BOOLEAN, + committed_at TIMESTAMP, + description VARCHAR(255), + file_name VARCHAR(255), + helper VARCHAR(255), + name VARCHAR(255), + storage_file_name VARCHAR(255) + ) + """ + self.cursor.execute(query, (table_name,)) + + def update(self, id, item): + pass + # super().cursor.execute( + # "UPDATE ? SET model = ?, parent_model = ?, session_id = ?, committed_at = ? WHERE id = ?", + # (super().table_name, item.model, item.parent_model, item.session_id, item.committed_at, id), + # ) + + def add(self, item: Package) -> Tuple[bool, Any]: + try: + valid, message = self._validate(item) + if not valid: + return False, message + + self._complement(item) + + super().cursor.execute( + "INSERT INTO ? (active, committed_at, description, file_name, helper, name, storage_file_name) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ( + super().table_name, + item["active"], + item["committed_at"], + item["description"], + item["file_name"], + item["helper"], + item["name"], + item["storage_file_name"], + ), + ) + return True, item + except Exception as e: + return False, str(e) + + def get_active(self) -> str: + raise NotImplementedError("Get active not implemented for SQLModelStore") + + def set_active(self, id: str) -> bool: + raise NotImplementedError("Set active not implemented for SQLModelStore") diff --git a/fedn/network/storage/statestore/stores/round_store.py b/fedn/network/storage/statestore/stores/round_store.py index 9148f0c63..873dfd24f 100644 --- a/fedn/network/storage/statestore/stores/round_store.py +++ b/fedn/network/storage/statestore/stores/round_store.py @@ -1,10 +1,13 @@ from typing import Any, Dict, List, Tuple import pymongo +from bson import ObjectId from pymongo.database import Database from fedn.network.storage.statestore.stores.store import MongoDBStore +from .shared import EntityNotFound, from_document + class Round: def __init__(self, id: str, round_id: str, status: str, round_config: dict, combiners: List[dict], round_data: dict): @@ -26,7 +29,19 @@ def get(self, id: str) -> Round: type: str return: The entity """ - return super().get(id) + kwargs = {} + if ObjectId.is_valid(id): + id_obj = ObjectId(id) + kwargs["_id"] = id_obj + else: + kwargs["round_id"] = id + + document = self.database[self.collection].find_one(kwargs) + + if document is None: + raise EntityNotFound(f"Entity with (id | model) {id} not found") + + return from_document(document) def update(self, id: str, item: Round) -> bool: raise NotImplementedError("Update not implemented for RoundStore") @@ -54,3 +69,4 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI return: The entities """ return super().list(limit, skip, sort_key or "round_id", sort_order, **kwargs) + return super().list(limit, skip, sort_key or "round_id", sort_order, **kwargs) diff --git a/fedn/network/storage/statestore/stores/store.py b/fedn/network/storage/statestore/stores/store.py index ec5e4e9be..4bbbba4fb 100644 --- a/fedn/network/storage/statestore/stores/store.py +++ b/fedn/network/storage/statestore/stores/store.py @@ -1,3 +1,4 @@ +import sqlite3 from abc import ABC, abstractmethod from typing import Any, Dict, Generic, List, Tuple, TypeVar @@ -111,3 +112,49 @@ def count(self, **kwargs) -> int: return: The count (int) """ return self.database[self.collection].count_documents(kwargs) + + +class SQLStore(Store[T], Generic[T]): + def __init__(self, db_name: str, table_name: str): + self.connection = sqlite3.connect(db_name) + self.cursor = self.connection.cursor() + self.table_name = table_name + self.create_table() + + @abstractmethod + def create_table(self): + """Create a table for the specific type.""" + pass + + def get(self, id: str) -> T: + self.cursor.execute( + "SELECT * FROM ? WHERE id = ?", + ( + self.table_name, + id, + ), + ) + row = self.cursor.fetchone() + if not row: + raise ValueError(f"Item with id '{id}' not found") + return row + + def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[T]]: + sort_order = "ASC" if sort_order == pymongo.ASCENDING else "DESC" + if limit and skip: + self.cursor.execute("SELECT * FROM ? ORDER BY ? ? LIMIT ? OFFSET ?", (self.table_name, sort_key, sort_order, limit, skip)) + else: + self.cursor.execute("SELECT * FROM ? ORDER BY ? ?", (self.table_name, sort_key, sort_order)) + rows = self.cursor.fetchall() + + self.cursor.execute("SELECT COUNT(*) FROM items") + count = self.cursor.fetchone()[0] + + return {"count": count, "result": rows} + + def count(self, **kwargs) -> int: + self.cursor.execute("SELECT COUNT(*) FROM ?", (self.table_name,)) + return self.cursor.fetchone()[0] + + def delete(self, id: str) -> bool: + super().cursor.execute("DELETE FROM ? WHERE id = ?", (super().table_name, id))