From a0645067570744bc839978b8b57044b4310115e8 Mon Sep 17 00:00:00 2001
From: Niklas <niklas@scaleoutsystems.com>
Date: Wed, 18 Oct 2023 15:17:06 +0200
Subject: [PATCH] get_models added to api

---
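Usage sketch for the new /list_models endpoint (for illustration only; the base
URL http://localhost:8092 and the requests client are assumptions, not part of
this patch):

    import json

    import requests

    # session_id, limit and skip are optional query parameters.
    response = requests.get(
        "http://localhost:8092/list_models",
        params={"limit": 10, "skip": 0},
    )
    data = response.json()
    print("count:", data["count"])
    for doc in data["result"]:
        # Each entry is a JSON-encoded model document from the statestore.
        model = json.loads(doc)
        print(model["model"], model["session_id"], model["committed_at"])
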
 fedn/fedn/network/api/interface.py            |  29 +-
 fedn/fedn/network/api/server.py               |  20 ++
 fedn/fedn/network/controller/control.py       | 237 +++++++++++-------
 fedn/fedn/network/controller/controlbase.py   | 124 +++++----
 .../network/statestore/mongostatestore.py     |  51 +++-
 5 files changed, 325 insertions(+), 136 deletions(-)

diff --git a/fedn/fedn/network/api/interface.py b/fedn/fedn/network/api/interface.py
index a86f0feb8..40792da81 100644
--- a/fedn/fedn/network/api/interface.py
+++ b/fedn/fedn/network/api/interface.py
@@ -180,7 +180,6 @@ def get_all_sessions(self, limit=None, skip=None):
         result = {"result": payload, "count": sessions_object["count"]}
 
         return jsonify(result)
-        return jsonify(payload)
 
     def get_session(self, session_id):
         """Get a session from the statestore.
@@ -632,6 +631,34 @@ def get_latest_model(self):
                 {"success": False, "message": "No initial model set."}
             )
 
+    def get_models(self, session_id=None, limit=None, skip=None):
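+        """Get models from the statestore.
+
+        :param session_id: The session id to list models for (optional).
+        :type session_id: str
+        :param limit: The maximum number of models to return (optional).
+        :type limit: int
+        :param skip: The number of models to skip (optional).
+        :type skip: int
+        :return: The models as a json response.
+        """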
+        result = self.statestore.list_models(session_id, limit, skip)
+
+        if result is None:
+            return (
+                jsonify({"success": False, "message": "No models found."}),
+                404,
+            )
+
+        json_docs = []
+        for doc in result["result"]:
+            json_doc = json.dumps(doc, default=json_util.default)
+            json_docs.append(json_doc)
+
+        json_docs.reverse()
+
+        return jsonify({"result": json_docs, "count": result["count"]})
+
     def get_model_trail(self):
         """Get the model trail for a given session.
 
diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py
index 2bdd7aae3..0c49343f8 100644
--- a/fedn/fedn/network/api/server.py
+++ b/fedn/fedn/network/api/server.py
@@ -32,6 +32,26 @@ def get_model_trail():
     return api.get_model_trail()
 
 
+@app.route("/list_models", methods=["GET"])
+def list_models():
+    """Get models from the statestore.
+    param:
+    session_id: The session id to get the model trail for.
+    limit: The maximum number of models to return.
+    type: limit: int
+    param: skip: The number of models to skip.
+    type: skip: int
+    Returns:
+        _type_: json
+    """
+
+    session_id = request.args.get("session_id", None)
+    limit = request.args.get("limit", None)
+    skip = request.args.get("skip", None)
+
+    return api.get_models(session_id, limit, skip)
+
+
 @app.route("/delete_model_trail", methods=["GET", "POST"])
 def delete_model_trail():
     """Delete the model trail for a given session.
diff --git a/fedn/fedn/network/controller/control.py b/fedn/fedn/network/controller/control.py
index 5f5bc6634..a8e32333d 100644
--- a/fedn/fedn/network/controller/control.py
+++ b/fedn/fedn/network/controller/control.py
@@ -9,10 +9,10 @@
 
 
 class UnsupportedStorageBackend(Exception):
-    """ Exception class for when storage backend is not supported. Passes """
+    """Exception class for when storage backend is not supported. Passes"""
 
     def __init__(self, message):
-        """ Constructor method.
+        """Constructor method.
 
         :param message: The exception message.
         :type message: str
@@ -23,46 +23,46 @@ def __init__(self, message):
 
 
 class MisconfiguredStorageBackend(Exception):
-    """ Exception class for when storage backend is misconfigured.
+    """Exception class for when storage backend is misconfigured.
 
     :param message: The exception message.
     :type message: str
     """
 
     def __init__(self, message):
-        """ Constructor method."""
+        """Constructor method."""
         self.message = message
         super().__init__(self.message)
 
 
 class NoModelException(Exception):
-    """ Exception class for when model is None
+    """Exception class for when model is None
 
     :param message: The exception message.
     :type message: str
     """
 
     def __init__(self, message):
-        """ Constructor method."""
+        """Constructor method."""
         self.message = message
         super().__init__(self.message)
 
 
 class Control(ControlBase):
-    """ Controller, implementing the overall global training, validation and inference logic.
+    """Controller, implementing the overall global training, validation and inference logic.
 
     :param statestore: A StateStorage instance.
     :type statestore: class: `fedn.network.statestorebase.StateStorageBase`
     """
 
     def __init__(self, statestore):
-        """ Constructor method."""
+        """Constructor method."""
 
         super().__init__(statestore)
         self.name = "DefaultControl"
 
     def session(self, config):
-        """ Execute a new training session. A session consists of one
+        """Execute a new training session. A session consists of one
             or several global rounds. All rounds in the same session
             have the same round_config.
 
@@ -72,7 +72,10 @@ def session(self, config):
         """
 
         if self._state == ReducerState.instructing:
-            print("Controller already in INSTRUCTING state. A session is in progress.", flush=True)
+            print(
+                "Controller already in INSTRUCTING state. A session is in progress.",
+                flush=True,
+            )
             return
 
         if not self.statestore.get_latest_model():
@@ -82,11 +85,16 @@ def session(self, config):
         self._state = ReducerState.instructing
 
         # Must be called to set info in the db
-        config['committed_at'] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        config["committed_at"] = datetime.datetime.now().strftime(
+            "%Y-%m-%d %H:%M:%S"
+        )
         self.new_session(config)
 
         if not self.statestore.get_latest_model():
-            print("No model in model chain, please provide a seed model!", flush=True)
+            print(
+                "No model in model chain, please provide a seed model!",
+                flush=True,
+            )
         self._state = ReducerState.monitoring
 
         last_round = int(self.get_latest_round_id())
@@ -96,7 +104,7 @@ def session(self, config):
             combiner.flush_model_update_queue()
 
         # Execute the rounds in this session
-        for round in range(1, int(config['rounds'] + 1)):
+        for round in range(1, int(config["rounds"]) + 1):
             # Increment the round number
 
             if last_round:
@@ -107,10 +115,17 @@ def session(self, config):
             try:
                 _, round_data = self.round(config, current_round)
             except TypeError as e:
-                print("Could not unpack data from round: {0}".format(e), flush=True)
-
-            print("CONTROL: Round completed with status {}".format(
-                round_data['status']), flush=True)
+                print(
+                    "Could not unpack data from round: {0}".format(e),
+                    flush=True,
+                )
+
+            print(
+                "CONTROL: Round completed with status {}".format(
+                    round_data["status"]
+                ),
+                flush=True,
+            )
 
             self.tracer.set_round_data(round_data)
 
@@ -118,7 +133,7 @@ def session(self, config):
         self._state = ReducerState.idle
 
     def round(self, session_config, round_id):
-        """ Execute a single global round.
+        """Execute a single global round.
 
         :param session_config: The session config.
         :type session_config: dict
@@ -126,35 +141,42 @@ def round(self, session_config, round_id):
         :type round_id: str(int)
         """
 
-        round_data = {'round_id': round_id}
+        round_data = {"round_id": round_id}
 
         if len(self.network.get_combiners()) < 1:
             print("REDUCER: No combiners connected!", flush=True)
-            round_data['status'] = 'Failed'
+            round_data["status"] = "Failed"
             return None, round_data
 
         # 1. Assemble round config for this global round,
         # and check which combiners are able to participate
         # in the round.
         round_config = copy.deepcopy(session_config)
-        round_config['rounds'] = 1
-        round_config['round_id'] = round_id
-        round_config['task'] = 'training'
-        round_config['model_id'] = self.statestore.get_latest_model()
-        round_config['helper_type'] = self.statestore.get_helper()
+        round_config["rounds"] = 1
+        round_config["round_id"] = round_id
+        round_config["task"] = "training"
+        round_config["model_id"] = self.statestore.get_latest_model()
+        round_config["helper_type"] = self.statestore.get_helper()
 
         combiners = self.get_participating_combiners(round_config)
         round_start = self.evaluate_round_start_policy(combiners)
 
         if round_start:
-            print("CONTROL: round start policy met, participating combiners {}".format(
-                combiners), flush=True)
+            print(
+                "CONTROL: round start policy met, participating combiners {}".format(
+                    combiners
+                ),
+                flush=True,
+            )
         else:
-            print("CONTROL: Round start policy not met, skipping round!", flush=True)
-            round_data['status'] = 'Failed'
+            print(
+                "CONTROL: Round start policy not met, skipping round!",
+                flush=True,
+            )
+            round_data["status"] = "Failed"
             return None
 
-        round_data['round_config'] = round_config
+        round_data["round_config"] = round_config
 
         # 2. Ask participating combiners to coordinate model updates
         _ = self.request_model_updates(combiners)
@@ -164,27 +186,37 @@ def round(self, session_config, round_id):
         # dict to store combiners that have successfully produced an updated model
         updated = {}
         # wait until all combiners have produced an updated model or until round timeout
-        print("CONTROL: Fetching round config (ID: {round_id}) from statestore:".format(
-            round_id=round_id), flush=True)
+        print(
+            "CONTROL: Fetching round config (ID: {round_id}) from statestore:".format(
+                round_id=round_id
+            ),
+            flush=True,
+        )
         while len(updated) < len(combiners):
             round = self.statestore.get_round(round_id)
             if round:
                 print("CONTROL: Round found!", flush=True)
                 # For each combiner in the round, check if it has produced an updated model (status == 'Success')
-                for combiner in round['combiners']:
+                for combiner in round["combiners"]:
                     print(combiner, flush=True)
-                    if combiner['status'] == 'Success':
-                        if combiner['name'] not in updated.keys():
+                    if combiner["status"] == "Success":
+                        if combiner["name"] not in updated.keys():
                             # Add combiner to updated dict
-                            updated[combiner['name']] = combiner['model_id']
+                            updated[combiner["name"]] = combiner["model_id"]
                     # Print combiner status
-                    print("CONTROL: Combiner {name} status: {status}".format(
-                        name=combiner['name'], status=combiner['status']), flush=True)
+                    print(
+                        "CONTROL: Combiner {name} status: {status}".format(
+                            name=combiner["name"], status=combiner["status"]
+                        ),
+                        flush=True,
+                    )
             else:
                 # Print every 10 seconds based on value of wait
                 if wait % 10 == 0:
-                    print("CONTROL: Waiting for round to complete...", flush=True)
-            if wait >= session_config['round_timeout']:
+                    print(
+                        "CONTROL: Waiting for round to complete...", flush=True
+                    )
+            if wait >= session_config["round_timeout"]:
                 print("CONTROL: Round timeout! Exiting round...", flush=True)
                 break
             # Update wait time used for timeout
@@ -194,53 +226,77 @@ def round(self, session_config, round_id):
         round_valid = self.evaluate_round_validity_policy(updated)
         if not round_valid:
             print("REDUCER CONTROL: Round invalid!", flush=True)
-            round_data['status'] = 'Failed'
+            round_data["status"] = "Failed"
             return None, round_data
 
         print("CONTROL: Reducing models from combiners...", flush=True)
         # 3. Reduce combiner models into a global model
         try:
             model, data = self.reduce(updated)
-            round_data['reduce'] = data
+            round_data["reduce"] = data
             print("CONTROL: Done reducing models from combiners!", flush=True)
         except Exception as e:
-            print("CONTROL: Failed to reduce models from combiners: {}".format(
-                e), flush=True)
-            round_data['status'] = 'Failed'
+            print(
+                "CONTROL: Failed to reduce models from combiners: {}".format(
+                    e
+                ),
+                flush=True,
+            )
+            round_data["status"] = "Failed"
             return None, round_data
 
         # 6. Commit the global model to model trail
         if model is not None:
-            print("CONTROL: Committing global model to model trail...", flush=True)
+            print(
+                "CONTROL: Committing global model to model trail...",
+                flush=True,
+            )
             tic = time.time()
             model_id = uuid.uuid4()
-            self.commit(model_id, model)
-            round_data['time_commit'] = time.time() - tic
-            print("CONTROL: Done committing global model to model trail!", flush=True)
+            session_id = (
+                session_config["session_id"]
+                if "session_id" in session_config
+                else None
+            )
+            self.commit(model_id, model, session_id)
+            round_data["time_commit"] = time.time() - tic
+            print(
+                "CONTROL: Done committing global model to model trail!",
+                flush=True,
+            )
         else:
-            print("REDUCER: failed to update model in round with config {}".format(
-                session_config), flush=True)
-            round_data['status'] = 'Failed'
+            print(
+                "REDUCER: failed to update model in round with config {}".format(
+                    session_config
+                ),
+                flush=True,
+            )
+            round_data["status"] = "Failed"
             return None, round_data
 
-        round_data['status'] = 'Success'
+        round_data["status"] = "Success"
 
         # 4. Trigger participating combiner nodes to execute a validation round for the current model
-        validate = session_config['validate']
+        validate = session_config["validate"]
         if validate:
             combiner_config = copy.deepcopy(session_config)
-            combiner_config['round_id'] = round_id
-            combiner_config['model_id'] = self.statestore.get_latest_model()
-            combiner_config['task'] = 'validation'
-            combiner_config['helper_type'] = self.statestore.get_helper()
+            combiner_config["round_id"] = round_id
+            combiner_config["model_id"] = self.statestore.get_latest_model()
+            combiner_config["task"] = "validation"
+            combiner_config["helper_type"] = self.statestore.get_helper()
 
             validating_combiners = self._select_participating_combiners(
-                combiner_config)
+                combiner_config
+            )
 
             for combiner, combiner_config in validating_combiners:
                 try:
-                    print("CONTROL: Submitting validation round to combiner {}".format(
-                        combiner), flush=True)
+                    print(
+                        "CONTROL: Submitting validation round to combiner {}".format(
+                            combiner
+                        ),
+                        flush=True,
+                    )
                     combiner.submit(combiner_config)
                 except CombinerUnavailableError:
                     self._handle_unavailable_combiner(combiner)
@@ -249,16 +305,16 @@ def round(self, session_config, round_id):
         return model_id, round_data
 
     def reduce(self, combiners):
-        """ Combine updated models from Combiner nodes into one global model.
+        """Combine updated models from Combiner nodes into one global model.
 
         :param combiners: dict of combiner names (key) and model IDs (value) to reduce
         :type combiners: dict
         """
 
         meta = {}
-        meta['time_fetch_model'] = 0.0
-        meta['time_load_model'] = 0.0
-        meta['time_aggregate_model'] = 0.0
+        meta["time_fetch_model"] = 0.0
+        meta["time_load_model"] = 0.0
+        meta["time_aggregate_model"] = 0.0
 
         i = 1
         model = None
@@ -268,18 +324,25 @@ def reduce(self, combiners):
             return model, meta
 
         for name, model_id in combiners.items():
-
             # TODO: Handle inactive RPC error in get_model and raise specific error
-            print("REDUCER: Fetching model ({model_id}) from combiner {name}".format(
-                model_id=model_id, name=name), flush=True)
+            print(
+                "REDUCER: Fetching model ({model_id}) from combiner {name}".format(
+                    model_id=model_id, name=name
+                ),
+                flush=True,
+            )
             try:
                 tic = time.time()
                 combiner = self.get_combiner(name)
                 data = combiner.get_model(model_id)
-                meta['time_fetch_model'] += (time.time() - tic)
+                meta["time_fetch_model"] += time.time() - tic
             except Exception as e:
-                print("REDUCER: Failed to fetch model from combiner {}: {}".format(
-                    name, e), flush=True)
+                print(
+                    "REDUCER: Failed to fetch model from combiner {}: {}".format(
+                        name, e
+                    ),
+                    flush=True,
+                )
                 data = None
 
             if data is not None:
@@ -288,21 +351,21 @@ def reduce(self, combiners):
                     helper = self.get_helper()
                     data.seek(0)
                     model_next = helper.load(data)
-                    meta['time_load_model'] += (time.time() - tic)
+                    meta["time_load_model"] += time.time() - tic
                     tic = time.time()
                     model = helper.increment_average(model, model_next, i, i)
-                    meta['time_aggregate_model'] += (time.time() - tic)
+                    meta["time_aggregate_model"] += time.time() - tic
                 except Exception:
                     tic = time.time()
                     data.seek(0)
                     model = helper.load(data)
-                    meta['time_aggregate_model'] += (time.time() - tic)
+                    meta["time_aggregate_model"] += time.time() - tic
                 i = i + 1
 
         return model, meta
 
     def infer_instruct(self, config):
-        """ Main entrypoint for executing the inference compute plan.
+        """Main entrypoint for executing the inference compute plan.
 
         :param config: configuration for the inference round
         """
@@ -330,7 +393,7 @@ def infer_instruct(self, config):
         self.__state = ReducerState.idle
 
     def inference_round(self, config):
-        """ Execute an inference round.
+        """Execute an inference round.
 
         :param config: configuration for the inference round
         """
@@ -345,21 +408,27 @@ def inference_round(self, config):
 
         # Setup combiner configuration
         combiner_config = copy.deepcopy(config)
-        combiner_config['model_id'] = self.statestore.get_latest_model()
-        combiner_config['task'] = 'inference'
-        combiner_config['helper_type'] = self.statestore.get_framework()
+        combiner_config["model_id"] = self.statestore.get_latest_model()
+        combiner_config["task"] = "inference"
+        combiner_config["helper_type"] = self.statestore.get_framework()
 
         # Select combiners
-        validating_combiners = self._select_round_combiners(
-            combiner_config)
+        validating_combiners = self._select_round_combiners(combiner_config)
 
         # Test round start policy
         round_start = self.check_round_start_policy(validating_combiners)
         if round_start:
-            print("CONTROL: round start policy met, participating combiners {}".format(
-                validating_combiners), flush=True)
+            print(
+                "CONTROL: round start policy met, participating combiners {}".format(
+                    validating_combiners
+                ),
+                flush=True,
+            )
         else:
-            print("CONTROL: Round start policy not met, skipping round!", flush=True)
+            print(
+                "CONTROL: Round start policy not met, skipping round!",
+                flush=True,
+            )
             return None
 
         # Synch combiners with latest model and trigger inference
diff --git a/fedn/fedn/network/controller/controlbase.py b/fedn/fedn/network/controller/controlbase.py
index e38d31e38..077620c14 100644
--- a/fedn/fedn/network/controller/controlbase.py
+++ b/fedn/fedn/network/controller/controlbase.py
@@ -11,7 +11,7 @@
 from fedn.network.state import ReducerState
 
 # Maximum number of tries to connect to statestore and retrieve storage configuration
-MAX_TRIES_BACKEND = os.getenv('MAX_TRIES_BACKEND', 10)
+MAX_TRIES_BACKEND = int(os.getenv("MAX_TRIES_BACKEND", 10))
 
 
 class UnsupportedStorageBackend(Exception):
@@ -27,7 +27,7 @@ class MisconfiguredHelper(Exception):
 
 
 class ControlBase(ABC):
-    """ Base class and interface for a global controller.
+    """Base class and interface for a global controller.
         Override this class to implement a global training strategy (control).
 
     :param statestore: The statestore object.
@@ -36,7 +36,7 @@ class ControlBase(ABC):
 
     @abstractmethod
     def __init__(self, statestore):
-        """ Constructor. """
+        """Constructor."""
         self._state = ReducerState.setup
 
         self.statestore = statestore
@@ -52,26 +52,36 @@ def __init__(self, statestore):
                     not_ready = False
                 else:
                     print(
-                        "REDUCER CONTROL: Storage backend not configured, waiting...", flush=True)
+                        "REDUCER CONTROL: Storage backend not configured, waiting...",
+                        flush=True,
+                    )
                     sleep(5)
                     tries += 1
                     if tries > MAX_TRIES_BACKEND:
                         raise Exception
         except Exception:
             print(
-                "REDUCER CONTROL: Failed to retrive storage configuration, exiting.", flush=True)
+                "REDUCER CONTROL: Failed to retrive storage configuration, exiting.",
+                flush=True,
+            )
             raise MisconfiguredStorageBackend()
 
-        if storage_config['storage_type'] == 'S3':
-            self.model_repository = S3ModelRepository(storage_config['storage_config'])
+        if storage_config["storage_type"] == "S3":
+            self.model_repository = S3ModelRepository(
+                storage_config["storage_config"]
+            )
         else:
-            print("REDUCER CONTROL: Unsupported storage backend, exiting.", flush=True)
+            print(
+                "REDUCER CONTROL: Unsupported storage backend, exiting.",
+                flush=True,
+            )
             raise UnsupportedStorageBackend()
 
         # The tracer is a helper that manages state in the database backend
         statestore_config = statestore.get_config()
         self.tracer = MongoTracer(
-            statestore_config['mongo_config'], statestore_config['network_id'])
+            statestore_config["mongo_config"], statestore_config["network_id"]
+        )
 
         if self.statestore.is_inited():
             self._state = ReducerState.idle
@@ -89,7 +99,7 @@ def reduce(self, combiners):
         pass
 
     def get_helper(self):
-        """ Get a helper instance from global config.
+        """Get a helper instance from global config.
 
         :return: Helper instance.
         :rtype: :class:`fedn.utils.plugins.helperbase.HelperBase`
@@ -97,11 +107,15 @@ def get_helper(self):
         helper_type = self.statestore.get_helper()
         helper = fedn.utils.helpers.get_helper(helper_type)
         if not helper:
-            raise MisconfiguredHelper("Unsupported helper type {}, please configure compute_package.helper !".format(helper_type))
+            raise MisconfiguredHelper(
+                "Unsupported helper type {}, please configure compute_package.helper !".format(
+                    helper_type
+                )
+            )
         return helper
 
     def get_state(self):
-        """ Get the current state of the controller.
+        """Get the current state of the controller.
 
         :return: The current state.
         :rtype: :class:`fedn.network.state.ReducerState`
@@ -109,7 +123,7 @@ def get_state(self):
         return self._state
 
     def idle(self):
-        """ Check if the controller is idle.
+        """Check if the controller is idle.
 
         :return: True if idle, False otherwise.
         :rtype: bool
@@ -139,7 +153,7 @@ def get_latest_round_id(self):
         if not last_round:
             return 0
         else:
-            return last_round['round_id']
+            return last_round["round_id"]
 
     def get_latest_round(self):
         round = self.statestore.get_latest_round()
@@ -153,27 +167,29 @@ def get_compute_package_name(self):
         definition = self.statestore.get_compute_package()
         if definition:
             try:
-                package_name = definition['filename']
+                package_name = definition["filename"]
                 return package_name
             except (IndexError, KeyError):
                 print(
-                    "No context filename set for compute context definition", flush=True)
+                    "No context filename set for compute context definition",
+                    flush=True,
+                )
                 return None
         else:
             return None
 
     def set_compute_package(self, filename, path):
-        """ Persist the configuration for the compute package. """
+        """Persist the configuration for the compute package."""
         self.model_repository.set_compute_package(filename, path)
         self.statestore.set_compute_package(filename)
 
-    def get_compute_package(self, compute_package=''):
+    def get_compute_package(self, compute_package=""):
         """
 
         :param compute_package:
         :return:
         """
-        if compute_package == '':
+        if compute_package == "":
             compute_package = self.get_compute_package_name()
         if compute_package:
             return self.model_repository.get_compute_package(compute_package)
@@ -181,42 +197,49 @@ def get_compute_package(self, compute_package=''):
             return None
 
     def new_session(self, config):
-        """ Initialize a new session in backend db. """
+        """Initialize a new session in backend db."""
 
         if "session_id" not in config.keys():
             session_id = uuid.uuid4()
-            config['session_id'] = str(session_id)
+            config["session_id"] = str(session_id)
         else:
-            session_id = config['session_id']
+            session_id = config["session_id"]
 
         self.tracer.new_session(id=session_id)
         self.tracer.set_session_config(session_id, config)
 
     def request_model_updates(self, combiners):
-        """Call Combiner server RPC to get a model update. """
+        """Call Combiner server RPC to get a model update."""
         cl = []
         for combiner, combiner_round_config in combiners:
             response = combiner.submit(combiner_round_config)
             cl.append((combiner, response))
         return cl
 
-    def commit(self, model_id, model=None):
-        """ Commit a model to the global model trail. The model commited becomes the lastest consensus model. """
+    def commit(self, model_id, model=None, session_id=None):
+        """Commit a model to the global model trail. The model commited becomes the lastest consensus model."""
 
         helper = self.get_helper()
         if model is not None:
-            print("CONTROL: Saving model file temporarily to disk...", flush=True)
+            print(
+                "CONTROL: Saving model file temporarily to disk...", flush=True
+            )
             outfile_name = helper.save(model)
             print("CONTROL: Uploading model to Minio...", flush=True)
             model_id = self.model_repository.set_model(
-                outfile_name, is_file=True)
+                outfile_name, is_file=True
+            )
 
             print("CONTROL: Deleting temporary model file...", flush=True)
             os.unlink(outfile_name)
 
-        print("CONTROL: Committing model {} to global model trail in statestore...".format(
-            model_id), flush=True)
-        self.statestore.set_latest_model(model_id)
+        print(
+            "CONTROL: Committing model {} to global model trail in statestore...".format(
+                model_id
+            ),
+            flush=True,
+        )
+        self.statestore.set_latest_model(model_id, session_id)
 
     def get_combiner(self, name):
         for combiner in self.network.get_combiners():
@@ -226,7 +249,7 @@ def get_combiner(self, name):
 
     def get_participating_combiners(self, combiner_round_config):
         """Assemble a list of combiners able to participate in a round as
-           descibed by combiner_round_config.
+        described by combiner_round_config.
         """
         combiners = []
         for combiner in self.network.get_combiners():
@@ -238,45 +261,47 @@ def get_participating_combiners(self, combiner_round_config):
 
             if combiner_state is not None:
                 is_participating = self.evaluate_round_participation_policy(
-                    combiner_round_config, combiner_state)
+                    combiner_round_config, combiner_state
+                )
                 if is_participating:
                     combiners.append((combiner, combiner_round_config))
         return combiners
 
-    def evaluate_round_participation_policy(self, compute_plan, combiner_state):
-        """ Evaluate policy for combiner round-participation.
-            A combiner participates if it is responsive and reports enough
-            active clients to participate in the round.
+    def evaluate_round_participation_policy(
+        self, compute_plan, combiner_state
+    ):
+        """Evaluate policy for combiner round-participation.
+        A combiner participates if it is responsive and reports enough
+        active clients to participate in the round.
         """
 
-        if compute_plan['task'] == 'training':
-            nr_active_clients = int(combiner_state['nr_active_trainers'])
-        elif compute_plan['task'] == 'validation':
-            nr_active_clients = int(combiner_state['nr_active_validators'])
+        if compute_plan["task"] == "training":
+            nr_active_clients = int(combiner_state["nr_active_trainers"])
+        elif compute_plan["task"] == "validation":
+            nr_active_clients = int(combiner_state["nr_active_validators"])
         else:
             print("Invalid task type!", flush=True)
             return False
 
-        if int(compute_plan['clients_required']) <= nr_active_clients:
+        if int(compute_plan["clients_required"]) <= nr_active_clients:
             return True
         else:
             return False
 
     def evaluate_round_start_policy(self, combiners):
-        """ Check if the policy to start a round is met. """
+        """Check if the policy to start a round is met."""
         if len(combiners) > 0:
-
             return True
         else:
             return False
 
     def evaluate_round_validity_policy(self, combiners):
-        """ Check if the round should be seen as valid.
+        """Check if the round should be seen as valid.
 
-            At the end of the round, before committing a model to the global model trail,
-            we check if the round validity policy has been met. This can involve
-            e.g. asserting that a certain number of combiners have reported in an
-            updated model, or that criteria on model performance have been met.
+        At the end of the round, before committing a model to the global model trail,
+        we check if the round validity policy has been met. This can involve
+        e.g. asserting that a certain number of combiners have reported in an
+        updated model, or that criteria on model performance have been met.
         """
         if combiners.keys() == []:
             return False
@@ -294,7 +319,8 @@ def _select_participating_combiners(self, compute_plan):
 
             if combiner_state:
                 is_participating = self.evaluate_round_participation_policy(
-                    compute_plan, combiner_state)
+                    compute_plan, combiner_state
+                )
                 if is_participating:
                     participating_combiners.append((combiner, compute_plan))
         return participating_combiners
diff --git a/fedn/fedn/network/statestore/mongostatestore.py b/fedn/fedn/network/statestore/mongostatestore.py
index 903d1a208..9262e9f13 100644
--- a/fedn/fedn/network/statestore/mongostatestore.py
+++ b/fedn/fedn/network/statestore/mongostatestore.py
@@ -145,7 +145,7 @@ def get_session(self, session_id):
         """
         return self.sessions.find_one({"session_id": session_id})
 
-    def set_latest_model(self, model_id):
+    def set_latest_model(self, model_id, session_id=None):
         """Set the latest model id.
 
         :param model_id: The model id.
@@ -153,6 +153,17 @@ def set_latest_model(self, model_id):
         :return:
         """
 
+        committed_at = str(datetime.now())
+
+        self.model.insert_one(
+            {
+                "key": "models",
+                "model": model_id,
+                "session_id": session_id,
+                "committed_at": commited_at,
+            }
+        )
+
         self.model.update_one(
             {"key": "current_model"}, {"$set": {"model": model_id}}, True
         )
@@ -161,7 +172,7 @@ def set_latest_model(self, model_id):
             {
                 "$push": {
                     "model": model_id,
-                    "committed_at": str(datetime.now()),
+                    "committed_at": commited_at,
                 }
             },
             True,
@@ -326,6 +337,42 @@ def get_helper(self):
         except (KeyError, IndexError):
             return None
 
+    def list_models(self, session_id=None, limit=None, skip=None):
+        """List all models in the statestore.
+
+        :param session_id: The session id.
+        :type session_id: str
+        :param limit: The maximum number of models to return.
+        :type limit: int
+        :param skip: The number of models to skip.
+        :type skip: int
+        :return: Dictionary with the matching models (result) and their count.
+        :rtype: dict
+        """
+        result = None
+
+        find_option = (
+            {"key": "models"}
+            if session_id is None
+            else {"key": "models", "session_id": session_id}
+        )
+
+        if limit is not None and skip is not None:
+            limit = int(limit)
+            skip = int(skip)
+
+            result = self.model.find(find_option).limit(limit).skip(skip)
+
+        else:
+            result = self.model.find(find_option)
+
+        count = self.model.count_documents(find_option)
+
+        return {
+            "result": result,
+            "count": count,
+        }
+
     def get_model_trail(self):
         """Get the model trail.