diff --git a/.ci/tests/examples/wait_for.py b/.ci/tests/examples/wait_for.py index dc3345da0..ccd76859d 100644 --- a/.ci/tests/examples/wait_for.py +++ b/.ci/tests/examples/wait_for.py @@ -18,7 +18,7 @@ def _retry(try_func, **func_args): for _ in range(RETRIES): is_success = try_func(**func_args) if is_success: - _eprint('Sucess.') + _eprint('Success.') return True _eprint(f'Sleeping for {SLEEP}.') sleep(SLEEP) @@ -30,7 +30,7 @@ def _test_rounds(n_rounds): client = pymongo.MongoClient( "mongodb://fedn_admin:password@localhost:6534") collection = client['fedn-network']['control']['rounds'] - query = {'reducer.status': 'Success'} + query = {'status': 'Finished'} n = collection.count_documents(query) client.close() _eprint(f'Succeded rounds: {n}.') @@ -60,7 +60,7 @@ def _test_nodes(n_nodes, node_type, reducer_host='localhost', reducer_port='8092 return count == n_nodes except Exception as e: - _eprint(f'Reques exception econuntered: {e}.') + _eprint(f'Request exception enconuntered: {e}.') return False diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 8a86fd439..bc45dc53b 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -12,3 +12,4 @@ python: install: - method: pip path: ./fedn + - requirements: docs/requirements.txt diff --git a/Dockerfile b/Dockerfile index 67f026d03..fa8c5bd22 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Base image -ARG BASE_IMG=python:3.9-slim +ARG BASE_IMG=python:3.10-slim FROM $BASE_IMG # Requirements (use MNIST Keras as default) diff --git a/docker-compose.yaml b/docker-compose.yaml index c8d3aff15..aa4550c25 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -68,7 +68,7 @@ services: build: context: . args: - BASE_IMG: ${BASE_IMG:-python:3.9-slim} + BASE_IMG: ${BASE_IMG:-python:3.10-slim} working_dir: /app volumes: - ${HOST_REPO_DIR:-.}/fedn:/app/fedn @@ -89,7 +89,7 @@ services: build: context: . 
args: - BASE_IMG: ${BASE_IMG:-python:3.9-slim} + BASE_IMG: ${BASE_IMG:-python:3.10-slim} working_dir: /app volumes: - ${HOST_REPO_DIR:-.}/fedn:/app/fedn @@ -110,7 +110,7 @@ services: build: context: . args: - BASE_IMG: ${BASE_IMG:-python:3.9-slim} + BASE_IMG: ${BASE_IMG:-python:3.10-slim} working_dir: /app volumes: - ${HOST_REPO_DIR:-.}/fedn:/app/fedn @@ -127,7 +127,7 @@ services: build: context: . args: - BASE_IMG: ${BASE_IMG:-python:3.9-slim} + BASE_IMG: ${BASE_IMG:-python:3.10-slim} working_dir: /app volumes: - ${HOST_REPO_DIR:-.}/fedn:/app/fedn diff --git a/docs/conf.py b/docs/conf.py index 963080333..bd2032b0e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,7 +12,7 @@ author = 'Scaleout Systems AB' # The full version, including alpha/beta/rc tags -release = '0.4.1' +release = '0.6.0' # Add any Sphinx extension module names here, as strings extensions = [ diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 000000000..4170c03ef --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1 @@ +sphinx-rtd-theme \ No newline at end of file diff --git a/examples/mnist-keras/bin/build.sh b/examples/mnist-keras/bin/build.sh index 18cdb5128..44eda61df 100755 --- a/examples/mnist-keras/bin/build.sh +++ b/examples/mnist-keras/bin/build.sh @@ -5,4 +5,4 @@ set -e client/entrypoint init_seed # Make compute package -tar -czvf package.tgz client \ No newline at end of file +tar -czvf package.tgz client diff --git a/fedn/fedn/common/storage/s3/miniorepo.py b/fedn/fedn/common/storage/s3/miniorepo.py index 9341704e6..154cea7e9 100644 --- a/fedn/fedn/common/storage/s3/miniorepo.py +++ b/fedn/fedn/common/storage/s3/miniorepo.py @@ -62,11 +62,13 @@ def __init__(self, config): self.create_bucket(self.bucket) def create_bucket(self, bucket_name): - """ + """ Create a new bucket. If bucket exists, do nothing. 
- :param bucket_name: + :param bucket_name: The name of the bucket + :type bucket_name: str """ found = self.client.bucket_exists(bucket_name) + if not found: try: self.client.make_bucket(bucket_name) diff --git a/fedn/fedn/common/tracer/mongotracer.py b/fedn/fedn/common/tracer/mongotracer.py index 0a3e28cdc..aa5c0810b 100644 --- a/fedn/fedn/common/tracer/mongotracer.py +++ b/fedn/fedn/common/tracer/mongotracer.py @@ -52,18 +52,26 @@ def drop_status(self): if self.status: self.status.drop() - def new_session(self, id=None): - """ Create a new session. """ + def create_session(self, id=None): + """ Create a new session. + + :param id: The ID of the created session. + :type id: uuid, str + + """ if not id: id = uuid.uuid4() data = {'session_id': str(id)} self.sessions.insert_one(data) - def new_round(self, id): - """ Create a new session. """ + def create_round(self, round_data): + """ Create a new round. - data = {'round_id': str(id)} - self.rounds.insert_one(data) + :param round_data: Dictionary with round data. 
+        :type round_data: dict
+        """
+        # TODO: Add check if round_id already exists
+        self.rounds.insert_one(round_data)
 
     def set_session_config(self, id, config):
         self.sessions.update_one({'session_id': str(id)}, {
@@ -72,18 +80,35 @@ def set_session_config(self, id, config):
 
     def set_round_combiner_data(self, data):
         """
-        :param round_meta:
+        :param data: The combiner data
+        :type data: dict
         """
         self.rounds.update_one({'round_id': str(data['round_id'])}, {
             '$push': {'combiners': data}}, True)
 
-    def set_round_data(self, round_data):
+    def set_round_config(self, round_id, round_config):
+        """ Set the round configuration.
+
+        :param round_config: The round configuration (dict)
+        """
+        self.rounds.update_one({'round_id': round_id}, {
+            '$set': {'round_config': round_config}}, True)
+
+    def set_round_status(self, round_id, round_status):
+        """ Set the round status.
+
+        :param round_status: The round status (str)
+        """
+        self.rounds.update_one({'round_id': round_id}, {
+            '$set': {'status': round_status}}, True)
+
+    def set_round_data(self, round_id, round_data):
         """
 
         :param round_meta:
         """
-        self.rounds.update_one({'round_id': str(round_data['round_id'])}, {
-            '$push': {'reducer': round_data}}, True)
+        self.rounds.update_one({'round_id': round_id}, {
+            '$set': {'round_data': round_data}}, True)
 
     def update_client_status(self, client_name, status):
         """ Update client status in statestore.
diff --git a/fedn/fedn/network/api/client.py b/fedn/fedn/network/api/client.py
index 0e0a48a52..58fc27304 100644
--- a/fedn/fedn/network/api/client.py
+++ b/fedn/fedn/network/api/client.py
@@ -1,5 +1,3 @@
-import uuid
-
 import requests
 
 __all__ = ['APIClient']
@@ -137,9 +135,6 @@ def start_session(self, session_id=None, round_timeout=180, rounds=5, round_buff
         :return: A dict with success or failure message and session config.
         :rtype: dict
         """
-        # If session id is None, generate a random session id.
- if session_id is None: - session_id = str(uuid.uuid4()) response = requests.post(self._get_url('start_session'), json={ 'session_id': session_id, 'round_timeout': round_timeout, diff --git a/fedn/fedn/network/api/interface.py b/fedn/fedn/network/api/interface.py index 61095e6ec..e56462493 100644 --- a/fedn/fedn/network/api/interface.py +++ b/fedn/fedn/network/api/interface.py @@ -2,6 +2,7 @@ import copy import os import threading +import uuid from io import BytesIO from flask import jsonify, send_from_directory @@ -707,9 +708,8 @@ def get_round(self, round_id): if round_object is None: return jsonify({"success": False, "message": "Round not found."}) payload = { - "round_id": round_object["round_id"], - "reducer": round_object["reducer"], - "combiners": round_object["combiners"], + 'round_id': round_object['round_id'], + 'combiners': round_object['combiners'], } return jsonify(payload) @@ -864,7 +864,7 @@ def start_session( # Setup session config session_config = { - "session_id": session_id, + "session_id": session_id if session_id else str(uuid.uuid4()), "round_timeout": round_timeout, "buffer_size": round_buffer_size, "model_id": model_id, diff --git a/fedn/fedn/network/clients/client.py b/fedn/fedn/network/clients/client.py index 9851b32ef..e27616925 100644 --- a/fedn/fedn/network/clients/client.py +++ b/fedn/fedn/network/clients/client.py @@ -4,7 +4,7 @@ import os import queue import re -import ssl +import socket import sys import tempfile import threading @@ -15,7 +15,9 @@ from io import BytesIO import grpc +from cryptography.hazmat.primitives.serialization import Encoding from google.protobuf.json_format import MessageToJson +from OpenSSL import SSL import fedn.common.net.grpc.fedn_pb2 as fedn import fedn.common.net.grpc.fedn_pb2_grpc as rpc @@ -127,6 +129,42 @@ def _assign(self): print("Received combiner config: {}".format(client_config), flush=True) return client_config + def _add_grpc_metadata(self, key, value): + """Add metadata for gRPC calls. 
+ + :param key: The key of the metadata. + :type key: str + :param value: The value of the metadata. + :type value: str + """ + # Check if metadata exists and add if not + if not hasattr(self, 'metadata'): + self.metadata = () + + # Check if metadata key already exists and replace value if so + for i, (k, v) in enumerate(self.metadata): + if k == key: + # Replace value + self.metadata = self.metadata[:i] + ((key, value),) + self.metadata[i + 1:] + return + + # Set metadata using tuple concatenation + self.metadata += ((key, value),) + + def _get_ssl_certificate(self, domain, port=443): + context = SSL.Context(SSL.SSLv23_METHOD) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect((domain, port)) + ssl_sock = SSL.Connection(context, sock) + ssl_sock.set_tlsext_host_name(domain.encode()) + ssl_sock.set_connect_state() + ssl_sock.do_handshake() + cert = ssl_sock.get_peer_certificate() + ssl_sock.close() + sock.close() + cert = cert.to_cryptography().public_bytes(Encoding.PEM).decode() + return cert + def _connect(self, client_config): """Connect to assigned combiner. 
@@ -137,6 +175,9 @@ def _connect(self, client_config): # TODO use the client_config['certificate'] for setting up secure comms' host = client_config['host'] + # Add host to gRPC metadata + self._add_grpc_metadata('grpc-server', host) + print("CLIENT: Using metadata: {}".format(self.metadata), flush=True) port = client_config['port'] secure = False if client_config['fqdn'] is not None: @@ -161,7 +202,7 @@ def _connect(self, client_config): elif self.config['secure']: secure = True print("CLIENT: using CA certificate for GRPC channel") - cert = ssl.get_server_certificate((host, port)) + cert = self._get_ssl_certificate(host, port=port) credentials = grpc.ssl_channel_credentials(cert.encode('utf-8')) if self.config['token']: @@ -331,7 +372,7 @@ def get_model(self, id): """ data = BytesIO() - for part in self.modelStub.Download(fedn.ModelRequest(id=id)): + for part in self.modelStub.Download(fedn.ModelRequest(id=id), metadata=self.metadata): if part.status == fedn.ModelStatus.IN_PROGRESS: data.write(part.data) @@ -386,7 +427,7 @@ def upload_request_generator(mdl): if not b: break - result = self.modelStub.Upload(upload_request_generator(bt)) + result = self.modelStub.Upload(upload_request_generator(bt), metadata=self.metadata) return result @@ -400,11 +441,12 @@ def _listen_to_model_update_request_stream(self): r = fedn.ClientAvailableMessage() r.sender.name = self.name r.sender.role = fedn.WORKER - metadata = [('client', r.sender.name)] + # Add client to metadata + self._add_grpc_metadata('client', self.name) while True: try: - for request in self.combinerStub.ModelUpdateRequestStream(r, metadata=metadata): + for request in self.combinerStub.ModelUpdateRequestStream(r, metadata=self.metadata): if request.sender.role == fedn.COMBINER: # Process training request self._send_status("Received model update request.", log_level=fedn.Status.AUDIT, @@ -438,7 +480,7 @@ def _listen_to_model_validation_request_stream(self): r.sender.role = fedn.WORKER while True: try: - for 
request in self.combinerStub.ModelValidationRequestStream(r): + for request in self.combinerStub.ModelValidationRequestStream(r, metadata=self.metadata): # Process validation request _ = request.model_id self._send_status("Recieved model validation request.", log_level=fedn.Status.AUDIT, @@ -589,7 +631,7 @@ def process_request(self): update.correlation_id = request.correlation_id update.meta = json.dumps(meta) # TODO: Check responses - _ = self.combinerStub.SendModelUpdate(update) + _ = self.combinerStub.SendModelUpdate(update, metadata=self.metadata) self._send_status("Model update completed.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_UPDATE, request=update) @@ -618,7 +660,7 @@ def process_request(self): validation.timestamp = self.str validation.correlation_id = request.correlation_id _ = self.combinerStub.SendModelValidation( - validation) + validation, metadata=self.metadata) # Set status type if request.is_inference: @@ -655,7 +697,7 @@ def _send_heartbeat(self, update_frequency=2.0): heartbeat = fedn.Heartbeat(sender=fedn.Client( name=self.name, role=fedn.WORKER)) try: - self.connectorStub.SendHeartbeat(heartbeat) + self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata) self._missed_heartbeat = 0 except grpc.RpcError as e: status_code = e.code() @@ -694,7 +736,7 @@ def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None) self.logs.append( "{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, status.status)) - _ = self.connectorStub.SendStatus(status) + _ = self.connectorStub.SendStatus(status, metadata=self.metadata) def run(self): """ Run the client. 
""" diff --git a/fedn/fedn/network/clients/test_client.py b/fedn/fedn/network/clients/test_client.py new file mode 100644 index 000000000..889c00b94 --- /dev/null +++ b/fedn/fedn/network/clients/test_client.py @@ -0,0 +1,45 @@ +import unittest + +from fedn.network.clients.client import Client + + +class TestClient(unittest.TestCase): + """Test the Client class.""" + + def setUp(self): + self.client = Client() + + def test_add_grpc_metadata(self): + """Test the _add_grpc_metadata method.""" + + # Test adding metadata when it doesn't exist + self.client._add_grpc_metadata('key1', 'value1') + self.assertEqual(self.client.metadata, (('key1', 'value1'),)) + + # Test adding metadata when it already exists + self.client._add_grpc_metadata('key1', 'value2') + self.assertEqual(self.client.metadata, (('key1', 'value2'),)) + + # Test adding multiple metadata + self.client._add_grpc_metadata('key2', 'value3') + self.assertEqual(self.client.metadata, (('key1', 'value2'), ('key2', 'value3'))) + + # Test adding metadata with special characters + self.client._add_grpc_metadata('key3', 'value4!@#$%^&*()') + self.assertEqual(self.client.metadata, (('key1', 'value2'), ('key2', 'value3'), ('key3', 'value4!@#$%^&*()'))) + + # Test adding metadata with empty key + with self.assertRaises(ValueError): + self.client._add_grpc_metadata('', 'value5') + + # Test adding metadata with empty value + with self.assertRaises(ValueError): + self.client._add_grpc_metadata('key4', '') + + # Test adding metadata with None value + with self.assertRaises(ValueError): + self.client._add_grpc_metadata('key5', None) + + +if __name__ == '__main__': + unittest.main() diff --git a/fedn/fedn/network/controller/control.py b/fedn/fedn/network/controller/control.py index a8e32333d..615edb3b5 100644 --- a/fedn/fedn/network/controller/control.py +++ b/fedn/fedn/network/controller/control.py @@ -3,6 +3,9 @@ import time import uuid +from tenacity import (retry, retry_if_exception_type, stop_after_delay, + wait_random) 
+ from fedn.network.combiner.interfaces import CombinerUnavailableError from fedn.network.controller.controlbase import ControlBase from fedn.network.state import ReducerState @@ -48,6 +51,20 @@ def __init__(self, message): super().__init__(self.message) +class CombinersNotDoneException(Exception): + """ Exception class for when model is None """ + + def __init__(self, message): + """ Constructor method. + + :param message: The exception message. + :type message: str + + """ + self.message = message + super().__init__(self.message) + + class Control(ControlBase): """Controller, implementing the overall global training, validation and inference logic. @@ -83,12 +100,10 @@ def session(self, config): return self._state = ReducerState.instructing - - # Must be called to set info in the db config["committed_at"] = datetime.datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) - self.new_session(config) + self.create_session(config) if not self.statestore.get_latest_model(): print( @@ -106,14 +121,13 @@ def session(self, config): # Execute the rounds in this session for round in range(1, int(config["rounds"] + 1)): # Increment the round number - if last_round: current_round = last_round + round else: current_round = round try: - _, round_data = self.round(config, current_round) + _, round_data = self.round(config, str(current_round)) except TypeError as e: print( "Could not unpack data from round: {0}".format(e), @@ -127,30 +141,27 @@ def session(self, config): flush=True, ) - self.tracer.set_round_data(round_data) - # TODO: Report completion of session self._state = ReducerState.idle def round(self, session_config, round_id): - """Execute a single global round. + """ Execute one global round. + + : param session_config: The session config. + : type session_config: dict + : param round_id: The round id. + : type round_id: str - :param session_config: The session config. - :type session_config: dict - :param round_id: The round id. 
- :type round_id: str(int) """ - round_data = {"round_id": round_id} + self.create_round({'round_id': round_id, 'status': "Pending"}) if len(self.network.get_combiners()) < 1: - print("REDUCER: No combiners connected!", flush=True) - round_data["status"] = "Failed" - return None, round_data + print("CONTROLLER: Round cannot start, no combiners connected!", flush=True) + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - # 1. Assemble round config for this global round, - # and check which combiners are able to participate - # in the round. + # Assemble round config for this global round round_config = copy.deepcopy(session_config) round_config["rounds"] = 1 round_config["round_id"] = round_id @@ -158,94 +169,85 @@ def round(self, session_config, round_id): round_config["model_id"] = self.statestore.get_latest_model() round_config["helper_type"] = self.statestore.get_helper() - combiners = self.get_participating_combiners(round_config) - round_start = self.evaluate_round_start_policy(combiners) + self.set_round_config(round_id, round_config) + + # Get combiners that are able to participate in round, given round_config + participating_combiners = self.get_participating_combiners(round_config) + + # Check if the policy to start the round is met + round_start = self.evaluate_round_start_policy(participating_combiners) if round_start: - print( - "CONTROL: round start policy met, participating combiners {}".format( - combiners - ), - flush=True, - ) + print("CONTROL: round start policy met, {} participating combiners.".format( + len(participating_combiners)), flush=True) else: - print( - "CONTROL: Round start policy not met, skipping round!", - flush=True, - ) - round_data["status"] = "Failed" - return None + print("CONTROL: Round start policy not met, skipping round!", flush=True) + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) + + # Ask participating combiners to coordinate model 
updates + _ = self.request_model_updates(participating_combiners) + # TODO: Check response + + # Wait until participating combiners have produced an updated global model, + # or round times out. + def do_if_round_times_out(result): + print("CONTROL: Round timed out!", flush=True) + + @retry(wait=wait_random(min=1.0, max=2.0), + stop=stop_after_delay(session_config['round_timeout']), + retry_error_callback=do_if_round_times_out, + retry=retry_if_exception_type(CombinersNotDoneException)) + def combiners_done(): - round_data["round_config"] = round_config - - # 2. Ask participating combiners to coordinate model updates - _ = self.request_model_updates(combiners) - - # Wait until participating combiners have produced an updated global model. - wait = 0.0 - # dict to store combiners that have successfully produced an updated model - updated = {} - # wait until all combiners have produced an updated model or until round timeout - print( - "CONTROL: Fetching round config (ID: {round_id}) from statestore:".format( - round_id=round_id - ), - flush=True, - ) - while len(updated) < len(combiners): round = self.statestore.get_round(round_id) - if round: - print("CONTROL: Round found!", flush=True) - # For each combiner in the round, check if it has produced an updated model (status == 'Success') - for combiner in round["combiners"]: - print(combiner, flush=True) - if combiner["status"] == "Success": - if combiner["name"] not in updated.keys(): - # Add combiner to updated dict - updated[combiner["name"]] = combiner["model_id"] - # Print combiner status - print( - "CONTROL: Combiner {name} status: {status}".format( - name=combiner["name"], status=combiner["status"] - ), - flush=True, - ) - else: - # Print every 10 seconds based on value of wait - if wait % 10 == 0: - print( - "CONTROL: Waiting for round to complete...", flush=True - ) - if wait >= session_config["round_timeout"]: - print("CONTROL: Round timeout! 
Exiting round...", flush=True) - break - # Update wait time used for timeout - time.sleep(1.0) - wait += 1.0 - - round_valid = self.evaluate_round_validity_policy(updated) + if 'combiners' not in round: + # TODO: use logger + print("CONTROL: Waiting for combiners to update model...", flush=True) + raise CombinersNotDoneException("Combiners have not yet reported.") + + if len(round['combiners']) < len(participating_combiners): + print("CONTROL: Waiting for combiners to update model...", flush=True) + raise CombinersNotDoneException("All combiners have not yet reported.") + + return True + + combiners_done() + + # Due to the distributed nature of the computation, there might be a + # delay before combiners have reported the round data to the db, + # so we need some robustness here. + @retry(wait=wait_random(min=0.1, max=1.0), + retry=retry_if_exception_type(KeyError)) + def check_combiners_done_reporting(): + round = self.statestore.get_round(round_id) + combiners = round['combiners'] + return combiners + + _ = check_combiners_done_reporting() + + round = self.statestore.get_round(round_id) + round_valid = self.evaluate_round_validity_policy(round) if not round_valid: print("REDUCER CONTROL: Round invalid!", flush=True) - round_data["status"] = "Failed" - return None, round_data + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - print("CONTROL: Reducing models from combiners...", flush=True) - # 3. 
Reduce combiner models into a global model + print("CONTROL: Reducing combiner level models...", flush=True) + # Reduce combiner models into a new global model + round_data = {} try: - model, data = self.reduce(updated) - round_data["reduce"] = data + round = self.statestore.get_round(round_id) + model, data = self.reduce(round['combiners']) + round_data['reduce'] = data print("CONTROL: Done reducing models from combiners!", flush=True) except Exception as e: - print( - "CONTROL: Failed to reduce models from combiners: {}".format( - e - ), - flush=True, - ) - round_data["status"] = "Failed" - return None, round_data + print("CONTROL: Failed to reduce models from combiners: {}".format( + e), flush=True) + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - # 6. Commit the global model to model trail + # Commit the new global model to the model trail if model is not None: print( "CONTROL: Committing global model to model trail...", @@ -271,10 +273,10 @@ def round(self, session_config, round_id): ), flush=True, ) - round_data["status"] = "Failed" - return None, round_data + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - round_data["status"] = "Success" + self.set_round_status(round_id, 'Success') # 4. 
Trigger participating combiner nodes to execute a validation round for the current model validate = session_config["validate"] @@ -285,9 +287,8 @@ def round(self, session_config, round_id): combiner_config["task"] = "validation" combiner_config["helper_type"] = self.statestore.get_helper() - validating_combiners = self._select_participating_combiners( - combiner_config - ) + validating_combiners = self.get_participating_combiners( + combiner_config) for combiner, combiner_config in validating_combiners: try: @@ -302,13 +303,15 @@ def round(self, session_config, round_id): self._handle_unavailable_combiner(combiner) pass - return model_id, round_data + self.set_round_data(round_id, round_data) + self.set_round_status(round_id, 'Finished') + return model_id, self.statestore.get_round(round_id) def reduce(self, combiners): """Combine updated models from Combiner nodes into one global model. - :param combiners: dict of combiner names (key) and model IDs (value) to reduce - :type combiners: dict + : param combiners: dict of combiner names(key) and model IDs(value) to reduce + : type combiners: dict """ meta = {} @@ -323,7 +326,9 @@ def reduce(self, combiners): print("REDUCER: No combiners to reduce!", flush=True) return model, meta - for name, model_id in combiners.items(): + for combiner in combiners: + name = combiner['name'] + model_id = combiner['model_id'] # TODO: Handle inactive RPC error in get_model and raise specific error print( "REDUCER: Fetching model ({model_id}) from combiner {name}".format( @@ -333,9 +338,9 @@ def reduce(self, combiners): ) try: tic = time.time() - combiner = self.get_combiner(name) - data = combiner.get_model(model_id) - meta["time_fetch_model"] += time.time() - tic + combiner_interface = self.get_combiner(name) + data = combiner_interface.get_model(model_id) + meta['time_fetch_model'] += (time.time() - tic) except Exception as e: print( "REDUCER: Failed to fetch model from combiner {}: {}".format( @@ -367,7 +372,7 @@ def reduce(self, 
combiners): def infer_instruct(self, config): """Main entrypoint for executing the inference compute plan. - :param config: configuration for the inference round + : param config: configuration for the inference round """ # Check/set instucting state @@ -395,7 +400,7 @@ def infer_instruct(self, config): def inference_round(self, config): """Execute an inference round. - :param config: configuration for the inference round + : param config: configuration for the inference round """ # Init meta @@ -413,7 +418,8 @@ def inference_round(self, config): combiner_config["helper_type"] = self.statestore.get_framework() # Select combiners - validating_combiners = self._select_round_combiners(combiner_config) + validating_combiners = self.get_participating_combiners( + combiner_config) # Test round start policy round_start = self.check_round_start_policy(validating_combiners) diff --git a/fedn/fedn/network/controller/controlbase.py b/fedn/fedn/network/controller/controlbase.py index 077620c14..fab6a2027 100644 --- a/fedn/fedn/network/controller/controlbase.py +++ b/fedn/fedn/network/controller/controlbase.py @@ -196,8 +196,8 @@ def get_compute_package(self, compute_package=""): else: return None - def new_session(self, config): - """Initialize a new session in backend db.""" + def create_session(self, config): + """ Initialize a new session in backend db. """ if "session_id" not in config.keys(): session_id = uuid.uuid4() @@ -205,11 +205,50 @@ def new_session(self, config): else: session_id = config["session_id"] - self.tracer.new_session(id=session_id) + self.tracer.create_session(id=session_id) self.tracer.set_session_config(session_id, config) + def create_round(self, round_data): + """Initialize a new round in backend db. """ + + self.tracer.create_round(round_data) + + def set_round_data(self, round_id, round_data): + """ Set round data. 
+
+        :param round_id: The round unique identifier
+        :type round_id: str
+        :param round_data: The round data
+        :type round_data: dict
+        """
+        self.tracer.set_round_data(round_id, round_data)
+
+    def set_round_status(self, round_id, status):
+        """ Set the round status.
+
+        :param round_id: The round unique identifier
+        :type round_id: str
+        :param status: The status
+        :type status: str
+        """
+        self.tracer.set_round_status(round_id, status)
+
+    def set_round_config(self, round_id, round_config):
+        """ Update round in backend db.
+
+        :param round_id: The round unique identifier
+        :type round_id: str
+        :param round_config: The round configuration
+        :type round_config: dict
+        """
+        self.tracer.set_round_config(round_id, round_config)
+
     def request_model_updates(self, combiners):
-        """Call Combiner server RPC to get a model update."""
+        """Ask Combiner server to produce a model update.
+
+        :param combiners: A list of combiners
+        :type combiners: tuple (combiner, combiner_round_config)
+        """
         cl = []
         for combiner, combiner_round_config in combiners:
             response = combiner.submit(combiner_round_config)
@@ -217,7 +256,15 @@ def request_model_updates(self, combiners):
             return cl
 
     def commit(self, model_id, model=None, session_id=None):
-        """Commit a model to the global model trail. The model commited becomes the lastest consensus model."""
+        """Commit a model to the global model trail. The model committed becomes the latest consensus model.
+
+        :param model_id: Unique identifier for the model to commit.
+        :type model_id: str (uuid)
+        :param model: The model object to commit
+        :type model: BytesIO
+        :param session_id: Unique identifier for the session
+        :type session_id: str
+        """
         helper = self.get_helper()
 
         if model is not None:
@@ -289,45 +336,47 @@ def evaluate_round_participation_policy(
         return False
 
     def evaluate_round_start_policy(self, combiners):
-        """Check if the policy to start a round is met."""
+        """Check if the policy to start a round is met.
+
+        :param combiners: A list of combiners
+        :type combiners: list
+        :return: True if the round policy is met, otherwise False
+        :rtype: bool
+        """
         if len(combiners) > 0:
             return True
         else:
             return False
 
-    def evaluate_round_validity_policy(self, combiners):
-        """Check if the round should be seen as valid.
+    def evaluate_round_validity_policy(self, round):
+        """ Check if the round is valid.
 
         At the end of the round, before committing a model to the global model trail,
         we check if the round validity policy has been met. This can involve
         e.g. asserting that a certain number of combiners have reported in an
         updated model, or that criteria on model performance have been met.
-        """
-        if combiners.keys() == []:
-            return False
-        else:
-            return True
 
-    def _select_participating_combiners(self, compute_plan):
-        participating_combiners = []
-        for combiner in self.network.get_combiners():
+        :param round: The round object
+        :type round: dict
+        :return: True if the policy is met, otherwise False
+        :rtype: bool
+        """
+        model_ids = []
+        for combiner in round['combiners']:
             try:
-                combiner_state = combiner.report()
-            except CombinerUnavailableError:
-                self._handle_unavailable_combiner(combiner)
-                combiner_state = None
+                model_ids.append(combiner['model_id'])
+            except KeyError:
+                pass
 
-            if combiner_state:
-                is_participating = self.evaluate_round_participation_policy(
-                    compute_plan, combiner_state
-                )
-                if is_participating:
-                    participating_combiners.append((combiner, compute_plan))
-        return participating_combiners
+        if len(model_ids) == 0:
+            return False
+
+        return True
 
     def state(self):
-        """
+        """ Get the current state of the controller
 
-        :return:
+        :return: The state
+        :rtype: str
         """
         return self._state
diff --git a/fedn/setup.py b/fedn/setup.py
index 1dbdb951f..62888ce09 100644
--- a/fedn/setup.py
+++ b/fedn/setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name='fedn',
-    version='0.5.0',
+    version='0.6.0',
     description="""Scaleout Federated Learning""",
     author='Scaleout Systems AB',
author_email='contact@scaleoutsystems.com',