From a0645067570744bc839978b8b57044b4310115e8 Mon Sep 17 00:00:00 2001 From: Niklas <niklas@scaleoutsystems.com> Date: Wed, 18 Oct 2023 15:17:06 +0200 Subject: [PATCH] get_models added to api --- fedn/fedn/network/api/interface.py | 19 +- fedn/fedn/network/api/server.py | 20 ++ fedn/fedn/network/controller/control.py | 237 +++++++++++------- fedn/fedn/network/controller/controlbase.py | 124 +++++---- .../network/statestore/mongostatestore.py | 51 +++- 5 files changed, 315 insertions(+), 136 deletions(-) diff --git a/fedn/fedn/network/api/interface.py b/fedn/fedn/network/api/interface.py index a86f0feb8..40792da81 100644 --- a/fedn/fedn/network/api/interface.py +++ b/fedn/fedn/network/api/interface.py @@ -180,7 +180,6 @@ def get_all_sessions(self, limit=None, skip=None): result = {"result": payload, "count": sessions_object["count"]} return jsonify(result) - return jsonify(payload) def get_session(self, session_id): """Get a session from the statestore. @@ -632,6 +631,24 @@ def get_latest_model(self): {"success": False, "message": "No initial model set."} ) + def get_models(self, session_id=None, limit=None, skip=None): + result = self.statestore.list_models(session_id, limit, skip) + + if result is None: + return ( + jsonify({"success": False, "message": "No models found."}), + 404, + ) + + json_docs = [] + for doc in result["result"]: + json_doc = json.dumps(doc, default=json_util.default) + json_docs.append(json_doc) + + json_docs.reverse() + + return jsonify({"result": json_docs, "count": result["count"]}) + def get_model_trail(self): """Get the model trail for a given session. diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py index 2bdd7aae3..0c49343f8 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/fedn/network/api/server.py @@ -32,6 +32,26 @@ def get_model_trail(): return api.get_model_trail() +@app.route("/list_models", methods=["GET"]) +def list_models(): + """Get models from the statestore. + param: + session_id: The session id to get the model trail for. + limit: The maximum number of models to return. + type: limit: int + param: skip: The number of models to skip. + type: skip: int + Returns: + _type_: json + """ + + session_id = request.args.get("session_id", None) + limit = request.args.get("limit", None) + skip = request.args.get("skip", None) + + return api.get_models(session_id, limit, skip) + + @app.route("/delete_model_trail", methods=["GET", "POST"]) def delete_model_trail(): """Delete the model trail for a given session. diff --git a/fedn/fedn/network/controller/control.py b/fedn/fedn/network/controller/control.py index 5f5bc6634..a8e32333d 100644 --- a/fedn/fedn/network/controller/control.py +++ b/fedn/fedn/network/controller/control.py @@ -9,10 +9,10 @@ class UnsupportedStorageBackend(Exception): - """ Exception class for when storage backend is not supported. Passes """ + """Exception class for when storage backend is not supported. Passes""" def __init__(self, message): - """ Constructor method. + """Constructor method. :param message: The exception message. :type message: str @@ -23,46 +23,46 @@ def __init__(self, message): class MisconfiguredStorageBackend(Exception): - """ Exception class for when storage backend is misconfigured. + """Exception class for when storage backend is misconfigured. :param message: The exception message. :type message: str """ def __init__(self, message): - """ Constructor method.""" + """Constructor method.""" self.message = message super().__init__(self.message) class NoModelException(Exception): - """ Exception class for when model is None + """Exception class for when model is None :param message: The exception message. :type message: str """ def __init__(self, message): - """ Constructor method.""" + """Constructor method.""" self.message = message super().__init__(self.message) class Control(ControlBase): - """ Controller, implementing the overall global training, validation and inference logic. + """Controller, implementing the overall global training, validation and inference logic. :param statestore: A StateStorage instance. :type statestore: class: `fedn.network.statestorebase.StateStorageBase` """ def __init__(self, statestore): - """ Constructor method.""" + """Constructor method.""" super().__init__(statestore) self.name = "DefaultControl" def session(self, config): - """ Execute a new training session. A session consists of one + """Execute a new training session. A session consists of one or several global rounds. All rounds in the same session have the same round_config. @@ -72,7 +72,10 @@ def session(self, config): """ if self._state == ReducerState.instructing: - print("Controller already in INSTRUCTING state. A session is in progress.", flush=True) + print( + "Controller already in INSTRUCTING state. A session is in progress.", + flush=True, + ) return if not self.statestore.get_latest_model(): @@ -82,11 +85,16 @@ def session(self, config): self._state = ReducerState.instructing # Must be called to set info in the db - config['committed_at'] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + config["committed_at"] = datetime.datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ) self.new_session(config) if not self.statestore.get_latest_model(): - print("No model in model chain, please provide a seed model!", flush=True) + print( + "No model in model chain, please provide a seed model!", + flush=True, + ) self._state = ReducerState.monitoring last_round = int(self.get_latest_round_id()) @@ -96,7 +104,7 @@ def session(self, config): combiner.flush_model_update_queue() # Execute the rounds in this session - for round in range(1, int(config['rounds'] + 1)): + for round in range(1, int(config["rounds"] + 1)): # Increment the round number if last_round: @@ -107,10 +115,17 @@ def session(self, config): try: _, round_data = self.round(config, current_round) except TypeError as e: - print("Could not unpack data from round: {0}".format(e), flush=True) - - print("CONTROL: Round completed with status {}".format( - round_data['status']), flush=True) + print( + "Could not unpack data from round: {0}".format(e), + flush=True, + ) + + print( + "CONTROL: Round completed with status {}".format( + round_data["status"] + ), + flush=True, + ) self.tracer.set_round_data(round_data) @@ -118,7 +133,7 @@ def session(self, config): self._state = ReducerState.idle def round(self, session_config, round_id): - """ Execute a single global round. + """Execute a single global round. :param session_config: The session config. :type session_config: dict @@ -126,35 +141,42 @@ def round(self, session_config, round_id): :type round_id: str(int) """ - round_data = {'round_id': round_id} + round_data = {"round_id": round_id} if len(self.network.get_combiners()) < 1: print("REDUCER: No combiners connected!", flush=True) - round_data['status'] = 'Failed' + round_data["status"] = "Failed" return None, round_data # 1. Assemble round config for this global round, # and check which combiners are able to participate # in the round. round_config = copy.deepcopy(session_config) - round_config['rounds'] = 1 - round_config['round_id'] = round_id - round_config['task'] = 'training' - round_config['model_id'] = self.statestore.get_latest_model() - round_config['helper_type'] = self.statestore.get_helper() + round_config["rounds"] = 1 + round_config["round_id"] = round_id + round_config["task"] = "training" + round_config["model_id"] = self.statestore.get_latest_model() + round_config["helper_type"] = self.statestore.get_helper() combiners = self.get_participating_combiners(round_config) round_start = self.evaluate_round_start_policy(combiners) if round_start: - print("CONTROL: round start policy met, participating combiners {}".format( - combiners), flush=True) + print( + "CONTROL: round start policy met, participating combiners {}".format( + combiners + ), + flush=True, + ) else: - print("CONTROL: Round start policy not met, skipping round!", flush=True) - round_data['status'] = 'Failed' + print( + "CONTROL: Round start policy not met, skipping round!", + flush=True, + ) + round_data["status"] = "Failed" return None - round_data['round_config'] = round_config + round_data["round_config"] = round_config # 2. Ask participating combiners to coordinate model updates _ = self.request_model_updates(combiners) @@ -164,27 +186,37 @@ def round(self, session_config, round_id): # dict to store combiners that have successfully produced an updated model updated = {} # wait until all combiners have produced an updated model or until round timeout - print("CONTROL: Fetching round config (ID: {round_id}) from statestore:".format( - round_id=round_id), flush=True) + print( + "CONTROL: Fetching round config (ID: {round_id}) from statestore:".format( + round_id=round_id + ), + flush=True, + ) while len(updated) < len(combiners): round = self.statestore.get_round(round_id) if round: print("CONTROL: Round found!", flush=True) # For each combiner in the round, check if it has produced an updated model (status == 'Success') - for combiner in round['combiners']: + for combiner in round["combiners"]: print(combiner, flush=True) - if combiner['status'] == 'Success': - if combiner['name'] not in updated.keys(): + if combiner["status"] == "Success": + if combiner["name"] not in updated.keys(): # Add combiner to updated dict - updated[combiner['name']] = combiner['model_id'] + updated[combiner["name"]] = combiner["model_id"] # Print combiner status - print("CONTROL: Combiner {name} status: {status}".format( - name=combiner['name'], status=combiner['status']), flush=True) + print( + "CONTROL: Combiner {name} status: {status}".format( + name=combiner["name"], status=combiner["status"] + ), + flush=True, + ) else: # Print every 10 seconds based on value of wait if wait % 10 == 0: - print("CONTROL: Waiting for round to complete...", flush=True) - if wait >= session_config['round_timeout']: + print( + "CONTROL: Waiting for round to complete...", flush=True + ) + if wait >= session_config["round_timeout"]: print("CONTROL: Round timeout! Exiting round...", flush=True) break # Update wait time used for timeout @@ -194,53 +226,77 @@ def round(self, session_config, round_id): round_valid = self.evaluate_round_validity_policy(updated) if not round_valid: print("REDUCER CONTROL: Round invalid!", flush=True) - round_data['status'] = 'Failed' + round_data["status"] = "Failed" return None, round_data print("CONTROL: Reducing models from combiners...", flush=True) # 3. Reduce combiner models into a global model try: model, data = self.reduce(updated) - round_data['reduce'] = data + round_data["reduce"] = data print("CONTROL: Done reducing models from combiners!", flush=True) except Exception as e: - print("CONTROL: Failed to reduce models from combiners: {}".format( - e), flush=True) - round_data['status'] = 'Failed' + print( + "CONTROL: Failed to reduce models from combiners: {}".format( + e + ), + flush=True, + ) + round_data["status"] = "Failed" return None, round_data # 6. Commit the global model to model trail if model is not None: - print("CONTROL: Committing global model to model trail...", flush=True) + print( + "CONTROL: Committing global model to model trail...", + flush=True, + ) tic = time.time() model_id = uuid.uuid4() - self.commit(model_id, model) - round_data['time_commit'] = time.time() - tic - print("CONTROL: Done committing global model to model trail!", flush=True) + session_id = ( + session_config["session_id"] + if "session_id" in session_config + else None + ) + self.commit(model_id, model, session_id) + round_data["time_commit"] = time.time() - tic + print( + "CONTROL: Done committing global model to model trail!", + flush=True, + ) else: - print("REDUCER: failed to update model in round with config {}".format( - session_config), flush=True) - round_data['status'] = 'Failed' + print( + "REDUCER: failed to update model in round with config {}".format( + session_config + ), + flush=True, + ) + round_data["status"] = "Failed" return None, round_data - round_data['status'] = 'Success' + round_data["status"] = "Success" # 4. Trigger participating combiner nodes to execute a validation round for the current model - validate = session_config['validate'] + validate = session_config["validate"] if validate: combiner_config = copy.deepcopy(session_config) - combiner_config['round_id'] = round_id - combiner_config['model_id'] = self.statestore.get_latest_model() - combiner_config['task'] = 'validation' - combiner_config['helper_type'] = self.statestore.get_helper() + combiner_config["round_id"] = round_id + combiner_config["model_id"] = self.statestore.get_latest_model() + combiner_config["task"] = "validation" + combiner_config["helper_type"] = self.statestore.get_helper() validating_combiners = self._select_participating_combiners( - combiner_config) + combiner_config + ) for combiner, combiner_config in validating_combiners: try: - print("CONTROL: Submitting validation round to combiner {}".format( - combiner), flush=True) + print( + "CONTROL: Submitting validation round to combiner {}".format( + combiner + ), + flush=True, + ) combiner.submit(combiner_config) except CombinerUnavailableError: self._handle_unavailable_combiner(combiner) @@ -249,16 +305,16 @@ def round(self, session_config, round_id): return model_id, round_data def reduce(self, combiners): - """ Combine updated models from Combiner nodes into one global model. + """Combine updated models from Combiner nodes into one global model. :param combiners: dict of combiner names (key) and model IDs (value) to reduce :type combiners: dict """ meta = {} - meta['time_fetch_model'] = 0.0 - meta['time_load_model'] = 0.0 - meta['time_aggregate_model'] = 0.0 + meta["time_fetch_model"] = 0.0 + meta["time_load_model"] = 0.0 + meta["time_aggregate_model"] = 0.0 i = 1 model = None @@ -268,18 +324,25 @@ def reduce(self, combiners): return model, meta for name, model_id in combiners.items(): - # TODO: Handle inactive RPC error in get_model and raise specific error - print("REDUCER: Fetching model ({model_id}) from combiner {name}".format( - model_id=model_id, name=name), flush=True) + print( + "REDUCER: Fetching model ({model_id}) from combiner {name}".format( + model_id=model_id, name=name + ), + flush=True, + ) try: tic = time.time() combiner = self.get_combiner(name) data = combiner.get_model(model_id) - meta['time_fetch_model'] += (time.time() - tic) + meta["time_fetch_model"] += time.time() - tic except Exception as e: - print("REDUCER: Failed to fetch model from combiner {}: {}".format( - name, e), flush=True) + print( + "REDUCER: Failed to fetch model from combiner {}: {}".format( + name, e + ), + flush=True, + ) data = None if data is not None: @@ -288,21 +351,21 @@ def reduce(self, combiners): helper = self.get_helper() data.seek(0) model_next = helper.load(data) - meta['time_load_model'] += (time.time() - tic) + meta["time_load_model"] += time.time() - tic tic = time.time() model = helper.increment_average(model, model_next, i, i) - meta['time_aggregate_model'] += (time.time() - tic) + meta["time_aggregate_model"] += time.time() - tic except Exception: tic = time.time() data.seek(0) model = helper.load(data) - meta['time_aggregate_model'] += (time.time() - tic) + meta["time_aggregate_model"] += time.time() - tic i = i + 1 return model, meta def infer_instruct(self, config): - """ Main entrypoint for executing the inference compute plan. + """Main entrypoint for executing the inference compute plan. :param config: configuration for the inference round """ @@ -330,7 +393,7 @@ def infer_instruct(self, config): self.__state = ReducerState.idle def inference_round(self, config): - """ Execute an inference round. + """Execute an inference round. :param config: configuration for the inference round """ @@ -345,21 +408,27 @@ def inference_round(self, config): # Setup combiner configuration combiner_config = copy.deepcopy(config) - combiner_config['model_id'] = self.statestore.get_latest_model() - combiner_config['task'] = 'inference' - combiner_config['helper_type'] = self.statestore.get_framework() + combiner_config["model_id"] = self.statestore.get_latest_model() + combiner_config["task"] = "inference" + combiner_config["helper_type"] = self.statestore.get_framework() # Select combiners - validating_combiners = self._select_round_combiners( - combiner_config) + validating_combiners = self._select_round_combiners(combiner_config) # Test round start policy round_start = self.check_round_start_policy(validating_combiners) if round_start: - print("CONTROL: round start policy met, participating combiners {}".format( - validating_combiners), flush=True) + print( + "CONTROL: round start policy met, participating combiners {}".format( + validating_combiners + ), + flush=True, + ) else: - print("CONTROL: Round start policy not met, skipping round!", flush=True) + print( + "CONTROL: Round start policy not met, skipping round!", + flush=True, + ) return None # Synch combiners with latest model and trigger inference diff --git a/fedn/fedn/network/controller/controlbase.py b/fedn/fedn/network/controller/controlbase.py index e38d31e38..077620c14 100644 --- a/fedn/fedn/network/controller/controlbase.py +++ b/fedn/fedn/network/controller/controlbase.py @@ -11,7 +11,7 @@ from fedn.network.state import ReducerState # Maximum number of tries to connect to statestore and retrieve storage configuration -MAX_TRIES_BACKEND = os.getenv('MAX_TRIES_BACKEND', 10) +MAX_TRIES_BACKEND = os.getenv("MAX_TRIES_BACKEND", 10) class UnsupportedStorageBackend(Exception): @@ -27,7 +27,7 @@ class MisconfiguredHelper(Exception): class ControlBase(ABC): - """ Base class and interface for a global controller. + """Base class and interface for a global controller. Override this class to implement a global training strategy (control). :param statestore: The statestore object. @@ -36,7 +36,7 @@ class ControlBase(ABC): @abstractmethod def __init__(self, statestore): - """ Constructor. """ + """Constructor.""" self._state = ReducerState.setup self.statestore = statestore @@ -52,26 +52,36 @@ def __init__(self, statestore): not_ready = False else: print( - "REDUCER CONTROL: Storage backend not configured, waiting...", flush=True) + "REDUCER CONTROL: Storage backend not configured, waiting...", + flush=True, + ) sleep(5) tries += 1 if tries > MAX_TRIES_BACKEND: raise Exception except Exception: print( - "REDUCER CONTROL: Failed to retrive storage configuration, exiting.", flush=True) + "REDUCER CONTROL: Failed to retrive storage configuration, exiting.", + flush=True, + ) raise MisconfiguredStorageBackend() - if storage_config['storage_type'] == 'S3': - self.model_repository = S3ModelRepository(storage_config['storage_config']) + if storage_config["storage_type"] == "S3": + self.model_repository = S3ModelRepository( + storage_config["storage_config"] + ) else: - print("REDUCER CONTROL: Unsupported storage backend, exiting.", flush=True) + print( + "REDUCER CONTROL: Unsupported storage backend, exiting.", + flush=True, + ) raise UnsupportedStorageBackend() # The tracer is a helper that manages state in the database backend statestore_config = statestore.get_config() self.tracer = MongoTracer( - statestore_config['mongo_config'], statestore_config['network_id']) + statestore_config["mongo_config"], statestore_config["network_id"] + ) if self.statestore.is_inited(): self._state = ReducerState.idle @@ -89,7 +99,7 @@ def reduce(self, combiners): pass def get_helper(self): - """ Get a helper instance from global config. + """Get a helper instance from global config. :return: Helper instance. :rtype: :class:`fedn.utils.plugins.helperbase.HelperBase` @@ -97,11 +107,15 @@ def get_helper(self): helper_type = self.statestore.get_helper() helper = fedn.utils.helpers.get_helper(helper_type) if not helper: - raise MisconfiguredHelper("Unsupported helper type {}, please configure compute_package.helper !".format(helper_type)) + raise MisconfiguredHelper( + "Unsupported helper type {}, please configure compute_package.helper !".format( + helper_type + ) + ) return helper def get_state(self): - """ Get the current state of the controller. + """Get the current state of the controller. :return: The current state. :rtype: :class:`fedn.network.state.ReducerState` @@ -109,7 +123,7 @@ def get_state(self): return self._state def idle(self): - """ Check if the controller is idle. + """Check if the controller is idle. :return: True if idle, False otherwise. :rtype: bool @@ -139,7 +153,7 @@ def get_latest_round_id(self): if not last_round: return 0 else: - return last_round['round_id'] + return last_round["round_id"] def get_latest_round(self): round = self.statestore.get_latest_round() @@ -153,27 +167,29 @@ def get_compute_package_name(self): definition = self.statestore.get_compute_package() if definition: try: - package_name = definition['filename'] + package_name = definition["filename"] return package_name except (IndexError, KeyError): print( - "No context filename set for compute context definition", flush=True) + "No context filename set for compute context definition", + flush=True, + ) return None else: return None def set_compute_package(self, filename, path): - """ Persist the configuration for the compute package. """ + """Persist the configuration for the compute package.""" self.model_repository.set_compute_package(filename, path) self.statestore.set_compute_package(filename) - def get_compute_package(self, compute_package=''): + def get_compute_package(self, compute_package=""): """ :param compute_package: :return: """ - if compute_package == '': + if compute_package == "": compute_package = self.get_compute_package_name() if compute_package: return self.model_repository.get_compute_package(compute_package) @@ -181,42 +197,49 @@ def get_compute_package(self, compute_package=''): return None def new_session(self, config): - """ Initialize a new session in backend db. """ + """Initialize a new session in backend db.""" if "session_id" not in config.keys(): session_id = uuid.uuid4() - config['session_id'] = str(session_id) + config["session_id"] = str(session_id) else: - session_id = config['session_id'] + session_id = config["session_id"] self.tracer.new_session(id=session_id) self.tracer.set_session_config(session_id, config) def request_model_updates(self, combiners): - """Call Combiner server RPC to get a model update. """ + """Call Combiner server RPC to get a model update.""" cl = [] for combiner, combiner_round_config in combiners: response = combiner.submit(combiner_round_config) cl.append((combiner, response)) return cl - def commit(self, model_id, model=None): - """ Commit a model to the global model trail. The model commited becomes the lastest consensus model. """ + def commit(self, model_id, model=None, session_id=None): + """Commit a model to the global model trail. The model commited becomes the lastest consensus model.""" helper = self.get_helper() if model is not None: - print("CONTROL: Saving model file temporarily to disk...", flush=True) + print( + "CONTROL: Saving model file temporarily to disk...", flush=True + ) outfile_name = helper.save(model) print("CONTROL: Uploading model to Minio...", flush=True) model_id = self.model_repository.set_model( - outfile_name, is_file=True) + outfile_name, is_file=True + ) print("CONTROL: Deleting temporary model file...", flush=True) os.unlink(outfile_name) - print("CONTROL: Committing model {} to global model trail in statestore...".format( - model_id), flush=True) - self.statestore.set_latest_model(model_id) + print( + "CONTROL: Committing model {} to global model trail in statestore...".format( + model_id + ), + flush=True, + ) + self.statestore.set_latest_model(model_id, session_id) def get_combiner(self, name): for combiner in self.network.get_combiners(): @@ -226,7 +249,7 @@ def get_combiner(self, name): def get_participating_combiners(self, combiner_round_config): """Assemble a list of combiners able to participate in a round as - descibed by combiner_round_config. + descibed by combiner_round_config. """ combiners = [] for combiner in self.network.get_combiners(): @@ -238,45 +261,47 @@ def get_participating_combiners(self, combiner_round_config): if combiner_state is not None: is_participating = self.evaluate_round_participation_policy( - combiner_round_config, combiner_state) + combiner_round_config, combiner_state + ) if is_participating: combiners.append((combiner, combiner_round_config)) return combiners - def evaluate_round_participation_policy(self, compute_plan, combiner_state): - """ Evaluate policy for combiner round-participation. - A combiner participates if it is responsive and reports enough - active clients to participate in the round. + def evaluate_round_participation_policy( + self, compute_plan, combiner_state + ): + """Evaluate policy for combiner round-participation. + A combiner participates if it is responsive and reports enough + active clients to participate in the round. """ - if compute_plan['task'] == 'training': - nr_active_clients = int(combiner_state['nr_active_trainers']) - elif compute_plan['task'] == 'validation': - nr_active_clients = int(combiner_state['nr_active_validators']) + if compute_plan["task"] == "training": + nr_active_clients = int(combiner_state["nr_active_trainers"]) + elif compute_plan["task"] == "validation": + nr_active_clients = int(combiner_state["nr_active_validators"]) else: print("Invalid task type!", flush=True) return False - if int(compute_plan['clients_required']) <= nr_active_clients: + if int(compute_plan["clients_required"]) <= nr_active_clients: return True else: return False def evaluate_round_start_policy(self, combiners): - """ Check if the policy to start a round is met. """ + """Check if the policy to start a round is met.""" if len(combiners) > 0: - return True else: return False def evaluate_round_validity_policy(self, combiners): - """ Check if the round should be seen as valid. + """Check if the round should be seen as valid. - At the end of the round, before committing a model to the global model trail, - we check if the round validity policy has been met. This can involve - e.g. asserting that a certain number of combiners have reported in an - updated model, or that criteria on model performance have been met. + At the end of the round, before committing a model to the global model trail, + we check if the round validity policy has been met. This can involve + e.g. asserting that a certain number of combiners have reported in an + updated model, or that criteria on model performance have been met. """ if combiners.keys() == []: return False @@ -294,7 +319,8 @@ def _select_participating_combiners(self, compute_plan): if combiner_state: is_participating = self.evaluate_round_participation_policy( - compute_plan, combiner_state) + compute_plan, combiner_state + ) if is_participating: participating_combiners.append((combiner, compute_plan)) return participating_combiners diff --git a/fedn/fedn/network/statestore/mongostatestore.py b/fedn/fedn/network/statestore/mongostatestore.py index 903d1a208..9262e9f13 100644 --- a/fedn/fedn/network/statestore/mongostatestore.py +++ b/fedn/fedn/network/statestore/mongostatestore.py @@ -145,7 +145,7 @@ def get_session(self, session_id): """ return self.sessions.find_one({"session_id": session_id}) - def set_latest_model(self, model_id): + def set_latest_model(self, model_id, session_id=None): """Set the latest model id. :param model_id: The model id. @@ -153,6 +153,17 @@ def set_latest_model(self, model_id): :return: """ + commited_at = str(datetime.now()) + + self.model.insert_one( + { + "key": "models", + "model": model_id, + "session_id": session_id, + "committed_at": commited_at, + } + ) + self.model.update_one( {"key": "current_model"}, {"$set": {"model": model_id}}, True ) @@ -161,7 +172,7 @@ def set_latest_model(self, model_id): { "$push": { "model": model_id, - "committed_at": str(datetime.now()), + "committed_at": commited_at, } }, True, @@ -326,6 +337,42 @@ def get_helper(self): except (KeyError, IndexError): return None + def list_models(self, session_id=None, limit=None, skip=None): + """List all models in the statestore. + + :param session_id: The session id. + :type session_id: str + :param limit: The maximum number of models to return. + :type limit: int + :param skip: The number of models to skip. + :type skip: int + :return: List of models. + :rtype: list + """ + result = None + + find_option = ( + {"key": "models"} + if session_id is None + else {"key": "models", "session_id": session_id} + ) + + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + + result = self.model.find(find_option).limit(limit).skip(skip) + + else: + result = self.model.find(find_option) + + count = self.model.count_documents({}) + + return { + "result": result, + "count": count, + } + def get_model_trail(self): """Get the model trail.