Skip to content

Commit

Permalink
set active model
Browse files Browse the repository at this point in the history
  • Loading branch information
niklastheman committed Dec 7, 2023
1 parent c3f3b60 commit 33c4c54
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
23 changes: 23 additions & 0 deletions fedn/fedn/network/api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,29 @@ def get_latest_model(self):
{"success": False, "message": "No initial model set."}
)

def set_active_model(self, model_id: str):
"""Set the active model in the statestore.
:param model_id: The model id to set.
:type model_id: str
:return: A json response with success or failure message.
:rtype: :class:`flask.Response`
"""
success = self.statestore.set_active_model(model_id)

if not success:
return (
jsonify(
{
"success": False,
"message": "Failed to set active model.",
}
),
400,
)

return jsonify({"success": True, "message": "Active model set."})

def get_models(self, session_id: str = None, limit: str = None, skip: str = None, include_active: str = None):
result = self.statestore.list_models(session_id, limit, skip)

Expand Down
18 changes: 18 additions & 0 deletions fedn/fedn/network/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,24 @@ def get_latest_model():
return api.get_latest_model()


@app.route("/set_active_model", methods=["PUT"])
def set_active_model():
"""Set the initial model in the statestore and upload to model repository.
Usage with curl:
curl -k -X PUT
-F id=<model-id>
http://localhost:8092/set_initial_model
param: id: The model id to set.
type: id: str
return: boolean.
rtype: json
"""
id = request.args.get("id", None)
if id is None:
return jsonify({"success": False, "message": "Missing model id."}), 400
return api.set_active_model(id)

# Get initial model endpoint


Expand Down
26 changes: 26 additions & 0 deletions fedn/fedn/network/statestore/mongostatestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,32 @@ def get_latest_model(self):
except (KeyError, IndexError):
return None

def set_active_model(self, model_id: str):
"""Set the active model in statestore.
:param model_id: The model id.
:type model_id: str
:return:
"""

try:

committed_at = datetime.now()

existing_model = self.model.find_one({"key": "models", "model": model_id})

if existing_model is not None:

self.model.update_one(
{"key": "active_model"}, {"$set": {"model": model_id, "committed_at": committed_at, "session_id": None}}, True
)

return True
except Exception as e:
print("ERROR: {}".format(e), flush=True)

return False

def get_latest_round(self):
"""Get the id of the most recent round.
Expand Down

0 comments on commit 33c4c54

Please sign in to comment.