Skip to content

Commit

Permalink
moved get and list clients from mongostatestore
Browse files Browse the repository at this point in the history
  • Loading branch information
niklastheman committed Dec 28, 2024
1 parent 08f8583 commit fab412a
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 238 deletions.
52 changes: 0 additions & 52 deletions fedn/network/api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,33 +48,6 @@ def _allowed_file_extension(self, filename, ALLOWED_EXTENSIONS={"gz", "bz2", "ta

return (False, None)

def get_clients(self, limit=None, skip=None, status=False):
"""Get all clients from the statestore.
:return: All clients as a json response.
:rtype: :class:`flask.Response`
"""
# Will return list of ObjectId
response = self.statestore.list_clients(limit, skip, status)

arr = []

for element in response["result"]:
obj = {
"id": element["name"],
"combiner": element["combiner"],
"combiner_preferred": element["combiner_preferred"],
"ip": element["ip"],
"status": element["status"],
"last_seen": element["last_seen"] if "last_seen" in element else "",
}

arr.append(obj)

result = {"result": arr, "count": response["count"]}

return jsonify(result)

def download_compute_package(self, name):
"""Download the compute package.
Expand Down Expand Up @@ -339,28 +312,3 @@ def get_client_config(self, checksum=True):
if success:
payload["checksum"] = checksum_str
return jsonify(payload)

def list_combiners_data(self, combiners):
"""Get combiners data.
:param combiners: The combiners to get data for.
:type combiners: list
:return: The combiners data as json response.
:rtype: :py:class:`flask.Response`
"""
response = self.statestore.list_combiners_data(combiners)

arr = []

# order list by combiner name
for element in response:
obj = {
"combiner": element["_id"],
"count": element["count"],
}

arr.append(obj)

result = {"result": arr}

return jsonify(result)
17 changes: 10 additions & 7 deletions fedn/network/api/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,19 @@ def add_client(self, client):
logger.info("adding client {}".format(client["client_id"]))
self.client_store.upsert(client)

def get_client(self, name):
"""Get client by name.
def get_client(self, client_id: str):
"""Get client by client_id.
:param name: name of client
:type name: str
:param client_id: client_id of client
:type client_id: str
:return: The client instance object
:rtype: ObjectId
"""
ret = self.statestore.get_client(name)
return ret
try:
client = self.client_store.get(client_id)
return client
except Exception:
return None

def update_client_data(self, client_data, status, role):
"""Update client status in statestore.
Expand All @@ -129,4 +132,4 @@ def get_client_info(self):
:return: list of client objects
:rtype: list(ObjectId)
"""
return self.statestore.list_clients()
return self.client_store.list(limit=0, skip=0, sort_key=None)
43 changes: 0 additions & 43 deletions fedn/network/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,24 +90,6 @@ def delete_model_trail():
app.add_url_rule(f"{custom_url_prefix}/delete_model_trail", view_func=delete_model_trail, methods=["GET", "POST"])


@app.route("/list_clients", methods=["GET"])
@jwt_auth_required(role="admin")
def list_clients():
"""Get all clients from the statestore.
return: All clients as a json object.
rtype: json
"""
limit = request.args.get("limit", None)
skip = request.args.get("skip", None)
status = request.args.get("status", None)

return api.get_clients(limit, skip, status)


if custom_url_prefix:
app.add_url_rule(f"{custom_url_prefix}/list_clients", view_func=list_clients, methods=["GET"])


@app.route("/get_active_clients", methods=["GET"])
@jwt_auth_required(role="admin")
def get_active_clients():
Expand Down Expand Up @@ -271,31 +253,6 @@ def add_client():
app.add_url_rule(f"{custom_url_prefix}/add_client", view_func=add_client, methods=["POST"])


@app.route("/list_combiners_data", methods=["POST"])
@jwt_auth_required(role="admin")
def list_combiners_data():
"""List data from combiners.
return: The response from control.
rtype: json
"""
json_data = request.get_json()

# expects a list of combiner names (strings) in an array
combiners = json_data.get("combiners", None)

try:
response = api.list_combiners_data(combiners)
except TypeError:
return jsonify({"success": False, "message": "Invalid data provided"}), 400
except Exception:
return jsonify({"success": False, "message": "An unexpected error occurred"}), 500
return response


if custom_url_prefix:
app.add_url_rule(f"{custom_url_prefix}/list_combiners_data", view_func=list_combiners_data, methods=["POST"])


def start_server_api():
config = get_controller_config()
port = config["port"]
Expand Down
13 changes: 0 additions & 13 deletions fedn/network/api/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,6 @@ def test_set_initial_model(self):
# Assert api.set_initial_model was called
fedn.network.api.server.api.set_initial_model.assert_called_once()

def test_list_clients(self):
""" Test list_clients endpoint. """
# Mock api.get_all_clients
return_value = {"test": "test"}
fedn.network.api.server.api.get_all_clients = MagicMock(return_value=return_value)
# Make request
response = self.app.get('/list_clients')
# Assert response
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json, return_value)
# Assert api.get_all_clients was called
fedn.network.api.server.api.get_all_clients.assert_called_once_with()

def test_get_active_clients(self):
""" Test get_active_clients endpoint. """
# Mock api.get_active_clients
Expand Down
123 changes: 0 additions & 123 deletions fedn/network/storage/statestore/mongostatestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,129 +186,6 @@ def set_storage_backend(self, config):
config["status"] = "enabled"
self.storage.update_one({"storage_type": config["storage_type"]}, {"$set": config}, True)

def get_client(self, client_id):
"""Get client by client_id.
:param client_id: client_id of client to get.
:type client_id: str
:return: The client. None if not found.
:rtype: ObjectId
"""
try:
ret = self.clients.find({"key": client_id})
if list(ret) == []:
return None
else:
return ret
except Exception:
return None

def list_clients(self, limit=None, skip=None, status=None, sort_key="last_seen", sort_order=pymongo.DESCENDING):
"""List all clients registered on the network.
:param limit: The maximum number of clients to return.
:type limit: int
:param skip: The number of clients to skip.
:type skip: int
:param status: online | offline
:type status: str
:param sort_key: The key to sort by.
"""
result = None
count = None

try:
find = {} if status is None else {"status": status}
projection = {"_id": False, "updated_at": False}

if limit is not None and skip is not None:
limit = int(limit)
skip = int(skip)
result = self.clients.find(find, projection).limit(limit).skip(skip).sort(sort_key, sort_order)
else:
result = self.clients.find(find, projection).sort(sort_key, sort_order)

count = self.clients.count_documents(find)

except Exception as e:
logger.error("{}".format(e))

return {
"result": result,
"count": count,
}

def list_combiners_data(self, combiners, sort_key="count", sort_order=pymongo.DESCENDING):
"""List all combiner data.
:param combiners: list of combiners to get data for.
:type combiners: list
:param sort_key: The key to sort by.
:type sort_key: str
:param sort_order: The sort order.
:type sort_order: pymongo.ASCENDING or pymongo.DESCENDING
:return: list of combiner data.
:rtype: list(ObjectId)
"""
result = None

try:
pipeline = (
[
{"$match": {"combiner": {"$in": combiners}, "status": "online"}},
{"$group": {"_id": "$combiner", "count": {"$sum": 1}}},
{"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}},
]
if combiners is not None
else [{"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}}]
)

result = self.clients.aggregate(pipeline)

except Exception as e:
logger.error(e)

return result

def report_status(self, msg):
"""Write status message to the database.
:param msg: The status message.
:type msg: str
"""
data = MessageToDict(msg)

if self.status is not None:
self.status.insert_one(data)

def report_validation(self, validation):
"""Write model validation to database.
:param validation: The model validation.
:type validation: dict
"""
data = MessageToDict(validation)

if self.validations is not None:
self.validations.insert_one(data)

def drop_status(self):
"""Drop the status collection."""
if self.status:
self.status.drop()

def create_session(self, id=None):
"""Create a new session object.
:param id: The ID of the created session.
:type id: uuid, str
"""
if not id:
id = uuid.uuid4()
data = {"session_id": str(id)}
self.sessions.insert_one(data)

def create_round(self, round_data):
"""Create a new round.
Expand Down

0 comments on commit fab412a

Please sign in to comment.